def build_encoder(cls, args, src_dict, embed_tokens, src_factor_embed_tokens=None):
    if src_factor_embed_tokens:
        src_encoder = SrcFactorEncoder(args, src_dict, embed_tokens, src_factor_embed_tokens)
    else:
        src_encoder = TransformerEncoder(args, src_dict, embed_tokens)
    if getattr(args, "apply_bert_init", False):
        src_encoder.apply(init_bert_params)

    if getattr(args, "share_encoder", False):
        mt_encoder = src_encoder
    else:
        if src_factor_embed_tokens:
            mt_encoder = SrcFactorEncoder(args, src_dict, embed_tokens, src_factor_embed_tokens)
        else:
            mt_encoder = TransformerEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            mt_encoder.apply(init_bert_params)

    encoder = MultisourceEncoder(src_encoder, {'mt': mt_encoder}, order=['mt'])
    return encoder
def build_encoder(self, cfg, dictionary, embed_tokens):
    encoder = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
    encoder.apply(init_bert_params)
    return encoder
def build_text_encoder(cls, args, src_dictionary, spch_encoder):
    if args.encoder_shared_layers > 0:
        mx_shared_layers = (
            args.speech_encoder_layers
            if args.speech_encoder_layers < args.text_encoder_layers
            else args.text_encoder_layers
        )
        args.encoder_shared_layers = (
            args.encoder_shared_layers
            if args.encoder_shared_layers <= mx_shared_layers
            else mx_shared_layers
        )
    cfg = {
        "encoder_embed_dim": args.encoder_text_embed_dim,
        "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
        "encoder_layers": args.text_encoder_layers,
        "encoder_layerdrop": args.encoder_layerdrop,
        "encoder_attention_heads": args.encoder_attention_heads,
        "encoder_learned_pos": args.encoder_learned_pos,
        "max_source_positions": args.max_source_positions,
        "dropout": args.dropout,
        "encoder_normalize_before": args.encoder_normalize_before,
        "activation_dropout": args.activation_dropout,
        "attention_dropout": args.attention_dropout,
        "activation_fn": args.activation_fn,
        "adaptive_input": args.adaptive_input,
        "no_token_positional_embeddings": args.no_token_positional_embeddings,
        "no_scale_embedding": args.no_scale_embedding,
        "quant_noise_pq": args.quant_noise_pq,
    }
    model_args = namedtuple("args", cfg.keys())(*cfg.values())
    enc_emb = nn.Embedding(
        len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
    )
    text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)

    if args.add_speech_eos:
        spch_encoder = spch_encoder.encoder
    if args.encoder_shared_layers > 0:
        text_encoder.layer_norm = cls.set_shared_layer(
            args.encoder_shared_layer_level,
            text_encoder.layer_norm,
            spch_encoder.layer_norm,
        )
        for i, ly in enumerate(
            spch_encoder.transformer_layers[-args.encoder_shared_layers:]
        ):
            ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
            if not isinstance(text_encoder.layers[ly_id], type(ly)):
                if text_encoder.layers[ly_id]._get_name() not in (
                    'TransformerEncoderLayerBase',
                    'TransformerEncoderLayer',
                ):
                    raise ValueError(
                        "The shared layers are expected from the same class"
                    )
            text_encoder.layers[ly_id] = cls.set_shared_layer(
                args.encoder_shared_layer_level,
                text_encoder.layers[ly_id],
                ly,
            )
    return text_encoder
def build_model(cls, args, task):
    """Build a new model instance."""
    # make sure all arguments are present in older models
    base_architecture(args)

    if not hasattr(args, 'max_source_positions'):
        args.max_source_positions = 1024
    if not hasattr(args, 'max_target_positions'):
        args.max_target_positions = 1024

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise RuntimeError('--share-all-embeddings requires a joined dictionary')
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise RuntimeError(
                '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
            )
        if args.decoder_embed_path and (args.decoder_embed_path != args.encoder_embed_path):
            raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
        encoder_embed_tokens = build_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path)
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = build_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path)
        decoder_embed_tokens = build_embedding(
            tgt_dict, args.decoder_embed_dim, args.decoder_embed_path)

    encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
    decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
    # second encoder/decoder pair for the reverse direction, reusing the embeddings above
    encoder2 = TransformerEncoder(args, tgt_dict, decoder_embed_tokens)
    decoder2 = TransformerDecoder(args, src_dict, encoder_embed_tokens)
    return TransformerDualModel(encoder, decoder, encoder2, decoder2)
def build_text_encoder(cls, args, src_dictionary):
    enc_emb = nn.Embedding(
        len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
    )
    model_args = cls.update_transformer_encoder_cfg(
        args, {"encoder_layers": args.text_encoder_layers}
    )
    text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
    return text_encoder
def build_encoder(cls, args, src_dict, embed_tokens):
    if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
        return HeadSelectionTransformerEncoder(args, src_dict, embed_tokens)
    else:
        return TransformerEncoder(args, src_dict, embed_tokens)
def get_encoder(lang):
    if lang not in lang_encoders:
        if shared_encoder_embed_tokens is not None:
            encoder_embed_tokens = shared_encoder_embed_tokens
        else:
            encoder_embed_tokens = build_embedding(
                task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
            )
        lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
    return lang_encoders[lang]
def build_model(cls, args, task):
    """Build a new model instance."""
    # make sure all arguments are present in older models
    base_lm_architecture(args)

    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
    if getattr(args, 'max_target_positions', None) is None:
        args.max_target_positions = getattr(
            args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

    if args.character_embeddings:
        embed_tokens = CharacterTokenEmbedder(
            task.source_dictionary,
            eval(args.character_filters),
            args.character_embedding_dim,
            args.encoder_embed_dim,
            args.char_embedder_highway_layers,
        )
    elif args.adaptive_input:
        embed_tokens = AdaptiveInput(
            len(task.source_dictionary),
            task.source_dictionary.pad(),
            args.encoder_input_dim,
            args.adaptive_input_factor,
            args.encoder_embed_dim,
            options.eval_str_list(args.adaptive_input_cutoff, type=int),
            args.quant_noise_pq,
            args.quant_noise_pq_block_size,
        )
    else:
        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.encoder_input_dim)

    if args.tie_adaptive_weights:
        assert args.adaptive_input
        assert args.adaptive_input_factor == args.adaptive_softmax_factor
        assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
            args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
        assert args.encoder_input_dim == args.encoder_output_dim

    encoder = TransformerEncoder(
        args,
        task.target_dictionary,
        embed_tokens,
    )
    print('Encoder Output Dimensions:', args.encoder_output_dim)
    print('Output Size:', len(task.target_dictionary))
    linear_layer = Linear(args.encoder_output_dim, len(task.target_dictionary))
    return cls(encoder, linear_layer)
def __init__(self, args, src_dictionary, dst_dictionary, src_embed_tokens,
             dst_embed_tokens, left_pad=True):
    super().__init__(None)
    self.src_dictionary = src_dictionary
    self.dst_dictionary = dst_dictionary
    self.encoder = TransformerEncoder(args, src_dictionary, src_embed_tokens,
                                      left_pad=left_pad)
    self.masked_encoder = TransformerEncoder(args, dst_dictionary, dst_embed_tokens,
                                             left_pad=left_pad)
def get_encoder(lang, lang_pair=None):
    if lang not in lang_encoders:
        if shared_encoder_embed_tokens is not None:
            encoder_embed_tokens = shared_encoder_embed_tokens
        elif args.share_all_langpair_embeddings:
            encoder_embed_tokens = lang_pair_embed[lang_pair]
        else:
            encoder_embed_tokens = build_embedding(
                task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path)
        lang_encoders[lang] = TransformerEncoder(
            args, task.dicts[lang], encoder_embed_tokens)
    return lang_encoders[lang]
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
    if is_encoder:
        if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
            return LatentTransformerEncoder(
                args, lang_dict, embed_tokens, num_logits=len(langs)
            )
        else:
            return TransformerEncoder(args, lang_dict, embed_tokens)
    else:
        if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
            return LatentTransformerDecoder(
                args, lang_dict, embed_tokens, num_logits=len(langs)
            )
        else:
            return TransformerDecoder(args, lang_dict, embed_tokens)
def build_encoder(cls, args, src_dict, embed_tokens, token2components_map):
    if args.model_type == 'transformer':
        return TransformerEncoder(args, src_dict, embed_tokens)
    elif args.model_type == 'lstm':
        return TarcLSTMEncoder(
            dictionary=src_dict,
            embed_dim=args.encoder_embed_dim,
            hidden_size=args.encoder_hidden_dim,
            num_layers=args.encoder_layers,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            bidirectional=True,
            pretrained_embed=embed_tokens,
            max_source_positions=args.max_source_positions,
            token_map=token2components_map,
            granularity_flags=(args.token_sequences, args.char_sequences),
        )
    else:
        raise NotImplementedError
def build_encoder(cls, args, src_dict, embed_tokens):
    encoder = TransformerEncoder(args, src_dict, embed_tokens)
    if getattr(args, "apply_bert_init", False):
        encoder.apply(init_bert_params)
    return encoder
def build_encoder(self, args, dictionary, embed_tokens):
    encoder = TransformerEncoder(args, dictionary, embed_tokens)
    encoder.apply(init_bert_params)
    return encoder
def build_encoder(cls, args, src_dict, embed_tokens):
    return TransformerEncoder(args, src_dict, embed_tokens)
def build_encoder(cls, args, task):
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_source_positions = 1024
    enc_emb = nn.Embedding(
        len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
    )
    text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
    spch_encoder = Wav2VecEncoderWithAdaptor(args)
    if getattr(args, "load_pretrained_mbart_from", None):
        text_encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=text_encoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "stack_w2v_mbart_encoder", False):
        assert getattr(args, "share_w2v_text_encoder", False) is False
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
        text_encoder.layer_norm = None
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "share_w2v_text_encoder", False):
        spch_encoder = SharedEncoder(
            spch_encoder.w2v_encoder,
            text_encoder,
            spch_encoder.adaptor,
            args.shared_w2v_layers,
        )

    for k, p in spch_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(args, "finetune_w2v_params") and need_finetuning(
            args.finetune_w2v_params, k
        ):
            p.requires_grad = True
        else:
            p.requires_grad = False
    for k, p in text_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(args, "finetune_mbart_encoder_params") and need_finetuning(
            args.finetune_mbart_encoder_params, k
        ):
            p.requires_grad = True
        else:
            p.requires_grad = False

    cross_attentive_loss_before_last_layer = (
        0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
    )
    encoder = DualInputEncoder(
        args,
        spch_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )
    return encoder
class UnsupervisedTranslation(Model):
    """
    This ``UnsupervisedTranslation`` class is a :class:`Model` which takes a sequence, encodes it,
    and then uses the encoded representations to decode another sequence. You can use this as the
    basis for a neural machine translation system, an abstractive summarization system, or any
    other common seq2seq problem. The model here is simple, but should be a decent starting place
    for implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the
        same value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function : ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels. See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
""" def __init__(self, vocab: Vocabulary, dataset_reader: DatasetReader, source_embedder: TextFieldEmbedder, lang2_namespace: str = "tokens", use_bleu: bool = True) -> None: super().__init__(vocab) self._lang1_namespace = lang2_namespace # TODO: DO NOT HARDCODE IT self._lang2_namespace = lang2_namespace # TODO: do not hardcore this self._backtranslation_src_langs = ["en", "ru"] self._coeff_denoising = 1 self._coeff_backtranslation = 1 self._coeff_translation = 1 self._label_smoothing = 0.1 self._pad_index_lang1 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang1_namespace) self._oov_index_lang1 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang1_namespace) self._end_index_lang1 = self.vocab.get_token_index( END_SYMBOL, self._lang1_namespace) self._pad_index_lang2 = vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._lang2_namespace) self._oov_index_lang2 = vocab.get_token_index(DEFAULT_OOV_TOKEN, self._lang2_namespace) self._end_index_lang2 = self.vocab.get_token_index( END_SYMBOL, self._lang2_namespace) self._reader = dataset_reader self._langs_list = self._reader._langs_list self._ae_steps = self._reader._ae_steps self._bt_steps = self._reader._bt_steps self._para_steps = self._reader._para_steps if use_bleu: self._bleu = Average() else: self._bleu = None args = ArgsStub() transformer_iwslt_de_en(args) # build encoder if not hasattr(args, 'max_source_positions'): args.max_source_positions = 1024 if not hasattr(args, 'max_target_positions'): args.max_target_positions = 1024 # Dense embedding of source vocab tokens. self._source_embedder = source_embedder # Dense embedding of vocab words in the target space. num_tokens_lang1 = self.vocab.get_vocab_size(self._lang1_namespace) num_tokens_lang2 = self.vocab.get_vocab_size(self._lang2_namespace) args.share_decoder_input_output_embed = False # TODO implement shared embeddings lang1_dict = DictStub(num_tokens=num_tokens_lang1, pad=self._pad_index_lang1, unk=self._oov_index_lang1, eos=self._end_index_lang1) lang2_dict = DictStub(num_tokens=num_tokens_lang2, pad=self._pad_index_lang2, unk=self._oov_index_lang2, eos=self._end_index_lang2) # instantiate fairseq classes emb_golden_tokens = FairseqEmbedding(num_tokens_lang2, args.decoder_embed_dim, self._pad_index_lang2) self._encoder = TransformerEncoder(args, lang1_dict, self._source_embedder) self._decoder = TransformerDecoder(args, lang2_dict, emb_golden_tokens) self._model = TransformerModel(self._encoder, self._decoder) # TODO: do not hardcode max_len_b and beam size self._sequence_generator_greedy = FairseqBeamSearchWrapper( SequenceGenerator(tgt_dict=lang2_dict, beam_size=1, max_len_b=20)) self._sequence_generator_beam = FairseqBeamSearchWrapper( SequenceGenerator(tgt_dict=lang2_dict, beam_size=7, max_len_b=20)) @overrides def forward( self, # type: ignore lang_pair: List[str], lang1_tokens: Dict[str, torch.LongTensor] = None, lang1_golden: Dict[str, torch.LongTensor] = None, lang2_tokens: Dict[str, torch.LongTensor] = None, lang2_golden: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ """ # detect training mode and what kind of task we need to compute if lang2_tokens is None and lang1_tokens is None: raise ConfigurationError( "source_tokens and target_tokens can not both be None") mode_training = self.training mode_validation = not self.training and lang2_tokens is not None # change 'target_tokens' condition mode_prediction = lang2_tokens is None # change 'target_tokens' condition lang_src, lang_tgt = lang_pair[0].split('-') if 
mode_training: # task types task_translation = False task_denoising = False task_backtranslation = False if lang_src == 'xx': task_backtranslation = True elif lang_src == lang_tgt: task_denoising = True elif lang_src != lang_tgt: task_translation = True else: raise ConfigurationError("All tasks are false") output_dict = {} if mode_training: if task_translation: loss = self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) if self._bleu: predicted_indices = self._sequence_generator_beam.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) predicted_strings = self._indices_to_strings( predicted_indices) golden_strings = self._indices_to_strings( lang2_tokens["tokens"]) golden_strings = self._remove_pad_eos(golden_strings) # print(golden_strings, predicted_strings) self._bleu(corpus_bleu(golden_strings, predicted_strings)) elif task_denoising: # might need to split it into two blocks for interlingua loss loss = self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) elif task_backtranslation: # our goal is also to learn from regular cross-entropy loss, but since we do not have source tokens, # we will generate them ourselves with current model langs_src = self._backtranslation_src_langs.copy() langs_src.remove(lang_tgt) bt_losses = {} for lang_src in langs_src: curr_lang_pair = lang_src + "-" + lang_tgt # TODO: require to pass target language to forward on encoder outputs # We use greedy decoder because it was shown better for backtranslation with torch.no_grad(): predicted_indices = self._sequence_generator_greedy.generate( [self._model], lang2_tokens, self._get_true_pad_mask(lang2_tokens), self._end_index_lang2) model_input = self._strings_to_batch( self._indices_to_strings(predicted_indices), lang2_tokens, lang2_golden, curr_lang_pair) bt_losses['bt:' + curr_lang_pair] = self._forward_seq2seq( **model_input) else: raise ConfigurationError("No task have been detected") if task_translation: loss = self._coeff_translation * loss elif task_denoising: loss = self._coeff_denoising * loss elif task_backtranslation: loss = 0 for bt_loss in bt_losses.values(): loss += self._coeff_backtranslation * bt_loss output_dict["loss"] = loss elif mode_validation: output_dict["loss"] = self._coeff_translation * \ self._forward_seq2seq(lang_pair, lang1_tokens, lang2_tokens, lang2_golden) if self._bleu: predicted_indices = self._sequence_generator_greedy.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) predicted_strings = self._indices_to_strings(predicted_indices) golden_strings = self._indices_to_strings( lang2_tokens["tokens"]) golden_strings = self._remove_pad_eos(golden_strings) print(golden_strings, predicted_strings) self._bleu(corpus_bleu(golden_strings, predicted_strings)) elif mode_prediction: # TODO: pass target language (in the fseq_encoder append embedded target language to the encoder out) predicted_indices = self._sequence_generator_beam.generate( [self._model], lang1_tokens, self._get_true_pad_mask(lang1_tokens), self._end_index_lang2) output_dict["predicted_indices"] = predicted_indices output_dict["predicted_strings"] = self._indices_to_strings( predicted_indices) return output_dict def _get_true_pad_mask(self, indexed_input): mask = util.get_text_field_mask(indexed_input) # TODO: account for cases when text field mask doesn't work, like BERT return mask def _remove_pad_eos(self, golden_strings): tmp = [] for x in golden_strings: tmp.append( list( filter( 
lambda a: a != DEFAULT_PADDING_TOKEN and a != END_SYMBOL, x))) return tmp def _convert_to_sentences(self, golden_strings, predicted_strings): golden_strings_nopad = [] for s in golden_strings: s_nopad = list(filter(lambda t: t != DEFAULT_PADDING_TOKEN, s)) s_nopad = " ".join(s_nopad) golden_strings_nopad.append(s_nopad) predicted_strings = [" ".join(s) for s in predicted_strings] return golden_strings_nopad, predicted_strings def _forward_seq2seq( self, lang_pair: List[str], source_tokens: Dict[str, torch.LongTensor], target_tokens: Dict[str, torch.LongTensor], target_golden: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: source_tokens_padding_mask = self._get_true_pad_mask(source_tokens) encoder_out = self._encoder.forward(source_tokens, source_tokens_padding_mask) logits, _ = self._decoder.forward(target_tokens["tokens"], encoder_out) loss = self._get_ce_loss(logits, target_golden) return loss def _get_ce_loss(self, logits, golden): target_mask = util.get_text_field_mask(golden) loss = util.sequence_cross_entropy_with_logits( logits, golden["golden_tokens"], target_mask, label_smoothing=self._label_smoothing) return loss def _indices_to_strings(self, indices: torch.Tensor): all_predicted_tokens = [] for hyp in indices: predicted_tokens = [ self.vocab.get_token_from_index( idx.item(), namespace=self._lang2_namespace) for idx in hyp ] all_predicted_tokens.append(predicted_tokens) return all_predicted_tokens def _strings_to_batch(self, source_tokens: List[List[str]], target_tokens: Dict[str, torch.Tensor], target_golden: Dict[str, torch.Tensor], lang_pair: str): """ Converts list of sentences which are itself lists of strings into Batch suitable for passing into model's forward function. TODO: Make sure the right device (CPU/GPU) is used. Predicted tokens might get copied on CPU in `self.decode` method... """ # convert source tokens into source tensor_dict instances = [] lang_pairs = [] for sentence in source_tokens: sentence = " ".join(sentence) instances.append(self._reader.string_to_instance(sentence)) lang_pairs.append(lang_pair) source_batch = Batch(instances) source_batch.index_instances(self.vocab) source_batch = source_batch.as_tensor_dict() model_input = { "source_tokens": source_batch["tokens"], "target_golden": target_golden, "target_tokens": target_tokens, "lang_pair": lang_pairs } return model_input @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._bleu and not self.training: all_metrics.update({"BLEU": self._bleu.get_metric(reset=reset)}) return all_metrics
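# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal construction sketch for the UnsupervisedTranslation model above.
# It assumes an AllenNLP Vocabulary, a project-specific DatasetReader that
# exposes _langs_list/_ae_steps/_bt_steps/_para_steps, and a TextFieldEmbedder
# whose output size matches transformer_iwslt_de_en's encoder_embed_dim (512).
# `UnsupervisedTranslationDatasetReader` and the data path are hypothetical
# placeholders, not part of the original code.
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

reader = UnsupervisedTranslationDatasetReader(...)        # hypothetical reader
train_instances = reader.read("data/train")               # hypothetical path
vocab = Vocabulary.from_instances(train_instances)

source_embedder = BasicTextFieldEmbedder({
    "tokens": Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                        embedding_dim=512),               # matches encoder_embed_dim
})

model = UnsupervisedTranslation(vocab=vocab,
                                dataset_reader=reader,
                                source_embedder=source_embedder,
                                lang2_namespace="tokens",
                                use_bleu=True)
# Batches are then passed to model.forward() with a lang_pair such as ["en-ru"]
# plus the indexed lang1/lang2 token fields and their golden-token fields.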
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: WavBart2BartConfig, tgt_dict=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.bart_encoder = bart.model.encoder
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder
        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

        self.pad_token = cfg.pad_token
        self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        input_lengths = (1 - padding_mask.long()).sum(-1)
        output_length = torch.max(
            self.w2v_model._get_feat_extract_output_lengths(input_lengths))
        # print('output_lengths', output_length, 'self.pad_token', self.pad_token)
        # print('kwargs', kwargs['bart_input_tokens'].shape, kwargs['bart_input_tokens'].type())
        batch_size, ntoken = kwargs['bart_input_tokens'].shape
        bart_input = torch.zeros(batch_size, output_length).long().fill_(
            self.pad_token).to(kwargs['bart_input_tokens'])
        bart_input[:, :ntoken] = kwargs['bart_input_tokens']
        # print(bart_input, bart_input.shape)
        # raise

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)
        x_bart = self.bart_encoder(src_tokens=bart_input, src_lengths=None,
                                   token_embeddings=None, return_all_hiddens=False)

        if self.proj:
            x = self.proj(x)

        x_bart = x_bart['encoder_out'][0]
        # print('x.shape', x.shape)
        # print('x_bart', x_bart['encoder_out'][0].shape)
        # print(x_bart['encoder_padding_mask'][0].shape)

        prob = torch.sigmoid(
            torch.FloatTensor([self.num_updates / self.mix_normalization_factor])) * 2 - 1
        # n_mix = int(self.mix_rate * output_length)
        # indices = torch.randperm(output_length)[:n_mix]
        # print(n_mix, indices)
        # print(prob)
        # mask = torch.bernoulli(torch.full(x.shape, prob.item())).int().to(x)
        mask = torch.bernoulli(torch.full(x.shape[:1], prob.item()))[:, None, None].to(x)
        reverse_mask = 1 - mask
        x = x * mask + x_bart * reverse_mask
        # x_bart[indices,:,:] = x[indices,:,:]
        # print('self.num_updates', prob, self.num_updates)
        if self.num_updates % 1000 == 0:
            print('self.num_updates', prob, self.num_updates)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2BartConfig, tgt_dict=None, transform_embed=None, bart=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            if os.path.isfile(os.path.join(cfg.w2v_path)):
                print('load wav2vec from cfg path')
                state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            else:
                print('load wav2vec from relative path')
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    'models/wav2vec_small.pt', arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        self.bart_encoder = bart.model.encoder
        bart_encoder = bart.model.encoder
        self.bart_encoder = TransformerEncoder(bart_encoder.args,
                                               bart_encoder.dictionary,
                                               bart_encoder.embed_tokens)
        self.bart_encoder.load_state_dict(bart_encoder.state_dict())
        self.fix_bart_encoder = cfg.fix_bart_encoder
        if self.fix_bart_encoder:
            print('fix bart encoder')
            for n, parameter in self.bart_encoder.named_parameters():
                parameter.requires_grad = False

        # if tgt_dict is not None:
        print('len(tgt_dict)', len(tgt_dict))
        self.proj = Linear(d, len(tgt_dict))
        # elif getattr(cfg, "decoder_embed_dim", d) != d:
        #     self.proj = Linear(d, cfg.decoder_embed_dim)
        # else:
        #     self.proj = None

        # bart.model.encoder.embed_tokens.weight.shape
        # here assume wav2vec and bart have the same hidden size
        self.bart_encoder.embed_tokens.weight.requires_grad_(cfg.bart_embedding_finetune)
        self.transform_embed = transform_embed
        self.emb = EmbeddingTransformed(self.bart_encoder.embed_tokens, self.transform_embed)
        # if fix bart embedding

        self.pad_token = cfg.pad_token
        self.ctc_weight = cfg.ctc_weight
        self.ce_weight = cfg.ce_weight
        # self.mix_normalization_factor = cfg.mix_normalization_factor

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        # ----------- transform embedding -----------
        target_tokens = kwargs['target_tokens']
        bart_emb = self.bart_encoder.embed_tokens.weight
        # transformed_emb = self.transform_embed(bart_emb.T).T

        # ----------- wav2vec -----------
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # finetuning all without freeze
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x_wav2vec = self.final_dropout(x)  # hidden embedding
        logits_wav2vec = self.proj(x)  # T x B x V

        # ----------- pad predicted tokens -----------
        # if ft:
        logit_lengths = (1 - padding_mask.long()).sum(-1)  # B x T
        logit_preds = torch.argmax(logits_wav2vec, dim=-1)  # B
        if tbc:
            logit_preds = logit_preds.transpose(0, 1)  # B x T
        print('logits_wav2vec.shape, logit_preds.shape',
              logits_wav2vec.shape, logit_preds.shape, logit_preds)

        pred_idxs, pred_lengths = [], []
        for i, (y, length) in enumerate(zip(logit_preds, logit_lengths)):
            emb_idx = torch.stack([x[0] for x in groupby(y[:length])])
            pred_idxs.append(emb_idx)
            pred_lengths.append(len(emb_idx))
        max_len = max(pred_lengths)
        print('pred_lengths', pred_lengths, max_len)
        tokens_w2v = torch.zeros(len(logit_preds), max_len).long().fill_(self.pad_token)
        for i, pred_idx in enumerate(pred_idxs):
            tokens_w2v[i, :(len(pred_idx))] = pred_idx

        # use target_tokens if finetuning embedding and transformation (not ft)
        # use tokens_w2v from wav2vec if finetuning
        if ft:
            # finetune from predictions (after {freeze_finetune_updates} steps)
            bart_input = tokens_w2v
            bart_input_lengths = pred_lengths
            ctc_weight, ce_weight = self.ctc_weight, 1
        else:
            # initial steps, from ground truth
            bart_input = target_tokens
            bart_input_lengths = kwargs['target_token_lengths']
            ctc_weight, ce_weight = 1, 1

        token_emb = self.emb(bart_input)
        # token_emb = torch.index_select(transformed_emb, 0, bart_input.reshape(-1)).view(*bart_input.shape, -1)

        # feed tokens to the bart encoder
        bart_encoder_output = self.bart_encoder(
            src_tokens=bart_input,
            src_lengths=bart_input_lengths,
            token_embeddings=token_emb,  # pass in customized embedding
            return_all_hiddens=False,
        )

        # if self.num_updates % 1000 == 0:
        #     print('self.num_updates', self.num_updates)

        return {
            "encoder_out": bart_encoder_output['encoder_out'],  # T x B x C
            "encoder_padding_mask": bart_encoder_output['encoder_padding_mask'],  # B x T
            "wav2vec_logits": logits_wav2vec,  # T x B x C
            "wav2vec_padding_mask": padding_mask,
            "ctc_weight": ctc_weight,
            "ce_weight": ce_weight,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [
                encoder_out["encoder_out"][0].index_select(1, new_order)
            ]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict