def transformer_iwslt_de_en_dist(args):
    transformer_iwslt_de_en(args)


def transformer_monotonic_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
    base_monotonic_architecture(args)


def transformer_unidirectional_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
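
# The helpers above follow fairseq's architecture-override pattern: each takes
# an argparse-style namespace and fills in any hyperparameters not already set
# (fairseq's own functions use getattr(args, name, default), so an empty
# namespace works). A minimal usage sketch, assuming ArgsStub is the bare
# attribute container defined elsewhere in this repo and that fairseq's stock
# IWSLT de-en preset is in effect:
#
#     args = ArgsStub()
#     transformer_iwslt_de_en(args)    # writes embed dims, layer counts, ...
#     print(args.encoder_embed_dim)    # 512 under the stock preset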
def __init__(self,
             vocab: Vocabulary,
             dataset_reader: DatasetReader,
             source_embedder: TextFieldEmbedder,
             lang2_namespace: str = "tokens",
             use_bleu: bool = True) -> None:
    super().__init__(vocab)
    self._lang1_namespace = lang2_namespace  # TODO: do not hardcode it
    self._lang2_namespace = lang2_namespace  # TODO: do not hardcode this

    self._backtranslation_src_langs = ["en", "ru"]
    self._coeff_denoising = 1
    self._coeff_backtranslation = 1
    self._coeff_translation = 1

    self._label_smoothing = 0.1

    self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang1_namespace)
    self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang1_namespace)
    self._end_index_lang1 = self.vocab.get_token_index(END_SYMBOL, self._lang1_namespace)

    self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang2_namespace)
    self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang2_namespace)
    self._end_index_lang2 = self.vocab.get_token_index(END_SYMBOL, self._lang2_namespace)

    self._reader = dataset_reader
    self._langs_list = self._reader._langs_list
    self._ae_steps = self._reader._ae_steps
    self._bt_steps = self._reader._bt_steps
    self._para_steps = self._reader._para_steps

    if use_bleu:
        self._bleu = Average()
    else:
        self._bleu = None

    args = ArgsStub()
    transformer_iwslt_de_en(args)

    # build encoder
    if not hasattr(args, 'max_source_positions'):
        args.max_source_positions = 1024
    if not hasattr(args, 'max_target_positions'):
        args.max_target_positions = 1024

    # Dense embedding of source vocab tokens.
    self._source_embedder = source_embedder

    # Dense embedding of vocab words in the target space.
    num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
    num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

    args.share_decoder_input_output_embed = False  # TODO: implement shared embeddings

    lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                          pad=self._pad_index_lang1,
                          unk=self._oov_index_lang1,
                          eos=self._end_index_lang1)
    lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                          pad=self._pad_index_lang2,
                          unk=self._oov_index_lang2,
                          eos=self._end_index_lang2)

    # instantiate fairseq classes
    emb_golden_tokens = FairseqEmbedding(num_tokens_lang2, args.decoder_embed_dim, self._pad_index_lang2)

    self._encoder = TransformerEncoder(args, lang1_dict, self._source_embedder)
    self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
    self._model = TransformerModel(self._encoder, self._decoder)

    # TODO: do not hardcode max_len_b and beam size
    self._sequence_generator_greedy = FairseqBeamSearchWrapper(
        SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
    self._sequence_generator_beam = FairseqBeamSearchWrapper(
        SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
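
# NOTE: hedged sketches of the two stub classes the constructor relies on,
# assuming they are defined elsewhere in this repo. These are illustrative
# reconstructions, not the repo's actual definitions: ArgsStub only needs to
# accept attribute writes from the fairseq architecture functions, and
# DictStub only needs the slice of fairseq's Dictionary API that
# TransformerEncoder/TransformerDecoder and SequenceGenerator actually call
# (pad(), unk(), eos(), and len()).

class ArgsStub:
    """Empty attribute container standing in for an argparse Namespace."""
    pass


class DictStub:
    """Minimal fairseq-Dictionary look-alike: exposes the pad/unk/eos indices
    and the vocabulary size through the methods fairseq components query."""

    def __init__(self, num_tokens: int, pad: int, unk: int, eos: int) -> None:
        self._num_tokens = num_tokens
        self._pad = pad
        self._unk = unk
        self._eos = eos

    def pad(self) -> int:
        return self._pad

    def unk(self) -> int:
        return self._unk

    def eos(self) -> int:
        return self._eos

    def __len__(self) -> int:
        return self._num_tokens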