Example #1
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(cfg.w2v_path):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Rebuild the BART encoder as a fresh TransformerEncoder over the same
        # dictionary and embeddings, then copy in the pretrained weights.
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            # Freeze all BART encoder parameters.
            for parameter in self.bart_encoder.parameters():
                parameter.requires_grad = False

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        input_lengths = (1 - padding_mask.long()).sum(-1)
        output_length = torch.max(
            self.w2v_model._get_feat_extract_output_lengths(input_lengths))
        # Pad the BART input tokens out to the (down-sampled) wav2vec output length
        # so the two encoder outputs can be mixed position-wise below
        # (assumes ntoken <= output_length).
        batch_size, ntoken = kwargs['bart_input_tokens'].shape
        bart_input = torch.zeros(batch_size, int(output_length)).long().fill_(
            self.pad_token).to(kwargs['bart_input_tokens'])
        bart_input[:, :ntoken] = kwargs['bart_input_tokens']
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        x_bart = self.bart_encoder(src_tokens=bart_input,
                                   src_lengths=None,
                                   token_embeddings=None,
                                   return_all_hiddens=False)

        if self.proj:
            x = self.proj(x)
        x_bart = x_bart['encoder_out'][0]
        # Scheduled mixing: prob is 0 at the start of training (sigmoid(0) * 2 - 1)
        # and ramps towards 1, gradually replacing the BART encoder features with
        # the wav2vec features.
        prob = torch.sigmoid(
            torch.FloatTensor(
                [self.num_updates / self.mix_normalization_factor])) * 2 - 1
        # Per-timestep Bernoulli mask over the T dimension of x (T x B x C):
        # positions drawn as 1 keep the wav2vec features, the rest take the BART features.
        mask = torch.bernoulli(torch.full(x.shape[:1],
                                          prob.item()))[:, None, None].to(x)
        reverse_mask = 1 - mask
        x = x * mask + x_bart * reverse_mask

        if self.num_updates % 1000 == 0:
            print('self.num_updates', prob, self.num_updates)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(
                    0, new_order)
            ]

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
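
A minimal sketch of the scheduled mixing used in the forward pass above, assuming a hypothetical mix_normalization_factor of 10000: it only illustrates how prob = sigmoid(num_updates / factor) * 2 - 1 ramps from 0 towards 1 over training, and how the per-timestep Bernoulli mask blends two T x B x C encoder outputs.

import torch

def mix_schedule_demo(mix_normalization_factor=10000):
    # prob is 0 at update 0 (everything comes from the BART encoder) and approaches 1
    # late in training (everything comes from the wav2vec encoder).
    for num_updates in [0, 1000, 5000, 10000, 50000]:
        prob = torch.sigmoid(
            torch.FloatTensor([num_updates / mix_normalization_factor])) * 2 - 1
        print(num_updates, round(prob.item(), 3))

    # Per-timestep blend of two dummy T x B x C encoder outputs, as in forward().
    T, B, C = 6, 2, 4
    x_w2v, x_bart = torch.randn(T, B, C), torch.randn(T, B, C)
    prob = torch.sigmoid(torch.FloatTensor([20000 / mix_normalization_factor])) * 2 - 1
    mask = torch.bernoulli(torch.full((T,), prob.item()))[:, None, None]
    return x_w2v * mask + x_bart * (1 - mask)

mix_schedule_demo()
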
Example #2
    def build_encoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "stack_w2v_mbart_encoder", False):
            assert getattr(args, "share_w2v_text_encoder", False) is False
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        for k, p in spch_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and need_finetuning(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        for k, p in text_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_encoder_params"
            ) and need_finetuning(
                args.finetune_mbart_encoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder
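
The loops above follow a freeze-by-default pattern: every parameter of the pretrained speech and text encoders gets requires_grad = False unless its name matches the configured finetune patterns. Below is a minimal sketch of that pattern, with a hypothetical need_finetuning_like helper that does simple substring matching on a comma-separated pattern list (an approximation for illustration, not fairseq's exact need_finetuning).

import torch.nn as nn

def need_finetuning_like(patterns: str, param_name: str) -> bool:
    # Hypothetical stand-in for need_finetuning: "all" unfreezes everything,
    # otherwise unfreeze parameters whose name contains any listed substring.
    if patterns == "all":
        return True
    return any(p and p in param_name for p in patterns.split(","))

def freeze_except(module: nn.Module, patterns: str):
    # Freeze pretrained weights by default, selectively re-enabling gradients.
    for name, param in module.named_parameters():
        param.requires_grad = need_finetuning_like(patterns, name)

# Usage sketch on a toy module: only the second Linear ("1.weight", "1.bias") stays trainable.
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
freeze_except(toy, "1.")
print([(n, p.requires_grad) for n, p in toy.named_parameters()])
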
Example #3
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2BartConfig, tgt_dict=None, transform_embed=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(cfg.w2v_path):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu('models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Rebuild the BART encoder as a fresh TransformerEncoder over the same
        # dictionary and embeddings, then copy in the pretrained weights.
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args, bart_encoder.dictionary, bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            # Freeze all BART encoder parameters.
            for parameter in self.bart_encoder.parameters():
                parameter.requires_grad = False

        # Projection from wav2vec features to the target vocabulary (CTC head).
        print('len(tgt_dict)', len(tgt_dict))
        self.proj = Linear(d, len(tgt_dict))

        # Assumes wav2vec and BART share the same hidden size.
        # Optionally fine-tune the BART token embeddings, and wrap them with the
        # learned transformation used to embed the token inputs.
        self.bart_encoder.embed_tokens.weight.requires_grad_(cfg.bart_embedding_finetune)
        self.transform_embed = transform_embed
        self.emb = EmbeddingTransformed(self.bart_encoder.embed_tokens, self.transform_embed)

        self.pad_token = cfg.pad_token
        self.ctc_weight = cfg.ctc_weight
        self.ce_weight = cfg.ce_weight

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        # -----------transform embedding-----------
        target_tokens = kwargs['target_tokens']

        # -----------wav2vec-----------
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # finetuning all without freeze
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)
            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x_wav2vec = self.final_dropout(x)  # hidden features after dropout
        logits_wav2vec = self.proj(x_wav2vec)  # T x B x V
        
        # -----------pad predicted tokens-----------
        logit_lengths = (1 - padding_mask.long()).sum(-1)  # B (valid frames per utterance)
        logit_preds = torch.argmax(logits_wav2vec, dim=-1)  # T x B

        if tbc:
            logit_preds = logit_preds.transpose(0, 1)  # B x T

        print('logits_wav2vec.shape, logit_preds.shape', logits_wav2vec.shape, logit_preds.shape, logit_preds)
        pred_idxs, pred_lengths = [], []
        for y, length in zip(logit_preds, logit_lengths):
            # Collapse consecutive repeated predictions (CTC-style) into a token sequence.
            emb_idx = torch.stack([g[0] for g in groupby(y[:length])])
            pred_idxs.append(emb_idx)
            pred_lengths.append(len(emb_idx))

        max_len = max(pred_lengths)
        print('pred_lengths', pred_lengths, max_len)
        # Right-pad the collapsed sequences with the pad token, on the same device as the predictions.
        tokens_w2v = torch.zeros(len(logit_preds), max_len, dtype=torch.long,
                                 device=logit_preds.device).fill_(self.pad_token)

        for i, pred_idx in enumerate(pred_idxs):
            tokens_w2v[i, :len(pred_idx)] = pred_idx

        # Feed the ground-truth target_tokens while the embedding and transformation
        # are still warming up (not ft); switch to the wav2vec predictions once fine-tuning starts.
        if ft:  # fine-tune from predictions (after {freeze_finetune_updates} steps)
            bart_input = tokens_w2v
            bart_input_lengths = pred_lengths
            ctc_weight, ce_weight = self.ctc_weight, 1
        else:  # initial steps: use the ground-truth tokens
            bart_input = target_tokens
            bart_input_lengths = kwargs['target_token_lengths']
            ctc_weight, ce_weight = 1, 1
        token_emb = self.emb(bart_input)


        # Feed the tokens (with the customized embeddings) to the BART encoder.
        bart_encoder_output = self.bart_encoder(
            src_tokens=bart_input,
            src_lengths=bart_input_lengths,
            token_embeddings=token_emb, # pass in customized embedding
            return_all_hiddens=False,
        )

        # if self.num_updates % 1000 == 0:
        #     print('self.num_updates', self.num_updates)

        return {
            "encoder_out": bart_encoder_output['encoder_out'],  # T x B x C
            "encoder_padding_mask": bart_encoder_output['encoder_padding_mask'],  # B x T
            "wav2vec_logits": logits_wav2vec,  # T x B x C
            "wav2vec_padding_mask": padding_mask,
            "ctc_weight": ctc_weight,
            "ce_weight": ce_weight,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] # T x B x C

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
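
For reference, a small self-contained sketch of the collapse-and-pad step from the forward pass above: argmax predictions are de-duplicated with itertools.groupby (CTC-style) and right-padded to the batch maximum length. The pad_token value of 1 and the function name are only placeholders for this illustration.

from itertools import groupby

import torch

def collapse_and_pad(logit_preds, logit_lengths, pad_token=1):
    # logit_preds: B x T argmax predictions; logit_lengths: valid length per utterance.
    pred_idxs = []
    for y, length in zip(logit_preds, logit_lengths):
        # Collapse consecutive repeats within the valid prefix of each utterance.
        pred_idxs.append(torch.stack([g[0] for g in groupby(y[:length])]))
    max_len = max(len(p) for p in pred_idxs)
    tokens = torch.full((len(pred_idxs), max_len), pad_token,
                        dtype=torch.long, device=logit_preds.device)
    for i, p in enumerate(pred_idxs):
        tokens[i, :len(p)] = p
    return tokens

preds = torch.tensor([[5, 5, 7, 7, 7, 2], [3, 3, 3, 3, 0, 0]])
print(collapse_and_pad(preds, torch.tensor([6, 4])))
# -> tensor([[5, 7, 2], [3, 1, 1]])  (second row collapses to [3] and is padded)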