recpack.algorithms.loss_functions.warp_loss
- recpack.algorithms.loss_functions.warp_loss(dist_pos_interaction: torch.Tensor, dist_neg_interaction: torch.Tensor, margin: float, num_items: int, num_negatives: int) → torch.Tensor
WARP loss
WARP loss as described in Cheng-Kang Hsieh et al., Collaborative Metric Learning, WWW 2017 (http://www.cs.cornell.edu/~ylongqi/paper/HsiehYCLBE17.pdf), which builds on J. Weston, S. Bengio, and N. Usunier, Large scale image annotation: learning to rank with joint word-image embeddings, Machine Learning, 81(1):21–35, 2010.
Adds a loss penalty for every negative sample that is not at least margin further away from the reference sample than the positive sample. Each of these per-sample penalties is weighted by a factor that increases with the number of negative samples in the batch that were “misclassified”, i.e. that violate the margin in this way.
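In formula form, the description above corresponds to the loss below. This is a sketch following our reading of the Collaborative Metric Learning paper, which writes it with squared Euclidean distances; here it is written in terms of whatever distances are passed in. $d^{+}_{ij}$ stands for dist_pos_interaction, $d^{-}_{ik}$ for dist_neg_interaction, $m$ for margin, $J$ for num_items, $U$ for num_negatives, and $M_{ij}$ for the number of sampled negatives that violate the margin. The symbol names are our own; how the library reduces over negatives and over the batch is not specified on this page.

```latex
\mathcal{L} = \sum_{(i,j) \in S} \sum_{k=1}^{U} w_{ij} \left[ m + d^{+}_{ij} - d^{-}_{ik} \right]_{+},
\qquad
w_{ij} = \log\left( \left\lfloor \frac{J \, M_{ij}}{U} \right\rfloor + 1 \right),
\qquad
M_{ij} = \left| \{ k : m + d^{+}_{ij} - d^{-}_{ik} > 0 \} \right|
```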
- Parameters
dist_pos_interaction (torch.Tensor) – Tensor of distances between the positive sample and the reference sample.
dist_neg_interaction (torch.Tensor) – Tensor of distances between the negative samples and the reference sample.
margin (float) – Required margin between positive and negative sample.
num_items (int) – Total number of items in the dataset. (J in the paper)
num_negatives (int) – Number of negative samples used for every positive sample. (U in the paper)
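num_items (J) and num_negatives (U) enter the loss only through the rank estimate used for the penalty weight. The sketch below assumes the paper's approximation $\lfloor J \cdot M / U \rfloor$ with a $\log(\text{rank} + 1)$ weight. The helper name estimated_rank_weight is hypothetical and not part of recpack.

```python
import math


def estimated_rank_weight(num_impostors: int, num_items: int, num_negatives: int) -> float:
    """Hypothetical helper: WARP rank estimate floor(J * M / U) and its log weight."""
    # M / U is the fraction of sampled negatives that violate the margin;
    # multiplying by J extrapolates that fraction to the whole item catalogue.
    est_rank = math.floor(num_items * num_impostors / num_negatives)
    # log(rank + 1) keeps the weight at 0 when no negative violates the margin.
    return math.log(est_rank + 1)


# 5 violating negatives among U = 50 samples, catalogue of J = 1000 items:
# estimated rank floor(1000 * 5 / 50) = 100, so the weight is log(101) ≈ 4.615
print(estimated_rank_weight(5, 1000, 50))
```

Using a logarithm instead of the raw rank estimate keeps a positive that is badly ranked in a large catalogue from dominating the batch loss.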
- Returns
0-D tensor containing the WARP loss.
- Return type
torch.Tensor
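The following is a minimal PyTorch sketch of a batched WARP-style loss that follows the description on this page. It is a hypothetical re-implementation, not recpack's source. It assumes dist_pos_interaction has shape (batch,) and dist_neg_interaction has shape (batch, num_negatives). It sums the weighted hinge penalties over negatives and averages over the batch, but the library's own reduction may differ.

```python
import torch


def warp_loss_sketch(
    dist_pos: torch.Tensor,  # assumed shape (batch,)
    dist_neg: torch.Tensor,  # assumed shape (batch, num_negatives)
    margin: float,
    num_items: int,
    num_negatives: int,
) -> torch.Tensor:
    """Hypothetical WARP-style loss; an illustration, not recpack's implementation."""
    # Positive where a negative is not at least `margin` further away than the positive.
    violation = margin + dist_pos.unsqueeze(-1) - dist_neg  # (batch, num_negatives)
    hinge = torch.clamp(violation, min=0.0)

    # M: number of margin-violating negatives for each positive sample.
    num_impostors = (violation > 0).sum(dim=-1).float()

    # Rank estimate floor(J * M / U) and its log weight, one per positive sample.
    est_rank = torch.floor(num_impostors * num_items / num_negatives)
    weight = torch.log(est_rank + 1.0)

    # Weight each row's summed penalties, then average to a 0-D tensor.
    return (weight * hinge.sum(dim=-1)).mean()


loss = warp_loss_sketch(
    torch.tensor([0.1, 1.0]),
    torch.tensor([[0.15, 0.5], [2.0, 3.0]]),
    margin=0.1,
    num_items=100,
    num_negatives=2,
)
print(loss)  # ≈ 0.0983: only the first row violates the margin, weighted by log(51)
```

Rows with no violating negatives contribute nothing, because both their hinge terms and their weight $\log(0 + 1)$ are zero.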