Bert4Rec

Bert4Rec

class replay.models.nn.sequential.Bert4Rec(tensor_schema, block_count=2, head_count=4, hidden_size=256, max_seq_len=100, dropout_rate=0.1, pass_per_transformer_block_count=1, enable_positional_embedding=True, enable_embedding_tying=False, loss_type='CE', loss_sample_count=None, negative_sampling_strategy='global_uniform', negatives_sharing=False, optimizer_factory=<replay.models.nn.optimizer_utils.optimizer_factory.FatOptimizerFactory object>, lr_scheduler_factory=None)

Implements the BERT training and validation loop.

__init__(tensor_schema, block_count=2, head_count=4, hidden_size=256, max_seq_len=100, dropout_rate=0.1, pass_per_transformer_block_count=1, enable_positional_embedding=True, enable_embedding_tying=False, loss_type='CE', loss_sample_count=None, negative_sampling_strategy='global_uniform', negatives_sharing=False, optimizer_factory=<replay.models.nn.optimizer_utils.optimizer_factory.FatOptimizerFactory object>, lr_scheduler_factory=None)
Parameters
  • tensor_schema (TensorSchema) – Tensor schema of features.

  • block_count (int) – Number of Transformer blocks. Default: 2.

  • head_count (int) – Number of Attention heads. Default: 4.

  • hidden_size (int) – Hidden size of transformer. Default: 256.

  • max_seq_len (int) – Max length of sequence. Default: 100.

  • dropout_rate (float) – Dropout rate. Default: 0.1.

  • pass_per_transformer_block_count (int) – Number of times to pass data over each Transformer block. Default: 1.

  • enable_positional_embedding (bool) – Add positional embedding to the result. Default: True.

  • enable_embedding_tying (bool) – Use embedding tying head. If True, the result scores are calculated as the dot product of input and output embeddings; if False, a default linear layer is applied to calculate the logits for each item. Default: False.

  • loss_type (Literal['BCE', 'CE', 'CE_restricted']) – Loss type. Default: CE.

  • loss_sample_count (Optional[int]) – Sample count to calculate loss. Default: None.

  • negative_sampling_strategy (Literal['global_uniform', 'inbatch']) –

    Negative sampling strategy used to calculate the loss on sampled negatives. Useful when the dataset contains a large number of items.

    Default: global_uniform.

  • negatives_sharing (bool) –

    Apply negative sharing when calculating sampled logits.

    Default: False.

  • optimizer_factory (OptimizerFactory) –

    Optimizer factory.

    Default: FatOptimizerFactory.

  • lr_scheduler_factory (Optional[LRSchedulerFactory]) –

    Learning rate schedule factory.

    Default: None.
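The two scoring heads described for enable_embedding_tying can be pictured with a small sketch. It is plain Python on toy lists, not RePlay or PyTorch code; the function names tied_scores and linear_scores and all numeric values are hypothetical and chosen only for illustration.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def tied_scores(hidden: Vector, item_embeddings: List[Vector]) -> Vector:
    """Tied head: the input item embeddings are reused as the output projection."""
    return [dot(hidden, emb) for emb in item_embeddings]


def linear_scores(hidden: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """Untied head: a separate linear layer with its own weights and bias."""
    return [dot(hidden, w) + b for w, b in zip(weights, bias)]


# Toy setup (assumed values): 3 items, hidden size 2.
item_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden = [0.5, 2.0]

tied = tied_scores(hidden, item_embeddings)
untied = linear_scores(hidden, [[2.0, 0.0], [0.0, 0.5], [1.0, 1.0]], [0.0, 1.0, -1.0])
```

In both cases the result holds one score per item; the only difference is where the projection weights come from.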

forward(feature_tensors, padding_mask, tokens_mask, candidates_to_score=None)
Parameters
  • feature_tensors (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 otherwise.

  • tokens_mask (BoolTensor) – Token mask where 0 - <MASK> tokens, 1 otherwise.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Calculated scores.

Return type

Tensor
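As a rough illustration of the two mask conventions above, the sketch below builds both masks as Python lists of booleans. It assumes left padding and uses a hypothetical helper build_masks; the real masks are BoolTensor objects produced by the dataset classes, so treat this only as a picture of the 0/1 semantics.

```python
from typing import List, Set, Tuple


def build_masks(seq_len: int, real_len: int, masked_positions: Set[int]) -> Tuple[List[bool], List[bool]]:
    """Build padding_mask and tokens_mask for one sequence.

    Assumption: padding occupies the leftmost seq_len - real_len positions.
    padding_mask: False (0) at <PAD> positions, True (1) otherwise.
    tokens_mask:  False (0) at <MASK> positions, True (1) otherwise.
    """
    pad_count = seq_len - real_len
    padding_mask = [pos >= pad_count for pos in range(seq_len)]
    tokens_mask = [pos not in masked_positions for pos in range(seq_len)]
    return padding_mask, tokens_mask


# A sequence of length 5 with 3 real items; the item at position 3 is masked.
padding_mask, tokens_mask = build_masks(seq_len=5, real_len=3, masked_positions={3})
```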

predict(batch, candidates_to_score=None)
Parameters
  • batch (Union[Bert4RecPredictionBatch, dict]) – Batch of prediction data.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Calculated scores on prediction batch.

Return type

Tensor

Bert4RecModel

class replay.models.nn.sequential.bert4rec.Bert4RecModel(schema, max_len=100, hidden_size=256, num_blocks=2, num_heads=4, num_passes_over_block=1, dropout=0.1, enable_positional_embedding=True, enable_embedding_tying=False)

BERT model

__init__(schema, max_len=100, hidden_size=256, num_blocks=2, num_heads=4, num_passes_over_block=1, dropout=0.1, enable_positional_embedding=True, enable_embedding_tying=False)
Parameters
  • schema (TensorSchema) – Tensor schema of features.

  • max_len (int) – Max length of sequence. Default: 100.

  • hidden_size (int) – Hidden size of transformer. Default: 256.

  • num_blocks (int) – Number of Transformer blocks. Default: 2.

  • num_heads (int) – Number of Attention heads. Default: 4.

  • num_passes_over_block (int) – Number of times to pass data over each Transformer block. Default: 1.

  • dropout (float) – Dropout rate. Default: 0.1.

  • enable_positional_embedding (bool) – Add positional embedding to the result. Default: True.

  • enable_embedding_tying (bool) – Use embedding tying head. Default: False.

forward(inputs, pad_mask, token_mask)
Parameters
  • inputs (Mapping[str, Tensor]) – Batch of features.

  • pad_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

  • token_mask (BoolTensor) – Token mask where 0 - <MASK> tokens, 1 - otherwise.

Returns

Calculated scores.

Return type

Tensor

forward_step(inputs, pad_mask, token_mask)
Parameters
  • inputs (TensorMap) – Batch of features.

  • pad_mask (torch.BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

  • token_mask (torch.BoolTensor) – Token mask where 0 - <MASK> tokens, 1 - otherwise.

Returns

Output embeddings.

Return type

Tensor

get_logits(out_embeddings, item_ids=None)

Apply head to output embeddings of forward_step.

Parameters
  • out_embeddings (Tensor) – Embeddings after forward step.

  • item_ids (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Logits for each element in item_ids.

Return type

Tensor

get_query_embeddings(inputs, pad_mask, token_mask)
Parameters
  • inputs (Mapping[str, Tensor]) – Batch of features.

  • pad_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

  • token_mask (BoolTensor) – Token mask where 0 - <MASK> tokens, 1 - otherwise.

Returns

Query embeddings.

predict(inputs, pad_mask, token_mask, candidates_to_score=None)
Parameters
  • inputs (Mapping[str, Tensor]) – Batch of features.

  • pad_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

  • token_mask (BoolTensor) – Token mask where 0 - <MASK> tokens, 1 - otherwise.

  • candidates_to_score (Optional[LongTensor]) –

    Item ids to calculate scores.

    If None then predicts for all items. Default: None.

Returns

Calculated scores among candidates_to_score items.

Return type

Tensor
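The effect of candidates_to_score on the output can be sketched as follows. This is a toy illustration in plain Python, assuming item ids index directly into a per-item score vector; select_scores is a hypothetical helper, not part of the library.

```python
from typing import List, Optional, Sequence


def select_scores(
    all_item_scores: Sequence[float],
    candidates_to_score: Optional[Sequence[int]] = None,
) -> List[float]:
    """Return scores for the requested item ids, or for all items if None."""
    if candidates_to_score is None:
        return list(all_item_scores)
    return [all_item_scores[item_id] for item_id in candidates_to_score]


all_item_scores = [0.1, 0.7, 0.3, 0.9]

full = select_scores(all_item_scores)            # one score per item
subset = select_scores(all_item_scores, [3, 1])  # scores only for items 3 and 1
```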

Bert4RecTrainingDataset

class replay.models.nn.sequential.bert4rec.Bert4RecTrainingDataset(sequential, max_sequence_length, mask_prob=0.15, sliding_window_step=None, label_feature_name=None, custom_masker=None, padding_value=None)

Dataset that generates samples to train Bert4Rec model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. There is also an additional key needed to calculate the loss: positive_labels. The query_id key is required for possible debugging and calling additional lightning callbacks.

__init__(sequential, max_sequence_length, mask_prob=0.15, sliding_window_step=None, label_feature_name=None, custom_masker=None, padding_value=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with training data.

  • max_sequence_length (int) – Max length of sequence.

  • mask_prob (float) – Probability of masking each token in sequence. Default: 0.15.

  • sliding_window_step (Optional[int]) – A sliding window step. If not None, sequences are iterated over with a sliding window. Default: None.

  • label_feature_name (Optional[str]) – Name of label feature in provided dataset. If None, the item_id_feature name from the sequential dataset is used. Default: None.

  • custom_masker (Optional[Bert4RecMasker]) – Masker object to generate masks for Bert4Rec training. If None, a Bert4RecUniformMasker with the provided mask_prob is used. Default: None.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.
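To make mask_prob concrete, here is a minimal sketch of independent per-token masking. It is an assumption about what uniform masking means in general, not the actual Bert4RecUniformMasker implementation, which may apply extra rules; uniform_tokens_mask is a hypothetical name.

```python
import random
from typing import List


def uniform_tokens_mask(seq_len: int, mask_prob: float, rng: random.Random) -> List[bool]:
    """Return a tokens_mask-style list for one sequence.

    Each position is masked independently with probability mask_prob.
    False marks a <MASK> position, True keeps the original token.
    """
    return [rng.random() >= mask_prob for _ in range(seq_len)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
mask = uniform_tokens_mask(seq_len=10, mask_prob=0.15, rng=rng)
```

With mask_prob=0.0 no position is masked, and with mask_prob=1.0 every position is masked.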

Bert4RecValidationDataset

class replay.models.nn.sequential.bert4rec.Bert4RecValidationDataset(sequential, ground_truth, train, max_sequence_length, padding_value=None, label_feature_name=None)

Dataset that generates samples to infer and validate a BERT-like model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. The query_id key is required for possible debugging and calling additional lightning callbacks. The ground_truth and train keys are required for metrics calculation at the validation stage.

__init__(sequential, ground_truth, train, max_sequence_length, padding_value=None, label_feature_name=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with data to make predictions at.

  • ground_truth (SequentialDataset) – Sequential dataset with ground truth predictions.

  • train (SequentialDataset) – Sequential dataset with training data.

  • max_sequence_length (int) – Max length of sequence.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.

  • label_feature_name (Optional[str]) – Name of label feature in provided dataset. If None, the item_id_feature name from the sequential dataset is used. Default: None.

Bert4RecPredictionDataset

class replay.models.nn.sequential.bert4rec.Bert4RecPredictionDataset(sequential, max_sequence_length, padding_value=None)

Dataset that generates samples to run inference with the Bert4Rec model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. The query_id key is required for possible debugging and calling additional lightning callbacks.

__init__(sequential, max_sequence_length, padding_value=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with data to make predictions at.

  • max_sequence_length (int) – Max length of sequence.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.

Bert4RecTrainingBatch

class replay.models.nn.sequential.bert4rec.Bert4RecTrainingBatch(query_id, padding_mask, features, tokens_mask, labels)

Batch of data for training. Generated by Bert4RecTrainingDataset.

features: Mapping[str, Tensor]

Alias for field number 2

labels: LongTensor

Alias for field number 4

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

tokens_mask: BoolTensor

Alias for field number 3
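The "Alias for field number N" entries indicate a named-tuple layout, so fields can be read by name or by position. The sketch below mirrors that field order with typing.NamedTuple, using plain lists in place of tensors; TrainingBatchSketch and its values are hypothetical stand-ins, not the library class.

```python
from typing import List, Mapping, NamedTuple


class TrainingBatchSketch(NamedTuple):
    """Stand-in with the same field order as Bert4RecTrainingBatch."""

    query_id: List[int]                 # field 0
    padding_mask: List[bool]            # field 1
    features: Mapping[str, List[int]]   # field 2
    tokens_mask: List[bool]             # field 3
    labels: List[int]                   # field 4


batch = TrainingBatchSketch(
    query_id=[42],
    padding_mask=[False, True, True],
    features={"item_id": [0, 7, 9]},
    tokens_mask=[True, True, False],
    labels=[0, 7, 9],
)

# Positional and named access refer to the same objects.
features_by_index = batch[2]
features_by_name = batch.features
```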

Bert4RecValidationBatch

class replay.models.nn.sequential.bert4rec.Bert4RecValidationBatch(query_id, padding_mask, features, tokens_mask, ground_truth, train)

Batch of data for validation. Generated by Bert4RecValidationDataset.

features: Mapping[str, Tensor]

Alias for field number 2

ground_truth: LongTensor

Alias for field number 4

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

tokens_mask: BoolTensor

Alias for field number 3

train: LongTensor

Alias for field number 5

Bert4RecPredictionBatch

class replay.models.nn.sequential.bert4rec.Bert4RecPredictionBatch(query_id, padding_mask, features, tokens_mask)

Batch of data for model inference. Generated by Bert4RecPredictionDataset.

features: Mapping[str, Tensor]

Alias for field number 2

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

tokens_mask: BoolTensor

Alias for field number 3

SasRec (legacy)

SasRec

class replay.models.nn.sequential.SasRec(tensor_schema, block_count=2, head_count=1, hidden_size=50, max_seq_len=200, dropout_rate=0.2, ti_modification=False, time_span=256, loss_type='CE', loss_sample_count=None, negative_sampling_strategy='global_uniform', negatives_sharing=False, optimizer_factory=<replay.models.nn.optimizer_utils.optimizer_factory.FatOptimizerFactory object>, lr_scheduler_factory=None, sce_params=None)

SASRec Lightning module.

The initialization parameters are available through the hparams attribute of a SasRec instance.

__init__(tensor_schema, block_count=2, head_count=1, hidden_size=50, max_seq_len=200, dropout_rate=0.2, ti_modification=False, time_span=256, loss_type='CE', loss_sample_count=None, negative_sampling_strategy='global_uniform', negatives_sharing=False, optimizer_factory=<replay.models.nn.optimizer_utils.optimizer_factory.FatOptimizerFactory object>, lr_scheduler_factory=None, sce_params=None)
Parameters
  • tensor_schema (TensorSchema) – Tensor schema of features.

  • block_count (int) – Number of Transformer blocks. Default: 2.

  • head_count (int) – Number of Attention heads. Default: 1.

  • hidden_size (int) – Hidden size of transformer. Default: 50.

  • max_seq_len (int) – Max length of sequence. Default: 200.

  • dropout_rate (float) – Dropout rate. Default: 0.2.

  • ti_modification (bool) – Enable time relation. Default: False.

  • time_span (int) – Time span value. Default: 256.

  • loss_type (Literal['BCE', 'CE', 'SCE']) – Loss type. Default: CE.

  • loss_sample_count (Optional[int]) – Sample count to calculate loss. Suitable for "CE" and "BCE" loss functions. Default: None.

  • negative_sampling_strategy (str) – Negative sampling strategy used to calculate the loss on sampled negatives. Useful when the dataset contains a large number of items. Possible values: "global_uniform", "inbatch". Default: global_uniform.

  • negatives_sharing (bool) – Apply negative sharing in calculating sampled logits. Default: False.

  • optimizer_factory (OptimizerFactory) – Optimizer factory. Default: FatOptimizerFactory.

  • lr_scheduler_factory (Optional[LRSchedulerFactory]) – Learning rate schedule factory. Default: None.

  • sce_params (Optional[SCEParams]) – Dataclass with SCE parameters. Must be defined if loss_type is SCE. Default: None.
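The two negative_sampling_strategy values can be read roughly as follows. This sketch is an interpretation under stated assumptions, not the library's sampler: it assumes item ids are integers in [0, item_count), that "global_uniform" draws from the whole catalog, and that "inbatch" draws from ids present in the current batch. The helper names are hypothetical, and a real sampler may additionally exclude positives.

```python
import random
from typing import List, Sequence


def global_uniform_negatives(item_count: int, sample_count: int, rng: random.Random) -> List[int]:
    """Draw sample_count item ids uniformly from the whole catalog."""
    return [rng.randrange(item_count) for _ in range(sample_count)]


def inbatch_negatives(batch_item_ids: Sequence[int], sample_count: int, rng: random.Random) -> List[int]:
    """Draw sample_count item ids from the ids that occur in the current batch."""
    pool = sorted(set(batch_item_ids))
    return [rng.choice(pool) for _ in range(sample_count)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
global_negs = global_uniform_negatives(item_count=1000, sample_count=5, rng=rng)
batch_negs = inbatch_negatives(batch_item_ids=[3, 8, 8, 15], sample_count=5, rng=rng)
```

Here sample_count plays the role of loss_sample_count: the loss is computed against a small sampled set instead of every item.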

forward(feature_tensors, padding_mask, candidates_to_score=None)
Parameters
  • feature_tensors (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 otherwise.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Calculated scores.

Return type

Tensor

predict(batch, candidates_to_score=None)
Parameters
  • batch (Union[SasRecPredictionBatch, dict]) – Batch of prediction data.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Calculated scores.

Return type

Tensor

SasRecModel

class replay.models.nn.sequential.sasrec.SasRecModel(schema, num_blocks=2, num_heads=1, hidden_size=50, max_len=200, dropout=0.2, ti_modification=False, time_span=256)

SasRec model

__init__(schema, num_blocks=2, num_heads=1, hidden_size=50, max_len=200, dropout=0.2, ti_modification=False, time_span=256)
Parameters
  • schema (TensorSchema) – Tensor schema of features.

  • num_blocks (int) – Number of Transformer blocks. Default: 2.

  • num_heads (int) – Number of Attention heads. Default: 1.

  • hidden_size (int) – Hidden size of transformer. Default: 50.

  • max_len (int) – Max length of sequence. Default: 200.

  • dropout (float) – Dropout rate. Default: 0.2.

  • ti_modification (bool) – Enable time relation. Default: False.

  • time_span (int) – Time span if ti_modification is True. Default: 256.

forward(feature_tensor, padding_mask)
Parameters
  • feature_tensor (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

Returns

Calculated scores.

Return type

Tensor

forward_step(feature_tensor, padding_mask)
Parameters
  • feature_tensor (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

Returns

Output embeddings.

Return type

Tensor

get_logits(out_embeddings, item_ids=None)

Apply head to output embeddings of forward_step.

Parameters
  • out_embeddings (Tensor) – Embeddings after forward step.

  • item_ids (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Logits for each element in item_ids.

Return type

Tensor

get_query_embeddings(feature_tensor, padding_mask)
Parameters
  • feature_tensor (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

Returns

Query embeddings.

predict(feature_tensor, padding_mask, candidates_to_score=None)
Parameters
  • feature_tensor (Mapping[str, Tensor]) – Batch of features.

  • padding_mask (BoolTensor) – Padding mask where 0 - <PAD>, 1 - otherwise.

  • candidates_to_score (Optional[LongTensor]) –

    Item ids to calculate scores.

    If None then predicts for all items. Default: None.

Returns

Prediction among candidates_to_score items.

Return type

Tensor

SasRecTrainingDataset

class replay.models.nn.sequential.sasrec.SasRecTrainingDataset(sequential, max_sequence_length, sequence_shift=1, sliding_window_step=None, padding_value=None, label_feature_name=None)

Dataset that generates samples to train SasRec model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. There are also additional keys needed to calculate the loss: positive_labels and target_padding_mask. The query_id key is required for possible debugging and calling additional lightning callbacks.

__init__(sequential, max_sequence_length, sequence_shift=1, sliding_window_step=None, padding_value=None, label_feature_name=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with training data.

  • max_sequence_length (int) – Max length of sequence.

  • sequence_shift (int) – Shift of sequence to predict.

  • sliding_window_step (Optional[int]) – A sliding window step. If not None, sequences are iterated over with a sliding window. Default: None.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.

  • label_feature_name (Optional[str]) – Name of label feature in provided dataset. If None, the item_id_feature name from the sequential dataset is used. Default: None.
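The role of sequence_shift can be illustrated with a small next-item sketch: the labels are the input sequence moved forward by the shift. This is a simplified picture that ignores padding, truncation to max_sequence_length and sliding windows; shift_sequence is a hypothetical helper.

```python
from typing import List, Tuple


def shift_sequence(sequence: List[int], sequence_shift: int = 1) -> Tuple[List[int], List[int]]:
    """Split one interaction sequence into model inputs and shifted labels."""
    if sequence_shift < 1:
        raise ValueError("sequence_shift must be a positive integer")
    inputs = sequence[:-sequence_shift]  # everything except the last `shift` items
    labels = sequence[sequence_shift:]   # the same sequence moved forward by `shift`
    return inputs, labels


inputs, labels = shift_sequence([11, 12, 13, 14], sequence_shift=1)
```

With a shift of 1, the label at each position is the item that follows the input item at that position.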

SasRecValidationDataset

class replay.models.nn.sequential.sasrec.SasRecValidationDataset(sequential, ground_truth, train, max_sequence_length, padding_value=None, label_feature_name=None)

Dataset that generates samples to infer and validate SasRec model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. The query_id key is required for possible debugging and calling additional lightning callbacks. The ground_truth and train keys are required for metrics calculation at the validation stage.

__init__(sequential, ground_truth, train, max_sequence_length, padding_value=None, label_feature_name=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with data to make predictions at.

  • ground_truth (SequentialDataset) – Sequential dataset with ground truth predictions.

  • train (SequentialDataset) – Sequential dataset with training data.

  • max_sequence_length (int) – Max length of sequence.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.

  • label_feature_name (Optional[str]) – Name of label feature in provided dataset. If None, the item_id_feature name from the sequential dataset is used. Default: None.

SasRecPredictionDataset

class replay.models.nn.sequential.sasrec.SasRecPredictionDataset(sequential, max_sequence_length, padding_value=None)

Dataset that generates samples to run inference with the SasRec model.

As a result of the dataset iteration, a dictionary is formed. The keys in the dictionary match the names of the arguments in the model’s forward function. The query_id key is required for possible debugging and calling additional lightning callbacks.

__init__(sequential, max_sequence_length, padding_value=None)
Parameters
  • sequential (SequentialDataset) – Sequential dataset with data to make predictions at.

  • max_sequence_length (int) – Max length of sequence.

  • padding_value (Optional[int]) – Value for padding a sequence to match the max_sequence_length. Default: 0.

SasRecTrainingBatch

class replay.models.nn.sequential.sasrec.SasRecTrainingBatch(query_id, padding_mask, features, labels, labels_padding_mask)

Batch of data for training. Generated by SasRecTrainingDataset.

features: Mapping[str, Tensor]

Alias for field number 2

labels: LongTensor

Alias for field number 3

labels_padding_mask: BoolTensor

Alias for field number 4

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

SasRecValidationBatch

class replay.models.nn.sequential.sasrec.SasRecValidationBatch(query_id, padding_mask, features, ground_truth, train)

Batch of data for validation. Generated by SasRecValidationDataset.

features: Mapping[str, Tensor]

Alias for field number 2

ground_truth: LongTensor

Alias for field number 3

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

train: LongTensor

Alias for field number 4

SasRecPredictionBatch

class replay.models.nn.sequential.sasrec.SasRecPredictionBatch(query_id, padding_mask, features)

Batch of data for model inference. Generated by SasRecPredictionDataset.

features: Mapping[str, Tensor]

Alias for field number 2

padding_mask: BoolTensor

Alias for field number 1

query_id: LongTensor

Alias for field number 0

Compiled sequential models

Sequential models like SasRec and Bert4Rec can be converted to ONNX format for fast inference on CPU.

SasRecCompiled

class replay.models.nn.sequential.compiled.SasRecCompiled(compiled_model, schema)

SasRec CPU-optimized model for inference via OpenVINO. The recommended way to build it is the compile method, which accepts either a SasRec checkpoint or the model object itself. Alternatively, you can compile the model yourself and pass it to __init__ together with a TensorSchema.

Note that compilation requires disk write (and possibly delete) permission.

classmethod compile(model, mode='one_query', batch_size=None, num_candidates_to_score=None, num_threads=None, onnx_path=None)

Model compilation.

Parameters
  • model (Union[SasRec, str, Path]) – Path to lightning SasRec model saved in .ckpt format or the SasRec object itself.

  • mode (Literal['batch', 'one_query', 'dynamic_batch_size']) –

    Inference mode, defines shape of inputs.

    one_query - sets input shape to [1, max_seq_len]

    batch - sets input shape to [batch_size, max_seq_len]

    dynamic_batch_size - sets input shape to [?, max_seq_len], where the batch dimension is dynamic

    Default: one_query.

  • batch_size (Optional[int]) – Batch size, required for batch mode. Default: None.

  • num_candidates_to_score (Optional[int]) –

    Number of item ids to calculate scores.

    Could be one of [None, -1, N].

    -1 - sets candidates_to_score shape to dynamic range [1, ?]

    N - sets candidates_to_score shape to [1, N]

    None - disables candidates_to_score usage

    Default: None.

  • num_threads (Optional[int]) – Number of CPU threads to use. Must be a natural number or None. If None, then compiler will set this parameter automatically. Default: None.

  • onnx_path (Optional[str]) – Save ONNX model to path, if defined. Default: None.

Return type

SasRecCompiled
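The shape rules listed for mode and num_candidates_to_score can be summarized in a short sketch. It only restates the documented mapping in plain Python; input_shape and candidates_shape are hypothetical helpers, and "?" stands for a dynamic dimension as in the parameter descriptions.

```python
from typing import List, Optional, Union

Dim = Union[int, str]  # "?" marks a dynamic dimension


def input_shape(mode: str, max_seq_len: int, batch_size: Optional[int] = None) -> List[Dim]:
    """Input shape implied by the inference mode."""
    if mode == "one_query":
        return [1, max_seq_len]
    if mode == "batch":
        if batch_size is None:
            raise ValueError("batch_size is required for batch mode")
        return [batch_size, max_seq_len]
    if mode == "dynamic_batch_size":
        return ["?", max_seq_len]
    raise ValueError(f"unknown mode: {mode}")


def candidates_shape(num_candidates_to_score: Optional[int]) -> Optional[List[Dim]]:
    """Shape of candidates_to_score, or None when candidates are disabled."""
    if num_candidates_to_score is None:
        return None
    if num_candidates_to_score == -1:
        return [1, "?"]
    return [1, num_candidates_to_score]


one_query = input_shape("one_query", max_seq_len=200)
fixed_batch = input_shape("batch", max_seq_len=200, batch_size=32)
dynamic = input_shape("dynamic_batch_size", max_seq_len=200)
```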

predict(batch, candidates_to_score=None)

Inference on one batch.

Parameters
  • batch (Union[SasRecPredictionBatch, dict]) – Prediction input.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Tensor with scores.

Return type

Tensor

Bert4RecCompiled

class replay.models.nn.sequential.compiled.Bert4RecCompiled(compiled_model, schema)

Bert4Rec CPU-optimized model for inference via OpenVINO. The recommended way to build it is the compile method, which accepts either a Bert4Rec checkpoint or the model object itself. Alternatively, you can compile the model yourself and pass it to __init__ together with a TensorSchema.

Note that compilation requires disk write (and possibly delete) permission.

classmethod compile(model, mode='one_query', batch_size=None, num_candidates_to_score=None, num_threads=None, onnx_path=None)

Model compilation.

Parameters
  • model (Union[Bert4Rec, str, Path]) – Path to lightning Bert4Rec model saved in .ckpt format or the Bert4Rec object itself.

  • mode (Literal['batch', 'one_query', 'dynamic_batch_size']) –

    Inference mode, defines shape of inputs. Could be one of [one_query, batch, dynamic_batch_size].

    one_query - sets input shape to [1, max_seq_len]

    batch - sets input shape to [batch_size, max_seq_len]

    dynamic_batch_size - sets input shape to [?, max_seq_len], where the batch dimension is dynamic

    Default: one_query.

  • batch_size (Optional[int]) – Batch size, required for batch mode. Default: None.

  • num_candidates_to_score (Optional[int]) –

    Number of item ids to calculate scores. Could be one of [None, -1, N].

    -1 - sets candidates_to_score shape to dynamic range [1, ?]

    N - sets candidates_to_score shape to [1, N]

    None - disables candidates_to_score usage

    Default: None.

  • num_threads (Optional[int]) – Number of CPU threads to use. Must be a natural number or None. If None, then compiler will set this parameter automatically. Default: None.

  • onnx_path (Optional[str]) – Save ONNX model to path, if defined. Default: None.

Return type

Bert4RecCompiled

predict(batch, candidates_to_score=None)

Inference on one batch.

Parameters
  • batch (Union[Bert4RecPredictionBatch, dict]) – Prediction input.

  • candidates_to_score (Optional[LongTensor]) – Item ids to calculate scores. Default: None.

Returns

Tensor with scores.

Return type

Tensor