def compute_preds(self, x, y, negatives):
    # negatives that happen to be identical to the positive target would make the
    # contrastive task ambiguous, so flag them for masking below
    neg_is_pos = (y == negatives).all(-1)
    y = y.unsqueeze(0)
    # stack the positive target first, followed by the sampled negatives
    targets = torch.cat([y, negatives], dim=0)

    # cosine similarity over the feature dimension, scaled by the temperature
    logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
    logits = logits / self.logit_temp
    logits = logits.type_as(x)

    if is_xla_tensor(logits) or neg_is_pos.any():
        if not hasattr(self, "_inftensor"):
            # on xla, use a large finite fill value instead of a python -inf
            fillval = -float(2 ** 30)
            self._inftensor = (
                torch.tensor(fillval).to(x.device)
                if is_xla_tensor(logits)
                else float("-inf")
            )
        # suppress negatives that coincide with the positive (row 0 is the positive)
        logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)

    return logits
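# Minimal standalone sketch of the same contrastive scoring step, for reference only
# (it is not part of the model and ignores the XLA handling above). It assumes x and y
# are (B, T, C) predictions and targets at masked positions, negatives is (N, B, T, C),
# and the 0.1 temperature is just a placeholder; the true target ends up at index 0 of
# the first dimension of the returned logits.
def _compute_preds_sketch(x, y, negatives, logit_temp=0.1):
    import torch

    neg_is_pos = (y == negatives).all(-1)  # (N, B, T): negatives equal to the positive
    targets = torch.cat([y.unsqueeze(0), negatives], dim=0)  # (N + 1, B, T, C), positive first
    logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) / logit_temp
    # mask out "negatives" that are identical to the positive so they cannot win
    logits[1:][neg_is_pos] = float("-inf")
    return logits  # (N + 1, B, T)

# Shape-only example:
#   x = torch.randn(2, 5, 16); y = torch.randn(2, 5, 16); negs = torch.randn(10, 2, 5, 16)
#   _compute_preds_sketch(x, y, negs).shape  # -> torch.Size([11, 2, 5])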
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    net_output = model(**sample["net_input"])
    logits = model.get_logits(net_output).float()
    target = model.get_targets(sample, net_output)
    self.xla = is_xla_tensor(logits)

    # XXX: handle weights on xla.
    weights = None
    if hasattr(model, "get_target_weights") and not self.infonce:
        weights = model.get_target_weights(target, net_output)
        if torch.is_tensor(weights):
            weights = weights.float()

    losses = []

    reduction = "none" if ((not reduce) or self.xla) else "sum"
    if self.infonce:
        loss = F.cross_entropy(logits, target, reduction=reduction)
    else:
        loss = F.binary_cross_entropy_with_logits(
            logits, target.float(), weights, reduction=reduction
        )

    if self.xla:
        # tpu-comment: since dynamic shapes lead to recompilations on xla,
        # we don't shrink tensors using mask_indices.
        # Instead, we use mask indices to adjust loss.
        mi = (
            sample["net_input"]["mask_indices"]
            .transpose(0, 1)  # logits are transposed in `model.get_logits`
            .reshape(logits.size(0))
        )
        loss = (loss * mi).sum() if reduce else (loss * mi)

    if "sample_size" in sample and self.infonce:
        sample_size = sample["sample_size"]
    elif "mask_indices" in sample["net_input"]:
        sample_size = sample["net_input"]["mask_indices"].sum()
    else:
        sample_size = target.numel() if self.infonce else target.long().sum().item()
    losses.append(loss.detach().clone())

    if self.loss_weights is not None:
        assert hasattr(model, "get_extra_losses")
        extra_losses = model.get_extra_losses(net_output)
        if torch.is_tensor(extra_losses):
            extra_losses = [extra_losses]
        if len(self.loss_weights) == 1 and len(extra_losses) != 1:
            self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
        assert len(extra_losses) == len(
            self.loss_weights
        ), f"{len(extra_losses)}, {len(self.loss_weights)}"
        for p, coef in zip(extra_losses, self.loss_weights):
            if coef != 0 and p is not None:
                p = coef * p.float() * sample_size
                loss += p
                losses.append(p)

    logging_output = {
        "loss": loss.item() if (reduce and not self.xla) else loss.detach(),
        "ntokens": sample_size,
        "nsentences": sample["id"].numel(),
        "sample_size": sample_size,
    }

    for lk in self.log_keys:
        # Only store "logits" and "target" for computing MAP and MAUC
        # during validation
        if lk == "logits":
            if not self.training:
                logging_output["logits"] = logits.cpu().numpy()
        elif lk == "target":
            if not self.training:
                logging_output["target"] = target.cpu().numpy()
        elif lk in net_output:
            value = net_output[lk]
            if not is_xla_tensor(value):
                value = float(value)
            logging_output[lk] = value

    if len(losses) > 1:
        for i, l in enumerate(losses):
            logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()

    if self.infonce:
        with torch.no_grad():
            if logits.numel() == 0:
                corr = 0
                count = 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                if is_xla_tensor(logits):
                    max, min = max * mi, min * mi
                    both = max & min
                    corr = max.long().sum() - both.long().sum()
                    count = mi.sum()
                else:
                    both = max & min
                    corr = max.long().sum().item() - both.long().sum().item()
                    count = float(max.numel())

        logging_output["correct"] = corr
        logging_output["count"] = count

    return loss, sample_size, logging_output
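# Self-contained illustration (made-up values, not part of the criterion) of the InfoNCE
# branch above: assuming `model.get_logits` reshapes the contrastive scores so that each
# row holds the positive's score in column 0, the cross-entropy target is simply all zeros,
# and "correct" counts rows where the positive is ranked highest.
def _infonce_target_sketch():
    import torch
    import torch.nn.functional as F

    # two masked positions, one positive + two negatives each (illustrative scores)
    logits = torch.tensor([[2.5, 0.1, -0.3],
                           [1.0, 1.2, -0.7]])
    target = torch.zeros(logits.size(0), dtype=torch.long)  # positive is always class 0
    loss = F.cross_entropy(logits, target, reduction="sum")
    correct = (logits.argmax(-1) == 0).long().sum()  # mirrors the "correct"/"count" logging
    return loss, correct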
def forward(
    self,
    source,
    padding_mask=None,
    mask=True,
    features_only=False,
    layer=None,
    mask_indices=None,
    mask_channel_indices=None,
    padding_count=None,
):
    if self.feature_grad_mult > 0:
        features = self.feature_extractor(source)
        if self.feature_grad_mult != 1.0:
            features = GradMultiply.apply(features, self.feature_grad_mult)
    else:
        with torch.no_grad():
            features = self.feature_extractor(source)

    features_pen = features.float().pow(2).mean()

    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    unmasked_features = features.clone()

    if padding_mask is not None and padding_mask.any():
        input_lengths = (1 - padding_mask.long()).sum(-1)
        # apply conv formula to get real output_lengths
        output_lengths = self._get_feat_extract_output_lengths(input_lengths)

        padding_mask = torch.zeros(
            features.shape[:2], dtype=features.dtype, device=features.device
        )

        # these two operations make sure that all values
        # before the output lengths indices are attended to
        padding_mask[
            (
                torch.arange(padding_mask.shape[0], device=padding_mask.device),
                output_lengths - 1,
            )
        ] = 1
        padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
    else:
        padding_mask = None

    if self.post_extract_proj is not None:
        features = self.post_extract_proj(features)

    features = self.dropout_input(features)
    unmasked_features = self.dropout_features(unmasked_features)

    num_vars = None
    code_ppl = None
    prob_ppl = None
    curr_temp = None

    if self.input_quantizer:
        q = self.input_quantizer(features, produce_targets=False)
        features = q["x"]
        num_vars = q["num_vars"]
        code_ppl = q["code_perplexity"]
        prob_ppl = q["prob_perplexity"]
        curr_temp = q["temp"]
        features = self.project_inp(features)

    if mask:
        x, mask_indices = self.apply_mask(
            features,
            padding_mask,
            mask_indices=mask_indices,
            mask_channel_indices=mask_channel_indices,
        )
        if not is_xla_tensor(x) and mask_indices is not None:
            # tpu-comment: reducing the size in a dynamic way causes
            # too many recompilations on xla.
            y = unmasked_features[mask_indices].view(
                unmasked_features.size(0), -1, unmasked_features.size(-1)
            )
        else:
            y = unmasked_features
    else:
        x = features
        y = unmasked_features
        mask_indices = None

    x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)

    if features_only:
        return {
            "x": x,
            "padding_mask": padding_mask,
            "features": unmasked_features,
            "layer_results": layer_results,
        }

    if self.quantizer:
        q = self.quantizer(y, produce_targets=False)
        y = q["x"]
        num_vars = q["num_vars"]
        code_ppl = q["code_perplexity"]
        prob_ppl = q["prob_perplexity"]
        curr_temp = q["temp"]

        y = self.project_q(y)

        if self.negatives_from_everywhere:
            neg_cands = self.quantizer(unmasked_features, produce_targets=False)["x"]
            negs, _ = self.sample_negatives(
                neg_cands,
                y.size(1),
                padding_count=padding_count,
            )
            negs = self.project_q(negs)
        else:
            negs, _ = self.sample_negatives(
                y,
                y.size(1),
                padding_count=padding_count,
            )

        if self.codebook_negatives > 0:
            cb_negs = self.quantizer.sample_from_codebook(
                y.size(0) * y.size(1), self.codebook_negatives
            )
            cb_negs = cb_negs.view(
                self.codebook_negatives, y.size(0), y.size(1), -1
            )  # order doesn't matter
            cb_negs = self.project_q(cb_negs)
            negs = torch.cat([negs, cb_negs], dim=0)
    else:
        y = self.project_q(y)

        if self.negatives_from_everywhere:
            negs, _ = self.sample_negatives(
                unmasked_features,
                y.size(1),
                padding_count=padding_count,
            )
            negs = self.project_q(negs)
        else:
            negs, _ = self.sample_negatives(
                y,
                y.size(1),
                padding_count=padding_count,
            )

    if not is_xla_tensor(x):
        # tpu-comment: reducing the size in a dynamic way causes
        # too many recompilations on xla.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

    if self.target_glu:
        y = self.target_glu(y)
        negs = self.target_glu(negs)

    x = self.final_proj(x)
    x = self.compute_preds(x, y, negs)

    result = {
        "x": x,
        "padding_mask": padding_mask,
        "features_pen": features_pen,
    }

    if prob_ppl is not None:
        result["prob_perplexity"] = prob_ppl
        result["code_perplexity"] = code_ppl
        result["num_vars"] = num_vars
        result["temp"] = curr_temp

    return result
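# Hedged usage sketch: how this forward might be called to pull encoder features for a
# downstream task. It assumes `model` is an already-constructed instance of the class this
# method belongs to and `wav` is a (batch, samples) float waveform; both names are
# illustrative. With features_only=True the quantizer and negative sampling are skipped.
def _extract_features_sketch(model, wav, padding_mask=None, layer=None):
    import torch

    with torch.no_grad():
        out = model.forward(
            wav,
            padding_mask=padding_mask,
            mask=False,           # no masking at inference time
            features_only=True,   # return encoder outputs only
            layer=layer,
        )
    # (batch, frames, dim) contextual features plus the per-layer outputs
    return out["x"], out["layer_results"]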
def forward(
    self,
    source,
    padding_mask=None,
    mask=True,
    features_only=False,
    mask_indices=None,
    mask_channel_indices=None,
    padding_count=None,
):
    if self.feature_grad_mult > 0:
        features = self.feature_extractor(source)
        if self.feature_grad_mult != 1.0:
            features = GradMultiply.apply(features, self.feature_grad_mult)
    else:
        with torch.no_grad():
            features = self.feature_extractor(source)

    features_pen = features.float().pow(2).mean()

    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    unmasked_features = features.clone()

    if padding_mask is not None:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)

    if self.post_extract_proj is not None:
        features = self.post_extract_proj(features)

    features = self.dropout_input(features)
    unmasked_features = self.dropout_features(unmasked_features)

    num_vars = None
    code_ppl = None
    prob_ppl = None
    curr_temp = None

    if self.input_quantizer:
        q = self.input_quantizer(features, produce_targets=False)
        features = q["x"]
        num_vars = q["num_vars"]
        code_ppl = q["code_perplexity"]
        prob_ppl = q["prob_perplexity"]
        curr_temp = q["temp"]
        features = self.project_inp(features)

    if mask:
        x, mask_indices = self.apply_mask(
            features,
            padding_mask,
            mask_indices=mask_indices,
            mask_channel_indices=mask_channel_indices,
        )
        if not is_xla_tensor(x) and mask_indices is not None:
            # tpu-comment: reducing the size in a dynamic way causes
            # too many recompilations on xla.
            y = unmasked_features[mask_indices].view(
                unmasked_features.size(0), -1, unmasked_features.size(-1)
            )
        else:
            y = unmasked_features
    else:
        x = features
        y = unmasked_features
        mask_indices = None

    x = self.encoder(x, padding_mask=padding_mask)

    if features_only:
        return {"x": x, "padding_mask": padding_mask}

    if self.quantizer:
        q = self.quantizer(y, produce_targets=False)
        y = q["x"]
        num_vars = q["num_vars"]
        code_ppl = q["code_perplexity"]
        prob_ppl = q["prob_perplexity"]
        curr_temp = q["temp"]

        y = self.project_q(y)

        if self.negatives_from_everywhere:
            neg_cands = self.quantizer(unmasked_features, produce_targets=False)["x"]
            negs, _ = self.sample_negatives(
                neg_cands,
                y.size(1),
                padding_count=padding_count,
            )
            negs = self.project_q(negs)
        else:
            negs, _ = self.sample_negatives(
                y,
                y.size(1),
                padding_count=padding_count,
            )

        if self.codebook_negatives > 0:
            cb_negs = self.quantizer.sample_from_codebook(
                y.size(0) * y.size(1), self.codebook_negatives
            )
            cb_negs = cb_negs.view(
                self.codebook_negatives, y.size(0), y.size(1), -1
            )  # order doesn't matter
            cb_negs = self.project_q(cb_negs)
            negs = torch.cat([negs, cb_negs], dim=0)
    else:
        y = self.project_q(y)

        if self.negatives_from_everywhere:
            negs, _ = self.sample_negatives(
                unmasked_features,
                y.size(1),
                padding_count=padding_count,
            )
            negs = self.project_q(negs)
        else:
            negs, _ = self.sample_negatives(
                y,
                y.size(1),
                padding_count=padding_count,
            )

    if not is_xla_tensor(x):
        # tpu-comment: reducing the size in a dynamic way causes
        # too many recompilations on xla.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

    if self.target_glu:
        y = self.target_glu(y)
        negs = self.target_glu(negs)

    x = self.final_proj(x)
    x = self.compute_preds(x, y, negs)

    result = {
        "x": x,
        "padding_mask": padding_mask,
        "features_pen": features_pen,
    }

    if prob_ppl is not None:
        result["prob_perplexity"] = prob_ppl
        result["code_perplexity"] = code_ppl
        result["num_vars"] = num_vars
        result["temp"] = curr_temp

    return result
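# Standalone sketch (illustrative, not part of the model) of the padding-mask reduction
# used in this variant: the sample-level mask is trimmed to a multiple of the number of
# feature frames, folded into (batch, frames, samples_per_frame), and a frame counts as
# padding only if every sample it covers is padding.
def _reduce_padding_mask_sketch(padding_mask, num_frames):
    extra = padding_mask.size(1) % num_frames
    if extra > 0:
        padding_mask = padding_mask[:, :-extra]
    padding_mask = padding_mask.view(padding_mask.size(0), num_frames, -1)
    return padding_mask.all(-1)

# Example (assumes `import torch`):
#   mask = torch.zeros(1, 10, dtype=torch.bool); mask[:, 6:] = True  # last 4 samples padded
#   _reduce_padding_mask_sketch(mask, 5)  # -> tensor([[False, False, False, True, True]])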