def build_encoder(cls,
                  args,
                  src_dict,
                  embed_tokens,
                  src_factor_embed_tokens=None):
    if src_factor_embed_tokens:
        src_encoder = SrcFactorEncoder(args, src_dict, embed_tokens,
                                       src_factor_embed_tokens)
    else:
        src_encoder = TransformerEncoder(args, src_dict, embed_tokens)
    if getattr(args, "apply_bert_init", False):
        src_encoder.apply(init_bert_params)

    if getattr(args, "share_encoder", False):
        mt_encoder = src_encoder
    else:
        if src_factor_embed_tokens:
            mt_encoder = SrcFactorEncoder(args, src_dict, embed_tokens,
                                          src_factor_embed_tokens)
        else:
            mt_encoder = TransformerEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            mt_encoder.apply(init_bert_params)
    encoder = MultisourceEncoder(src_encoder, {'mt': mt_encoder},
                                 order=['mt'])
    return encoder
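
For context, here is a minimal sketch of how a classmethod like this is typically reached from fairseq's model-building flow. The class name and the build_embedding/build_decoder helpers below are assumptions for illustration only, not part of the example above.

class ExampleMultisourceModel(FairseqEncoderDecoderModel):  # hypothetical subclass
    @classmethod
    def build_model(cls, args, task):
        # assumed wiring: fairseq calls build_model(args, task) on the registered model class
        src_dict = task.source_dictionary
        embed_tokens = cls.build_embedding(args, src_dict, args.encoder_embed_dim)  # assumed helper
        encoder = cls.build_encoder(args, src_dict, embed_tokens)  # the method shown above
        decoder = cls.build_decoder(args, task.target_dictionary, embed_tokens)  # assumed helper
        return cls(encoder, decoder)
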
Example #2
 def build_encoder(self, cfg, dictionary, embed_tokens):
     encoder = TransformerEncoder(cfg.transformer,
                                  dictionary,
                                  embed_tokens,
                                  return_fc=True)
     encoder.apply(init_bert_params)
     return encoder
Example #3
 def build_text_encoder(cls, args, src_dictionary, spch_encoder):
     if args.encoder_shared_layers > 0:
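         # cap encoder_shared_layers at min(speech_encoder_layers, text_encoder_layers)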
         mx_shared_layers = (
             args.speech_encoder_layers
             if args.speech_encoder_layers < args.text_encoder_layers else
             args.text_encoder_layers)
         args.encoder_shared_layers = (
             args.encoder_shared_layers
             if args.encoder_shared_layers <= mx_shared_layers else
             mx_shared_layers)
     cfg = {
         "encoder_embed_dim": args.encoder_text_embed_dim,
         "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
         "encoder_layers": args.text_encoder_layers,
         "encoder_layerdrop": args.encoder_layerdrop,
         "encoder_attention_heads": args.encoder_attention_heads,
         "encoder_learned_pos": args.encoder_learned_pos,
         "max_source_positions": args.max_source_positions,
         "dropout": args.dropout,
         "encoder_normalize_before": args.encoder_normalize_before,
         "activation_dropout": args.activation_dropout,
         "attention_dropout": args.attention_dropout,
         "activation_fn": args.activation_fn,
         "adaptive_input": args.adaptive_input,
         "no_token_positional_embeddings":
         args.no_token_positional_embeddings,
         "no_scale_embedding": args.no_scale_embedding,
         "quant_noise_pq": args.quant_noise_pq,
     }
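     # expose the config dict as an attribute-style namespace so TransformerEncoder can read it like argparse args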
     model_args = namedtuple("args", cfg.keys())(*cfg.values())
     enc_emb = nn.Embedding(len(src_dictionary),
                            model_args.encoder_embed_dim,
                            src_dictionary.pad())
     text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
     if args.add_speech_eos:
         spch_encoder = spch_encoder.encoder
     if args.encoder_shared_layers > 0:
         text_encoder.layer_norm = cls.set_shared_layer(
             args.encoder_shared_layer_level,
             text_encoder.layer_norm,
             spch_encoder.layer_norm,
         )
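         # tie the top encoder_shared_layers text-encoder layers to the corresponding top speech-encoder layers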
         for i, ly in enumerate(
                 spch_encoder.
                 transformer_layers[-args.encoder_shared_layers:]):
             ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
             if not isinstance(text_encoder.layers[ly_id], type(ly)):
                 if text_encoder.layers[ly_id]._get_name() not in (
                         'TransformerEncoderLayerBase',
                         'TransformerEncoderLayer'):
                     raise ValueError(
                         "The shared layers are expected from the same class"
                     )
             text_encoder.layers[ly_id] = cls.set_shared_layer(
                 args.encoder_shared_layer_level,
                 text_encoder.layers[ly_id],
                 ly,
             )
     return text_encoder
Example #4
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError(
                    '--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                )
            if args.decoder_embed_path and (args.decoder_embed_path !=
                                            args.encoder_embed_path):
                raise RuntimeError(
                    '--share-all-embeddings not compatible with --decoder-embed-path'
                )
            encoder_embed_tokens = build_embedding(src_dict,
                                                   args.encoder_embed_dim,
                                                   args.encoder_embed_path)
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict,
                                                   args.encoder_embed_dim,
                                                   args.encoder_embed_path)
            decoder_embed_tokens = build_embedding(tgt_dict,
                                                   args.decoder_embed_dim,
                                                   args.decoder_embed_path)

        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
        encoder2 = TransformerEncoder(args, tgt_dict, decoder_embed_tokens)
        decoder2 = TransformerDecoder(args, src_dict, encoder_embed_tokens)
        return TransformerDualModel(encoder, decoder, encoder2, decoder2)
Example #5
 def build_text_encoder(cls, args, src_dictionary):
     enc_emb = nn.Embedding(len(src_dictionary), args.encoder_embed_dim,
                            src_dictionary.pad())
     model_args = cls.update_transformer_encoder_cfg(
         args, {"encoder_layers": args.text_encoder_layers})
     text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
     return text_encoder
 def build_encoder(cls, args, src_dict, embed_tokens):
     if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
         return HeadSelectionTransformerEncoder(
             args, src_dict, embed_tokens
         )
     else:
         return TransformerEncoder(args, src_dict, embed_tokens)
Example #7
 def get_encoder(lang):
     if lang not in lang_encoders:
         if shared_encoder_embed_tokens is not None:
             encoder_embed_tokens = shared_encoder_embed_tokens
         else:
             encoder_embed_tokens = build_embedding(
                 task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
             )
         lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
     return lang_encoders[lang]
Example #8
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = getattr(args, 'tokens_per_sample',
                                                DEFAULT_MAX_TARGET_POSITIONS)

        if args.character_embeddings:
            embed_tokens = CharacterTokenEmbedder(
                task.source_dictionary,
                eval(args.character_filters),
                args.character_embedding_dim,
                args.encoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.source_dictionary),
                task.source_dictionary.pad(),
                args.encoder_input_dim,
                args.adaptive_input_factor,
                args.encoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            embed_tokens = cls.build_embedding(args, task.source_dictionary,
                                               args.encoder_input_dim)

        if args.tie_adaptive_weights:
            assert args.adaptive_input
            assert args.adaptive_input_factor == args.adaptive_softmax_factor
            assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
            assert args.encoder_input_dim == args.encoder_output_dim

        encoder = TransformerEncoder(
            args,
            task.target_dictionary,
            embed_tokens,
        )
        print('Encoder Output Dimensions:', args.encoder_output_dim)
        print('Output Size:', len(task.target_dictionary))
        linear_layer = Linear(args.encoder_output_dim,
                              len(task.target_dictionary))
        return cls(encoder, linear_layer)
Example #9
    def __init__(self,
                 args,
                 src_dictionary,
                 dst_dictionary,
                 src_embed_tokens,
                 dst_embed_tokens,
                 left_pad=True):

        super().__init__(None)

        self.src_dictionary = src_dictionary
        self.dst_dictionary = dst_dictionary
        self.encoder = TransformerEncoder(args,
                                          src_dictionary,
                                          src_embed_tokens,
                                          left_pad=left_pad)

        self.masked_encoder = TransformerEncoder(args,
                                                 dst_dictionary,
                                                 dst_embed_tokens,
                                                 left_pad=left_pad)
Example #10
 def get_encoder(lang, lang_pair=None):
     if lang not in lang_encoders:
         if shared_encoder_embed_tokens is not None:
             encoder_embed_tokens = shared_encoder_embed_tokens
         elif args.share_all_langpair_embeddings:
             encoder_embed_tokens = lang_pair_embed[lang_pair]
         else:
             encoder_embed_tokens = build_embedding(
                 task.dicts[lang], args.encoder_embed_dim,
                 args.encoder_embed_path)
         lang_encoders[lang] = TransformerEncoder(
             args, task.dicts[lang], encoder_embed_tokens)
     return lang_encoders[lang]
 def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
     if is_encoder:
         if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
             return LatentTransformerEncoder(
                 args, lang_dict, embed_tokens, num_logits=len(langs)
             )
         else:
             return TransformerEncoder(args, lang_dict, embed_tokens)
     else:
         if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
             return LatentTransformerDecoder(
                 args, lang_dict, embed_tokens, num_logits=len(langs)
             )
         else:
             return TransformerDecoder(args, lang_dict, embed_tokens)
 def build_encoder(cls, args, src_dict, embed_tokens, token2components_map):
     if args.model_type == 'transformer':
         return TransformerEncoder(args, src_dict, embed_tokens)
     elif args.model_type == 'lstm':
         return TarcLSTMEncoder(
             dictionary=src_dict,
             embed_dim=args.encoder_embed_dim,
             hidden_size=args.encoder_hidden_dim,
             num_layers=args.encoder_layers,
             dropout_in=args.encoder_dropout_in,
             dropout_out=args.encoder_dropout_out,
             bidirectional=True,
             pretrained_embed=embed_tokens,
             max_source_positions=args.max_source_positions,
             token_map=token2components_map,
             granularity_flags=(args.token_sequences, args.char_sequences))
     else:
         raise NotImplementedError
Example #13
 def build_encoder(cls, args, src_dict, embed_tokens):
     encoder = TransformerEncoder(args, src_dict, embed_tokens)
     if getattr(args, "apply_bert_init", False):
         encoder.apply(init_bert_params)
     return encoder
Example #14
 def build_encoder(self, args, dictionary, embed_tokens):
     encoder = TransformerEncoder(args, dictionary, embed_tokens)
     encoder.apply(init_bert_params)
     return encoder
 def build_encoder(cls, args, src_dict, embed_tokens):
     return TransformerEncoder(args, src_dict, embed_tokens)
Example #16
    def build_encoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )
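        # optionally stack the mBART encoder layers (and layer norm) on top of the wav2vec encoder and adaptor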
        if getattr(args, "stack_w2v_mbart_encoder", False):
            assert getattr(args, "share_w2v_text_encoder", False) is False
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        for k, p in spch_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and need_finetuning(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        for k, p in text_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_encoder_params"
            ) and need_finetuning(
                args.finetune_mbart_encoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder
Example #17
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._lang1_namespace = lang2_namespace  # TODO: DO NOT HARDCODE IT
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        args = ArgsStub()

        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
Example #18
class UnsupervisedTranslation(Model):
    """
    This ``UnsupervisedTranslation`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function: ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._lang1_namespace = lang2_namespace  # TODO: DO NOT HARDCODE IT
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        args = ArgsStub()

        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))

    @overrides
    def forward(
        self,  # type: ignore
        lang_pair: List[str],
        lang1_tokens: Dict[str, torch.LongTensor] = None,
        lang1_golden: Dict[str, torch.LongTensor] = None,
        lang2_tokens: Dict[str, torch.LongTensor] = None,
        lang2_golden: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        """
        # detect training mode and what kind of task we need to compute
        if lang2_tokens is None and lang1_tokens is None:
            raise ConfigurationError(
                "source_tokens and target_tokens can not both be None")

        mode_training = self.training
        mode_validation = not self.training and lang2_tokens is not None  # change 'target_tokens' condition
        mode_prediction = lang2_tokens is None  # change 'target_tokens' condition

        lang_src, lang_tgt = lang_pair[0].split('-')

        if mode_training:
            # task types
            task_translation = False
            task_denoising = False
            task_backtranslation = False

            if lang_src == 'xx':
                task_backtranslation = True
            elif lang_src == lang_tgt:
                task_denoising = True
            elif lang_src != lang_tgt:
                task_translation = True
            else:
                raise ConfigurationError("All tasks are false")

        output_dict = {}
        if mode_training:

            if task_translation:
                loss = self._forward_seq2seq(lang_pair, lang1_tokens,
                                             lang2_tokens, lang2_golden)
                if self._bleu:
                    predicted_indices = self._sequence_generator_beam.generate(
                        [self._model], lang1_tokens,
                        self._get_true_pad_mask(lang1_tokens),
                        self._end_index_lang2)
                    predicted_strings = self._indices_to_strings(
                        predicted_indices)
                    golden_strings = self._indices_to_strings(
                        lang2_tokens["tokens"])
                    golden_strings = self._remove_pad_eos(golden_strings)
                    # print(golden_strings, predicted_strings)
                    self._bleu(corpus_bleu(golden_strings, predicted_strings))
            elif task_denoising:  # might need to split it into two blocks for interlingua loss
                loss = self._forward_seq2seq(lang_pair, lang1_tokens,
                                             lang2_tokens, lang2_golden)
            elif task_backtranslation:
                # our goal is also to learn from regular cross-entropy loss, but since we do not have source tokens,
                # we will generate them ourselves with current model
                langs_src = self._backtranslation_src_langs.copy()
                langs_src.remove(lang_tgt)
                bt_losses = {}
                for lang_src in langs_src:
                    curr_lang_pair = lang_src + "-" + lang_tgt
                    # TODO: require to pass target language to forward on encoder outputs
                    # We use greedy decoder because it was shown better for backtranslation
                    with torch.no_grad():
                        predicted_indices = self._sequence_generator_greedy.generate(
                            [self._model], lang2_tokens,
                            self._get_true_pad_mask(lang2_tokens),
                            self._end_index_lang2)
                    model_input = self._strings_to_batch(
                        self._indices_to_strings(predicted_indices),
                        lang2_tokens, lang2_golden, curr_lang_pair)
                    bt_losses['bt:' + curr_lang_pair] = self._forward_seq2seq(
                        **model_input)
            else:
                raise ConfigurationError("No task have been detected")

            if task_translation:
                loss = self._coeff_translation * loss
            elif task_denoising:
                loss = self._coeff_denoising * loss
            elif task_backtranslation:
                loss = 0
                for bt_loss in bt_losses.values():
                    loss += self._coeff_backtranslation * bt_loss

            output_dict["loss"] = loss

        elif mode_validation:
            output_dict["loss"] = self._coeff_translation * \
                                  self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden)
            if self._bleu:
                predicted_indices = self._sequence_generator_greedy.generate(
                    [self._model], lang1_tokens,
                    self._get_true_pad_mask(lang1_tokens),
                    self._end_index_lang2)
                predicted_strings = self._indices_to_strings(predicted_indices)
                golden_strings = self._indices_to_strings(
                    lang2_tokens["tokens"])
                golden_strings = self._remove_pad_eos(golden_strings)
                print(golden_strings, predicted_strings)
                self._bleu(corpus_bleu(golden_strings, predicted_strings))

        elif mode_prediction:
            # TODO: pass target language (in the fseq_encoder append embedded target language to the encoder out)
            predicted_indices = self._sequence_generator_beam.generate(
                [self._model], lang1_tokens,
                self._get_true_pad_mask(lang1_tokens), self._end_index_lang2)
            output_dict["predicted_indices"] = predicted_indices
            output_dict["predicted_strings"] = self._indices_to_strings(
                predicted_indices)

        return output_dict

    def _get_true_pad_mask(self, indexed_input):
        mask = util.get_text_field_mask(indexed_input)
        # TODO: account for cases when text field mask doesn't work, like BERT
        return mask

    def _remove_pad_eos(self, golden_strings):
        tmp = []
        for x in golden_strings:
            tmp.append(
                list(
                    filter(
                        lambda a: a != DEFAULT_PADDING_TOKEN and a !=
                        END_SYMBOL, x)))
        return tmp

    def _convert_to_sentences(self, golden_strings, predicted_strings):
        golden_strings_nopad = []
        for s in golden_strings:
            s_nopad = list(filter(lambda t: t != DEFAULT_PADDING_TOKEN, s))
            s_nopad = " ".join(s_nopad)
            golden_strings_nopad.append(s_nopad)
        predicted_strings = [" ".join(s) for s in predicted_strings]
        return golden_strings_nopad, predicted_strings

    def _forward_seq2seq(
            self, lang_pair: List[str], source_tokens: Dict[str,
                                                            torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor],
            target_golden: Dict[str,
                                torch.LongTensor]) -> Dict[str, torch.Tensor]:
        source_tokens_padding_mask = self._get_true_pad_mask(source_tokens)
        encoder_out = self._encoder.forward(source_tokens,
                                            source_tokens_padding_mask)
        logits, _ = self._decoder.forward(target_tokens["tokens"], encoder_out)
        loss = self._get_ce_loss(logits, target_golden)
        return loss

    def _get_ce_loss(self, logits, golden):
        target_mask = util.get_text_field_mask(golden)
        loss = util.sequence_cross_entropy_with_logits(
            logits,
            golden["golden_tokens"],
            target_mask,
            label_smoothing=self._label_smoothing)
        return loss

    def _indices_to_strings(self, indices: torch.Tensor):
        all_predicted_tokens = []
        for hyp in indices:
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    idx.item(), namespace=self._lang2_namespace) for idx in hyp
            ]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    def _strings_to_batch(self, source_tokens: List[List[str]],
                          target_tokens: Dict[str, torch.Tensor],
                          target_golden: Dict[str,
                                              torch.Tensor], lang_pair: str):
        """
        Converts list of sentences which are itself lists of strings into Batch
        suitable for passing into model's forward function.

        TODO: Make sure the right device (CPU/GPU) is used. Predicted tokens might get copied on
        CPU in `self.decode` method...
        """
        # convert source tokens into source tensor_dict
        instances = []
        lang_pairs = []
        for sentence in source_tokens:
            sentence = " ".join(sentence)
            instances.append(self._reader.string_to_instance(sentence))
            lang_pairs.append(lang_pair)

        source_batch = Batch(instances)
        source_batch.index_instances(self.vocab)
        source_batch = source_batch.as_tensor_dict()
        model_input = {
            "source_tokens": source_batch["tokens"],
            "target_golden": target_golden,
            "target_tokens": target_tokens,
            "lang_pair": lang_pairs
        }

        return model_input

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update({"BLEU": self._bleu.get_metric(reset=reset)})
        return all_metrics
Example #19
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.bart_encoder = bart.model.encoder
        bart_encoder = bart.model.encoder
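        # rebuild the BART encoder as a standalone TransformerEncoder and copy its weights over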
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor
Example #20
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.bart_encoder = bart.model.encoder
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
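        # padding_mask is 1 at padded positions, so this counts the real samples per utterance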
        input_lengths = (1 - padding_mask.long()).sum(-1)
        output_length = torch.max(
            self.w2v_model._get_feat_extract_output_lengths(input_lengths))
        # print('output_lengths', output_length,  'self.pad_token', self.pad_token)
        # print('kwargs', kwargs['bart_input_tokens'].shape, kwargs['bart_input_tokens'].type())
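        # pad the BART input tokens out to the wav2vec feature length so both streams share the same T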
        batch_size, ntoken = kwargs['bart_input_tokens'].shape
        bart_input = torch.zeros(batch_size, output_length).long().fill_(
            self.pad_token).to(kwargs['bart_input_tokens'])
        bart_input[:, :ntoken] = kwargs['bart_input_tokens']
        # print(bart_input, bart_input.shape)
        # raise
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        x_bart = self.bart_encoder(src_tokens=bart_input,
                                   src_lengths=None,
                                   token_embeddings=None,
                                   return_all_hiddens=False)

        if self.proj:
            x = self.proj(x)
        x_bart = x_bart['encoder_out'][0]
        # print('x.shape', x.shape, )
        # print('x_bart', x_bart['encoder_out'][0].shape)
        # print(x_bart['encoder_padding_mask'][0].shape)
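        # mixing weight ramps from 0 toward 1 as training progresses: sigmoid(num_updates / factor) * 2 - 1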
        prob = torch.sigmoid(
            torch.FloatTensor(
                [self.num_updates / self.mix_normalization_factor])) * 2 - 1
        # n_mix = int(self.mix_rate * output_length)
        # indices = torch.randperm(output_length)[:n_mix]
        # print(n_mix, indices)
        # print(prob)
        # mask = torch.bernoulli(torch.full(x.shape, prob.item())).int().to(x)
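        # per-timestep Bernoulli mask: each frame keeps the wav2vec features with probability `prob`, otherwise takes the BART frame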
        mask = torch.bernoulli(torch.full(x.shape[:1],
                                          prob.item()))[:, None, None].to(x)
        reverse_mask = 1 - mask
        x = x * mask + x_bart * reverse_mask
        # x_bart[indices,:,:] = x[indices,:,:]

        # print('self.num_updates', prob, self.num_updates)
        if self.num_updates % 1000 == 0:
            print('self.num_updates', prob, self.num_updates)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(
                    0, new_order)
            ]

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
Example #21
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2BartConfig, tgt_dict=None, transform_embed=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu('models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.bart_encoder = bart.model.encoder
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args, bart_encoder.dictionary, bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder

        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        # if tgt_dict is not None:
        print('len(tgt_dict)', len(tgt_dict))
        self.proj = Linear(d, len(tgt_dict))
        # elif getattr(cfg, "decoder_embed_dim", d) != d:
        # self.proj = Linear(d, cfg.decoder_embed_dim)
        # else:
        #     self.proj = None

        # bart.model.encoder.embed_tokens.weight.shape
        # here assume wav2vec and bart have same hidden size
        self.bart_encoder.embed_tokens.weight.requires_grad_(cfg.bart_embedding_finetune)
        self.transform_embed = transform_embed
        self.emb = EmbeddingTransformed(self.bart_encoder.embed_tokens, self.transform_embed)
        # if fix bart embedding 

        self.pad_token = cfg.pad_token
        self.ctc_weight = cfg.ctc_weight
        self.ce_weight = cfg.ce_weight
        
        # self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        # -----------transform embedding-----------
        target_tokens = kwargs['target_tokens']
        bart_emb = self.bart_encoder.embed_tokens.weight
        # transformed_emb = self.transform_embed(bart_emb.T).T

        # -----------wav2vec-----------
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # finetuning all without freeze
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)
            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x_wav2vec = self.final_dropout(x) # hidden embedding
        logits_wav2vec = self.proj(x) # T x B x V
        
        # -----------pad predict tokens-----------
        # if ft:

        logit_lengths = (1 - padding_mask.long()).sum(-1)  # B
        logit_preds = torch.argmax(logits_wav2vec, dim=-1)  # T x B
        
        if tbc:
            logit_preds = logit_preds.transpose(0, 1) # B x T

        print('logits_wav2vec.shape, logit_preds.shape', logits_wav2vec.shape, logit_preds.shape, logit_preds)
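        # greedy CTC-style collapse: merge consecutive repeated token predictions within each utterance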
        pred_idxs, pred_lengths = [], []
        for i, (y, length) in enumerate(zip(logit_preds, logit_lengths)):
            emb_idx = torch.stack([x[0] for x in groupby(y[:length])])
            pred_idxs.append(emb_idx)
            pred_lengths.append(len(emb_idx))
        
        max_len = max(pred_lengths)
        print('pred_lengths', pred_lengths, max_len)
        tokens_w2v = torch.zeros(len(logit_preds), max_len).long().fill_(self.pad_token).to(logit_preds.device)

        for i, pred_idx in enumerate(pred_idxs):
            tokens_w2v[i,:(len(pred_idx))] = pred_idx

        # use target_tokens while finetuning the embedding and transformation (not ft)
        # use tokens_w2v from wav2vec when finetuning
        if ft:  # finetune from predictions (after {freeze_finetune_updates} steps)
            bart_input = tokens_w2v
            bart_input_lengths = pred_lengths
            ctc_weight, ce_weight = self.ctc_weight, 1
        else: # initial steps, from ground truth
            bart_input = target_tokens
            bart_input_lengths = kwargs['target_token_lengths']
            ctc_weight, ce_weight = 1, 1
        token_emb = self.emb(bart_input)
        # token_emb = torch.index_select(transformed_emb, 0, bart_input.reshape(-1)).view(*bart_input.shape, -1)


        # feed token to bart encoder
        bart_encoder_output = self.bart_encoder(
            src_tokens=bart_input,
            src_lengths=bart_input_lengths,
            token_embeddings=token_emb, # pass in customized embedding
            return_all_hiddens=False,
        )

        # if self.num_updates % 1000 == 0:
        #     print('self.num_updates', self.num_updates)

        return {
            "encoder_out": bart_encoder_output['encoder_out'],  # T x B x C
            "encoder_padding_mask": bart_encoder_output['encoder_padding_mask'],  # B x T
            "wav2vec_logits": logits_wav2vec,  # T x B x C
            "wav2vec_padding_mask": padding_mask,
            "ctc_weight": ctc_weight,
            "ce_weight": ce_weight,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] # T x B x C

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict