# The fairseq imports below are the ones this module relies on; project-local
# classes referenced here (WavBart2BartConfig, Wav2Vec2BartConfig, Linear,
# EmbeddingTransformed) are assumed to be defined or imported elsewhere in
# this repository.
import contextlib
import os
from argparse import Namespace
from itertools import groupby

import torch
import torch.nn as nn

from fairseq import checkpoint_utils, tasks
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import FairseqEncoder
from fairseq.models.transformer import TransformerEncoder


class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu('models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Re-instantiate the BART encoder and copy its weights so it can be
        # trained (or frozen) independently of the original BART model.
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args, bart_encoder.dictionary, bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())

        self.fix_bart_encoder = cfg.fix_bart_encoder
        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        # Length of the wav2vec feature sequence for each utterance.
        input_lengths = (1 - padding_mask.long()).sum(-1)
        output_length = torch.max(
            self.w2v_model._get_feat_extract_output_lengths(input_lengths))

        # Right-pad the BART input tokens to the wav2vec output length.
        batch_size, ntoken = kwargs['bart_input_tokens'].shape
        bart_input = torch.zeros(batch_size, output_length).long().fill_(
            self.pad_token).to(kwargs['bart_input_tokens'])
        bart_input[:, :ntoken] = kwargs['bart_input_tokens']

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        x_bart = self.bart_encoder(src_tokens=bart_input, src_lengths=None,
                                   token_embeddings=None, return_all_hiddens=False)

        if self.proj:
            x = self.proj(x)

        x_bart = x_bart['encoder_out'][0]

        # Mixing probability annealed from 0 toward 1 with the number of updates:
        # early updates mostly pass BART encoder features, later updates mostly
        # pass wav2vec features.
        prob = torch.sigmoid(
            torch.FloatTensor(
                [self.num_updates / self.mix_normalization_factor])) * 2 - 1

        # Sample a Bernoulli mask along the leading dimension (time steps when
        # tbc=True) and mix wav2vec and BART encoder features.
        mask = torch.bernoulli(torch.full(x.shape[:1], prob.item()))[:, None, None].to(x)
        reverse_mask = 1 - mask
        x = x * mask + x_bart * reverse_mask

        if self.num_updates % 1000 == 0:
            print('self.num_updates', prob, self.num_updates)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
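# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how the scheduled mixing
# in Wav2VecEncoder.forward() behaves. The helper name and the example
# `mix_normalization_factor` value below are hypothetical; only the formula
# follows the code above. The mixing probability anneals from 0 toward 1 with
# the number of updates, so early updates mostly pass BART encoder features
# and later updates mostly pass wav2vec features.
# ---------------------------------------------------------------------------
def _example_mix_prob(num_updates: int, mix_normalization_factor: float) -> float:
    """Annealed probability of keeping a wav2vec feature instead of a BART one."""
    prob = torch.sigmoid(
        torch.FloatTensor([num_updates / mix_normalization_factor])) * 2 - 1
    return prob.item()

# For example, with a hypothetical mix_normalization_factor of 10000:
#   _example_mix_prob(0, 10000)      -> 0.00   (all features come from BART)
#   _example_mix_prob(10000, 10000)  -> ~0.46
#   _example_mix_prob(50000, 10000)  -> ~0.99  (almost all features from wav2vec)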
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2BartConfig, tgt_dict=None, transform_embed=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu('models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Re-instantiate the BART encoder and copy its weights so it can be
        # trained (or frozen) independently of the original BART model.
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args, bart_encoder.dictionary, bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())

        self.fix_bart_encoder = cfg.fix_bart_encoder
        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        # Project wav2vec features to the target vocabulary for CTC-style predictions.
        print('len(tgt_dict)', len(tgt_dict))
        self.proj = Linear(d, len(tgt_dict))

        # Here we assume wav2vec and BART have the same hidden size.
        self.bart_encoder.embed_tokens.weight.requires_grad_(cfg.bart_embedding_finetune)
        self.transform_embed = transform_embed
        self.emb = EmbeddingTransformed(self.bart_encoder.embed_tokens, self.transform_embed)

        self.pad_token = cfg.pad_token
        self.ctc_weight = cfg.ctc_weight
        self.ce_weight = cfg.ce_weight

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        # ----------- transform embedding -----------
        target_tokens = kwargs['target_tokens']
        bart_emb = self.bart_encoder.embed_tokens.weight  # V x C BART token embeddings
        # transformed_emb = self.transform_embed(bart_emb.T).T  (applied inside self.emb)

        # ----------- wav2vec -----------
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # Fine-tune everything once the freeze period is over.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x_wav2vec = self.final_dropout(x)  # hidden embedding
        logits_wav2vec = self.proj(x)  # T x B x V

        # ----------- pad predicted tokens -----------
        logit_lengths = (1 - padding_mask.long()).sum(-1)  # B
        logit_preds = torch.argmax(logits_wav2vec, dim=-1)  # T x B
        if tbc:
            logit_preds = logit_preds.transpose(0, 1)  # B x T
        print('logits_wav2vec.shape, logit_preds.shape', logits_wav2vec.shape, logit_preds.shape, logit_preds)

        # Collapse consecutive repeated predictions (CTC-style) into token sequences.
        pred_idxs, pred_lengths = [], []
        for i, (y, length) in enumerate(zip(logit_preds, logit_lengths)):
            emb_idx = torch.stack([tok for tok, _ in groupby(y[:length])])
            pred_idxs.append(emb_idx)
            pred_lengths.append(len(emb_idx))
        max_len = max(pred_lengths)
        print('pred_lengths', pred_lengths, max_len)

        # Right-pad the collapsed predictions into a single batch tensor
        # (kept on the same device as the predictions).
        tokens_w2v = torch.zeros(len(logit_preds), max_len).long().fill_(
            self.pad_token).to(logit_preds.device)
        for i, pred_idx in enumerate(pred_idxs):
            tokens_w2v[i, :len(pred_idx)] = pred_idx

        # Use target_tokens while fine-tuning the embedding and transformation (not ft);
        # use tokens_w2v predicted by wav2vec when fine-tuning.
        if ft:
            # Fine-tune from predictions (after {freeze_finetune_updates} steps).
            bart_input = tokens_w2v
            bart_input_lengths = pred_lengths
            ctc_weight, ce_weight = self.ctc_weight, 1
        else:
            # Initial steps: feed the ground-truth tokens.
            bart_input = target_tokens
            bart_input_lengths = kwargs['target_token_lengths']
            ctc_weight, ce_weight = 1, 1

        token_emb = self.emb(bart_input)

        # Feed the tokens to the BART encoder with the customized embedding.
        bart_encoder_output = self.bart_encoder(
            src_tokens=bart_input,
            src_lengths=bart_input_lengths,
            token_embeddings=token_emb,  # pass in customized embedding
            return_all_hiddens=False,
        )

        return {
            "encoder_out": bart_encoder_output['encoder_out'],  # T x B x C
            "encoder_padding_mask": bart_encoder_output['encoder_padding_mask'],  # B x T
            "wav2vec_logits": logits_wav2vec,  # T x B x V
            "wav2vec_padding_mask": padding_mask,
            "ctc_weight": ctc_weight,
            "ce_weight": ce_weight,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
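# ---------------------------------------------------------------------------
# Hypothetical sketch of the EmbeddingTransformed module used above. The real
# class is defined elsewhere in this repository; this reconstruction only
# follows the commented-out lines in forward(), where `transform_embed` is
# applied to the transposed BART embedding matrix and the result is indexed
# by token id. The class name below is made up to avoid clashing with the
# actual implementation.
# ---------------------------------------------------------------------------
class EmbeddingTransformedSketch(nn.Module):
    """Look up BART token embeddings after applying a learned transformation."""

    def __init__(self, embed_tokens: nn.Embedding, transform_embed: nn.Module):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.transform_embed = transform_embed

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        # V x C -> C x V -> transform -> back to V x C
        transformed_emb = self.transform_embed(self.embed_tokens.weight.T).T
        # Gather one transformed embedding row per token id, keeping the
        # original token tensor shape plus the embedding dimension.
        return torch.index_select(
            transformed_emb, 0, tokens.reshape(-1)
        ).view(*tokens.shape, -1)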