class Wav2Vec2Model(BaseFairseqModel):
    """wav2vec 2.0 model built from a ``Wav2Vec2Config``.

    Raw ``source`` is turned into frame features by a convolutional
    extractor.  The features are optionally quantized, masked and encoded by
    a Transformer.  For pre-training, the encoder outputs at the masked frames
    are scored (cosine similarity divided by ``logit_temp``) against the true
    target features and a set of sampled negatives.
    """

    def __init__(self, cfg: Wav2Vec2Config):
        super().__init__()
        self.cfg = cfg

        # NOTE(review): ``eval`` executes the config string as Python code;
        # only load configs from trusted sources.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # First field of the last conv-layer spec = feature dimension.
        self.embed = feature_enc_layers[-1][0]

        # Sub-modules are created in a fixed order below; changing that order
        # changes parameter initialisation under a fixed seed.
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # Project features to the encoder width, unless they already match
        # or the input quantizer takes care of the projection.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        # Time-masking hyper-parameters.
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # Channel-masking hyper-parameters.
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                # Reuse the target quantizer.  Its output width is the
                # ``vq_dim`` it was built with above, which is ``latent_dim``
                # when that is set.  It is not always ``final_dim``, so keep
                # the existing value rather than overwriting it.
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                )
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # Learned vector written into masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(self, x, padding_mask):
        """Mask time steps and/or channels of ``x`` in place.

        Masked time steps are overwritten with ``self.mask_emb``.  Masked
        channels are zeroed across all time steps.

        Args:
            x: 3-D feature tensor unpacked as ``(B, T, C)``.
            padding_mask: forwarded to ``compute_mask_indices`` for the time
                mask (may be None).

        Returns:
            ``(x, mask_indices)``.  ``mask_indices`` is the boolean time mask
            on ``x.device``, or None when ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the (B, C) channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample distractor frames for the contrastive loss.

        Args:
            y: candidate frames, unpacked as ``(bsz, tsz, fsz)``.
            num: number of target steps per utterance that need negatives.

        Returns:
            ``(negs, neg_idxs)``.  ``negs`` has shape
            ``(n_negatives + cross_sample_negatives, bsz, num, fsz)``.
            ``neg_idxs`` are the sampled row indices into the flattened
            ``(bsz * tsz, fsz)`` view of ``y``.  When no negatives are
            configured, returns ``(y.new(0), None)`` so that callers can
            always unpack two values.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # An empty 1-D tensor is skipped by ``torch.cat``, so codebook
            # negatives can still be appended by the caller.
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Skip the positive: shift indices at or past it up by one.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Move each row's within-utterance indices into the flattened
            # (B*T) index space; row 0 gets an offset of 0.
            neg_idxs += (torch.arange(bsz) * high).unsqueeze(1)
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Score ``x`` against the positive ``y`` and ``negatives``.

        Returns logits stacked on dim 0: index 0 is the positive, followed by
        one row per negative.  Negatives that are identical to the positive
        get ``-inf`` so they cannot be chosen.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)

        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run feature extraction, masking, encoding and (optionally) scoring.

        Args:
            source: raw input passed to ``self.feature_extractor``.
            padding_mask: input-rate padding mask.  It is downsampled to the
                feature frame rate by trimming and ``all(-1)`` pooling.
            mask: whether to apply time/channel masking.
            features_only: if True, return only the encoder output.

        Returns:
            A dict with ``"x"`` and ``"padding_mask"``.  Unless
            ``features_only`` is set, it also contains ``"features_pen"`` and,
            when a quantizer ran, perplexity statistics.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Trim so the input length is a multiple of the frame count, then
            # mark a frame as padding only if all of its samples are padding.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(
                padding_mask.size(0), features.size(1), -1
            )
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Targets are the unmasked features at the masked positions.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict; the quantized frames are
                # under "x".  Tuple-unpacking the dict would yield its keys.
                neg_cands = self.quantizer(
                    unmasked_features, produce_targets=False
                )["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the encoder outputs at the masked positions.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu is not None:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen}

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return the target quantizer's ``forward_idx`` output for input ``x``."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` without computing logits."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten the logits from ``forward`` to ``(-1, num_candidates)`` rows."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero class targets: candidate 0 is always the positive."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalty terms present in ``net_output``.

        Returns a list that may contain a codebook-diversity penalty
        ``(num_vars - prob_perplexity) / num_vars`` and ``features_pen``.
        """
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop modules used only by the pre-training objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
def __init__(self, cfg: Wav2Vec2Config):
    """Build a wav2vec 2.0 model from ``cfg``.

    Sub-modules are created in a fixed order; changing that order would
    change parameter initialisation under a fixed seed.
    """
    super().__init__()
    self.cfg = cfg

    # NOTE(review): ``eval`` executes the config string as Python code; only
    # load configs from trusted sources.
    feature_enc_layers = eval(cfg.conv_feature_layers)
    # First field of the last conv-layer spec = feature dimension.
    self.embed = feature_enc_layers[-1][0]

    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        mode=cfg.extractor_mode,
        conv_bias=cfg.conv_bias,
    )

    # Project features to the encoder width, unless they already match or
    # the input quantizer takes care of the projection.
    self.post_extract_proj = (
        nn.Linear(self.embed, cfg.encoder_embed_dim)
        if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
        else None
    )

    # Time-masking hyper-parameters.
    self.mask_prob = cfg.mask_prob
    self.mask_selection = cfg.mask_selection
    self.mask_other = cfg.mask_other
    self.mask_length = cfg.mask_length
    self.no_mask_overlap = cfg.no_mask_overlap
    self.mask_min_space = cfg.mask_min_space

    # Channel-masking hyper-parameters.
    self.mask_channel_prob = cfg.mask_channel_prob
    self.mask_channel_selection = cfg.mask_channel_selection
    self.mask_channel_other = cfg.mask_channel_other
    self.mask_channel_length = cfg.mask_channel_length
    self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
    self.mask_channel_min_space = cfg.mask_channel_min_space

    self.dropout_input = nn.Dropout(cfg.dropout_input)
    self.dropout_features = nn.Dropout(cfg.dropout_features)

    self.feature_grad_mult = cfg.feature_grad_mult

    self.quantizer = None
    self.input_quantizer = None

    self.n_negatives = cfg.num_negatives
    self.cross_sample_negatives = cfg.cross_sample_negatives
    self.codebook_negatives = cfg.codebook_negatives
    self.negatives_from_everywhere = cfg.negatives_from_everywhere

    self.logit_temp = cfg.logit_temp

    final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

    if cfg.quantize_targets:
        vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
        self.quantizer = GumbelVectorQuantizer(
            dim=self.embed,
            num_vars=cfg.latent_vars,
            temp=cfg.latent_temp,
            groups=cfg.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )
        self.project_q = nn.Linear(vq_dim, final_dim)
    else:
        self.project_q = nn.Linear(self.embed, final_dim)

    if cfg.quantize_input:
        if cfg.same_quantizer and self.quantizer is not None:
            # Reuse the target quantizer.  Its output width is the ``vq_dim``
            # it was built with above, which is ``latent_dim`` when that is
            # set.  It is not always ``final_dim``, so keep the existing value.
            self.input_quantizer = self.quantizer
        else:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
            self.input_quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
        self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

    # Learned vector written into masked time steps.
    self.mask_emb = nn.Parameter(
        torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
    )

    self.encoder = TransformerEncoder(cfg)
    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if cfg.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU()
        )

    self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
class Wav2Vec2Model(BaseFairseqModel):
    """wav2vec 2.0 model configured from argparse ``args``.

    Raw ``source`` is turned into frame features by a convolutional
    extractor.  The features are optionally quantized, masked and encoded by
    a Transformer.  For pre-training, the encoder outputs at the masked frames
    are scored against the true targets and sampled negatives.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""

        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)",
        )

        parser.add_argument(
            "--encoder-layers",
            type=int,
            metavar="L",
            help="num encoder layers in the transformer",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )

        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="dropout probability for the transformer",
        )

        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )

        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )

        parser.add_argument(
            "--final-dim",
            type=int,
            metavar="D",
            help="project final representations and targets to this many dimensions",
        )

        parser.add_argument(
            "--layer-norm-first",
            action="store_true",
            help="apply layernorm first in the transformer",
        )

        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            help="probability of dropping a transformer layer",
        )

        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            metavar="EXPR",
            help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]",
        )

        parser.add_argument(
            "--logit-temp", type=float, help="temperature to divide logits by"
        )

        parser.add_argument(
            "--quantize-targets", action="store_true", help="use quantized targets"
        )

        parser.add_argument(
            "--quantize-input", action="store_true", help="use quantized inputs"
        )

        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="multiply feature extractor var grads by this",
        )

        parser.add_argument(
            "--latent-vars",
            type=int,
            metavar="N",
            help="number of latent variables V in each group of the codebook",
        )

        parser.add_argument(
            "--latent-groups",
            type=int,
            metavar="N",
            help="number of groups G of latent variables in the codebook",
        )

        parser.add_argument(
            "--latent-dim",
            type=int,
            metavar="N",
            help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups",
        )

        parser.add_argument("--mask-length", type=int, help="mask length")

        parser.add_argument(
            "--mask-prob", type=float, help="probability of replacing a token with mask"
        )

        parser.add_argument(
            "--mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--mask-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )

        parser.add_argument(
            "--no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        parser.add_argument(
            "--mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )

        parser.add_argument(
            "--mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )

        parser.add_argument(
            "--mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--mask-channel-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )

        parser.add_argument(
            "--no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )

        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the features (after feat extr)",
        )

        parser.add_argument(
            "--num-negatives", type=int, metavar="N", help="number of negative examples"
        )

        parser.add_argument(
            "--negatives-from-everywhere",
            action="store_true",
            help="sample negatives from everywhere, not just masked states",
        )

        parser.add_argument(
            "--cross-sample-negatives",
            type=int,
            metavar="N",
            help="num of cross sampled negatives",
        )

        parser.add_argument(
            "--codebook-negatives",
            type=int,
            metavar="N",
            help="num of codebook sampled negatives",
        )

        parser.add_argument(
            "--conv-pos",
            type=int,
            metavar="N",
            help="number of filters for convolutional positional embeddings",
        )

        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            metavar="N",
            help="number of groups for convolutional positional embedding",
        )

        parser.add_argument(
            "--latent-temp",
            type=str,
            metavar="D",
            help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)",
        )

        parser.add_argument(
            "--target-glu", action="store_true", help="adds projection + glu to targets"
        )

        parser.add_argument(
            "--conv-bias", action="store_true", help="include bias in conv encoder"
        )

    def __init__(self, args):
        super().__init__()
        self.args = args

        # NOTE(review): ``eval`` executes command-line strings as Python code
        # (here and for ``latent_temp`` below); only accept trusted input.
        feature_enc_layers = eval(args.conv_feature_layers)
        # First field of the last conv-layer spec = feature dimension.
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

        # Project features to the encoder width, unless they already match
        # or the input quantizer takes care of the projection.
        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim and not args.quantize_input
            else None
        )

        # Time-masking hyper-parameters.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        # Channel-masking hyper-parameters.
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = args.num_negatives
        self.cross_sample_negatives = args.cross_sample_negatives
        self.codebook_negatives = args.codebook_negatives
        self.negatives_from_everywhere = args.negatives_from_everywhere

        self.logit_temp = args.logit_temp

        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        # The target quantizer is built first so that ``same_quantizer`` can
        # actually share it.  Previously the input quantizer was built first
        # and shared ``self.quantizer`` while it was still None.
        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            # ``same_quantizer`` is not registered in ``add_args``; default
            # to a separate quantizer when it is absent.
            if getattr(args, "same_quantizer", False) and self.quantizer is not None:
                # Reuse the target quantizer; its output width is the
                # ``vq_dim`` computed above.
                self.input_quantizer = self.quantizer
            else:
                vq_dim = (
                    args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
                )
                # The input quantizer sees the layer-normed features, whose
                # width is ``self.embed``.  ``post_extract_proj`` is None when
                # ``quantize_input`` is set.
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=args.latent_vars,
                    temp=eval(args.latent_temp),
                    groups=args.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        # Learned vector written into masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(args)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        return cls(args)

    def apply_mask(self, x, padding_mask):
        """Mask time steps and/or channels of ``x`` in place.

        Masked time steps are overwritten with ``self.mask_emb``.  Masked
        channels are zeroed across all time steps.

        Returns:
            ``(x, mask_indices)``.  ``mask_indices`` is the boolean time mask
            on ``x.device``, or None when ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the (B, C) channel mask over every time step.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample distractor frames for the contrastive loss.

        Returns:
            ``(negs, neg_idxs)``.  ``negs`` has shape
            ``(n_negatives + cross_sample_negatives, bsz, num, fsz)``.  When
            no negatives are configured, returns ``(y.new(0), None)`` so that
            callers can always unpack two values.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # An empty 1-D tensor is skipped by ``torch.cat``, so codebook
            # negatives can still be appended by the caller.
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # Skip the positive: shift indices at or past it up by one.
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Move each row's within-utterance indices into the flattened
            # (B*T) index space; row 0 gets an offset of 0.
            neg_idxs += (torch.arange(bsz) * high).unsqueeze(1)
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Score ``x`` against the positive ``y`` (index 0) and ``negatives``.

        Negatives that are identical to the positive get ``-inf``.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)

        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run feature extraction, masking, encoding and (optionally) scoring.

        Returns a dict with ``"x"`` and ``"padding_mask"``.  Unless
        ``features_only`` is set, it also contains ``"features_pen"`` and,
        when a quantizer ran, perplexity statistics.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Trim so the input length is a multiple of the frame count, then
            # mark a frame as padding only if all of its samples are padding.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(
                padding_mask.size(0), features.size(1), -1
            )
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Targets are the unmasked features at the masked positions.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict; the quantized frames are
                # under "x".  Tuple-unpacking the dict would yield its keys.
                neg_cands = self.quantizer(
                    unmasked_features, produce_targets=False
                )["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        # Keep only the encoder outputs at the masked positions.
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu is not None:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen}

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return the target quantizer's ``forward_idx`` output for input ``x``."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` without computing logits."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten the logits from ``forward`` to ``(-1, num_candidates)`` rows."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero class targets: candidate 0 is always the positive."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalty terms present in ``net_output``."""
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop modules used only by the pre-training objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
def __init__(self, cfg: Wav2VecConfig):
    """Build a wav2vec (v1) model from ``cfg``.

    Creates a conv feature extractor, an optional vector quantizer, a feature
    aggregator (CNN or GRU), the prediction head, and dropout layers.

    Raises:
        Exception: for an unknown ``activation``, ``aggregator`` or
            ``project_features`` value.
        AssertionError: for an unknown ``vq_type``.
    """
    super().__init__()

    self.prediction_steps = cfg.prediction_steps
    offset = cfg.offset

    if cfg.activation == "relu":
        activation = nn.ReLU()
    elif cfg.activation == "gelu":
        activation = nn.GELU()
    else:
        raise Exception("unknown activation " + cfg.activation)

    # NOTE(review): ``eval`` executes the config string as Python code; only
    # load configs from trusted sources.
    feature_enc_layers = eval(cfg.conv_feature_layers)
    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        log_compression=cfg.log_compression,
        skip_connections=cfg.skip_connections_feat,
        residual_scale=cfg.residual_scale,
        non_affine_group_norm=cfg.non_affine_group_norm,
        activation=activation,
    )
    # First field of the last conv-layer spec = feature dimension.
    embed = feature_enc_layers[-1][0]

    self.vector_quantizer = None
    if cfg.vq_type == "gumbel":
        self.vector_quantizer = GumbelVectorQuantizer(
            dim=embed,
            num_vars=cfg.vq_vars,
            temp=cfg.vq_temp,
            groups=cfg.vq_groups,
            combine_groups=cfg.combine_groups,
            vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
            time_first=False,
            activation=activation,
            weight_proj_depth=cfg.vq_depth,
            weight_proj_factor=2,
        )
    elif cfg.vq_type == "kmeans":
        self.vector_quantizer = KmeansVectorQuantizer(
            dim=embed,
            num_vars=cfg.vq_vars,
            groups=cfg.vq_groups,
            combine_groups=cfg.combine_groups,
            vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
            time_first=False,
            gamma=cfg.vq_gamma,
        )
    else:
        assert (
            cfg.vq_type == "none" or cfg.vq_type is None
        ), "Unknown quantizer type"

    if cfg.offset == "auto":
        # Derive the offset from the conv stack: ``rin`` accumulates the
        # receptive field and ``jin`` the cumulative stride (jump).
        jin = 0
        rin = 0
        for _, k, stride in feature_enc_layers:
            if rin == 0:
                rin = k
            rin = rin + (k - 1) * jin
            if jin == 0:
                jin = stride
            else:
                jin *= stride
        offset = math.ceil(rin / jin)

    offset = int(offset)

    def make_aggregator():
        # Build a fresh aggregator; returns (module, output_dim).
        if cfg.aggregator == "cnn":
            agg_layers = eval(cfg.conv_aggregator_layers)
            agg_dim = agg_layers[-1][0]
            feature_aggregator = ConvAggegator(
                conv_layers=agg_layers,
                embed=embed,
                dropout=cfg.dropout,
                skip_connections=cfg.skip_connections_agg,
                residual_scale=cfg.residual_scale,
                non_affine_group_norm=cfg.non_affine_group_norm,
                conv_bias=not cfg.no_conv_bias,
                zero_pad=cfg.agg_zero_pad,
                activation=activation,
            )
        elif cfg.aggregator == "gru":
            agg_dim = cfg.gru_dim
            feature_aggregator = nn.Sequential(
                TransposeLast(),
                nn.GRU(
                    input_size=embed,
                    hidden_size=agg_dim,
                    num_layers=1,
                    dropout=cfg.dropout,
                ),
                TransposeLast(deconstruct_idx=0),
            )
        else:
            raise Exception("unknown aggregator type " + cfg.aggregator)

        return feature_aggregator, agg_dim

    self.feature_aggregator, agg_dim = make_aggregator()

    self.wav2vec_predictions = Wav2VecPredictionsModel(
        in_dim=agg_dim,
        out_dim=embed,
        prediction_steps=cfg.prediction_steps,
        n_negatives=cfg.num_negatives,
        cross_sample_negatives=cfg.cross_sample_negatives,
        sample_distance=cfg.sample_distance,
        dropout=cfg.dropout,
        offset=offset,
        balanced_classes=cfg.balanced_classes,
        infonce=cfg.infonce,
    )

    self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
    self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)

    if cfg.project_features == "none":
        self.project_features = None
    elif cfg.project_features == "same":
        # Share the aggregator module (and its weights).
        self.project_features = self.feature_aggregator
    elif cfg.project_features == "new":
        # Separate aggregator with its own weights.
        self.project_features, _ = make_aggregator()
    else:
        # Fail fast: otherwise ``self.project_features`` would be left unset
        # and only surface later as an AttributeError.
        raise Exception(
            "unknown project_features type " + str(cfg.project_features)
        )
class Wav2Vec2Model(BaseFairseqModel): def __init__(self, cfg: Wav2Vec2Config): super().__init__() self.cfg = cfg feature_enc_layers = eval(cfg.conv_feature_layers) self.embed = feature_enc_layers[-1][0] self.feature_extractor = ConvFeatureExtractionModel( conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias, ) self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input else None) self.mask_prob = cfg.mask_prob self.mask_selection = cfg.mask_selection self.mask_other = cfg.mask_other self.mask_length = cfg.mask_length self.no_mask_overlap = cfg.no_mask_overlap self.mask_min_space = cfg.mask_min_space self.mask_channel_prob = cfg.mask_channel_prob self.mask_channel_before = cfg.mask_channel_before self.mask_channel_selection = cfg.mask_channel_selection self.mask_channel_other = cfg.mask_channel_other self.mask_channel_length = cfg.mask_channel_length self.no_mask_channel_overlap = cfg.no_mask_channel_overlap self.mask_channel_min_space = cfg.mask_channel_min_space self.dropout_input = nn.Dropout(cfg.dropout_input) self.dropout_features = nn.Dropout(cfg.dropout_features) self.feature_grad_mult = cfg.feature_grad_mult self.quantizer = None self.input_quantizer = None self.n_negatives = cfg.num_negatives self.cross_sample_negatives = cfg.cross_sample_negatives self.codebook_negatives = cfg.codebook_negatives self.negatives_from_everywhere = cfg.negatives_from_everywhere self.logit_temp = cfg.logit_temp final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim if cfg.quantize_targets: vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim self.quantizer = GumbelVectorQuantizer( dim=self.embed, num_vars=cfg.latent_vars, temp=cfg.latent_temp, groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, weight_proj_depth=cfg.quantizer_depth, weight_proj_factor=cfg.quantizer_factor, ) self.project_q = 
nn.Linear(vq_dim, final_dim) else: self.project_q = nn.Linear(self.embed, final_dim) if cfg.quantize_input: if cfg.same_quantizer and self.quantizer is not None: vq_dim = final_dim self.input_quantizer = self.quantizer else: vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim self.input_quantizer = GumbelVectorQuantizer( dim=self.embed, num_vars=cfg.latent_vars, temp=cfg.latent_temp, groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, weight_proj_depth=cfg.quantizer_depth, weight_proj_factor=cfg.quantizer_factor, ) self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) self.mask_emb = nn.Parameter( torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) self.encoder = TransformerEncoder(cfg) self.layer_norm = LayerNorm(self.embed) self.target_glu = None if cfg.target_glu: self.target_glu = nn.Sequential( nn.Linear(final_dim, final_dim * 2), nn.GLU()) self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) """Upgrade a (possibly old) state dict for new versions of fairseq.""" return state_dict @classmethod def build_model(cls, cfg: Wav2Vec2Config, task=None): """Build a new model instance.""" return cls(cfg) def apply_mask( self, x, padding_mask, mask_indices=None, mask_channel_indices=None, ): B, T, C = x.shape if self.mask_channel_prob > 0 and self.mask_channel_before: mask_channel_indices = compute_mask_indices( (B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space, ) mask_channel_indices = (torch.from_numpy(mask_channel_indices).to( x.device).unsqueeze(1).expand(-1, T, -1)) x[mask_channel_indices] = 0 if self.mask_prob > 0: if mask_indices is None: mask_indices = compute_mask_indices( (B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, 
self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space, ) mask_indices = torch.from_numpy(mask_indices).to(x.device) x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None if self.mask_channel_prob > 0 and not self.mask_channel_before: if mask_channel_indices is None: mask_channel_indices = compute_mask_indices( (B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space, ) mask_channel_indices = ( torch.from_numpy(mask_channel_indices).to( x.device).unsqueeze(1).expand(-1, T, -1)) x = index_put(x, mask_channel_indices, 0) return x, mask_indices def sample_negatives(self, y, num, padding_count=None): if self.n_negatives == 0 and self.cross_sample_negatives == 0: return y.new(0) bsz, tsz, fsz = y.shape y = y.view(-1, fsz) # BTC => (BxT)C # FIXME: what happens if padding_count is specified? cross_high = tsz * bsz high = tsz - (padding_count or 0) with torch.no_grad(): assert high > 1, f"{bsz,tsz,fsz}" if self.n_negatives > 0: tszs = (buffered_arange(num).unsqueeze(-1).expand( -1, self.n_negatives).flatten()) neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) neg_idxs[neg_idxs >= tszs] += 1 if self.cross_sample_negatives > 0: tszs = (buffered_arange(num).unsqueeze(-1).expand( -1, self.cross_sample_negatives).flatten()) cross_neg_idxs = torch.randint( low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), ) cross_neg_idxs[cross_neg_idxs >= tszs] += 1 if self.n_negatives > 0: for i in range(1, bsz): neg_idxs[i] += i * high else: neg_idxs = cross_neg_idxs if self.cross_sample_negatives > 0 and self.n_negatives > 0: neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) negs = y[neg_idxs.view(-1)] negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute(2, 0, 1, 3) # to NxBxTxC return negs, neg_idxs 
def compute_preds(self, x, y, negatives): neg_is_pos = (y == negatives).all(-1) y = y.unsqueeze(0) targets = torch.cat([y, negatives], dim=0) logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) logits = logits / self.logit_temp if is_xla_tensor(logits) or neg_is_pos.any(): fillval = -float(2**30) if not hasattr(self, "_inftensor"): self._inftensor = (torch.tensor(fillval).to(x.device) if is_xla_tensor(logits) else float("-inf")) logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) return logits def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers """ def _conv_out_length(input_length, kernel_size, stride): return torch.floor((input_length - kernel_size) / stride + 1) conv_cfg_list = eval(self.cfg.conv_feature_layers) for i in range(len(conv_cfg_list)): input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]) return input_lengths.to(torch.long) def forward( self, source, padding_mask=None, mask=True, features_only=False, layer=None, mask_indices=None, mask_channel_indices=None, padding_count=None, ): if self.feature_grad_mult > 0: features = self.feature_extractor(source) if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult) else: with torch.no_grad(): features = self.feature_extractor(source) features_pen = features.float().pow(2).mean() features = features.transpose(1, 2) features = self.layer_norm(features) unmasked_features = features.clone() if padding_mask is not None and padding_mask.any(): input_lengths = (1 - padding_mask.long()).sum(-1) # apply conv formula to get real output_lengths output_lengths = self._get_feat_extract_output_lengths( input_lengths) padding_mask = torch.zeros(features.shape[:2], dtype=features.dtype, device=features.device) # these two operations makes sure that all values # before the output lengths indices are attended to padding_mask[( 
torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1, )] = 1 padding_mask = ( 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() else: padding_mask = None if self.post_extract_proj is not None: features = self.post_extract_proj(features) features = self.dropout_input(features) unmasked_features = self.dropout_features(unmasked_features) num_vars = None code_ppl = None prob_ppl = None curr_temp = None if self.input_quantizer: q = self.input_quantizer(features, produce_targets=False) features = q["x"] num_vars = q["num_vars"] code_ppl = q["code_perplexity"] prob_ppl = q["prob_perplexity"] curr_temp = q["temp"] features = self.project_inp(features) if mask: x, mask_indices = self.apply_mask( features, padding_mask, mask_indices=mask_indices, mask_channel_indices=mask_channel_indices, ) if not is_xla_tensor(x) and mask_indices is not None: # tpu-comment: reducing the size in a dynamic way causes # too many recompilations on xla. y = unmasked_features[mask_indices].view( unmasked_features.size(0), -1, unmasked_features.size(-1)) else: y = unmasked_features else: x = features y = unmasked_features mask_indices = None x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer) if features_only: return { "x": x, "padding_mask": padding_mask, "features": unmasked_features, "layer_results": layer_results, } if self.quantizer: q = self.quantizer(y, produce_targets=False) y = q["x"] num_vars = q["num_vars"] code_ppl = q["code_perplexity"] prob_ppl = q["prob_perplexity"] curr_temp = q["temp"] y = self.project_q(y) if self.negatives_from_everywhere: neg_cands = self.quantizer(unmasked_features, produce_targets=False)["x"] negs, _ = self.sample_negatives( neg_cands, y.size(1), padding_count=padding_count, ) negs = self.project_q(negs) else: negs, _ = self.sample_negatives( y, y.size(1), padding_count=padding_count, ) if self.codebook_negatives > 0: cb_negs = self.quantizer.sample_from_codebook( y.size(0) * y.size(1), 
self.codebook_negatives) cb_negs = cb_negs.view(self.codebook_negatives, y.size(0), y.size(1), -1) # order doesnt matter cb_negs = self.project_q(cb_negs) negs = torch.cat([negs, cb_negs], dim=0) else: y = self.project_q(y) if self.negatives_from_everywhere: negs, _ = self.sample_negatives( unmasked_features, y.size(1), padding_count=padding_count, ) negs = self.project_q(negs) else: negs, _ = self.sample_negatives( y, y.size(1), padding_count=padding_count, ) if not is_xla_tensor(x): # tpu-comment: reducing the size in a dynamic way causes # too many recompilations on xla. x = x[mask_indices].view(x.size(0), -1, x.size(-1)) if self.target_glu: y = self.target_glu(y) negs = self.target_glu(negs) x = self.final_proj(x) x = self.compute_preds(x, y, negs) result = { "x": x, "padding_mask": padding_mask, "features_pen": features_pen, } if prob_ppl is not None: result["prob_perplexity"] = prob_ppl result["code_perplexity"] = code_ppl result["num_vars"] = num_vars result["temp"] = curr_temp return result def quantize(self, x): assert self.quantizer is not None x = self.feature_extractor(x) x = x.transpose(1, 2) x = self.layer_norm(x) return self.quantizer.forward_idx(x) def extract_features(self, source, padding_mask, mask=False, layer=None): res = self.forward(source, padding_mask, mask=mask, features_only=True, layer=layer) return res def get_logits(self, net_output): logits = net_output["x"] logits = logits.transpose(0, 2) logits = logits.reshape(-1, logits.size(-1)) return logits def get_targets(self, sample, net_output, expand_steps=True): x = net_output["x"] return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) def get_extra_losses(self, net_output): pen = [] if "prob_perplexity" in net_output: pen.append( (net_output["num_vars"] - net_output["prob_perplexity"]) / net_output["num_vars"]) if "features_pen" in net_output: pen.append(net_output["features_pen"]) return pen def remove_pretraining_modules(self): self.quantizer = None self.project_q = None 
self.target_glu = None self.final_proj = None
def __init__(self, args):
    """Build a wav2vec 2.0 style model and optionally warm-start it.

    Builds the convolutional feature extractor, optional Gumbel quantizers
    for the targets and/or the inputs, the transformer encoder and the
    output projections.  When ``args.w2v_path`` is set, weights are then
    loaded from that checkpoint with ``strict=False`` after discarding every
    entry whose key contains ``"quantizer"``; the value returned by
    ``load_state_dict`` (missing / unexpected keys) is printed.

    Args:
        args: namespace-like model config.  ``w2v_path`` is optional; every
            other attribute read below is required.
    """
    super().__init__()
    self.args = args

    # NOTE(review): `eval` executes the config string -- only use trusted
    # configs.  Each entry is a (dim, kernel_size, stride) tuple.
    feature_enc_layers = eval(args.conv_feature_layers)
    # Channel count of the last conv layer = width of the extracted features.
    self.embed = feature_enc_layers[-1][0]

    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        mode=args.extractor_mode,
        conv_bias=args.conv_bias,
    )

    # Project conv features to the encoder width only when the widths differ
    # and the input is not quantized (project_inp covers that case).
    self.post_extract_proj = (
        nn.Linear(self.embed, args.encoder_embed_dim)
        if self.embed != args.encoder_embed_dim and not args.quantize_input
        else None
    )

    # Time-axis masking hyper-parameters.
    self.mask_prob = args.mask_prob
    self.mask_selection = args.mask_selection
    self.mask_other = args.mask_other
    self.mask_length = args.mask_length
    self.no_mask_overlap = args.no_mask_overlap
    self.mask_min_space = args.mask_min_space

    # Channel-axis masking hyper-parameters.
    self.mask_channel_prob = args.mask_channel_prob
    self.mask_channel_selection = args.mask_channel_selection
    self.mask_channel_other = args.mask_channel_other
    self.mask_channel_length = args.mask_channel_length
    self.no_mask_channel_overlap = args.no_mask_channel_overlap
    self.mask_channel_min_space = args.mask_channel_min_space

    self.dropout_input = nn.Dropout(args.dropout_input)
    self.dropout_features = nn.Dropout(args.dropout_features)

    self.feature_grad_mult = args.feature_grad_mult

    self.quantizer = None
    self.input_quantizer = None

    # Negative sampling / contrastive logits.
    self.n_negatives = args.num_negatives
    self.cross_sample_negatives = args.cross_sample_negatives
    self.codebook_negatives = args.codebook_negatives
    self.negatives_from_everywhere = args.negatives_from_everywhere

    self.logit_temp = args.logit_temp

    final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

    if args.quantize_targets:
        # Codebook output width: latent_dim when set, otherwise final_dim.
        vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
        self.quantizer = GumbelVectorQuantizer(
            dim=self.embed,
            num_vars=args.latent_vars,
            temp=eval(args.latent_temp),  # e.g. "(2,0.5,0.999995)"
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )
        self.project_q = nn.Linear(vq_dim, final_dim)
    else:
        self.project_q = nn.Linear(self.embed, final_dim)

    if args.quantize_input:
        if args.same_quantizer and self.quantizer is not None:
            # Share the target quantizer.  Its output width is the `vq_dim`
            # chosen above, so keep it for project_inp rather than assuming
            # final_dim (the two differ when latent_dim > 0 and
            # latent_dim != final_dim).
            self.input_quantizer = self.quantizer
        else:
            vq_dim = (
                args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            )
            self.input_quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
        self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

    self.mask_emb = nn.Parameter(
        torch.FloatTensor(args.encoder_embed_dim).uniform_()
    )

    self.encoder = TransformerEncoder(args)
    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if args.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU()
        )

    self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    w2v_path = getattr(args, "w2v_path", None)
    if w2v_path:
        print('load Wav2VecEncoder from {}'.format(w2v_path))
        state = checkpoint_utils.load_checkpoint_to_cpu(w2v_path)
        model_state = state["model"]
        # Drop every quantizer weight (this also matches "input_quantizer.*")
        # so the quantizers keep their fresh initialisation.  Keys are
        # deleted in place so the state-dict object itself is preserved.
        for key in [k for k in model_state if "quantizer" in k]:
            del model_state[key]
        print(self.load_state_dict(model_state, strict=False))
def __init__(self, args):
    """Build a contrastive model over pre-computed frame features.

    There is no convolutional feature extractor: inputs are mapped from a
    hard-coded 80 features to ``encoder_embed_dim`` by ``post_extract_proj``.
    Optional target / input Gumbel quantizers, the transformer encoder and
    the output projections are then built from ``args``.

    Args:
        args: namespace-like model config; every attribute read below is
            required.
    """
    super().__init__()
    self.args = args

    # NOTE(review): assumes 80-dim input frames (e.g. filterbanks) -- confirm.
    self.post_extract_proj = nn.Linear(80, args.encoder_embed_dim)

    # Nothing in this constructor sets `self.embed` (there is no conv
    # extractor), yet project_q / layer_norm below need it.  Presumably the
    # features they see are post_extract_proj outputs, so default to
    # encoder_embed_dim unless the class already provides `embed`.
    if not hasattr(self, "embed"):
        self.embed = args.encoder_embed_dim

    # Time-axis masking hyper-parameters.
    self.mask_prob = args.mask_prob
    self.mask_selection = args.mask_selection
    self.mask_other = args.mask_other
    self.mask_length = args.mask_length
    self.no_mask_overlap = args.no_mask_overlap
    self.mask_min_space = args.mask_min_space

    # Channel-axis masking hyper-parameters.
    self.mask_channel_prob = args.mask_channel_prob
    self.mask_channel_selection = args.mask_channel_selection
    self.mask_channel_other = args.mask_channel_other
    self.mask_channel_length = args.mask_channel_length
    self.no_mask_channel_overlap = args.no_mask_channel_overlap
    self.mask_channel_min_space = args.mask_channel_min_space

    self.dropout_input = nn.Dropout(args.dropout_input)
    self.dropout_features = nn.Dropout(args.dropout_features)

    self.feature_grad_mult = args.feature_grad_mult

    self.quantizer = None
    self.input_quantizer = None

    # Negative sampling / contrastive logits.
    self.n_negatives = args.num_negatives
    self.cross_sample_negatives = args.cross_sample_negatives
    self.codebook_negatives = args.codebook_negatives
    self.negatives_from_everywhere = args.negatives_from_everywhere

    self.logit_temp = args.logit_temp

    # With `same_quantizer` the input quantizer must be the target quantizer,
    # which is only built further down (reading self.quantizer here would
    # still give None), so that sharing is resolved afterwards.  A separate
    # input quantizer is still built first, keeping parameter registration
    # order unchanged for that configuration.
    share_quantizer = (
        args.quantize_input and args.same_quantizer and args.quantize_targets
    )
    if args.quantize_input and not share_quantizer:
        vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
        self.input_quantizer = GumbelVectorQuantizer(
            dim=args.encoder_embed_dim,
            num_vars=args.latent_vars,
            temp=eval(args.latent_temp),
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )
        self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

    final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

    if args.quantize_targets:
        vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
        self.quantizer = GumbelVectorQuantizer(
            dim=args.encoder_embed_dim,
            num_vars=args.latent_vars,
            temp=eval(args.latent_temp),
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )
        self.project_q = nn.Linear(vq_dim, final_dim)
    else:
        self.project_q = nn.Linear(self.embed, final_dim)

    if share_quantizer:
        # `vq_dim` is the shared quantizer's output width computed above.
        self.input_quantizer = self.quantizer
        self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

    self.mask_emb = nn.Parameter(
        torch.FloatTensor(args.encoder_embed_dim).uniform_()
    )

    self.encoder = TransformerEncoder(args)
    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if args.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU()
        )

    self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
def __init__(self, args):
    """Build a wav2vec 2.0 style model and warm-start it from a checkpoint.

    Builds the convolutional feature extractor, optional Gumbel quantizers
    for the targets and/or the inputs and the transformer encoder, then
    loads ``cp["model"]`` from ``args.w2v_checkpoint`` with ``strict=False``,
    and only after that creates ``final_proj``.

    Args:
        args: namespace-like model config; ``w2v_checkpoint`` and every other
            attribute read below are required.
    """
    super().__init__()
    self.args = args

    # NOTE(review): `eval` executes the config string -- only use trusted
    # configs.  Each entry is a (dim, kernel_size, stride) tuple.
    feature_enc_layers = eval(args.conv_feature_layers)
    # Channel count of the last conv layer = width of the extracted features.
    self.embed = feature_enc_layers[-1][0]

    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        mode=args.extractor_mode,
        conv_bias=args.conv_bias,
    )

    # Project conv features to the encoder width only when the widths differ
    # and the input is not quantized (project_inp covers that case).
    self.post_extract_proj = (
        nn.Linear(self.embed, args.encoder_embed_dim)
        if self.embed != args.encoder_embed_dim and not args.quantize_input
        else None
    )

    # Time-axis masking hyper-parameters.
    self.mask_prob = args.mask_prob
    self.mask_selection = args.mask_selection
    self.mask_other = args.mask_other
    self.mask_length = args.mask_length
    self.no_mask_overlap = args.no_mask_overlap
    self.mask_min_space = args.mask_min_space

    # Channel-axis masking hyper-parameters.
    self.mask_channel_prob = args.mask_channel_prob
    self.mask_channel_selection = args.mask_channel_selection
    self.mask_channel_other = args.mask_channel_other
    self.mask_channel_length = args.mask_channel_length
    self.no_mask_channel_overlap = args.no_mask_channel_overlap
    self.mask_channel_min_space = args.mask_channel_min_space

    self.dropout_input = nn.Dropout(args.dropout_input)
    self.dropout_features = nn.Dropout(args.dropout_features)

    self.feature_grad_mult = args.feature_grad_mult

    self.quantizer = None
    self.input_quantizer = None

    # Negative sampling / contrastive logits.
    self.n_negatives = args.num_negatives
    self.cross_sample_negatives = args.cross_sample_negatives
    self.codebook_negatives = args.codebook_negatives
    self.negatives_from_everywhere = args.negatives_from_everywhere

    self.logit_temp = args.logit_temp

    final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

    if args.quantize_targets:
        # Codebook output width: latent_dim when set, otherwise final_dim.
        vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
        self.quantizer = GumbelVectorQuantizer(
            dim=self.embed,
            num_vars=args.latent_vars,
            temp=eval(args.latent_temp),
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )
        self.project_q = nn.Linear(vq_dim, final_dim)
    else:
        self.project_q = nn.Linear(self.embed, final_dim)

    if args.quantize_input:
        if args.same_quantizer and self.quantizer is not None:
            # Share the target quantizer.  Its output width is the `vq_dim`
            # chosen above, so keep it for project_inp rather than assuming
            # final_dim (the two differ when latent_dim > 0 and
            # latent_dim != final_dim).
            self.input_quantizer = self.quantizer
        else:
            vq_dim = (
                args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            )
            self.input_quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
        self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

    self.mask_emb = nn.Parameter(
        torch.FloatTensor(args.encoder_embed_dim).uniform_()
    )

    self.encoder = TransformerEncoder(args)
    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if args.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU()
        )

    # NOTE(review): the checkpoint is loaded *before* final_proj exists, so
    # any final_proj weights it contains are ignored (strict=False) and
    # final_proj keeps its random init.  Confirm this is intentional.
    # `map_location=lambda x, _: x` keeps storages on CPU; torch.load may
    # unpickle arbitrary objects, so only load trusted checkpoints.
    cp = torch.load(args.w2v_checkpoint, map_location=lambda x, _: x)
    self.load_state_dict(cp["model"], strict=False)

    self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
class DeCoAr2Model(BaseFairseqModel):
    """Contrastive (wav2vec 2.0 style) model over pre-computed frame features.

    There is no convolutional front-end: ``source`` is projected from a
    hard-coded 80 features to ``encoder_embed_dim`` by ``post_extract_proj``,
    then masked, encoded by ``TransformerEncoder`` and contrasted against
    (optionally Gumbel-quantized) targets with sampled negatives.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # NOTE(review): assumes 80-dim input frames (e.g. filterbanks) -- confirm.
        self.post_extract_proj = nn.Linear(80, args.encoder_embed_dim)
        # Width of everything downstream of post_extract_proj (quantizer
        # inputs and the un-quantized targets fed to project_q).  There is no
        # conv extractor to derive it from.
        self.embed = args.encoder_embed_dim

        # Time-axis masking hyper-parameters.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        # Channel-axis masking hyper-parameters.
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        # Negative sampling / contrastive logits.
        self.n_negatives = args.num_negatives
        self.cross_sample_negatives = args.cross_sample_negatives
        self.codebook_negatives = args.codebook_negatives
        self.negatives_from_everywhere = args.negatives_from_everywhere

        self.logit_temp = args.logit_temp

        # With `same_quantizer` the input quantizer must be the target
        # quantizer, which is only built further down (self.quantizer is
        # still None here), so that sharing is resolved afterwards.  A
        # separate input quantizer is still built first, keeping parameter
        # registration order unchanged for that configuration.
        share_quantizer = (
            args.quantize_input and args.same_quantizer and args.quantize_targets
        )
        if args.quantize_input and not share_quantizer:
            vq_dim = (
                args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            )
            self.input_quantizer = GumbelVectorQuantizer(
                dim=args.encoder_embed_dim,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=args.encoder_embed_dim,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if share_quantizer:
            # `vq_dim` is the shared quantizer's output width computed above.
            self.input_quantizer = self.quantizer
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(args)
        # NOTE(review): layer_norm is not applied anywhere in this class; it is
        # kept so the parameter set / state-dict keys stay the same.
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)
        return cls(args)

    def apply_mask(self, x, padding_mask):
        """Mask time steps (with ``mask_emb``) and channels (with 0) of ``x``.

        ``x`` is unpacked as (B, T, C) and modified in place.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is a boolean (B, T) tensor
            of masked time steps, or ``None`` when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the (B, C) channel mask over every time step.
            mask_channel_indices = (torch.from_numpy(mask_channel_indices).to(
                x.device).unsqueeze(1).expand(-1, T, -1))
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample distractors for each of ``num`` target time steps.

        ``forward`` relies on this method; it follows the same sampling scheme
        as the wav2vec 2.0 model (within-utterance negatives plus optional
        cross-utterance negatives, never the positive's own index).

        Args:
            y: targets, unpacked as (batch, time, channels).
            num: number of time steps to draw negatives for.

        Returns:
            ``(negs, neg_idxs)``; ``negs`` has shape
            (n_negatives + cross_sample_negatives, batch, num, channels).
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0), None

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # (B, T, C) -> (B*T, C)

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.n_negatives).flatten())
                # Draw from [0, high - 1) and shift indices at or past the
                # positive up by one, so the positive itself is never drawn.
                neg_idxs = torch.randint(low=0,
                                         high=high - 1,
                                         size=(bsz, self.n_negatives * num))
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.cross_sample_negatives).flatten())
                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Offset each utterance's indices into the flattened (B*T) axis.
            neg_idxs = neg_idxs + (
                torch.arange(bsz, device=neg_idxs.device).unsqueeze(-1) * high)
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives,
                         fsz).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine logits of ``x`` vs. [y; negatives].

        Row 0 scores the positive; rows 1.. score the negatives.  Negatives
        identical to the positive are set to -inf so they cannot be chosen.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(),
                                         dim=-1).type_as(x)
        logits = logits / self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Encode ``source`` and, unless ``features_only``, score targets.

        Args:
            source: input frames; post_extract_proj acts on the last dimension.
            padding_mask: optional padding mask passed to masking / encoder.
            mask: whether to apply time / channel masking.
            features_only: return only the encoder output.

        Returns:
            ``{"x", "padding_mask"}`` when ``features_only``; otherwise
            ``{"x": logits, "padding_mask", "features_pen"}`` plus
            ``prob_perplexity`` / ``code_perplexity`` / ``num_vars`` / ``temp``
            when a quantizer ran.
        """
        features = self.post_extract_proj(source)
        # Mean squared activation reported as "features_pen" (picked up by
        # get_extra_losses).  Without a conv front-end it is taken on the
        # post_extract_proj outputs.
        features_pen = features.float().pow(2).mean()

        unmasked_features = features.clone()

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                # Keep only the masked steps as targets: (B, n_masked, C).
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1))
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                # The quantizer returns a dict; take its "x" entry (unpacking
                # the dict itself would only yield its keys).
                neg_cands = self.quantizer(unmasked_features,
                                           produce_targets=False)["x"]
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives)
                cb_negs = cb_negs.view(self.codebook_negatives, y.size(0),
                                       y.size(1), -1)  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Return codebook indices for raw input frames ``x``.

        This class has no conv ``feature_extractor``; like ``forward``, the
        quantizer is applied to ``post_extract_proj`` outputs.
        """
        assert self.quantizer is not None
        x = self.post_extract_proj(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return ``(encoder_output, padding_mask)`` without the contrastive head."""
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        """Flatten logits so each row is one scored step (positive in column 0)."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return all-zero targets: the positive is always candidate 0."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect the codebook-diversity and feature penalties, if present."""
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"]) /
                net_output["num_vars"])

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop modules only needed for the pre-training objective."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
def __init__(self, args):
    """Build paired online / target networks for a BYOL-style objective.

    Every online submodule that has a target twin is built next to it, and
    their parameters are appended to ``self.online_params`` and
    ``self.target_params`` in the same order, so the two lists line up
    element-wise (presumably for a momentum update elsewhere -- confirm).
    All target parameters are frozen (``requires_grad = False``) at the end.

    Args:
        args: namespace-like model config; every attribute read below is
            required.
    """
    super().__init__()
    self.args = args

    # NOTE(review): `eval` executes the config string -- only use trusted
    # configs.  Each entry is a (dim, kernel_size, stride) tuple.
    feature_enc_layers = eval(args.conv_feature_layers)
    # Channel count of the last conv layer = width of the extracted features.
    self.embed = feature_enc_layers[-1][0]

    # Parallel lists: target_params[i] is the twin of online_params[i].
    self.online_params = []
    self.target_params = []

    def register_pair(online, target):
        # Append an online module's parameters and its twin's in lockstep.
        self.online_params += list(online.parameters())
        self.target_params += list(target.parameters())

    def make_feature_extractor():
        return ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

    def make_quantizer(vq_dim):
        return GumbelVectorQuantizer(
            dim=self.embed,
            num_vars=args.latent_vars,
            temp=eval(args.latent_temp),
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )

    self.feature_extractor = make_feature_extractor()
    self.feature_extractor_target = make_feature_extractor()
    register_pair(self.feature_extractor, self.feature_extractor_target)

    # Project conv features to the encoder width only when the widths differ
    # and the input is not quantized (project_inp covers that case).
    if self.embed != args.encoder_embed_dim and not args.quantize_input:
        self.post_extract_proj = nn.Linear(self.embed, args.encoder_embed_dim)
        self.post_extract_proj_target = nn.Linear(self.embed,
                                                  args.encoder_embed_dim)
        register_pair(self.post_extract_proj, self.post_extract_proj_target)
    else:
        self.post_extract_proj = None
        self.post_extract_proj_target = None

    # Time-axis masking hyper-parameters.
    self.mask_prob = args.mask_prob
    self.mask_selection = args.mask_selection
    self.mask_other = args.mask_other
    self.mask_length = args.mask_length
    self.no_mask_overlap = args.no_mask_overlap
    self.mask_min_space = args.mask_min_space

    # Channel-axis masking hyper-parameters.
    self.mask_channel_prob = args.mask_channel_prob
    self.mask_channel_selection = args.mask_channel_selection
    self.mask_channel_other = args.mask_channel_other
    self.mask_channel_length = args.mask_channel_length
    self.no_mask_channel_overlap = args.no_mask_channel_overlap
    self.mask_channel_min_space = args.mask_channel_min_space

    self.dropout_input = nn.Dropout(args.dropout_input)
    self.dropout_features = nn.Dropout(args.dropout_features)

    # Decay / schedule bookkeeping; its use is not visible in this method.
    self.base_decay = args.base_decay
    self.step = 0
    self.total_steps = args.total_num_update

    self.feature_grad_mult = args.feature_grad_mult

    self.quantizer = None
    self.input_quantizer = None
    self.shared_quantizer = args.shared_quantizer

    final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

    if args.quantize_targets:
        # Codebook output width: latent_dim when set, otherwise final_dim.
        vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
        self.quantizer = make_quantizer(vq_dim)
        if not args.shared_quantizer:
            self.quantizer_target = make_quantizer(vq_dim)
            register_pair(self.quantizer, self.quantizer_target)
        # TODO: separate project_q for the target branch?
        self.project_q = nn.Linear(vq_dim, final_dim)
    else:
        self.project_q = nn.Linear(self.embed, final_dim)

    if args.quantize_input:
        if args.same_quantizer and self.quantizer is not None:
            # Reuse the target quantizer(s).  Their output width is the
            # `vq_dim` chosen above, so keep it for project_inp rather than
            # assuming final_dim (they differ when latent_dim > 0 and
            # latent_dim != final_dim).
            self.input_quantizer = self.quantizer
            if not args.shared_quantizer:
                self.input_quantizer_target = self.quantizer_target
        else:
            vq_dim = (
                args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            )
            self.input_quantizer = make_quantizer(vq_dim)
            if not args.shared_quantizer:
                self.input_quantizer_target = make_quantizer(vq_dim)
                register_pair(self.input_quantizer, self.input_quantizer_target)
        self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

    self.mask_emb = nn.Parameter(
        torch.FloatTensor(args.encoder_embed_dim).uniform_())
    if not args.shared_emb:
        self.mask_emb_target = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_())
    # NOTE(review): mask_emb / mask_emb_target are not added to the
    # online/target lists, so mask_emb_target is neither paired nor frozen
    # by the loop at the end of this constructor.  Confirm that is intended.

    if args.mlp_encoder:
        # NOTE(review): this MLP outputs final_dim while final_proj expects
        # encoder_embed_dim inputs, and BatchNorm1d normalises dim 1 --
        # confirm the layouts/dims used by the caller.
        def make_mlp_encoder():
            return nn.Sequential(
                nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
                nn.BatchNorm1d(args.byol_hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(args.byol_hidden_dim, final_dim),
            )

        self.encoder = make_mlp_encoder()
        self.encoder_target = make_mlp_encoder()
    else:
        self.encoder = TransformerEncoder(args)
        self.encoder_target = TransformerEncoder(args)
    register_pair(self.encoder, self.encoder_target)

    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if args.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU())
        self.target_glu_target = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU())
        register_pair(self.target_glu, self.target_glu_target)

    self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
    self.final_proj_target = nn.Linear(args.encoder_embed_dim, final_dim)
    register_pair(self.final_proj, self.final_proj_target)

    # Prediction head exists only on the online branch (no target twin).
    # TODO: Transformer prediction network?
    if args.mlp_prediction:
        self.prediction = MLPPrediction(final_dim, args.byol_hidden_dim)
    else:
        # Use the resolved final_dim: args.final_dim <= 0 means "same as
        # encoder_embed_dim" and must not be passed through verbatim.
        self.prediction = TransformerEncoder(args, final_dim)

    # Target networks are never updated by back-propagation.
    for param in self.target_params:
        param.requires_grad = False
class ABCModel(BaseFairseqModel):
    """wav2vec 2.0 variant with an online branch and an EMA target branch.

    Each trainable online module that has a ``*_target`` twin is paired with it
    in ``online_params`` / ``target_params``. Target parameters never receive
    gradients (``requires_grad`` is switched off in ``__init__``). Instead,
    ``update_target`` moves them towards their online twins as an exponential
    moving average. Only the online branch owns a ``prediction`` head.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)",
        )
        parser.add_argument(
            "--encoder-layers",
            type=int,
            metavar="L",
            help="num encoder layers in the transformer",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="dropout probability for the transformer",
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--final-dim",
            type=int,
            metavar="D",
            help="project final representations and targets to this many dimensions",
        )
        parser.add_argument(
            "--layer-norm-first",
            action="store_true",
            help="apply layernorm first in the transformer",
        )
        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            help="probability of dropping a transformer layer",
        )
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            metavar="EXPR",
            help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]",
        )
        parser.add_argument("--logit-temp",
                            type=float,
                            help="temperature to divide logits by")
        parser.add_argument("--quantize-targets",
                            action="store_true",
                            help="use quantized targets")
        parser.add_argument("--quantize-input",
                            action="store_true",
                            help="use quantized inputs")
        parser.add_argument(
            "--same-quantizer",
            action="store_true",
            help="use same quantizer for inputs and targets",
        )
        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="multiply feature extractor var grads by this",
        )
        parser.add_argument(
            "--latent-vars",
            type=int,
            metavar="N",
            help="number of latent variables V in each group of the codebook",
        )
        parser.add_argument(
            "--latent-groups",
            type=int,
            metavar="N",
            help="number of groups G of latent variables in the codebook",
        )
        parser.add_argument(
            "--latent-dim",
            type=int,
            metavar="N",
            help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups",
        )
        parser.add_argument("--mask-length", type=int, help="mask length")
        parser.add_argument("--mask-prob",
                            type=float,
                            help="probability of replacing a token with mask")
        parser.add_argument(
            "--mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )
        parser.add_argument(
            "--mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )
        parser.add_argument(
            "--mask-channel-other",
            type=float,
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )
        parser.add_argument(
            "--no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )
        parser.add_argument(
            "--mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )
        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )
        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the features (after feat extr)",
        )
        parser.add_argument("--num-negatives",
                            type=int,
                            metavar="N",
                            help="number of negative examples")
        parser.add_argument(
            "--negatives-from-everywhere",
            action="store_true",
            help="sample negatives from everywhere, not just masked states",
        )
        parser.add_argument(
            "--cross-sample-negatives",
            type=int,
            metavar="N",
            help="num of cross sampled negatives",
        )
        parser.add_argument(
            "--codebook-negatives",
            type=int,
            metavar="N",
            help="num of codebook sampled negatives",
        )
        parser.add_argument("--base-decay",
                            type=float,
                            metavar="D",
                            help="base decay value of target network")
        parser.add_argument("--projection-dim",
                            type=int,
                            metavar="N",
                            help="byol projection network's output dimension")
        parser.add_argument("--prediction-dim",
                            type=int,
                            metavar="N",
                            help="byol prediction network's output dimension")
        parser.add_argument("--byol-hidden-dim",
                            type=int,
                            metavar="N",
                            help="hidden dimension of MLPs used for byol")
        parser.add_argument("--shared-quantizer",
                            action="store_true",
                            help="share quantizer with target network")
        parser.add_argument("--shared-emb",
                            action="store_true",
                            help="share mask embedding parameter")
        parser.add_argument("--mlp-prediction",
                            action="store_true",
                            help="use MLP as prediction network")
        parser.add_argument("--mlp-encoder",
                            action="store_true",
                            help="use MLP as encoder network")
        parser.add_argument(
            "--separate-mask-indices",
            action="store_true",
            help="use two separate mask indices for Transformers")
        parser.add_argument(
            "--conv-pos",
            type=int,
            metavar="N",
            help="number of filters for convolutional positional embeddings",
        )
        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            metavar="N",
            help="number of groups for convolutional positional embedding",
        )
        parser.add_argument(
            "--latent-temp",
            type=str,
            metavar="D",
            help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)",
        )
        parser.add_argument("--target-glu",
                            action="store_true",
                            help="adds projection + glu to targets")
        parser.add_argument("--conv-bias",
                            action="store_true",
                            help="include bias in conv encoder")

    def __init__(self, args):
        super().__init__()
        self.args = args

        # NOTE(review): ``eval`` on a config string. It is only acceptable because
        # the value comes from the experiment configuration, never from data.
        feature_enc_layers = eval(args.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        # Parallel lists: target_params[i] is the EMA copy of online_params[i]
        # (see update_target). Always extend both through _add_ema_pair.
        self.online_params = []
        self.target_params = []

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        self.feature_extractor_target = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        self._add_ema_pair(self.feature_extractor, self.feature_extractor_target)

        if self.embed != args.encoder_embed_dim and not args.quantize_input:
            self.post_extract_proj = nn.Linear(self.embed, args.encoder_embed_dim)
            self.post_extract_proj_target = nn.Linear(self.embed,
                                                      args.encoder_embed_dim)
            self._add_ema_pair(self.post_extract_proj,
                               self.post_extract_proj_target)
        else:
            self.post_extract_proj = None
            self.post_extract_proj_target = None

        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        # EMA schedule state (see update_target / forward).
        self.base_decay = args.base_decay
        self.step = 0
        self.total_steps = args.total_num_update

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None
        self.shared_quantizer = args.shared_quantizer

        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = self._build_quantizer(args, vq_dim)
            if not args.shared_quantizer:
                self.quantizer_target = self._build_quantizer(args, vq_dim)
                self._add_ema_pair(self.quantizer, self.quantizer_target)
            # TODO: separate project_q? It is currently shared by both branches.
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            if args.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
                if not args.shared_quantizer:
                    self.input_quantizer_target = self.quantizer_target
            else:
                vq_dim = (args.latent_dim
                          if args.latent_dim > 0 else args.encoder_embed_dim)
                self.input_quantizer = self._build_quantizer(args, vq_dim)
                if not args.shared_quantizer:
                    self.input_quantizer_target = self._build_quantizer(args, vq_dim)
                    self._add_ema_pair(self.input_quantizer,
                                       self.input_quantizer_target)
            # Shared by both branches; there is no EMA copy of project_inp.
            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_())
        if not args.shared_emb:
            self.mask_emb_target = nn.Parameter(
                torch.FloatTensor(args.encoder_embed_dim).uniform_())
            # Register the Parameter itself. ``list(param)`` would iterate over
            # its elements instead of tracking the Parameter.
            self._add_ema_pair(self.mask_emb, self.mask_emb_target)

        if args.mlp_encoder:
            self.encoder = self._build_mlp_encoder(args)
            self.encoder_target = self._build_mlp_encoder(args)
        else:
            self.encoder = TransformerEncoder(args)
            self.encoder_target = TransformerEncoder(args)
        self._add_ema_pair(self.encoder, self.encoder_target)

        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())
            self.target_glu_target = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())
            self._add_ema_pair(self.target_glu, self.target_glu_target)

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
        self.final_proj_target = nn.Linear(args.encoder_embed_dim, final_dim)
        self._add_ema_pair(self.final_proj, self.final_proj_target)

        # TODO: Transformer prediction network?
        # The head consumes final_proj output, so it is sized by the resolved
        # final_dim (args.final_dim may be <= 0, meaning "use encoder_embed_dim").
        if args.mlp_prediction:
            self.prediction = MLPPrediction(final_dim, args.byol_hidden_dim)
        else:
            self.prediction = TransformerEncoder(args, final_dim)

        # Target weights change only through update_target.
        for param in self.target_params:
            param.requires_grad = False

    def _add_ema_pair(self, online, target):
        """Register ``target`` (a module or Parameter) as the EMA copy of ``online``."""
        if isinstance(online, nn.Module):
            self.online_params += list(online.parameters())
            self.target_params += list(target.parameters())
        else:
            self.online_params.append(online)
            self.target_params.append(target)

    def _build_quantizer(self, args, vq_dim):
        """Create a Gumbel quantizer over ``self.embed``-dim features."""
        return GumbelVectorQuantizer(
            dim=self.embed,
            num_vars=args.latent_vars,
            # NOTE(review): eval of a trusted config string (e.g. a tuple).
            temp=eval(args.latent_temp),
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=vq_dim,
            time_first=True,
        )

    @staticmethod
    def _build_mlp_encoder(args):
        """Build the frame-wise MLP used as encoder when ``--mlp-encoder`` is set.

        The output size is ``encoder_embed_dim`` because ``final_proj`` (and its
        target twin) take ``encoder_embed_dim`` inputs.
        """
        return nn.Sequential(
            nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
            nn.BatchNorm1d(args.byol_hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(args.byol_hidden_dim, args.encoder_embed_dim),
        )

    def _run_encoder(self, encoder, x, padding_mask):
        """Apply ``encoder`` to ``x`` of shape (B, T, C)."""
        if self.args.mlp_encoder:
            # nn.Sequential accepts no padding_mask and BatchNorm1d expects (N, C),
            # so every frame is encoded independently.
            bsz, tsz, dim = x.shape
            return encoder(x.reshape(bsz * tsz, dim)).reshape(bsz, tsz, -1)
        return encoder(x, padding_mask=padding_mask)

    def update_target(self):
        """Move every target parameter one EMA step towards its online twin.

        The decay follows a cosine schedule from ``base_decay`` at step 0 to 1.0
        at ``total_steps``. It stays at 1.0 afterwards, and stays at
        ``base_decay`` when ``total_steps`` is unset or not positive. The method
        advances ``self.step`` by one.
        """
        if self.total_steps and self.total_steps > 0:
            # Clamp so the cosine does not swing back after total_steps.
            progress = min(self.step / self.total_steps, 1.0)
        else:
            progress = 0.0
        decay = float(1 - (1 - self.base_decay) *
                      (np.cos(np.pi * progress) + 1) / 2.0)
        self.step += 1
        with torch.no_grad():
            for online_param, target_param in zip(self.online_params,
                                                  self.target_params):
                # In place: keeps the target's storage instead of rebinding .data.
                target_param.mul_(decay).add_(online_param.detach(),
                                              alpha=1.0 - decay)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)
        return cls(args)

    def apply_mask(self, x, padding_mask, mask_indices=None, mask_emb=None):
        """Mask time steps (and optionally channels) of ``x`` in place.

        Args:
            x: features of shape (B, T, C).
            padding_mask: optional (B, T) padding flags passed to
                ``compute_mask_indices``.
            mask_indices: optional precomputed (B, T) time mask to reuse; a new
                one is sampled when ``None`` and ``mask_prob > 0``.
            mask_emb: vector written into masked time steps; defaults to
                ``self.mask_emb``.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is ``None`` when no time mask
            was applied.
        """
        if mask_emb is None:
            mask_emb = self.mask_emb
        B, T, C = x.shape
        if mask_indices is None:
            if self.mask_prob > 0:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
                x[mask_indices] = mask_emb
        else:
            x[mask_indices] = mask_emb

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (torch.from_numpy(mask_channel_indices).to(
                x.device).unsqueeze(1).expand(-1, T, -1))
            x[mask_channel_indices] = 0

        return x, mask_indices

    def quantize(self, x):
        """Return the quantizer's code indices for raw input ``x``."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        """Return online encoder output and frame-level padding mask."""
        res = self.predict(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def _prepare_features(self, features, padding_mask, post_extract_proj):
        """Post-process raw conv features; shared by both branches.

        Args:
            features: conv extractor output; transposed here so the embedding
                axis is last before ``layer_norm``.
            padding_mask: optional input-level padding mask, downsampled here to
                one flag per feature frame.
            post_extract_proj: optional projection to ``encoder_embed_dim``.

        Returns:
            ``(features, unmasked_features, padding_mask, features_pen)``.
        """
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(padding_mask.size(0),
                                             features.size(1), -1)
            # A frame is padding only if all input positions it covers are.
            padding_mask = padding_mask.all(-1)

        if post_extract_proj is not None:
            features = post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)
        return features, unmasked_features, padding_mask, features_pen

    def _mask_features(self, features, unmasked_features, padding_mask, mask,
                       mask_indices, mask_emb):
        """Optionally mask ``features`` and gather the matching unmasked targets.

        Returns ``(x, y, mask_indices)``. ``y`` holds only the masked frames
        when a time mask was applied, and ``mask_indices`` is ``None`` otherwise.
        """
        if not mask:
            return features, unmasked_features, None
        x, mask_indices = self.apply_mask(features,
                                          padding_mask,
                                          mask_indices,
                                          mask_emb=mask_emb)
        if mask_indices is None:
            return x, unmasked_features, None
        y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1,
                                                 unmasked_features.size(-1))
        return x, y, mask_indices

    @staticmethod
    def _select_masked(x, padding_mask, mask_indices):
        """Keep only the time-masked frames of ``x`` (B, T, C).

        The padding mask is reduced to the same frames so that it still lines up
        with the returned ``x``. Without a time mask both are returned unchanged.
        """
        if mask_indices is None:
            return x, padding_mask
        bsz = x.size(0)
        x = x[mask_indices].view(bsz, -1, x.size(-1))
        if padding_mask is not None:
            padding_mask = padding_mask[mask_indices].view(bsz, -1)
        return x, padding_mask

    @staticmethod
    def _quantizer_stats(q):
        """Pick the diagnostics reported from a quantizer output dict."""
        return {
            "prob_perplexity": q["prob_perplexity"],
            "code_perplexity": q["code_perplexity"],
            "num_vars": q["num_vars"],
            "temp": q["temp"],
        }

    @staticmethod
    def _make_result(x, padding_mask, features_pen, quant_stats):
        """Assemble a branch output dict, adding quantizer stats when present."""
        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }
        if quant_stats is not None and quant_stats["prob_perplexity"] is not None:
            result.update(quant_stats)
        return result

    def predict(self,
                source,
                padding_mask=None,
                mask=True,
                features_only=False,
                mask_indices=None):
        """Run the online branch on one view.

        Args:
            source: raw input for ``feature_extractor``.
            padding_mask: optional input-level padding mask.
            mask: whether to apply masking.
            features_only: return ``{"x", "padding_mask"}`` right after the
                encoder.
            mask_indices: optional time mask to reuse (e.g. from another view).

        Returns:
            ``{"x", "padding_mask"}`` if ``features_only``. Otherwise returns
            ``(result, mask_indices)``, where ``result["x"]`` is the prediction
            head output on the masked frames.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features, unmasked_features, padding_mask, features_pen = (
            self._prepare_features(features, padding_mask,
                                   self.post_extract_proj))

        quant_stats = None
        if self.input_quantizer is not None:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            quant_stats = self._quantizer_stats(q)
            features = self.project_inp(features)

        x, y, mask_indices = self._mask_features(features, unmasked_features,
                                                 padding_mask, mask,
                                                 mask_indices, self.mask_emb)

        x = self._run_encoder(self.encoder, x, padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer is not None:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            quant_stats = self._quantizer_stats(q)
        y = self.project_q(y)
        if self.target_glu is not None:
            y = self.target_glu(y)
        # NOTE(review): ``y`` itself is not returned; only the quantizer
        # statistics computed from it reach the result.

        x, pred_padding_mask = self._select_masked(x, padding_mask,
                                                   mask_indices)
        x = self.final_proj(x)

        if self.args.mlp_prediction:
            x = self.prediction(x)
        else:
            # Padding mask restricted to the masked frames, matching x.
            x = self.prediction(x, padding_mask=pred_padding_mask)

        result = self._make_result(x, padding_mask, features_pen, quant_stats)
        return result, mask_indices

    def target_predict(self,
                       source,
                       padding_mask=None,
                       mask=True,
                       mask_indices=None):
        """Run the EMA target branch on one view without tracking gradients.

        Mirrors ``predict`` with the ``*_target`` modules, except that there is
        no prediction head. ``project_inp`` and ``project_q`` have no target
        copies and are shared. Returns a result dict with the same keys as the
        online result.
        """
        with torch.no_grad():
            # GradMultiply is irrelevant here: nothing is differentiated.
            features = self.feature_extractor_target(source)
            features, unmasked_features, padding_mask, features_pen = (
                self._prepare_features(features, padding_mask,
                                       self.post_extract_proj_target))

            quant_stats = None
            if self.input_quantizer is not None:
                input_quantizer = (self.input_quantizer if self.shared_quantizer
                                   else self.input_quantizer_target)
                q = input_quantizer(features, produce_targets=False)
                features = q["x"]
                quant_stats = self._quantizer_stats(q)
                features = self.project_inp(features)

            mask_emb = (self.mask_emb
                        if self.args.shared_emb else self.mask_emb_target)
            x, y, mask_indices = self._mask_features(features,
                                                     unmasked_features,
                                                     padding_mask, mask,
                                                     mask_indices, mask_emb)

            x = self._run_encoder(self.encoder_target, x, padding_mask)

            if self.quantizer is not None:
                quantizer = (self.quantizer
                             if self.shared_quantizer else self.quantizer_target)
                q = quantizer(y, produce_targets=False)
                y = q["x"]
                quant_stats = self._quantizer_stats(q)
            y = self.project_q(y)
            if self.target_glu is not None:
                y = self.target_glu_target(y)

            x, _ = self._select_masked(x, padding_mask, mask_indices)
            x = self.final_proj_target(x)

        return self._make_result(x, padding_mask, features_pen, quant_stats)

    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        """Run both views through the online and target branches.

        ``source`` (and ``padding_mask`` if given) are indexed ``[0]`` / ``[1]``
        for the two views. Online view 0 and target view 1 use ``mask_indices0``.
        Online view 1 and target view 0 use ``mask_indices1``, which equals
        ``mask_indices0`` unless ``--separate-mask-indices`` is set.

        Returns:
            ``(result_0, result_1, result_target_0, result_target_1)``.
        """
        # One EMA step per training forward; the first call is skipped. Eval
        # forwards leave the target untouched. update_target advances
        # self.step itself, so only count here when it is skipped.
        # NOTE(review): with gradient accumulation this runs per forward, not
        # per optimizer update, and self.step is not checkpointed.
        if self.training:
            if self.step != 0:
                self.update_target()
            else:
                self.step += 1

        padding_mask_0 = padding_mask[0] if padding_mask is not None else None
        padding_mask_1 = padding_mask[1] if padding_mask is not None else None

        result_0, mask_indices0 = self.predict(source[0], padding_mask_0, mask,
                                               features_only)
        if self.args.separate_mask_indices:
            result_1, mask_indices1 = self.predict(
                source[1],
                padding_mask_1,
                mask,
                features_only,
            )
        else:
            result_1, _ = self.predict(
                source[1],
                padding_mask_1,
                mask,
                features_only,
                mask_indices=mask_indices0,
            )
            mask_indices1 = mask_indices0

        result_target_0 = self.target_predict(source[0],
                                              padding_mask_0,
                                              mask,
                                              mask_indices=mask_indices1)
        result_target_1 = self.target_predict(source[1],
                                              padding_mask_1,
                                              mask,
                                              mask_indices=mask_indices0)
        return result_0, result_1, result_target_0, result_target_1

    def remove_pretraining_modules(self):
        """Drop modules that are only needed for pre-training (e.g. before fine-tuning)."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
        self.feature_extractor_target = None
        self.post_extract_proj_target = None
        self.quantizer_target = None
        self.input_quantizer_target = None
        if not self.args.shared_emb:
            self.mask_emb_target = None
        self.encoder_target = None
        self.target_glu_target = None
        self.final_proj_target = None
        self.prediction = None
        # The EMA lists still reference the removed modules' parameters; clear
        # them so that memory can actually be freed.
        self.online_params = []
        self.target_params = []