class Wav2Vec2Model(BaseFairseqModel):
    """Masked contrastive speech model configured from a ``Wav2Vec2Config``.

    The constructor wires together a convolutional feature extractor,
    optional vector quantizers for the encoder input and/or the targets, a
    learned mask embedding, a Transformer encoder and the projection heads
    that ``forward`` uses to score the true target against sampled
    negatives.
    """

    def __init__(self, cfg: Wav2Vec2Config):
        """Build all sub-modules described by ``cfg``."""
        super().__init__()
        self.cfg = cfg

        # NOTE(review): eval() executes arbitrary code. This is only safe
        # while ``conv_feature_layers`` comes from a trusted config.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # Output width of the last conv layer (first field of its spec).
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # Only needed when the widths differ and no input quantizer (with its
        # own ``project_inp``) takes care of the projection.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                # The shared quantizer was built with this vq_dim, so
                # project_inp must accept the same width (previously
                # final_dim was used, which mismatches when latent_dim > 0).
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                )
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # Learned vector written over every masked time step.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(self, x, padding_mask):
        """Mask time steps and channels of ``x`` in place.

        Args:
            x: tensor unpacked as ``(B, T, C)``; it is modified in place.
            padding_mask: forwarded to ``compute_mask_indices`` for the
                time mask.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is the boolean
            ``(B, T)`` time mask, or ``None`` when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-sample channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Draw negative examples for each of ``num`` positions from ``y``.

        Args:
            y: candidate tensor unpacked as ``(bsz, tsz, fsz)``.
            num: number of target positions per sample.

        Returns:
            ``(negs, neg_idxs)``. ``negs`` is permuted to N x B x T x C
            with ``N = n_negatives + cross_sample_negatives``. When both
            counts are 0 the result is ``(y.new(0), None)`` so callers can
            always unpack two values (the bare tensor returned before made
            every ``negs, _ = ...`` call site fail).
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                # Position index of each draw, used to skip the positive.
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Shift draws at/after the own position so it is never hit.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Convert per-sample offsets into indices of the flattened y.
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine logits of ``x`` against targets.

        Index 0 of the first dimension is the positive ``y``; the remaining
        entries are ``negatives``. Negatives identical to the positive are
        set to ``-inf``.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)

        logits /= self.logit_temp

        if neg_is_pos.any():
            # logits[1:] is a view, so this writes through to logits.
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run feature extraction, masking, encoding and contrastive scoring.

        Args:
            source: input passed to the convolutional feature extractor.
            padding_mask: optional boolean mask over ``source`` positions.
            mask: apply ``apply_mask`` before the encoder.
            features_only: return encoder output without computing logits.

        Returns:
            With ``features_only``: ``{"x", "padding_mask"}``. Otherwise a
            dict with ``"x"`` (logits from ``compute_preds``),
            ``"padding_mask"``, ``"features_pen"`` and, when a quantizer
            ran, ``"prob_perplexity"``, ``"code_perplexity"``,
            ``"num_vars"`` and ``"temp"``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # A non-positive multiplier freezes the feature extractor.
            with torch.no_grad():
                features = self.feature_extractor(source)

        # Mean squared activation, reported via get_extra_losses.
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Drop trailing entries so the mask splits evenly into one
            # group per feature frame; a frame is padding only when every
            # entry of its group is padding.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(
                padding_mask.size(0), features.size(1), -1
            )
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Targets are the unmasked features at the masked positions.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict (see q[...] above); star
                # unpacking it would yield a key string, not the tensor.
                neg_cands = self.quantizer(
                    unmasked_features, produce_targets=False
                )["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the encoder outputs at masked positions.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return ``quantizer.forward_idx`` of the normalized conv features."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` without computing logits."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten ``net_output["x"]`` into rows with one column per candidate."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero class targets: the positive sits at index 0."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect the auxiliary penalties present in ``net_output``."""
        pen = []

        if "prob_perplexity" in net_output:
            # Fraction of the codebook that is effectively unused.
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the modules only needed for the contrastive objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
class Wav2Vec2Model(BaseFairseqModel):
    """Masked contrastive speech model configured from a ``Wav2Vec2Config``.

    The constructor wires together a convolutional feature extractor,
    optional vector quantizers for the encoder input and/or the targets, a
    learned mask embedding, a Transformer encoder and the projection heads
    that ``forward`` uses to score the true target against sampled
    negatives. Several code paths branch on ``is_xla_tensor`` to avoid
    dynamically sized tensors on XLA devices.
    """

    def __init__(self, cfg: Wav2Vec2Config):
        """Build all sub-modules described by ``cfg``."""
        super().__init__()
        self.cfg = cfg

        # NOTE(review): eval() executes arbitrary code. This is only safe
        # while ``conv_feature_layers`` comes from a trusted config.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # Output width of the last conv layer (first field of its spec).
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # Only needed when the widths differ and no input quantizer (with its
        # own ``project_inp``) takes care of the projection.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                # The shared quantizer was built with this vq_dim, so
                # project_inp must accept the same width (previously
                # final_dim was used, which mismatches when latent_dim > 0).
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                )
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # Learned vector written over every masked time step.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Apply time and channel masking to ``x``.

        Args:
            x: tensor unpacked as ``(B, T, C)``.
            padding_mask: forwarded to ``compute_mask_indices`` for the
                time mask.
            mask_indices: precomputed time mask; sampled when ``None``.
            mask_channel_indices: precomputed channel mask, honoured only
                when channel masking runs after time masking; the
                ``mask_channel_before`` path always samples its own.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is ``None`` when
            ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-sample channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def sample_negatives(self, y, num, padding_count=None):
        """Draw negative examples for each of ``num`` positions from ``y``.

        Args:
            y: candidate tensor unpacked as ``(bsz, tsz, fsz)``.
            num: number of target positions per sample.
            padding_count: optional count subtracted from ``tsz`` to bound
                within-sample draws.

        Returns:
            ``(negs, neg_idxs)``. ``negs`` is permuted to N x B x T x C
            with ``N = n_negatives + cross_sample_negatives``. When both
            counts are 0 the result is ``(y.new(0), None)`` so callers can
            always unpack two values (the bare tensor returned before made
            every ``negs, _ = ...`` call site fail).
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        # FIXME: what happens if padding_count is specified?
        cross_high = tsz * bsz
        high = tsz - (padding_count or 0)
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                # Position index of each draw, used to skip the positive.
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Shift draws at/after the own position so it is never hit.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Convert per-sample offsets into indices of the flattened y.
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine logits of ``x`` against targets.

        Index 0 of the first dimension is the positive ``y``; the remaining
        entries are ``negatives``. Negatives identical to the positive are
        overwritten with ``-inf`` (or ``-2**30`` on XLA tensors).
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)

        logits = logits / self.logit_temp

        # On XLA the fill always runs, skipping the data-dependent .any().
        if is_xla_tensor(logits) or neg_is_pos.any():
            fillval = -float(2**30)
            if not hasattr(self, "_inftensor"):
                # Cached on first use and reused on later calls.
                self._inftensor = (
                    torch.tensor(fillval).to(x.device)
                    if is_xla_tensor(logits)
                    else float("-inf")
                )
            logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)

        return logits

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        # NOTE(review): eval() executes arbitrary code. This is only safe
        # while ``conv_feature_layers`` comes from a trusted config.
        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        # Fields 1 and 2 of each layer spec are kernel size and stride.
        for layer_cfg in conv_cfg_list:
            input_lengths = _conv_out_length(
                input_lengths, layer_cfg[1], layer_cfg[2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        """Run feature extraction, masking, encoding and contrastive scoring.

        Args:
            source: input passed to the convolutional feature extractor.
            padding_mask: optional mask over ``source`` positions; it is
                recomputed at feature-frame resolution, or dropped when it
                marks nothing.
            mask: apply ``apply_mask`` before the encoder.
            features_only: return encoder output without computing logits.
            layer: forwarded to the encoder.
            mask_indices: optional precomputed time mask.
            mask_channel_indices: optional precomputed channel mask.
            padding_count: forwarded to ``sample_negatives``.

        Returns:
            With ``features_only``: ``{"x", "padding_mask", "features",
            "layer_results"}``. Otherwise a dict with ``"x"`` (logits from
            ``compute_preds``), ``"padding_mask"``, ``"features_pen"`` and,
            when a quantizer ran, ``"prob_perplexity"``,
            ``"code_perplexity"``, ``"num_vars"`` and ``"temp"``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # A non-positive multiplier freezes the feature extractor.
            with torch.no_grad():
                features = self.feature_extractor(source)

        # Mean squared activation, reported via get_extra_losses.
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None and padding_mask.any():
            # Count of non-padded input positions per sample.
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations makes sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (
                1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
            ).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
            if not is_xla_tensor(x) and mask_indices is not None:
                # tpu-comment: reducing the size in a dynamic way causes
                # too many recompilations on xla.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "features": unmasked_features,
                "layer_results": layer_results,
            }

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                neg_cands = self.quantizer(
                    unmasked_features, produce_targets=False
                )["x"]
                negs, _ = self.sample_negatives(
                    neg_cands,
                    y.size(1),
                    padding_count=padding_count,
                )
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(
                    unmasked_features,
                    y.size(1),
                    padding_count=padding_count,
                )
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

        if not is_xla_tensor(x):
            # tpu-comment: reducing the size in a dynamic way causes
            # too many recompilations on xla.
            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return ``quantizer.forward_idx`` of the normalized conv features."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        """Return the ``features_only`` result dict of ``forward``."""
        res = self.forward(
            source, padding_mask, mask=mask, features_only=True, layer=layer
        )
        return res

    def get_logits(self, net_output):
        """Flatten ``net_output["x"]`` into rows with one column per candidate."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero class targets: the positive sits at index 0."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect the auxiliary penalties present in ``net_output``."""
        pen = []

        if "prob_perplexity" in net_output:
            # Fraction of the codebook that is effectively unused.
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the modules only needed for the contrastive objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
class Wav2Vec2Model(BaseFairseqModel):
    """Masked contrastive speech model configured from an argparse namespace.

    ``add_args`` registers the command-line options; the constructor reads
    them to wire together a convolutional feature extractor, optional
    vector quantizers for the encoder input and/or the targets, a learned
    mask embedding, a Transformer encoder and the projection heads that
    ``forward`` uses to score the true target against sampled negatives.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)",
        )
        parser.add_argument(
            "--encoder-layers",
            type=int,
            metavar="L",
            help="num encoder layers in the transformer",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="dropout probability for the transformer",
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--final-dim",
            type=int,
            metavar="D",
            help="project final representations and targets to this many dimensions",
        )
        parser.add_argument(
            "--layer-norm-first",
            action="store_true",
            help="apply layernorm first in the transformer",
        )
        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            help="probability of dropping a transformer layer",
        )
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            metavar="EXPR",
            help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]",
        )
        parser.add_argument(
            "--logit-temp", type=float, help="temperature to divide logits by"
        )
        parser.add_argument(
            "--quantize-targets", action="store_true", help="use quantized targets"
        )
        parser.add_argument(
            "--quantize-input", action="store_true", help="use quantized inputs"
        )
        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="multiply feature extractor var grads by this",
        )
        parser.add_argument(
            "--latent-vars",
            type=int,
            metavar="N",
            help="number of latent variables V in each group of the codebook",
        )
        parser.add_argument(
            "--latent-groups",
            type=int,
            metavar="N",
            help="number of groups G of latent variables in the codebook",
        )
        parser.add_argument(
            "--latent-dim",
            type=int,
            metavar="N",
            help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups",
        )
        parser.add_argument("--mask-length", type=int, help="mask length")
        parser.add_argument(
            "--mask-prob",
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )
        parser.add_argument(
            "--mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-channel-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )
        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the features (after feat extr)",
        )
        parser.add_argument(
            "--num-negatives",
            type=int,
            metavar="N",
            help="number of negative examples",
        )
        parser.add_argument(
            "--negatives-from-everywhere",
            action="store_true",
            help="sample negatives from everywhere, not just masked states",
        )
        parser.add_argument(
            "--cross-sample-negatives",
            type=int,
            metavar="N",
            help="num of cross sampled negatives",
        )
        parser.add_argument(
            "--codebook-negatives",
            type=int,
            metavar="N",
            help="num of codebook sampled negatives",
        )
        parser.add_argument(
            "--conv-pos",
            type=int,
            metavar="N",
            help="number of filters for convolutional positional embeddings",
        )
        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            metavar="N",
            help="number of groups for convolutional positional embedding",
        )
        parser.add_argument(
            "--latent-temp",
            type=str,
            metavar="D",
            help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)",
        )
        parser.add_argument(
            "--target-glu",
            action="store_true",
            help="adds projection + glu to targets",
        )
        parser.add_argument(
            "--conv-bias", action="store_true", help="include bias in conv encoder"
        )

    def __init__(self, args):
        """Build all sub-modules described by ``args``."""
        super().__init__()
        self.args = args

        # NOTE(review): eval() executes arbitrary code. This is only safe
        # while ``--conv-feature-layers`` comes from a trusted source.
        feature_enc_layers = eval(args.conv_feature_layers)
        # Output width of the last conv layer (first field of its spec).
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

        # Only needed when the widths differ and no input quantizer (with its
        # own ``project_inp``) takes care of the projection.
        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim and not args.quantize_input
            else None
        )

        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = args.num_negatives
        self.cross_sample_negatives = args.cross_sample_negatives
        self.codebook_negatives = args.codebook_negatives
        self.negatives_from_everywhere = args.negatives_from_everywhere

        self.logit_temp = args.logit_temp

        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        # The target quantizer is built first so that --same-quantizer can
        # actually share it; building the input quantizer first always saw
        # ``self.quantizer`` as None.
        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                # NOTE(review): eval() on a CLI string; trusted input only.
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            if args.same_quantizer and self.quantizer is not None:
                # project_inp must accept the shared quantizer's width.
                vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
                )
                self.input_quantizer = GumbelVectorQuantizer(
                    # forward() feeds it unprojected features of width
                    # self.embed (post_extract_proj is None in this mode).
                    dim=self.embed,
                    num_vars=args.latent_vars,
                    # NOTE(review): eval() on a CLI string; trusted input only.
                    temp=eval(args.latent_temp),
                    groups=args.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        # Learned vector written over every masked time step.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(args)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        return cls(args)

    def apply_mask(self, x, padding_mask):
        """Mask time steps and channels of ``x`` in place.

        Args:
            x: tensor unpacked as ``(B, T, C)``; it is modified in place.
            padding_mask: forwarded to ``compute_mask_indices`` for the
                time mask.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is the boolean
            ``(B, T)`` time mask, or ``None`` when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-sample channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Draw negative examples for each of ``num`` positions from ``y``.

        Args:
            y: candidate tensor unpacked as ``(bsz, tsz, fsz)``.
            num: number of target positions per sample.

        Returns:
            ``(negs, neg_idxs)``. ``negs`` is permuted to N x B x T x C
            with ``N = n_negatives + cross_sample_negatives``. When both
            counts are 0 the result is ``(y.new(0), None)`` so callers can
            always unpack two values (the bare tensor returned before made
            every ``negs, _ = ...`` call site fail).
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                # Position index of each draw, used to skip the positive.
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Shift draws at/after the own position so it is never hit.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Convert per-sample offsets into indices of the flattened y.
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine logits of ``x`` against targets.

        Index 0 of the first dimension is the positive ``y``; the remaining
        entries are ``negatives``. Negatives identical to the positive are
        set to ``-inf``.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)

        logits /= self.logit_temp

        if neg_is_pos.any():
            # logits[1:] is a view, so this writes through to logits.
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run feature extraction, masking, encoding and contrastive scoring.

        Args:
            source: input passed to the convolutional feature extractor.
            padding_mask: optional boolean mask over ``source`` positions.
            mask: apply ``apply_mask`` before the encoder.
            features_only: return encoder output without computing logits.

        Returns:
            With ``features_only``: ``{"x", "padding_mask"}``. Otherwise a
            dict with ``"x"`` (logits from ``compute_preds``),
            ``"padding_mask"``, ``"features_pen"`` and, when a quantizer
            ran, ``"prob_perplexity"``, ``"code_perplexity"``,
            ``"num_vars"`` and ``"temp"``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # A non-positive multiplier freezes the feature extractor.
            with torch.no_grad():
                features = self.feature_extractor(source)

        # Mean squared activation, reported via get_extra_losses.
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Drop trailing entries so the mask splits evenly into one
            # group per feature frame; a frame is padding only when every
            # entry of its group is padding.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(
                padding_mask.size(0), features.size(1), -1
            )
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Targets are the unmasked features at the masked positions.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict (see q[...] above); star
                # unpacking it would yield a key string, not the tensor.
                neg_cands = self.quantizer(
                    unmasked_features, produce_targets=False
                )["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the encoder outputs at masked positions.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return ``quantizer.forward_idx`` of the normalized conv features."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` without computing logits."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten ``net_output["x"]`` into rows with one column per candidate."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero class targets: the positive sits at index 0."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect the auxiliary penalties present in ``net_output``."""
        pen = []

        if "prob_perplexity" in net_output:
            # Fraction of the codebook that is effectively unused.
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the modules only needed for the contrastive objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
class DeCoAr2Model(BaseFairseqModel):
    """Masked-prediction model operating on 80-dimensional input frames.

    Input frames are projected to ``args.encoder_embed_dim`` by a single
    linear layer, optionally quantized, masked, and encoded by a
    ``TransformerEncoder``.  During pre-training the encoder output at the
    masked positions is scored against (optionally quantized) targets and
    sampled negatives.
    """

    def __init__(self, args):
        """Build all sub-modules from the hyper-parameters in ``args``."""
        super().__init__()
        self.args = args

        # Every feature tensor in this model comes out of ``post_extract_proj``
        # and is therefore ``encoder_embed_dim`` wide.  ``self.embed`` was read
        # below but never assigned, which made construction fail.
        self.embed = args.encoder_embed_dim
        self.post_extract_proj = nn.Linear(80, args.encoder_embed_dim)

        # Time-axis masking hyper-parameters.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        # Channel-axis masking hyper-parameters.
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = args.num_negatives
        self.cross_sample_negatives = args.cross_sample_negatives
        self.codebook_negatives = args.codebook_negatives
        self.negatives_from_everywhere = args.negatives_from_everywhere

        self.logit_temp = args.logit_temp

        def _build_quantizer(vq_dim):
            temp = args.latent_temp
            if isinstance(temp, str):
                # NOTE(security): eval() executes arbitrary code.  This is
                # only acceptable because ``latent_temp`` is expected to come
                # from trusted configuration; never feed it untrusted input.
                temp = eval(temp)
            return GumbelVectorQuantizer(
                dim=args.encoder_embed_dim,
                num_vars=args.latent_vars,
                temp=temp,
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )

        final_dim = (args.final_dim
                     if args.final_dim > 0 else args.encoder_embed_dim)

        # The target quantizer is created first so that ``same_quantizer``
        # can actually share it; previously the input quantizer was resolved
        # while ``self.quantizer`` was still None.
        target_vq_dim = None
        if args.quantize_targets:
            target_vq_dim = (args.latent_dim
                             if args.latent_dim > 0 else final_dim)
            self.quantizer = _build_quantizer(target_vq_dim)
            self.project_q = nn.Linear(target_vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            if args.same_quantizer and self.quantizer is not None:
                # Shared quantizer: its output width is the target vq_dim.
                input_vq_dim = target_vq_dim
                self.input_quantizer = self.quantizer
            else:
                input_vq_dim = (args.latent_dim if args.latent_dim > 0 else
                                args.encoder_embed_dim)
                self.input_quantizer = _build_quantizer(input_vq_dim)
            self.project_inp = nn.Linear(input_vq_dim,
                                         args.encoder_embed_dim)

        # Learned vector written over masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(args)
        # NOTE(review): not applied anywhere in this class's forward path; it
        # is kept so the module set matches the original definition.
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)
        return cls(args)

    def apply_mask(self, x, padding_mask):
        """Mask ``x`` in place along the time and/or channel axis.

        Args:
            x: 3-D tensor unpacked as ``(B, T, C)``; modified in place.
            padding_mask: forwarded to ``compute_mask_indices`` for the
                time-axis mask.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is the boolean
            ``(B, T)`` time mask, or ``None`` when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-sample channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).to(
                    x.device).unsqueeze(1).expand(-1, T, -1))
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample negative examples from ``y`` for each of ``num`` positions.

        Same logic as the sibling wav2vec 2.0 model in this file;
        ``forward`` calls this method but the class did not define it.

        Args:
            y: 3-D tensor unpacked as ``(bsz, tsz, fsz)``.
            num: number of target positions per sample.

        Returns:
            ``(negs, neg_idxs)`` with ``negs`` laid out as N x B x T x C.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # NOTE(review): this early return is a bare tensor while callers
            # unpack two values — confirm whether a zero-negative
            # configuration is meant to be supported.
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.n_negatives).flatten())

                neg_idxs = torch.randint(low=0,
                                         high=high - 1,
                                         size=(bsz, self.n_negatives * num))
                # Shift indices at/after the positive so it is never drawn.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.cross_sample_negatives).flatten())

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Offset per-sample indices into the flattened (BxT) axis.
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(bsz, num,
                         self.n_negatives + self.cross_sample_negatives,
                         fsz).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Score ``x`` against the positive ``y`` and the ``negatives``.

        Same logic as the sibling wav2vec 2.0 model in this file;
        ``forward`` calls this method but the class did not define it.

        Returns:
            Cosine-similarity logits divided by ``logit_temp``; index 0 of
            the first axis is the positive.  Negatives identical to the
            positive are set to ``-inf``.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(),
                                         dim=-1).type_as(x)
        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self,
                source,
                padding_mask=None,
                mask=True,
                features_only=False):
        """Run the model.

        Args:
            source: input whose last dimension is 80 (it is fed straight to
                ``post_extract_proj``).
            padding_mask: passed to ``apply_mask`` and to the encoder.
            mask: whether to apply masking before encoding.
            features_only: if True, return right after the encoder.

        Returns:
            With ``features_only``: ``{"x", "padding_mask"}``.  Otherwise a
            dict with ``"x"`` (logits from ``compute_preds``),
            ``"padding_mask"``, ``"features_pen"`` and, when a quantizer ran,
            ``"prob_perplexity"``, ``"code_perplexity"``, ``"num_vars"`` and
            ``"temp"``.
        """
        features = self.post_extract_proj(source)

        # ``features_pen`` was returned below without ever being computed.
        # NOTE(review): this follows the sibling model (mean squared
        # activation of the freshly extracted features) — confirm this is
        # the intended penalty for this architecture.
        features_pen = features.float().pow(2).mean()

        unmasked_features = features.clone()

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Targets are the unmasked features at the masked positions.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1,
                    unmasked_features.size(-1))
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict (see ``q["x"]`` above);
                # tuple-unpacking it yielded a key name, not the tensor.
                neg_cands = self.quantizer(unmasked_features,
                                           produce_targets=False)["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives)
                cb_negs = cb_negs.view(self.codebook_negatives, y.size(0),
                                       y.size(1), -1)  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the encoder outputs at the masked positions.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu is not None:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return the target quantizer's ``forward_idx`` output for ``x``.

        ``x`` goes through the same feature path that feeds the quantizer in
        ``forward`` (``post_extract_proj`` only).  The previous body called
        ``self.feature_extractor``, which this class never creates.
        NOTE(review): confirm that no normalisation is expected here.
        """
        assert self.quantizer is not None
        x = self.post_extract_proj(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` via ``forward``."""
        res = self.forward(source,
                           padding_mask,
                           mask=mask,
                           features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten ``net_output["x"]`` to 2-D with candidates on the last axis."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero long targets (the positive sits at index 0)."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalties present in ``net_output``.

        Returns a list with the normalised perplexity gap
        ``(num_vars - prob_perplexity) / num_vars`` and/or ``features_pen``,
        in that order, for whichever keys exist.
        """
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"]) /
                net_output["num_vars"])

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the modules used only by the pre-training objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None