class Wav2Vec2FTASRModel(BaseFairseqModel):
    """Convolutional feature extractor + Transformer encoder with a contrastive head.

    The model extracts features with ``ConvFeatureExtractionModel``, optionally
    quantizes inputs and/or targets with ``GumbelVectorQuantizer``, masks time
    steps and channels, encodes the result with ``TransformerEncoder`` and scores
    the encoder output against targets and sampled negatives using a
    temperature-scaled cosine similarity.

    When ``cfg.n_freeze_layers > 0`` every module outside the Transformer encoder
    that this class owns (feature extractor, projections, layer norm, target
    quantizer) is frozen at construction time; see :meth:`set_grad_to_false`.
    """

    def __init__(self, cfg: Wav2Vec2FTASRConfig):
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: model configuration; every hyper-parameter read below is an
                attribute of this object.
        """
        super().__init__()
        self.cfg = cfg

        # NOTE(review): eval() executes arbitrary Python. This is only safe while
        # ``conv_feature_layers`` comes from a trusted config; consider
        # ``ast.literal_eval`` if the value can ever be user-supplied.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # Output channel count of the last conv layer.
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # Only needed when the conv output width differs from the encoder width
        # and the input is not going through the quantizer + project_inp path.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        # Time-step masking hyper-parameters.
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # Channel masking hyper-parameters.
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.n_freeze = cfg.n_freeze_layers
        self.n_remove = cfg.n_remove_layers

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        # Negative sampling hyper-parameters.
        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                )
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # Learned vector written over masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        self.set_grad_to_false()

    def set_grad_to_false(self):
        """Freeze the non-encoder modules when ``n_freeze_layers`` is positive.

        Modules that were not built (``post_extract_proj`` and ``quantizer`` are
        ``None`` for some configurations) or that were dropped by
        :meth:`remove_pretraining_modules` are skipped instead of raising
        ``AttributeError``.
        """
        if self.n_freeze <= 0:
            return

        modules_to_freeze = (
            self.feature_extractor,
            self.post_extract_proj,
            self.layer_norm,
            self.quantizer,
            self.project_q,
            self.final_proj,
        )
        for module in modules_to_freeze:
            if module is None:
                continue
            for param in module.parameters():
                param.requires_grad = False

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2FTASRConfig, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(self, x, padding_mask):
        """Mask time steps and channels of ``x`` in place.

        Args:
            x: tensor unpacked as ``(B, T, C)``; modified in place.
            padding_mask: forwarded to ``compute_mask_indices`` for the
                time-step mask.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is the boolean
            time-step mask on ``x.device``, or ``None`` when ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample negative examples from ``y``.

        Args:
            y: tensor unpacked as ``(bsz, tsz, fsz)``.
            num: number of positions to draw negatives for.

        Returns:
            ``(negs, neg_idxs)`` where ``negs`` is permuted to N x B x T x C.

        NOTE(review): when both ``n_negatives`` and ``cross_sample_negatives``
        are 0 this returns a single empty tensor, while the callers in
        ``forward`` unpack two values - confirm that configuration is never used.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Shift indices at/after the positive so it is never sampled.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Offset per-sample indices into the flattened (BxT) view.
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine-similarity logits.

        Index 0 along the first dimension is the positive target ``y``; the
        remaining entries are ``negatives``. Negatives identical to the positive
        are set to ``-inf``.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run the model.

        Args:
            source: input with at most two dimensions (passed through the conv
                feature extractor), or an input with more dimensions that is
                used directly as features.
            padding_mask: optional padding mask at input resolution; it is
                reduced to feature resolution below.
            mask: whether to apply time/channel masking.
            features_only: return encoder output without the contrastive head.

        Returns:
            With ``features_only``: ``{"x", "padding_mask"}``. Otherwise a dict
            with ``"x"`` (logits), ``"padding_mask"``, ``"features_pen"`` and,
            when a quantizer ran, ``"prob_perplexity"``, ``"code_perplexity"``,
            ``"num_vars"`` and ``"temp"``.
        """
        if len(source.shape) <= 2:
            if self.feature_grad_mult > 0:
                features = self.feature_extractor(source)
                if self.feature_grad_mult != 1.0:
                    features = GradMultiply.apply(features, self.feature_grad_mult)
            else:
                with torch.no_grad():
                    features = self.feature_extractor(source)
            features = features.transpose(1, 2)
        else:
            features = source

        features_pen = features.float().pow(2).mean()

        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Trim the remainder, then mark a frame as padding only when every
            # input position folded into it is padding.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(
                padding_mask.size(0), features.size(1), -1
            )
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict (see ``q`` above); star-unpacking
                # it would yield its keys, so read the quantized tensor by key.
                neg_cands = self.quantizer(unmasked_features, produce_targets=False)[
                    "x"
                ]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the masked positions of the encoder output.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu is not None:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return ``self.quantizer.forward_idx`` of the normalized conv features."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` via ``features_only`` forward."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten ``net_output["x"]`` to 2-D with the candidate axis last."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero long targets (the positive is always at index 0)."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalties present in ``net_output``.

        Returns:
            A list containing, when available, the codebook diversity term
            ``(num_vars - prob_perplexity) / num_vars`` and ``features_pen``.
        """
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the modules used only by the contrastive head."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
class ABCModel(BaseFairseqModel):
    """Online/target two-branch model trained on two views of the input.

    The *online* branch (feature extractor, optional quantizers, encoder,
    projection and prediction head) is trained by gradient descent. The *target*
    branch mirrors it with ``*_target`` modules whose parameters have
    ``requires_grad = False`` and are updated as an exponential moving average
    of the online parameters in :meth:`update_target`.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)",
        )
        parser.add_argument(
            "--encoder-layers",
            type=int,
            metavar="L",
            help="num encoder layers in the transformer",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="dropout probability for the transformer",
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--final-dim",
            type=int,
            metavar="D",
            help="project final representations and targets to this many dimensions",
        )
        parser.add_argument(
            "--layer-norm-first",
            action="store_true",
            help="apply layernorm first in the transformer",
        )
        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            help="probability of dropping a tarnsformer layer",
        )
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            metavar="EXPR",
            help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]",
        )
        parser.add_argument(
            "--logit-temp", type=float, help="temperature to divide logits by"
        )
        parser.add_argument(
            "--quantize-targets", action="store_true", help="use quantized targets"
        )
        parser.add_argument(
            "--quantize-input", action="store_true", help="use quantized inputs"
        )
        parser.add_argument(
            "--same-quantizer",
            action="store_true",
            help="use same quantizer for inputs and targets",
        )
        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="multiply feature extractor var grads by this",
        )
        parser.add_argument(
            "--latent-vars",
            type=int,
            metavar="N",
            help="number of latent variables V in each group of the codebook",
        )
        parser.add_argument(
            "--latent-groups",
            type=int,
            metavar="N",
            help="number of groups G of latent variables in the codebook",
        )
        parser.add_argument(
            "--latent-dim",
            type=int,
            metavar="N",
            help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups",
        )
        parser.add_argument("--mask-length", type=int, help="mask length")
        parser.add_argument(
            "--mask-prob",
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )
        parser.add_argument(
            "--mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-channel-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )
        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the features (after feat extr)",
        )
        parser.add_argument(
            "--num-negatives",
            type=int,
            metavar="N",
            help="number of negative examples",
        )
        parser.add_argument(
            "--negatives-from-everywhere",
            action="store_true",
            help="sample negatives from everywhere, not just masked states",
        )
        parser.add_argument(
            "--cross-sample-negatives",
            type=int,
            metavar="N",
            help="num of cross sampled negatives",
        )
        parser.add_argument(
            "--codebook-negatives",
            type=int,
            metavar="N",
            help="num of codebook sampled negatives",
        )
        parser.add_argument(
            "--base-decay",
            type=float,
            metavar="D",
            help="base decay value of target network",
        )
        parser.add_argument(
            "--projection-dim",
            type=int,
            metavar="N",
            help="byol projection network's output dimension",
        )
        parser.add_argument(
            "--prediction-dim",
            type=int,
            metavar="N",
            help="byol prediction network's output dimension",
        )
        parser.add_argument(
            "--byol-hidden-dim",
            type=int,
            metavar="N",
            help="hidden dimension of MLPs used for byol",
        )
        parser.add_argument(
            "--shared-quantizer",
            action="store_true",
            help="share quantizer with target network",
        )
        parser.add_argument(
            "--shared-emb",
            action="store_true",
            help="share mask embedding parameter",
        )
        parser.add_argument(
            "--mlp-prediction",
            action="store_true",
            help="use MLP as prediction network",
        )
        parser.add_argument(
            "--mlp-encoder",
            action="store_true",
            help="use MLP as encoder network",
        )
        parser.add_argument(
            "--separate-mask-indices",
            action="store_true",
            help="use two separate mask indices for Transformers",
        )
        parser.add_argument(
            "--conv-pos",
            type=int,
            metavar="N",
            help="number of filters for convolutional positional embeddings",
        )
        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            metavar="N",
            help="number of groups for convolutional positional embedding",
        )
        parser.add_argument(
            "--latent-temp",
            type=str,
            metavar="D",
            help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)",
        )
        parser.add_argument(
            "--target-glu",
            action="store_true",
            help="adds projection + glu to targets",
        )
        parser.add_argument(
            "--conv-bias",
            action="store_true",
            help="include bias in conv encoder",
        )

    def __init__(self, args):
        """Build the online and target branches from ``args``.

        ``self.online_params`` and ``self.target_params`` are parallel lists:
        entry *i* of the target list is the moving-average copy of entry *i* of
        the online list. All target parameters are frozen at the end.
        """
        super().__init__()
        self.args = args

        # NOTE(review): eval() executes arbitrary Python; safe only while
        # ``conv_feature_layers`` / ``latent_temp`` come from trusted config.
        feature_enc_layers = eval(args.conv_feature_layers)
        # Output channel count of the last conv layer.
        self.embed = feature_enc_layers[-1][0]

        self.online_params = []
        self.target_params = []

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        self.feature_extractor_target = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        self.online_params += list(self.feature_extractor.parameters())
        self.target_params += list(self.feature_extractor_target.parameters())

        if self.embed != args.encoder_embed_dim and not args.quantize_input:
            self.post_extract_proj = nn.Linear(self.embed, args.encoder_embed_dim)
            self.post_extract_proj_target = nn.Linear(
                self.embed, args.encoder_embed_dim
            )
            self.online_params += list(self.post_extract_proj.parameters())
            self.target_params += list(self.post_extract_proj_target.parameters())
        else:
            self.post_extract_proj = None
            self.post_extract_proj_target = None

        # Time-step masking hyper-parameters.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        # Channel masking hyper-parameters.
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        # Moving-average schedule state, see update_target().
        self.base_decay = args.base_decay
        self.step = 0
        self.total_steps = args.total_num_update

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.quantizer_target = None
        self.input_quantizer = None
        self.input_quantizer_target = None
        self.shared_quantizer = args.shared_quantizer

        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        def build_quantizer(vq_dim):
            # latent_temp is only evaluated when a quantizer is actually built.
            return GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )

        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = build_quantizer(vq_dim)
            if not args.shared_quantizer:
                self.quantizer_target = build_quantizer(vq_dim)
                # Only track the pair when a separate target quantizer exists;
                # with a shared quantizer there is nothing to average into.
                self.online_params += list(self.quantizer.parameters())
                self.target_params += list(self.quantizer_target.parameters())
            # TODO: separate project_q?
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            if args.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
                if not args.shared_quantizer:
                    self.input_quantizer_target = self.quantizer_target
            else:
                vq_dim = (
                    args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
                )
                self.input_quantizer = build_quantizer(vq_dim)
                if not args.shared_quantizer:
                    self.input_quantizer_target = build_quantizer(vq_dim)
                    self.online_params += list(self.input_quantizer.parameters())
                    self.target_params += list(
                        self.input_quantizer_target.parameters()
                    )
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        # Learned vector written over masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )
        if not args.shared_emb:
            # NOTE(review): apply_mask() always writes self.mask_emb, so this
            # parameter is currently not read by the target branch.
            self.mask_emb_target = nn.Parameter(
                torch.FloatTensor(args.encoder_embed_dim).uniform_()
            )

        if args.mlp_encoder:
            # NOTE(review): predict()/target_predict() call the encoder with a
            # ``padding_mask`` keyword, which nn.Sequential does not accept -
            # confirm how this option is meant to be used.
            self.encoder = nn.Sequential(
                nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
                nn.BatchNorm1d(args.byol_hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(args.byol_hidden_dim, final_dim),
            )
            self.encoder_target = nn.Sequential(
                nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
                nn.BatchNorm1d(args.byol_hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(args.byol_hidden_dim, final_dim),
            )
        else:
            self.encoder = TransformerEncoder(args)
            self.encoder_target = TransformerEncoder(args)
        self.online_params += list(self.encoder.parameters())
        self.target_params += list(self.encoder_target.parameters())

        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )
            self.target_glu_target = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )
            self.online_params += list(self.target_glu.parameters())
            self.target_params += list(self.target_glu_target.parameters())

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
        self.final_proj_target = nn.Linear(args.encoder_embed_dim, final_dim)
        self.online_params += list(self.final_proj.parameters())
        self.target_params += list(self.final_proj_target.parameters())

        # TODO: Transformer prediction network?
        if args.mlp_prediction:
            self.prediction = MLPPrediction(final_dim, args.byol_hidden_dim)
        else:
            self.prediction = TransformerEncoder(args, args.final_dim)

        # The target branch is never trained by back-propagation.
        for param in self.target_params:
            param.requires_grad = False

    def update_target(self):
        """Move target parameters toward online parameters and advance ``step``.

        The decay follows a cosine schedule from ``base_decay`` (at step 0)
        towards 1 (at ``total_steps``).
        """
        decay = (
            1
            - (1 - self.base_decay)
            * (np.cos(np.pi * self.step / self.total_steps) + 1)
            / 2.0
        )
        self.step += 1
        for online_param, target_param in zip(self.online_params, self.target_params):
            target_old = target_param.data
            target_param.data = decay * target_old + (1 - decay) * online_param.data

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)
        return cls(args)

    @staticmethod
    def _downsample_padding_mask(padding_mask, features):
        """Reduce an input-resolution padding mask to feature resolution."""
        if padding_mask is None:
            return None
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        # A frame is padding only when every input position in it is padding.
        return padding_mask.all(-1)

    def apply_mask(self, x, padding_mask, mask_indices=None):
        """Mask time steps and channels of ``x`` in place.

        Args:
            x: tensor unpacked as ``(B, T, C)``; modified in place.
            padding_mask: forwarded to ``compute_mask_indices`` when a new
                time-step mask is drawn.
            mask_indices: optional precomputed time-step mask to reuse; when
                ``None`` a new one is drawn if ``mask_prob > 0``.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is ``None`` when none was
            given and ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if mask_indices is None and self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)

        if mask_indices is not None:
            x[mask_indices] = self.mask_emb

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def quantize(self, x):
        """Return ``self.quantizer.forward_idx`` of the normalized conv features."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` from the online branch."""
        res = self.predict(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def predict(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        mask_indices=None,
    ):
        """Run the online branch on one view.

        Args:
            source: input passed to the conv feature extractor.
            padding_mask: optional padding mask at input resolution.
            mask: whether to apply time/channel masking.
            features_only: return right after the encoder.
            mask_indices: optional time-step mask to reuse instead of drawing one.

        Returns:
            With ``features_only``: the dict ``{"x", "padding_mask"}``.
            Otherwise ``(result, mask_indices)`` where ``result`` holds ``"x"``
            (prediction-head output), ``"padding_mask"``, ``"features_pen"`` and,
            when a quantizer ran, the perplexity/temperature entries.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        padding_mask = self._downsample_padding_mask(padding_mask, features)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, mask_indices)
            if mask_indices is not None:
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
        y = self.project_q(y)

        # Keep only the masked positions of the encoder output.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu is not None:
            # No negatives are sampled in this model, so only ``y`` is gated.
            y = self.target_glu(y)

        x = self.final_proj(x)

        if self.args.mlp_prediction:
            x = self.prediction(x)
        else:
            x = self.prediction(x, padding_mask=padding_mask)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result, mask_indices

    def target_predict(self, source, padding_mask=None, mask=True, mask_indices=None):
        """Run the target branch on one view without tracking gradients.

        Uses the ``*_target`` modules where they exist. ``layer_norm``,
        ``project_inp`` and ``project_q`` have no target copies and are shared
        with the online branch.

        Returns:
            A dict with ``"x"`` (output of ``final_proj_target``),
            ``"padding_mask"``, ``"features_pen"`` and, when a quantizer ran,
            the perplexity/temperature entries.
        """
        # The whole branch is a stop-gradient path: its output is a regression
        # target, so nothing here may contribute gradients.
        with torch.no_grad():
            features = self.feature_extractor_target(source)

            features_pen = features.float().pow(2).mean()

            features = features.transpose(1, 2)
            features = self.layer_norm(features)
            unmasked_features = features.clone()

            padding_mask = self._downsample_padding_mask(padding_mask, features)

            if self.post_extract_proj_target is not None:
                features = self.post_extract_proj_target(features)

            features = self.dropout_input(features)
            unmasked_features = self.dropout_features(unmasked_features)

            num_vars = None
            code_ppl = None
            prob_ppl = None
            curr_temp = None

            if self.input_quantizer is not None:
                input_quantizer = (
                    self.input_quantizer
                    if self.shared_quantizer
                    else self.input_quantizer_target
                )
                q = input_quantizer(features, produce_targets=False)
                features = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]
                # No separate target projection is built in __init__.
                features = self.project_inp(features)

            if mask:
                x, mask_indices = self.apply_mask(features, padding_mask, mask_indices)
                if mask_indices is not None:
                    y = unmasked_features[mask_indices].view(
                        unmasked_features.size(0), -1, unmasked_features.size(-1)
                    )
                else:
                    y = unmasked_features
            else:
                x = features
                y = unmasked_features
                mask_indices = None

            x = self.encoder_target(x, padding_mask=padding_mask)

            if self.quantizer is not None:
                quantizer = (
                    self.quantizer if self.shared_quantizer else self.quantizer_target
                )
                q = quantizer(y, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]
            # No separate target projection is built in __init__.
            y = self.project_q(y)

            # Keep only the masked positions of the encoder output.
            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

            if self.target_glu is not None:
                y = self.target_glu_target(y)

            x = self.final_proj_target(x)

            result = {
                "x": x,
                "padding_mask": padding_mask,
                "features_pen": features_pen,
            }

            if prob_ppl is not None:
                result["prob_perplexity"] = prob_ppl
                result["code_perplexity"] = code_ppl
                result["num_vars"] = num_vars
                result["temp"] = curr_temp

        return result

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run both branches on both views.

        Args:
            source: indexable pair of views (``source[0]``, ``source[1]``).
            padding_mask: optional indexable pair of padding masks.
            mask: whether to apply masking.
            features_only: forwarded to :meth:`predict`.

        Returns:
            ``(result_0, result_1, result_target_0, result_target_1)``.

        NOTE(review): with ``features_only=True`` :meth:`predict` returns a dict
        rather than a ``(result, mask_indices)`` pair, so the unpacking below
        does not yield a mask; use :meth:`extract_features` for features.
        """
        # Advance the schedule exactly once per call: update_target() increments
        # ``step`` itself, so only the very first call increments it here.
        if self.step != 0:
            self.update_target()
        else:
            self.step += 1

        padding_mask_0 = padding_mask[0] if padding_mask is not None else None
        padding_mask_1 = padding_mask[1] if padding_mask is not None else None

        result_0, mask_indices0 = self.predict(
            source[0], padding_mask_0, mask, features_only
        )

        if self.args.separate_mask_indices:
            result_1, mask_indices1 = self.predict(
                source[1], padding_mask_1, mask, features_only
            )
        else:
            result_1, _ = self.predict(
                source[1],
                padding_mask_1,
                mask,
                features_only,
                mask_indices=mask_indices0,
            )
            mask_indices1 = mask_indices0

        # Each target view is masked with the mask of the *other* online view.
        result_target_0 = self.target_predict(
            source[0], padding_mask_0, mask, mask_indices=mask_indices1
        )
        result_target_1 = self.target_predict(
            source[1], padding_mask_1, mask, mask_indices=mask_indices0
        )

        return result_0, result_1, result_target_0, result_target_1

    def remove_pretraining_modules(self):
        """Drop the target branch and the modules used only during pre-training."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None

        self.feature_extractor_target = None
        self.post_extract_proj_target = None
        self.quantizer_target = None
        self.input_quantizer_target = None
        if not self.args.shared_emb:
            self.mask_emb_target = None
        self.encoder_target = None
        self.target_glu_target = None
        self.final_proj_target = None
        self.prediction = None