recpack.algorithms.loss_functions.warp_loss(dist_pos_interaction: torch.Tensor, dist_neg_interaction: torch.Tensor, margin: float, num_items: int, num_negatives: int) → torch.Tensor

WARP loss

WARP loss as described in Cheng-Kang Hsieh et al., Collaborative Metric Learning, WWW 2017, which builds on J. Weston, S. Bengio, and N. Usunier, Large scale image annotation: learning to rank with joint word-image embeddings, Machine Learning, 81(1):21–35, 2010.

Adds a loss penalty for every negative sample that is not at least a distance of margin further away from the reference sample than the positive sample. Each per-sample penalty is scaled by a weight that grows with the number of samples in the negative batch that were "misclassified", i.e. that violate this margin with respect to the positive sample.
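Written out, the per-pair penalty from the Collaborative Metric Learning paper takes roughly the following form, with i the reference sample, j the positive sample, k ranging over the U sampled negatives, m the margin, J the total number of items, M the number of sampled negatives that violate the margin, and [x]_+ = max(0, x). The paper states it in terms of squared Euclidean distances; here d stands for whatever distances are passed in, and the exact form used by recpack may differ.

```latex
\mathcal{L}_{ij} = w_{ij} \sum_{k=1}^{U} \bigl[\, m + d(i,j) - d(i,k) \,\bigr]_{+},
\qquad
w_{ij} = \log\!\left( \left\lfloor \frac{J \cdot M}{U} \right\rfloor + 1 \right)
```

The weight estimates the rank of the positive item among all J items from the fraction M/U of violating negatives, so a positive that many negatives outrank is penalized more heavily.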

Parameters

  • dist_pos_interaction (torch.Tensor) – Tensor of distances between the positive sample and the reference sample.

  • dist_neg_interaction (torch.Tensor) – Tensor of distances between the negative samples and the reference sample.

  • margin (float) – Required margin between the positive and negative distances.

  • num_items (int) – Total number of items in the dataset (J in the paper).

  • num_negatives (int) – Number of negative samples used for every positive sample (U in the paper).

Returns

0-D Tensor containing the WARP loss.

Return type

torch.Tensor
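A minimal PyTorch sketch of the computation described above, following the weighting from the Collaborative Metric Learning paper. It assumes dist_pos_interaction has shape (batch,) and dist_neg_interaction has shape (batch, num_negatives); the name warp_loss_sketch is hypothetical, the shapes are assumptions, and recpack's own implementation may differ in details such as how the per-negative penalties are aggregated.

```python
import torch


def warp_loss_sketch(
    dist_pos_interaction: torch.Tensor,  # assumed shape: (batch,)
    dist_neg_interaction: torch.Tensor,  # assumed shape: (batch, num_negatives)
    margin: float,
    num_items: int,
    num_negatives: int,
) -> torch.Tensor:
    # Hypothetical re-implementation for illustration, not recpack's source.
    # Margin violation per negative: m + d(i, j) - d(i, k), broadcasting the
    # positive distance across all sampled negatives of the same row.
    violation = margin + dist_pos_interaction.unsqueeze(-1) - dist_neg_interaction
    # Hinge: only negatives that violate the margin contribute a penalty.
    hinge = torch.clamp(violation, min=0.0)
    # M: number of violating negatives ("impostors") per row.
    num_impostors = (violation > 0).sum(dim=-1).float()
    # Estimated rank of the positive among all J items: floor(J * M / U).
    est_rank = torch.floor(num_items * num_impostors / num_negatives)
    # Rank-based weight log(rank + 1); rows without violations get weight 0.
    weight = torch.log(est_rank + 1)
    # Weighted hinge summed over negatives, then averaged over the batch,
    # which yields a 0-D tensor.
    return (weight * hinge.sum(dim=-1)).mean()


# Example: two (reference, positive) pairs, each with U = 4 sampled negatives.
dist_pos = torch.tensor([0.5, 0.5])
dist_neg = torch.tensor([[0.4, 0.55, 2.0, 3.0], [2.0, 3.0, 4.0, 5.0]])
loss = warp_loss_sketch(dist_pos, dist_neg, margin=0.1, num_items=100, num_negatives=4)
```

In this example only the first row has margin violations (two of four negatives), giving an estimated rank of floor(100 · 2 / 4) = 50 and a weight of log(51); the second row contributes zero, so the batch mean is half of the first row's weighted hinge.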