Splitters

Splits data into train and test

Below is the documentation of the core splitter classes implemented in RePlay. For practical use, splitters can be composed and combined with auxiliary utilities from RePlay (see Preprocessing) to obtain different data partitioning schemes.

As proposed in the paper Time to Split: Exploring Data Splitting Strategies for Offline Evaluation of Sequential Recommenders (RecSys'25), advanced data-splitting schemes, namely Global Temporal Split (GTS) with the last interaction as the target and GTS with a random interaction as the target, can be implemented in RePlay by composing TimeSplitter with either LastNSplitter (e.g., N=1) or RandomNextNSplitter (e.g., N=1). These pipelines can be complemented with auxiliary utilities such as cold-start filtering via filter_cold() (see filters and Preprocessing) and dataset merging via merge_subsets(). A minimal sketch of the composition is shown below; for an end-to-end illustration, see examples/04_splitters.ipynb.
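
The following is a minimal, illustrative sketch of the GTS-with-last-target composition using only the splitters documented on this page. The toy log and the threshold date are assumptions made for the example; a full pipeline would typically also merge the train split back into the holdout inputs (e.g. via merge_subsets()) and filter cold users and items.

>>> import pandas as pd
>>> from datetime import datetime
>>> from replay.splitters import TimeSplitter, LastNSplitter
>>> log = pd.DataFrame({
...     "query_id":  [1, 1, 1, 2, 2, 2],
...     "item_id":   [1, 2, 3, 4, 5, 6],
...     "timestamp": pd.to_datetime([
...         "2020-01-01", "2020-01-03", "2020-01-05",
...         "2020-01-02", "2020-01-04", "2020-01-06",
...     ]),
... })
>>> # Step 1: Global Temporal Split -- interactions up to the threshold go to train.
>>> train, holdout = TimeSplitter(
...     time_threshold=datetime.strptime("2020-01-02", "%Y-%m-%d")
... ).split(log)
>>> # Step 2: within the holdout, the last interaction per user becomes the target.
>>> holdout_input, holdout_target = LastNSplitter(N=1, divide_column="query_id").split(holdout)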

Splits are returned by the split method.

replay.splitters.base_splitter.Splitter.split(self, interactions)

Splits input DataFrame into train and test

Parameters

interactions (Union[pandas.DataFrame, pyspark.sql.DataFrame, polars.DataFrame]) – input DataFrame with columns [timestamp, user_id, item_id, relevance]

Returns

Tuple of split DataFrames (train, test)

Return type

tuple[Union[pandas.core.frame.DataFrame, pyspark.sql.dataframe.DataFrame, polars.dataframe.frame.DataFrame], Union[pandas.core.frame.DataFrame, pyspark.sql.dataframe.DataFrame, polars.dataframe.frame.DataFrame]]

TwoStageSplitter

class replay.splitters.two_stage_splitter.TwoStageSplitter(first_divide_size, second_divide_size, first_divide_column='query_id', second_divide_column='item_id', shuffle=False, drop_cold_items=False, drop_cold_users=False, seed=None, query_column='query_id', item_column='item_id', timestamp_column='timestamp')

Splits data by two columns in two stages. First stage: first_divide_size distinct values of first_divide_column are selected for the test split. Second stage: second_divide_size values of second_divide_column are taken, among the rows selected in the first stage, to form the test split.

Example:

>>> from replay.utils.session_handler import get_spark_session, State
>>> spark = get_spark_session(1, 1)
>>> state = State(spark)
>>> from replay.splitters import TwoStageSplitter
>>> import pandas as pd
>>> data_frame = pd.DataFrame({"query_id": [1,1,1,2,2,2],
...    "item_id": [1,2,3,1,2,3],
...    "relevance": [1,2,3,4,5,6],
...    "timestamp": [1,2,3,3,2,1]})
>>> data_frame
   query_id  item_id  relevance  timestamp
0         1         1          1          1
1         1         2          2          2
2         1         3          3          3
3         2         1          4          3
4         2         2          5          2
5         2         3          6          1
>>> train, test = TwoStageSplitter(first_divide_size=1, second_divide_size=2, seed=42).split(data_frame)
>>> test
   query_id  item_id  relevance  timestamp
3         2         1          4          3
4         2         2          5          2
>>> train, test = TwoStageSplitter(first_divide_size=0.5, second_divide_size=2, seed=42).split(data_frame)
>>> test
   query_id  item_id  relevance  timestamp
3         2         1          4          3
4         2         2          5          2
>>> train, test = TwoStageSplitter(first_divide_size=0.5, second_divide_size=0.7, seed=42).split(data_frame)
>>> test
   query_id  item_id  relevance  timestamp
3         2         1          4          3
4         2         2          5          2
__init__(first_divide_size, second_divide_size, first_divide_column='query_id', second_divide_column='item_id', shuffle=False, drop_cold_items=False, drop_cold_users=False, seed=None, query_column='query_id', item_column='item_id', timestamp_column='timestamp')
Parameters
  • first_divide_size (float) – fraction or number of distinct first_divide_column values (e.g. users) to put into the test split. None means all available users.

  • second_divide_size (float) – fraction or number of second_divide_column values (e.g. items) per user to put into the test split.

  • shuffle – if True, take random items instead of the last ones based on timestamp.

  • drop_cold_items (bool) – flag to drop cold items from test

  • drop_cold_users (bool) – flag to drop cold users from test

  • seed (Optional[int]) – random seed

  • query_column (str) – query id column name

  • item_column (Optional[str]) – item id column name

  • timestamp_column (Optional[str]) – timestamp column name

KFolds

replay.splitters.k_folds.KFolds(n_folds=5, strategy='query', drop_cold_items=False, drop_cold_users=False, seed=None, query_column='query_id', item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')

Splits interactions inside each query into folds at random.
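
A hypothetical usage sketch. It assumes that KFolds.split iterates over the folds and yields a (train, test) pair for each one, unlike the single pair returned by the other splitters; verify this against the implementation before relying on it.

>>> import pandas as pd
>>> from replay.splitters import KFolds
>>> log = pd.DataFrame({"query_id": [1, 1, 1, 2, 2, 2],
...                     "item_id": [1, 2, 3, 1, 2, 3],
...                     "timestamp": [1, 2, 3, 1, 2, 3]})
>>> folds = KFolds(n_folds=3, strategy="query", seed=42)
>>> for fold_train, fold_test in folds.split(log):  # assumed to yield one pair per fold
...     pass  # fit and evaluate a model on each fold here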

TimeSplitter

class replay.splitters.time_splitter.TimeSplitter(time_threshold, query_column='query_id', drop_cold_users=False, drop_cold_items=False, item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test', time_column_format='%Y-%m-%d %H:%M:%S')

Split interactions by time.

>>> from datetime import datetime
>>> import pandas as pd
>>> columns = ["query_id", "item_id", "timestamp"]
>>> data = [
...     (1, 1, "01-01-2020"),
...     (1, 2, "02-01-2020"),
...     (1, 3, "03-01-2020"),
...     (1, 4, "04-01-2020"),
...     (1, 5, "05-01-2020"),
...     (2, 1, "06-01-2020"),
...     (2, 2, "07-01-2020"),
...     (2, 3, "08-01-2020"),
...     (2, 9, "09-01-2020"),
...     (2, 10, "10-01-2020"),
...     (3, 1, "01-01-2020"),
...     (3, 5, "02-01-2020"),
...     (3, 3, "03-01-2020"),
...     (3, 1, "04-01-2020"),
...     (3, 2, "05-01-2020"),
... ]
>>> dataset = pd.DataFrame(data, columns=columns)
>>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
>>> dataset
   query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
2         1        3 2020-01-03
3         1        4 2020-01-04
4         1        5 2020-01-05
5         2        1 2020-01-06
6         2        2 2020-01-07
7         2        3 2020-01-08
8         2        9 2020-01-09
9         2       10 2020-01-10
10        3        1 2020-01-01
11        3        5 2020-01-02
12        3        3 2020-01-03
13        3        1 2020-01-04
14        3        2 2020-01-05
>>> train, test = TimeSplitter(
...     time_threshold=datetime.strptime("2020-01-04", "%Y-%m-%d")
... ).split(dataset)
>>> train
   query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
2         1        3 2020-01-03
3         1        4 2020-01-04
10        3        1 2020-01-01
11        3        5 2020-01-02
12        3        3 2020-01-03
13        3        1 2020-01-04
>>> test
   query_id  item_id  timestamp
4         1        5 2020-01-05
5         2        1 2020-01-06
6         2        2 2020-01-07
7         2        3 2020-01-08
8         2        9 2020-01-09
9         2       10 2020-01-10
14        3        2 2020-01-05
__init__(time_threshold, query_column='query_id', drop_cold_users=False, drop_cold_items=False, item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test', time_column_format='%Y-%m-%d %H:%M:%S')
Parameters
  • time_threshold (Union[datetime, str, float]) – Test threshold. A datetime splits by date, an int Unix timestamp also splits by date, a string is converted to datetime using time_column_format, and a float splits by ratio (the value must be between 0 and 1).

  • query_column (str) – Name of query interaction column.

  • drop_cold_users (bool) – Drop users from the test DataFrame which are not in the train DataFrame, default: False.

  • drop_cold_items (bool) – Drop items from test DataFrame which are not in train DataFrame, default: False.

  • item_column (str) – Name of item interaction column. If drop_cold_items is False, then you can omit this parameter. Default: item_id.

  • timestamp_column (str) – Name of time column, Default: timestamp.

  • session_id_column (Optional[str]) – Name of session id column whose values cannot be split between train and test, default: None.

  • session_id_processing_strategy (str) – Strategy for processing a session if it is split: train or test. train means the whole session goes to train; test means it goes to test. Default: test.
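
As noted for time_threshold, passing a float splits by ratio instead of by date. A minimal sketch reusing the dataset above (the 0.8 value is chosen only for illustration):

>>> ratio_train, ratio_test = TimeSplitter(time_threshold=0.8).split(dataset)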

LastNSplitter

class replay.splitters.last_n_splitter.LastNSplitter(N, divide_column='query_id', time_column_format='yyyy-MM-dd HH:mm:ss', strategy='interactions', drop_cold_users=False, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')

Split interactions by last N interactions/timedelta per user. Type of splitting depends on the strategy parameter.

>>> from datetime import datetime
>>> import pandas as pd
>>> columns = ["query_id", "item_id", "timestamp"]
>>> data = [
...     (1, 1, "01-01-2020"),
...     (1, 2, "02-01-2020"),
...     (1, 3, "03-01-2020"),
...     (1, 4, "04-01-2020"),
...     (1, 5, "05-01-2020"),
...     (2, 1, "06-01-2020"),
...     (2, 2, "07-01-2020"),
...     (2, 3, "08-01-2020"),
...     (2, 9, "09-01-2020"),
...     (2, 10, "10-01-2020"),
...     (3, 1, "01-01-2020"),
...     (3, 5, "02-01-2020"),
...     (3, 3, "03-01-2020"),
...     (3, 1, "04-01-2020"),
...     (3, 2, "05-01-2020"),
... ]
>>> dataset = pd.DataFrame(data, columns=columns)
>>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
>>> dataset
   query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
2         1        3 2020-01-03
3         1        4 2020-01-04
4         1        5 2020-01-05
5         2        1 2020-01-06
6         2        2 2020-01-07
7         2        3 2020-01-08
8         2        9 2020-01-09
9         2       10 2020-01-10
10        3        1 2020-01-01
11        3        5 2020-01-02
12        3        3 2020-01-03
13        3        1 2020-01-04
14        3        2 2020-01-05
>>> splitter = LastNSplitter(
...     N=2,
...     divide_column="query_id",
...     time_column_format="yyyy-MM-dd",
...     query_column="query_id",
...     item_column="item_id"
... )
>>> train, test = splitter.split(dataset)
>>> train
   query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
2         1        3 2020-01-03
5         2        1 2020-01-06
6         2        2 2020-01-07
7         2        3 2020-01-08
10        3        1 2020-01-01
11        3        5 2020-01-02
12        3        3 2020-01-03
>>> test
   query_id  item_id  timestamp
3         1        4 2020-01-04
4         1        5 2020-01-05
8         2        9 2020-01-09
9         2       10 2020-01-10
13        3        1 2020-01-04
14        3        2 2020-01-05
__init__(N, divide_column='query_id', time_column_format='yyyy-MM-dd HH:mm:ss', strategy='interactions', drop_cold_users=False, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')
Parameters
  • N (int) – Number of last interactions or size of the time window in seconds

  • divide_column (str) – Name of column for dividing in dataframe, default: query_id.

  • time_column_format (str) – Format of the timestamp column, used for converting string dates to a numerical timestamp when strategy is ‘timedelta’. If the column is already a datetime object or a numerical timestamp, this parameter is ignored. default: yyyy-MM-dd HH:mm:ss

  • strategy (Literal['interactions', 'timedelta']) – Defines the type of data splitting. Must be interactions or timedelta. default: interactions.

  • query_column (str) – Name of query interaction column.

  • drop_cold_users (bool) – Drop users from the test DataFrame which are not in the train DataFrame, default: False.

  • drop_cold_items (bool) – Drop items from test DataFrame which are not in train DataFrame, default: False.

  • item_column (str) – Name of item interaction column. If drop_cold_items is False, then you can omit this parameter. Default: item_id.

  • timestamp_column (str) – Name of time column, Default: timestamp.

  • session_id_column (Optional[str]) – Name of session id column whose values cannot be split between train and test, default: None.

  • session_id_processing_strategy (str) – Strategy for processing a session if it is split: train or test. train means the whole session goes to train; test means it goes to test. Default: test.
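
With strategy="timedelta", N is interpreted as a time window in seconds rather than a number of interactions. A minimal sketch reusing the dataset above (a two-day window, chosen only for illustration):

>>> td_splitter = LastNSplitter(N=2 * 24 * 60 * 60, strategy="timedelta", divide_column="query_id")
>>> td_train, td_test = td_splitter.split(dataset)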

RatioSplitter

class replay.splitters.ratio_splitter.RatioSplitter(test_size, divide_column='query_id', drop_cold_users=False, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', min_interactions_per_group=None, split_by_fractions=True, session_id_column=None, session_id_processing_strategy='test')

Split interactions into train and test by ratio. Split is made for each user separately.

>>> from datetime import datetime
>>> import pandas as pd
>>> columns = ["query_id", "item_id", "timestamp"]
>>> data = [
...     (1, 1, "01-01-2020"),
...     (1, 2, "02-01-2020"),
...     (1, 3, "03-01-2020"),
...     (1, 4, "04-01-2020"),
...     (1, 5, "05-01-2020"),
...     (2, 1, "06-01-2020"),
...     (2, 2, "07-01-2020"),
...     (2, 3, "08-01-2020"),
...     (2, 9, "09-01-2020"),
...     (2, 10, "10-01-2020"),
...     (3, 1, "01-01-2020"),
...     (3, 5, "02-01-2020"),
...     (3, 3, "03-01-2020"),
...     (3, 1, "04-01-2020"),
...     (3, 2, "05-01-2020"),
... ]
>>> dataset = pd.DataFrame(data, columns=columns)
>>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
>>> dataset
    query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
2         1        3 2020-01-03
3         1        4 2020-01-04
4         1        5 2020-01-05
5         2        1 2020-01-06
6         2        2 2020-01-07
7         2        3 2020-01-08
8         2        9 2020-01-09
9         2       10 2020-01-10
10        3        1 2020-01-01
11        3        5 2020-01-02
12        3        3 2020-01-03
13        3        1 2020-01-04
14        3        2 2020-01-05
>>> splitter = RatioSplitter(
...     test_size=0.5,
...     divide_column="query_id",
...     query_column="query_id",
...     item_column="item_id"
... )
>>> train, test = splitter.split(dataset)
>>> train
   query_id  item_id  timestamp
0         1        1 2020-01-01
1         1        2 2020-01-02
5         2        1 2020-01-06
6         2        2 2020-01-07
10        3        1 2020-01-01
11        3        5 2020-01-02
>>> test
   query_id  item_id  timestamp
2         1        3 2020-01-03
3         1        4 2020-01-04
4         1        5 2020-01-05
7         2        3 2020-01-08
8         2        9 2020-01-09
9         2       10 2020-01-10
12        3        3 2020-01-03
13        3        1 2020-01-04
14        3        2 2020-01-05
__init__(test_size, divide_column='query_id', drop_cold_users=False, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', min_interactions_per_group=None, split_by_fractions=True, session_id_column=None, session_id_processing_strategy='test')
Parameters
  • test_size (float) – test size, must be in (0, 1).

  • divide_column (str) – Name of column for dividing in dataframe, default: query_id.

  • drop_cold_users (bool) – Drop users from the test DataFrame which are not in the train DataFrame, default: False.

  • drop_cold_items (bool) – Drop items from test DataFrame which are not in train DataFrame, default: False.

  • query_column (str) – Name of query interaction column. If drop_cold_users is False, then you can omit this parameter. Default: query_id.

  • item_column (str) – Name of item interaction column. If drop_cold_items is False, then you can omit this parameter. Default: item_id.

  • timestamp_column (str) – Name of time column, Default: timestamp.

  • min_interactions_per_group (Optional[int]) – minimum number of interactions a group must have to be split. If a group has fewer interactions than min_interactions_per_group, the whole group goes to train. If not set, groups of any size are split. Default: None.

  • split_by_fractions (bool) – whether to split by fractions. If True, each row in a group is assigned its fraction (row number / number of rows) and only rows whose fraction exceeds test_size go to the test split. If False, the number of train rows is the total number of rows minus the number of rows multiplied by test_size, rounded; due to this rounding, one more interaction per group (one item per user) falls into train, whereas with the fraction-based method those interactions fall into test. Default: True. See the sketch after this parameter list for a usage contrast.

  • session_id_column (Optional[str]) – Name of session id column whose values cannot be split between train and test, default: None.

  • session_id_processing_strategy (str) – Strategy for processing a session if it is split: train or test. train means the whole session goes to train; test means it goes to test. Default: test.
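
To contrast the two counting rules described for split_by_fractions, the splitter can simply be rerun with split_by_fractions=False on the same dataset; per the description above, this variant tends to keep one more interaction per user in train (sketch only, outputs not shown):

>>> count_splitter = RatioSplitter(test_size=0.5, split_by_fractions=False)
>>> count_train, count_test = count_splitter.split(dataset)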

RandomSplitter

class replay.splitters.random_splitter.RandomSplitter(test_size, drop_cold_items=False, drop_cold_users=False, seed=None, query_column='query_id', item_column='item_id')

Assign records into train and test at random.

__init__(test_size, drop_cold_items=False, drop_cold_users=False, seed=None, query_column='query_id', item_column='item_id')
Parameters
  • test_size (float) – test size, from 0 to 1

  • drop_cold_items (bool) – flag to drop cold items from test

  • drop_cold_users (bool) – flag to drop cold users from test

  • seed (Optional[int]) – random seed

  • query_column (str) – Name of query interaction column

  • item_column (str) – Name of item interaction column
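
A minimal usage sketch (the toy DataFrame is an illustrative assumption):

>>> import pandas as pd
>>> from replay.splitters import RandomSplitter
>>> interactions = pd.DataFrame({"query_id": [1, 1, 2, 2],
...                              "item_id": [1, 2, 1, 3]})
>>> train, test = RandomSplitter(test_size=0.25, seed=42).split(interactions)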

NewUsersSplitter

class replay.splitters.new_users_splitter.NewUsersSplitter(test_size, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')

Only new users are assigned to the test set. Splits interactions by timestamp so that the test set contains the test_size fraction of the most recent users.

>>> from replay.splitters import NewUsersSplitter
>>> import pandas as pd
>>> data_frame = pd.DataFrame({"query_id": [1,1,2,2,3,4],
...    "item_id": [1,2,3,1,2,3],
...    "relevance": [1,2,3,4,5,6],
...    "timestamp": [20,40,20,30,10,40]})
>>> data_frame
   query_id   item_id  relevance  timestamp
0         1         1          1         20
1         1         2          2         40
2         2         3          3         20
3         2         1          4         30
4         3         2          5         10
5         4         3          6         40
>>> train, test = NewUsersSplitter(test_size=0.1).split(data_frame)
>>> train
  query_id  item_id  relevance  timestamp
0        1        1          1         20
2        2        3          3         20
3        2        1          4         30
4        3        2          5         10

>>> test
  query_id  item_id  relevance  timestamp
0        4        3          6         40

The train DataFrame can be drastically reduced even with a moderate test_size if the number of new users is small.

>>> train, test = NewUsersSplitter(test_size=0.3).split(data_frame)
>>> train
  query_id  item_id  relevance  timestamp
4        3        2          5         10
__init__(test_size, drop_cold_items=False, query_column='query_id', item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')
Parameters
  • test_size (float) – test size, from 0 to 1

  • drop_cold_items (bool) – flag to drop cold items from test

  • query_column (str) – query id column name

  • item_column (Optional[str]) – item id column name

  • timestamp_column (Optional[str]) – timestamp column name

  • session_id_column (Optional[str]) – Name of session id column whose values cannot be split between train and test.

  • session_id_processing_strategy (str) – Strategy for processing a session if it is split: train or test. train means the whole session goes to train; test means it goes to test. Default: test.

ColdUserRandomSplitter

class replay.splitters.cold_user_random_splitter.ColdUserRandomSplitter(test_size, drop_cold_items=False, seed=None, query_column='query_id', item_column='item_id')

Test set consists of all actions of randomly chosen users.

__init__(test_size, drop_cold_items=False, seed=None, query_column='query_id', item_column='item_id')
Parameters
  • test_size (float) – The proportion of users to allocate to the test set. Must be a float between 0.0 and 1.0.

  • drop_cold_items (bool) – Drop items from test DataFrame which are not in train DataFrame, default: False.

  • seed (Optional[int]) – Seed for the random number generator to ensure reproducibility of the split, default: None.

  • query_column (str) – Name of query interaction column. default: query_id.

  • item_column (Optional[str]) – Name of item interaction column. default: item_id.
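
A minimal usage sketch (the toy DataFrame is an illustrative assumption; all interactions of the randomly chosen users end up in test):

>>> import pandas as pd
>>> from replay.splitters import ColdUserRandomSplitter
>>> interactions = pd.DataFrame({"query_id": [1, 1, 2, 2, 3, 3],
...                              "item_id": [1, 2, 1, 3, 2, 3]})
>>> train, test = ColdUserRandomSplitter(test_size=0.3, seed=42).split(interactions)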

RandomNextNSplitter

class replay.splitters.random_next_n_splitter.RandomNextNSplitter(N=1, divide_column='query_id', seed=None, query_column='query_id', drop_cold_users=False, drop_cold_items=False, item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')

Split interactions by a random position in the user sequence. For each user, a random cut index is sampled and the target part consists of the next N interactions starting from this cut; the train part contains all interactions before the cut. Interactions after the target window are discarded.

Note: by changing the seed attribute on an existing splitter instance, you can obtain different splits without recreating the object. This is useful when you need to generate multiple randomized splits of the same dataset.

>>> from datetime import datetime
>>> import pandas as pd
>>> columns = ["query_id", "item_id", "timestamp"]
>>> data = [
...     (1, 1, "01-01-2020"),
...     (1, 2, "02-01-2020"),
...     (1, 3, "03-01-2020"),
...     (2, 1, "06-01-2020"),
...     (2, 2, "07-01-2020"),
...     (2, 3, "08-01-2020"),
... ]
>>> dataset = pd.DataFrame(data, columns=columns)
>>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
>>> splitter = RandomNextNSplitter(
...     N=2,
...     divide_column="query_id",
...     seed=42,
...     query_column="query_id",
...     item_column="item_id",
... )
>>> train, test = splitter.split(dataset)
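
As the note above describes, reassigning the seed attribute on the same instance produces a different random cut on the next call to split:

>>> splitter.seed = 7
>>> another_train, another_test = splitter.split(dataset)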
__init__(N=1, divide_column='query_id', seed=None, query_column='query_id', drop_cold_users=False, drop_cold_items=False, item_column='item_id', timestamp_column='timestamp', session_id_column=None, session_id_processing_strategy='test')
Parameters
  • N (Optional[int]) – Optional window size. If None, the test set contains all interactions from the cut to the end; otherwise the next N interactions. Must be >= 1. Default: 1.

  • divide_column (str) – Name of the column used to group interactions for random cut sampling, default: query_id.

  • seed (Optional[int]) – Random seed used to sample cut indices, default: None.

  • query_column (str) – Name of query interaction column.

  • drop_cold_users (bool) – Drop users from test DataFrame which are not in the train DataFrame, default: False.

  • drop_cold_items (bool) – Drop items from test DataFrame which are not in the train DataFrame, default: False.

  • item_column (str) – Name of item interaction column. If drop_cold_items is False, then you can omit this parameter. Default: item_id.

  • timestamp_column (str) – Name of time column, default: timestamp.

  • session_id_column (Optional[str]) – Name of session id column whose values cannot be split between train/test, default: None.

  • session_id_processing_strategy (str) – Strategy to process a session if it crosses the train/test boundary: train or test. train means the whole session goes to train; test means it goes to test. Default: test.