    def build_submodel(cls, args, task, reverse=False):
        # make sure all arguments are present (small dual-transformer defaults)
        dual_transformer_small(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, 'max_source_positions', None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        if not reverse:
            src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        else:
            # swap language directions to build the reverse (target-to-source) submodel
            args.source_lang, args.target_lang = args.target_lang, args.source_lang
            src_dict, tgt_dict = task.target_dictionary, task.source_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens, src_dict=src_dict)
        return TransformerModel(args, encoder, decoder)
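
A minimal usage sketch (not part of the original excerpt): since build_submodel takes cls and a reverse flag, a dual-direction wrapper would typically call it once per translation direction. The class name DualTransformerModel is an assumption; the real wrapper class is not shown here.

    forward_model = DualTransformerModel.build_submodel(args, task, reverse=False)
    backward_model = DualTransformerModel.build_submodel(args, task, reverse=True)
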
Example #2
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        tgt_dict = task.target_dictionary

        encoder_embed_speech = ASRFeature(
            cmvn=args.cmvn,
            n_mels=args.fbank_dim,
            dropout=args.dropout,
            sample_rate=task.sample_rate,  # NOTE: assumes load_dataset is called before build_model
            n_fft=args.stft_dim,
            stride=args.stft_stride,
            n_subsample=args.encoder_subsample_layers,
            odim=args.encoder_embed_dim,
        )

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict,
                                               args.decoder_embed_dim,
                                               args.decoder_embed_path)

        setattr(encoder_embed_speech, "padding_idx",
                -1)  # decoder_embed_tokens.padding_idx)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        encoder = cls.build_encoder(args, encoder_embed_speech)
        return TransformerModel(encoder, decoder)
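
A hedged sketch of how a fairseq task would typically reach this classmethod; note that load_dataset must run first so task.sample_rate is available to ASRFeature (see the NOTE above). The exact task/architecture names used in this project are not shown here.

    task = tasks.setup_task(args)           # fairseq.tasks.setup_task
    task.load_dataset(args.train_subset)    # populates task.sample_rate
    model = task.build_model(args)          # dispatches to the classmethod above
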
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 dataset_reader: DatasetReader,
                 source_embedder: TextFieldEmbedder,
                 lang2_namespace: str = "tokens",
                 use_bleu: bool = True) -> None:
        super().__init__(vocab)
        self._lang1_namespace = lang2_namespace  # TODO: DO NOT HARDCODE IT
        self._lang2_namespace = lang2_namespace

        # TODO: do not hardcode this
        self._backtranslation_src_langs = ["en", "ru"]
        self._coeff_denoising = 1
        self._coeff_backtranslation = 1
        self._coeff_translation = 1

        self._label_smoothing = 0.1

        self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang1_namespace)
        self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang1_namespace)
        self._end_index_lang1 = self.vocab.get_token_index(
            END_SYMBOL, self._lang1_namespace)

        self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                      self._lang2_namespace)
        self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                      self._lang2_namespace)
        self._end_index_lang2 = self.vocab.get_token_index(
            END_SYMBOL, self._lang2_namespace)

        self._reader = dataset_reader
        self._langs_list = self._reader._langs_list
        self._ae_steps = self._reader._ae_steps
        self._bt_steps = self._reader._bt_steps
        self._para_steps = self._reader._para_steps

        if use_bleu:
            self._bleu = Average()
        else:
            self._bleu = None

        args = ArgsStub()

        transformer_iwslt_de_en(args)

        # build encoder
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Dense embedding of vocab words in the target space.
        num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
        num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

        args.share_decoder_input_output_embed = False  # TODO implement shared embeddings

        lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                              pad=self._pad_index_lang1,
                              unk=self._oov_index_lang1,
                              eos=self._end_index_lang1)

        lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                              pad=self._pad_index_lang2,
                              unk=self._oov_index_lang2,
                              eos=self._end_index_lang2)

        # instantiate fairseq classes
        emb_golden_tokens = FairseqEmbedding(num_tokens_lang2,
                                             args.decoder_embed_dim,
                                             self._pad_index_lang2)

        self._encoder = TransformerEncoder(args, lang1_dict,
                                           self._source_embedder)
        self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
        self._model = TransformerModel(self._encoder, self._decoder)

        # TODO: do not hardcode max_len_b and beam size
        self._sequence_generator_greedy = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
        self._sequence_generator_beam = FairseqBeamSearchWrapper(
            SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
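
DictStub and ArgsStub above are lightweight stand-ins that let fairseq's TransformerEncoder/TransformerDecoder run on an AllenNLP vocabulary; their definitions are not part of this excerpt. A minimal sketch of what DictStub could look like, assuming fairseq only needs the dictionary length and the pad/unk/eos indices:

    class DictStub:
        """Mimics the parts of fairseq.data.Dictionary used by the Transformer."""

        def __init__(self, num_tokens, pad, unk, eos):
            self._num_tokens = num_tokens
            self._pad, self._unk, self._eos = pad, unk, eos

        def __len__(self):
            return self._num_tokens

        def pad(self):
            return self._pad

        def unk(self):
            return self._unk

        def eos(self):
            return self._eos
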