def build_submodel(cls, args, task, reverse=False):
    dual_transformer_small(args)

    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
    if args.decoder_layers_to_keep:
        args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

    if getattr(args, 'max_source_positions', None) is None:
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if getattr(args, 'max_target_positions', None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    if not reverse:
        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
    else:
        args.source_lang, args.target_lang = args.target_lang, args.source_lang
        tgt_dict, src_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise ValueError('--share-all-embeddings requires a joined dictionary')
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
        if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path):
            raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
        encoder_embed_tokens = build_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = build_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = build_embedding(
            tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )

    encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens, src_dict=src_dict)
    return TransformerModel(args, encoder, decoder)
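# The Embedding() helper called inside build_embedding is not shown in this
# excerpt. In fairseq-style transformer code it is typically a thin wrapper
# around nn.Embedding with scaled-normal initialization and a zeroed padding
# row; the sketch below follows that convention and is an assumption, not the
# verbatim helper from this repository.
import torch.nn as nn

def Embedding(num_embeddings, embed_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embed_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m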
def build_model(cls, args, task):
    """Build a new model instance."""

    # make sure all arguments are present in older models
    base_architecture(args)

    if not hasattr(args, 'max_source_positions'):
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if not hasattr(args, 'max_target_positions'):
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    tgt_dict = task.target_dictionary

    encoder_embed_speech = ASRFeature(
        cmvn=args.cmvn,
        n_mels=args.fbank_dim,
        dropout=args.dropout,
        sample_rate=task.sample_rate,  # NOTE: assumes load_dataset is called before build_model
        n_fft=args.stft_dim,
        stride=args.stft_stride,
        n_subsample=args.encoder_subsample_layers,
        odim=args.encoder_embed_dim,
    )

    def build_embedding(dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim, args.decoder_embed_path)

    setattr(encoder_embed_speech, "padding_idx", -1)  # decoder_embed_tokens.padding_idx)

    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    encoder = cls.build_encoder(args, encoder_embed_speech)
    return TransformerModel(encoder, decoder)
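# In fairseq, build_model classmethods like the two above are normally reached
# through a registered model class, roughly as sketched below. Only the
# registration pattern itself is the standard fairseq convention; the class
# name ('dual_transformer') and the bodies are placeholders for illustration,
# not the actual definitions from this repository.
from fairseq.models import FairseqEncoderDecoderModel, register_model, register_model_architecture

@register_model('dual_transformer')
class DualTransformerModel(FairseqEncoderDecoderModel):
    @classmethod
    def build_model(cls, args, task):
        # would delegate to the build_submodel / build_model logic shown above
        ...

@register_model_architecture('dual_transformer', 'dual_transformer_small')
def dual_transformer_small(args):
    # would fill in default hyperparameters on the argparse namespace
    ...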
def __init__(self,
             vocab: Vocabulary,
             dataset_reader: DatasetReader,
             source_embedder: TextFieldEmbedder,
             lang2_namespace: str = "tokens",
             use_bleu: bool = True) -> None:
    super().__init__(vocab)
    self._lang1_namespace = lang2_namespace  # TODO: do not hardcode this
    self._lang2_namespace = lang2_namespace  # TODO: do not hardcode this

    self._backtranslation_src_langs = ["en", "ru"]
    self._coeff_denoising = 1
    self._coeff_backtranslation = 1
    self._coeff_translation = 1

    self._label_smoothing = 0.1

    self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang1_namespace)
    self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang1_namespace)
    self._end_index_lang1 = self.vocab.get_token_index(END_SYMBOL, self._lang1_namespace)

    self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang2_namespace)
    self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang2_namespace)
    self._end_index_lang2 = self.vocab.get_token_index(END_SYMBOL, self._lang2_namespace)

    self._reader = dataset_reader
    self._langs_list = self._reader._langs_list
    self._ae_steps = self._reader._ae_steps
    self._bt_steps = self._reader._bt_steps
    self._para_steps = self._reader._para_steps

    if use_bleu:
        self._bleu = Average()
    else:
        self._bleu = None

    args = ArgsStub()
    transformer_iwslt_de_en(args)

    # build encoder
    if not hasattr(args, 'max_source_positions'):
        args.max_source_positions = 1024
    if not hasattr(args, 'max_target_positions'):
        args.max_target_positions = 1024

    # Dense embedding of source vocab tokens.
    self._source_embedder = source_embedder

    # Dense embedding of vocab words in the target space.
    num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace)
    num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace)

    args.share_decoder_input_output_embed = False  # TODO: implement shared embeddings

    lang1_dict = DictStub(num_tokens=num_tokens_lang1,
                          pad=self._pad_index_lang1,
                          unk=self._oov_index_lang1,
                          eos=self._end_index_lang1)
    lang2_dict = DictStub(num_tokens=num_tokens_lang2,
                          pad=self._pad_index_lang2,
                          unk=self._oov_index_lang2,
                          eos=self._end_index_lang2)

    # instantiate fairseq classes
    emb_golden_tokens = FairseqEmbedding(num_tokens_lang2, args.decoder_embed_dim, self._pad_index_lang2)

    self._encoder = TransformerEncoder(args, lang1_dict, self._source_embedder)
    self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens)
    self._model = TransformerModel(self._encoder, self._decoder)

    # TODO: do not hardcode max_len_b and beam size
    self._sequence_generator_greedy = FairseqBeamSearchWrapper(
        SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20))
    self._sequence_generator_beam = FairseqBeamSearchWrapper(
        SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20))
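# ArgsStub and DictStub above are small adapters that make AllenNLP vocabulary
# and config data look like the argparse namespace and fairseq Dictionary that
# TransformerEncoder, TransformerDecoder and SequenceGenerator expect. Their
# definitions are not part of this excerpt; the sketch below shows one plausible
# shape (the method names follow fairseq's Dictionary interface, everything else
# is an assumption).
class ArgsStub:
    """Empty namespace; transformer_iwslt_de_en(args) fills in the defaults via setattr."""
    pass

class DictStub:
    """Exposes the subset of fairseq's Dictionary API used by the model above."""
    def __init__(self, num_tokens, pad, unk, eos):
        self._num_tokens = num_tokens
        self._pad, self._unk, self._eos = pad, unk, eos

    def __len__(self):
        return self._num_tokens

    def pad(self):
        return self._pad

    def unk(self):
        return self._unk

    def eos(self):
        return self._eos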