Example #1
0
 def build_encoder(self, cfg, dictionary, embed_tokens):
     """Build a fc-returning TransformerEncoder with BERT-style weight init.

     Args:
         cfg: config whose ``transformer`` sub-config parameterizes the encoder.
         dictionary: source-side vocabulary.
         embed_tokens: token embedding module handed to the encoder.

     Returns:
         The freshly initialized TransformerEncoder.
     """
     enc = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
     # Re-initialize every sub-module with BERT-style parameter initialization.
     enc.apply(init_bert_params)
     return enc
Example #2
0
 def build_encoder(cls, args, src_dict, embed_tokens):
     """Build a TransformerEncoder, optionally applying BERT-style init.

     The init is applied only when ``args.apply_bert_init`` is present and
     truthy; otherwise the encoder keeps its default initialization.
     """
     model = TransformerEncoder(args, src_dict, embed_tokens)
     wants_bert_init = getattr(args, "apply_bert_init", False)
     if wants_bert_init:
         model.apply(init_bert_params)
     return model
Example #3
0
 def build_encoder(self, args, dictionary, embed_tokens):
     """Build a TransformerEncoder and re-initialize it BERT-style.

     Returns:
         The TransformerEncoder after ``init_bert_params`` has been applied
         to all of its sub-modules.
     """
     enc = TransformerEncoder(args, dictionary, embed_tokens)
     enc.apply(init_bert_params)
     return enc
 def build_encoder(cls, args, src_dict, embed_tokens):
     """Build a plain TransformerEncoder over the source dictionary."""
     encoder = TransformerEncoder(args, src_dict, embed_tokens)
     return encoder
Example #5
0
    def build_encoder(cls, args, task):
        """Build the dual-input (speech + text) encoder.

        Constructs an mBART-style text TransformerEncoder and a wav2vec
        speech encoder, optionally combines them (stacked or shared,
        controlled by flags on ``args``), freezes pretrained parameters
        unless explicitly marked for finetuning, and wraps the pair in a
        DualInputEncoder.

        Args:
            args: model/training argument namespace.
            task: task object providing ``src_dict`` (source vocabulary).

        Returns:
            A DualInputEncoder wrapping the speech and text encoders.
        """
        # The text encoder gets its own copy of args with mBART-specific
        # dropout values and a fixed source-position budget.
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )

        stack = getattr(args, "stack_w2v_mbart_encoder", False)
        stack_nonorm = getattr(args, "stack_w2v_mbart_nonorm_encoder", False)
        if stack or stack_nonorm:
            if stack:
                # Stacking and weight sharing are mutually exclusive.
                assert getattr(args, "share_w2v_text_encoder", False) is False
            else:
                # "nonorm" variant drops the final text-encoder layer norm
                # before stacking.
                text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        def _freeze_except_finetuned(module, finetune_attr):
            # Freeze pretrained weights by default; unfreeze only parameters
            # matched by the optional finetuning spec named ``finetune_attr``
            # on ``args``.
            for k, p in module.named_parameters():
                p.requires_grad = safe_hasattr(args, finetune_attr) and need_finetuning(
                    getattr(args, finetune_attr), k
                )

        _freeze_except_finetuned(spch_encoder, "finetune_w2v_params")
        _freeze_except_finetuned(text_encoder, "finetune_mbart_encoder_params")

        # 0 selects the layer before the last for the cross-attentive loss;
        # -1 disables it when no attentive-cost regularization is requested.
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder
Example #6
0
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True,
                 lang1_namespace=None) -> None:
        """Build an encoder-decoder translation model backed by fairseq.

        Args:
            vocab: vocabulary holding both language namespaces.
            dataset_reader: reader providing language list and AE/BT/parallel
                step configuration.
            source_embedder: dense embedder for source-side tokens.
            lang2_namespace: vocab namespace for the target language.
            use_bleu: when True, track an averaged BLEU metric.
            lang1_namespace: vocab namespace for the source language; when
                None it falls back to ``lang2_namespace`` (the previous
                hard-coded behavior).
        """
        super().__init__(vocab)
        # Previously lang1 was hard-coded to lang2's namespace; the new
        # keyword argument generalizes this while keeping the old default.
        self._lang1_namespace = (lang1_namespace
                                 if lang1_namespace is not None
                                 else lang2_namespace)
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        # Special-token indices for each language namespace.
        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        # Populate a fairseq-style args namespace with the iwslt de-en
        # transformer defaults.
        args = ArgsStub()
        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        # Minimal dictionary shims exposing the indices fairseq expects.
        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
Example #7
0
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        """Encoder combining a wav2vec 2.0 model with a copy of BART's encoder.

        Loads the wav2vec checkpoint (from ``cfg.w2v_path`` when it exists,
        otherwise a default relative path), rebuilds the wav2vec model with
        dropout/masking overrides taken from ``cfg``, and clones ``bart``'s
        encoder weights into a fresh TransformerEncoder so the copy can be
        trained or frozen independently.

        Args:
            cfg: configuration with wav2vec overrides and BART options.
            tgt_dict: optional target dictionary; when given, a projection to
                its size is created.
            bart: model whose ``model.encoder`` is cloned; must not be None.
        """
        self.apply_mask = cfg.apply_mask

        # Dropout/masking settings injected into the pretrained wav2vec config.
        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            # Single-argument os.path.join was a no-op; test the path directly.
            if os.path.isfile(cfg.w2v_path):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                # Older checkpoints store an argparse Namespace under "args".
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        # Drop pretraining-only heads before fine-tuning.
        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Rebuild an identical TransformerEncoder and copy BART's weights
        # into it.  (The previous direct assignment of bart.model.encoder was
        # a dead store, immediately overwritten — removed.)
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        # Output projection: to tgt vocab if given, to decoder dim if it
        # differs from the wav2vec dim, else identity (None).
        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor