class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing
    with: `dropout -> add residual`. We default to the approach in the paper,
    but the tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting;
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary ByteTensor of
                shape `(batch, src_len)` where padding elements are indicated
                by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    def compute_macs_params(self, T=1, S=1):
        macs = 0
        n_params = 0
        macs_attn = 0

        # LayerNorm
        n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()])
        n_params += sum([p.numel() for p in self.final_layer_norm.parameters()])

        # self-attention
        self_attn_layer = self.self_attn.compute_macs_params(T=T, S=T)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # encoder-decoder attention
        if self.encoder_attn is not None:
            enc_attn = self.encoder_attn.compute_macs_params(T=T, S=S)
            macs += enc_attn['macs']
            n_params += enc_attn['params']
            macs_attn += enc_attn['macs_attn']

        # FFN
        fc1_params = sum([p.numel() for p in self.fc1.parameters()])
        macs += fc1_params * T
        n_params += fc1_params

        fc2_params = sum([p.numel() for p in self.fc2.parameters()])
        macs += fc2_params * T
        n_params += fc2_params

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn,
        }
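
# A minimal usage sketch (illustrative only, not part of the original file).
# It assumes an argparse.Namespace carrying exactly the fields read in
# TransformerDecoderLayer.__init__ above; the hyperparameter values are
# hypothetical.
def _demo_decoder_layer():
    from argparse import Namespace

    args = Namespace(
        decoder_embed_dim=64,
        decoder_attention_heads=4,
        decoder_ffn_embed_dim=128,
        attention_dropout=0.0,
        dropout=0.0,
        decoder_normalize_before=False,  # True enables the tensor2tensor pre-norm variant
        encoder_embed_dim=64,
    )
    layer = TransformerDecoderLayer(args)
    tgt = torch.randn(10, 2, 64)   # (seq_len, batch, embed_dim)
    enc = torch.randn(12, 2, 64)   # encoder output, (src_len, batch, embed_dim)
    out, attn, _ = layer(tgt, encoder_out=enc)
    assert out.shape == tgt.shape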
class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.args = args  # kept for max_positions()
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        builder = AttentionBuilder.from_kwargs(
            attention_dropout=args.attention_dropout,
            clusters=args.clusters,
            iterations=args.iterations,
            topk=args.topk,
            bits=args.bits,
            hash_bias=args.hash_bias,
            length_limit=args.length_limit,
            # integer division: per-head dimensionality
            query_dimensions=args.encoder_embed_dim // args.encoder_attention_heads,
        )
        inner_attention = builder.get(args.attention_type)

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                inner_attention=inner_attention,
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
            for _ in range(args.encoder_layers)
        ])

        self.gradient_checkpoint = strtobool(args.gradient_checkpoint)

        # freeze the first n_freeze layers and disable the last n_remove layers
        self.n_freeze = args.n_freeze_layers
        self.n_remove = args.n_remove_layers
        self.n_active = args.encoder_layers - args.n_remove_layers
        for i in range(self.n_freeze):
            for p in self.layers[i].parameters():
                p.requires_grad = False
        for i in range(self.n_remove):
            for p in self.layers[-(i + 1)].parameters():
                p.requires_grad = False

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        if self.n_freeze > 0:
            for param in self.pos_conv.parameters():
                param.requires_grad = False
            if not self.layer_norm_first:
                for param in self.layer_norm.parameters():
                    param.requires_grad = False

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None):
        x = self.extract_features(x, padding_mask)
        if self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def extract_features(self, x, padding_mask=None):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        # avoid an in-place add: the conv above saves (a view of) x for its
        # backward pass when its weights require gradients
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        layer_results = []
        length_mask = None
        if padding_mask is not None and torch.any(padding_mask):
            input_lengths = (~padding_mask).sum(-1)
            length_mask = LengthMask(lengths=input_lengths.long())

        n_frozen = self.n_freeze
        attention_mask = None
        for i, layer in enumerate(self.layers):
            if i >= self.n_active:
                break
            dropout_probability = np.random.random()
            if i < n_frozen:
                # frozen layers always run: no layerdrop, no checkpointing
                x = layer(x, length_mask=length_mask)
                layer_results.append(x)
            else:
                if not self.training or (dropout_probability > self.layerdrop):
                    if self.gradient_checkpoint:
                        x = checkpoint(layer, x, attention_mask, length_mask)
                    else:
                        x = layer(x, length_mask=length_mask)
                    layer_results.append(x)

        return x

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict
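
# Illustrative shape-contract sketch for the convolutional positional
# embedding built above (not part of the original file; sizes are
# hypothetical). The Conv1d consumes (batch, channels, time); SamePad trims
# the extra frame created by the even-kernel padding so the result can be
# added back onto the (batch, time, channels) features after a transpose.
def _demo_pos_conv_shapes():
    embed_dim, conv_pos, groups = 16, 8, 4
    conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=conv_pos,
                     padding=conv_pos // 2, groups=groups)
    pos_conv = nn.Sequential(conv, SamePad(conv_pos), nn.GELU())
    x = torch.randn(2, 50, embed_dim)                     # (batch, time, channels)
    x_conv = pos_conv(x.transpose(1, 2)).transpose(1, 2)  # back to (batch, time, channels)
    assert x_conv.shape == x.shape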
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but
    the tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(T_tgt, T_src)`,
                where T_tgt is the length of the query and T_src is the length
                of the key (here both query and key are x).
                attn_mask[t_tgt, t_src] = 1 means t_src is excluded (masked
                out) when computing the embedding for t_tgt; = 0 means it is
                included in attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
            # anything in the original attn_mask = 1 becomes -1e8;
            # anything in the original attn_mask = 0 becomes 0.
            # Note that we cannot use -inf here, because in some edge cases
            # the attention weight (before softmax) for a padded element in
            # the query would become -inf, which results in NaN in model
            # parameters.
            # TODO: to formally solve this problem, we need to change
            # fairseq's MultiheadAttention. We will do this later on.
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,  # apply the additive mask prepared above
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x

    def compute_macs_params(self, S=1):
        macs = 0
        n_params = 0
        macs_attn = 0

        # Layer norms
        # MACs are zero for LayerNorm because they can be fused
        n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()])
        n_params += sum([p.numel() for p in self.final_layer_norm.parameters()])

        # Attention
        self_attn_layer = self.self_attn.compute_macs_params(T=S, S=S)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # FFN
        fc1_params = sum([p.numel() for p in self.fc1.parameters()])
        # scale MACs by S because S tokens are processed in parallel
        macs += fc1_params * S
        n_params += fc1_params

        fc2_params = sum([p.numel() for p in self.fc2.parameters()])
        # scale MACs by S because S tokens are processed in parallel
        macs += fc2_params * S
        n_params += fc2_params

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn,
        }
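
# Sketch of the additive-mask convention used in forward above (illustrative
# only, not part of the original file). Positions flagged with 1 are mapped to
# -1e8 so they contribute (effectively) nothing after softmax; -inf is avoided
# because a fully masked query row would produce NaNs.
def _demo_attn_mask_fill():
    attn_mask = torch.tensor([[0.0, 1.0], [0.0, 0.0]])
    additive = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
    # additive is [[0, -1e8], [0, 0]]: key 1 is masked out for query 0 only
    assert additive[0, 1] == -1e8 and additive[1, 1] == 0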
class Wav2Vec2FTASRModel(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2FTASRConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.n_freeze = cfg.n_freeze_layers
        self.n_remove = cfg.n_remove_layers

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
        self.set_grad_to_false()

    def set_grad_to_false(self):
        if self.n_freeze > 0:
            # guard against modules that are None for some configs
            frozen_modules = [
                self.feature_extractor,
                self.post_extract_proj,
                self.layer_norm,
                self.quantizer,
                self.project_q,
                self.final_proj,
            ]
            for module in frozen_modules:
                if module is None:
                    continue
                for param in module.parameters():
                    param.requires_grad = False

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2FTASRConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        if len(source.shape) <= 2:
            if self.feature_grad_mult > 0:
                features = self.feature_extractor(source)
                if self.feature_grad_mult != 1.0:
                    features = GradMultiply.apply(features, self.feature_grad_mult)
            else:
                with torch.no_grad():
                    features = self.feature_extractor(source)
            features = features.transpose(1, 2)
        else:
            features = source

        features_pen = features.float().pow(2).mean()

        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # the quantizer returns a dict; take the quantized features
                neg_cands = self.quantizer(unmasked_features, produce_targets=False)["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesn't matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
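
# Usage sketch (hypothetical; not part of the original file). It assumes the
# Wav2Vec2FTASRConfig dataclass is instantiable with its defaults; in practice
# the config comes from a fairseq checkpoint. `features_only` skips the
# quantizer and contrastive heads and returns encoder features directly.
def _demo_extract_features():
    cfg = Wav2Vec2FTASRConfig()  # assumption: defaults are usable
    model = Wav2Vec2FTASRModel.build_model(cfg).eval()
    wav = torch.randn(1, 16000)  # one second of 16 kHz audio
    with torch.no_grad():
        feats, pad = model.extract_features(wav, padding_mask=None, mask=False)
    # feats: (batch, frames, encoder_embed_dim)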