def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present in older models base_lm_architecture(args) if safe_hasattr(args, "max_target_positions") and not safe_hasattr( args, "tokens_per_sample" ): args.tokens_per_sample = args.max_target_positions decoder = FConvDecoder( dictionary=task.target_dictionary, embed_dim=args.decoder_embed_dim, convolutions=eval(args.decoder_layers), out_embed_dim=args.decoder_embed_dim, attention=eval(args.decoder_attention), dropout=args.dropout, max_positions=args.tokens_per_sample, share_embed=False, positional_embeddings=False, adaptive_softmax_cutoff=( utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == "adaptive_loss" else None ), adaptive_softmax_dropout=args.adaptive_softmax_dropout, ) return FConvLanguageModel(decoder)
def build_decoder(cls, args, task, embed_tokens):
    if (
        safe_hasattr(args, "decoder_self_attn_head_select")
        and args.decoder_self_attn_head_select
    ) or (
        safe_hasattr(args, "dec_enc_attn_head_select")
        and args.dec_enc_attn_head_select
    ):
        return HeadSelectionTransformerDecoderScriptable(
            args, task.target_dictionary, embed_tokens
        )
    else:
        return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present in older models base_architecture(args) if not safe_hasattr(args, "max_source_positions"): args.max_source_positions = 1024 if not safe_hasattr(args, "max_target_positions"): args.max_target_positions = 1024 src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim, path=None): num_embeddings = len(dictionary) padding_idx = dictionary.pad() emb = Embedding(num_embeddings, embed_dim, padding_idx) # if provided, load from preloaded dictionaries if path: embed_dict = utils.parse_embedding(path) utils.load_embedding(embed_dict, dictionary, emb) return emb if args.share_all_embeddings: if src_dict != tgt_dict: raise RuntimeError( "--share-all-embeddings requires a joined dictionary" ) if args.encoder_embed_dim != args.decoder_embed_dim: raise RuntimeError( "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" ) if args.decoder_embed_path and ( args.decoder_embed_path != args.encoder_embed_path ): raise RuntimeError( "--share-all-embeddings not compatible with --decoder-embed-path" ) encoder_embed_tokens = build_embedding( src_dict, args.encoder_embed_dim, args.encoder_embed_path ) decoder_embed_tokens = encoder_embed_tokens args.share_decoder_input_output_embed = True else: encoder_embed_tokens = build_embedding( src_dict, args.encoder_embed_dim, args.encoder_embed_path ) decoder_embed_tokens = build_embedding( tgt_dict, args.decoder_embed_dim, args.decoder_embed_path ) encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens) decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens) return LightConvModel(encoder, decoder)
def build_encoder(cls, args, src_dict, embed_tokens):
    if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
        return HeadSelectionTransformerEncoder(args, src_dict, embed_tokens)
    else:
        return TransformerEncoder(args, src_dict, embed_tokens)
def _copy_keys(args, cls, prefix, seen):
    """
    Copy the prefixed keys from the flat namespace (e.g. ``decoder_embed_dim``)
    onto the corresponding dataclass fields (e.g. ``decoder.embed_dim``).
    """
    cfg = cls()
    for fld in fields(cls):
        # for each field in the DC (e.g. embed_dim), look up the prefixed key
        # (e.g. decoder_embed_dim) in the namespace and set it on the dc.
        args_key = f"{prefix}_{fld.name}"
        if safe_hasattr(args, args_key):
            seen.add(args_key)
            setattr(cfg, fld.name, safe_getattr(args, args_key))
        if safe_hasattr(args, fld.name):
            seen.add(fld.name)
            setattr(cfg, fld.name, safe_getattr(args, fld.name))
    return cfg
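
# Illustrative sketch (assumption, not part of the original source): how `_copy_keys`
# is expected to behave. `ToyDecoderConfig` and the namespace values are hypothetical,
# and the call assumes `_copy_keys` is exposed as a classmethod on the config class.
from argparse import Namespace
from dataclasses import dataclass


@dataclass
class ToyDecoderConfig:
    embed_dim: int = 512
    layers: int = 6


def _demo_copy_keys(config_cls):
    args = Namespace(decoder_embed_dim=1024, decoder_layers=12, dropout=0.1)
    seen = set()
    cfg = config_cls._copy_keys(args, ToyDecoderConfig, "decoder", seen)
    # "decoder_embed_dim" / "decoder_layers" lose their prefix and land on the
    # dataclass fields; `seen` records which flat keys were consumed.
    assert (cfg.embed_dim, cfg.layers) == (1024, 12)
    assert seen == {"decoder_embed_dim", "decoder_layers"}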
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    output_projection=None,
):
    self.num_tasks = args.decoder_tasks
    self.num_layers = args.decoder_layers
    self.total_num_heads = args.total_decoder_attention_heads
    self.num_heads = args.decoder_attention_heads
    self.select_strategy = args.attn_head_select_strategy

    super().__init__(
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=no_encoder_attn,
        output_projection=output_projection,
    )
    self.self_attn_head_selector = None
    self.enc_attn_head_selector = None
    if (
        safe_hasattr(args, "decoder_self_attn_head_select")
        and args.decoder_self_attn_head_select
    ):
        self.self_attn_head_selector = AttnHeadSelector(
            self.num_tasks,
            self.num_layers,
            self.total_num_heads,
            self.num_heads,
            self.select_strategy,
        )
    if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
        self.enc_attn_head_selector = AttnHeadSelector(
            self.num_tasks,
            self.num_layers,
            self.total_num_heads,
            self.num_heads,
            self.select_strategy,
        )
    self.task_ids = None
    self.layers = nn.ModuleList(
        [
            self.build_head_selection_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ]
    )
def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present base_architecture(args) if not safe_hasattr(args, "max_positions"): args.max_positions = args.tokens_per_sample encoder = LinformerEncoder(args, task.source_dictionary) return cls(args, encoder)
def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present in older models base_architecture(args) if not safe_hasattr(args, "max_positions"): args.max_positions = args.tokens_per_sample logger.info(args) encoder = MaskedLMEncoder(args, task.dictionary) return cls(args, encoder)
def build_model(cls, args, task): """Build a new model instance.""" from omegaconf import OmegaConf if OmegaConf.is_config(args): OmegaConf.set_struct(args, False) # make sure all arguments are present base_architecture(args) if not safe_hasattr(args, "max_positions"): if not safe_hasattr(args, "tokens_per_sample"): args.tokens_per_sample = task.max_positions() args.max_positions = args.tokens_per_sample encoder = RobertaEncoder(args, task.source_dictionary) if OmegaConf.is_config(args): OmegaConf.set_struct(args, True) return cls(args, encoder)
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
    if is_encoder:
        if safe_hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
            return LatentTransformerEncoder(
                args, lang_dict, embed_tokens, num_logits=len(langs)
            )
        else:
            return TransformerEncoder(args, lang_dict, embed_tokens)
    else:
        if safe_hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
            return LatentTransformerDecoder(
                args, lang_dict, embed_tokens, num_logits=len(langs)
            )
        else:
            return TransformerDecoder(args, lang_dict, embed_tokens)
def from_namespace(cls, args):
    if args is None:
        return None
    if not isinstance(args, cls):
        seen = set()
        config = cls()
        # currently, we can go generically from DC fields to args hierarchically,
        # but we can't easily deconstruct a flat namespace into a hierarchical DC,
        # mostly because we could have a sub-dc called `decoder-foo` that should not
        # go to the sub struct called `decoder`. There are ways around this, but
        # let's keep it simple for now.
        for fld in fields(cls):
            # concretely, the transformer_config knows which sub-dcs it has, so we go
            # through all the dc fields, and for each one that is a sub-dc we build it
            # with `_copy_keys()`
            if fld.name == "decoder":
                if safe_hasattr(args, "decoder"):
                    # in some cases the args we receive are already structured
                    # (as DictConfigs), so let's just build the correct DC
                    seen.add("decoder")
                    config.decoder = DecoderConfig(**args.decoder)
                else:
                    config.decoder = cls._copy_keys(args, DecoderConfig, "decoder", seen)
            elif fld.name == "encoder":
                # same but for encoder
                if safe_hasattr(args, "encoder"):
                    seen.add("encoder")
                    config.encoder = EncDecBaseConfig(**args.encoder)
                else:
                    config.encoder = cls._copy_keys(
                        args, EncDecBaseConfig, "encoder", seen
                    )
            elif fld.name == "quant_noise":
                # same but for quant_noise
                if safe_hasattr(args, "quant_noise"):
                    seen.add("quant_noise")
                    config.quant_noise = QuantNoiseConfig(**args.quant_noise)
                else:
                    config.quant_noise = cls._copy_keys(
                        args, QuantNoiseConfig, "quant_noise", seen
                    )
            elif safe_hasattr(args, fld.name):
                # if it's not a structured field, it's just a normal field, copy it over
                seen.add(fld.name)
                setattr(config, fld.name, safe_getattr(args, fld.name))
        # we got all the fields defined in the dataclass, but the argparse
        # namespace might have extra args for two reasons:
        # - we are in a legacy class, so not all the args are declared in the dataclass.
        #   Ideally, once everyone has defined a dataclass for their model, we won't need this.
        # - some places expect args to be there but never define them.
        args_dict = (
            args._asdict()
            if safe_hasattr(args, "_asdict")
            else vars(args)
            if safe_hasattr(args, "__dict__")
            else {}
        )  # namedtuple doesn't have __dict__ :-/
        for key, value in args_dict.items():
            if key not in seen:
                setattr(config, key, value)
        return config
    else:
        return args
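
# Usage sketch (assumption, not part of the original source): converting a flat,
# argparse-style namespace into the hierarchical config via `from_namespace`.
# `TransformerConfig` is assumed to be the dataclass exposing `from_namespace` and
# `_copy_keys` as classmethods; the field values below are illustrative only.
def _demo_from_namespace():
    from argparse import Namespace

    from fairseq.models.transformer import TransformerConfig

    args = Namespace(
        encoder_embed_dim=512,
        decoder_embed_dim=512,
        decoder_layers=6,
        dropout=0.3,
        some_legacy_flag=True,  # hypothetical key not declared in the dataclass
    )
    cfg = TransformerConfig.from_namespace(args)
    # Prefixed flat keys land on the sub-configs; unknown keys are copied over verbatim.
    print(cfg.decoder.embed_dim, cfg.encoder.embed_dim, cfg.some_legacy_flag)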
def build_decoder(cls, args, task):
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_target_positions = 1024
    dec_emb = nn.Embedding(
        len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
    )
    decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
    if getattr(args, "load_pretrained_mbart_from", None):
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "no_final_norm_decoder", False):
        decoder.layer_norm = None
    for k, p in decoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_mbart_decoder_params"
        ) and need_finetuning(args.finetune_mbart_decoder_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False

    compute_cross_attentive_loss = (
        True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
    )
    cross_attentive_loss_without_norm = getattr(
        args, "attentive_cost_without_normalize", False
    )
    cross_attentive_loss_reverse = (
        False  # getattr(args, "attentive_cost_reverse", False)
    )
    decoder = TransformerMultiInputDecoder(
        dictionary=task.target_dictionary,
        spch_decoder=decoder,
        text_decoder=decoder,
        compute_cross_attentive_loss=compute_cross_attentive_loss,
        cross_attentive_loss_with_norm=True
        if not cross_attentive_loss_without_norm
        else False,
        cross_attentive_loss_reverse=cross_attentive_loss_reverse,
    )
    return decoder
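
# Minimal sketch (assumption) of the `need_finetuning` predicate used in the
# freeze/unfreeze loops here and below: it is expected to match a comma-separated
# list of parameter-name patterns, with "all" unfreezing everything. The actual
# helper in the codebase may differ in detail.
def need_finetuning_sketch(ft_params: str, param_name: str) -> bool:
    if ft_params == "all":
        return True
    return any(pattern in param_name for pattern in ft_params.split(","))


# e.g. need_finetuning_sketch("layer_norm,embed_tokens",
#                             "layers.0.self_attn_layer_norm.weight") -> True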
def __init__(self, args):
    super().__init__(None)
    self.w2v_encoder = Wav2VecEncoder(args)
    encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim
    # Projection + 8x shrinking
    self.adaptor = Conv1dAdaptor(
        encoder_out_dim,
        args.decoder_embed_dim,
        n_layers=args.adaptor_n_layers,
        kernel_size=args.adaptor_kernel_size,
        stride=args.adaptor_stride,
        add_layernorm=args.adaptor_layernorm,
    )
    for k, p in self.w2v_encoder.w2v_model.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_w2v_params"
        ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False
def build_encoder(cls, args):
    if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
        encoder = HeadSelectionS2TTransformerEncoder(args)
    else:
        encoder = S2TTransformerEncoder(args)
    pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
    if pretraining_path is not None:
        if not Path(pretraining_path).exists():
            logger.warning(
                f"skipped pretraining because {pretraining_path} does not exist"
            )
        else:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=pretraining_path
            )
            logger.info(f"loaded pretrained encoder from: {pretraining_path}")
    return encoder
def build_decoder(cls, args, task, embed_tokens):
    _args = copy.deepcopy(args)
    _args.dropout = args.decoder_dropout
    _args.attention_dropout = args.decoder_attention_dropout
    _args.activation_dropout = args.decoder_activation_dropout
    _args.max_target_positions = 1024
    decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
    if getattr(args, "load_pretrained_decoder_from", None):
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=args.load_pretrained_decoder_from
        )
    for k, p in decoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_decoder_params"
        ) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False
    return decoder
def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if safe_hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if safe_hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = safe_getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = safe_getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", False)
    args.activation_fn = safe_getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = safe_getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = safe_getattr(args, "decoder_layers_to_keep", None)
    args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0)

    args.base_layers = safe_getattr(args, "base_layers", 0)
    args.base_sublayers = safe_getattr(args, "base_sublayers", 1)
    args.base_shuffle = safe_getattr(args, "base_shuffle", False)

    args.add_bos_token = safe_getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = safe_getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = safe_getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = safe_getattr(args, "character_embeddings", False)

    args.decoder_output_dim = safe_getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = safe_getattr(
        args, "decoder_input_dim", args.decoder_embed_dim
    )

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = safe_getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = safe_getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
    args.offload_activations = safe_getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
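
# Minimal sketch (assumption, not from this file) of the `safe_hasattr`/`safe_getattr`
# helpers used throughout these snippets: `safe_hasattr` treats an attribute that exists
# but is None as absent, and `safe_getattr` falls back to `default` when the attribute
# is missing. The real helpers may additionally special-case OmegaConf configs.
def safe_hasattr_sketch(obj, k):
    return getattr(obj, k, None) is not None


def safe_getattr_sketch(obj, k, default=None):
    return getattr(obj, k, default)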
def decoder_latent_layer(self):
    return (
        safe_hasattr(self.args, "decoder_latent_layer")
        and self.args.decoder_latent_layer
    )
def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        safe_hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.post_process,
                    has_langtok=encoder_has_langtok,
                )
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample["id"], list):
                sample_ids = sample["id"].tolist()
            else:
                sample_ids = sample["id"]
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.post_process)
                    else:
                        src_str = ""

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample["nsentences"]
            if all_avg_pool.shape[0] >= 1000000:
                with open(
                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
                    "w",
                ) as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(
                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
                    "w",
                ) as sentence_file:
                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    if all_avg_pool is not None:
        with open(
            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
        ) as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(
            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
        ) as sentence_file:
            sentence_file.writelines(f"{line}\n" for line in source_sentences)
    return None
def build_encoder(cls, args, task):
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_source_positions = 1024
    enc_emb = nn.Embedding(
        len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
    )
    text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
    spch_encoder = Wav2VecEncoderWithAdaptor(args)
    if getattr(args, "load_pretrained_mbart_from", None):
        text_encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=text_encoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "stack_w2v_mbart_encoder", False):
        assert getattr(args, "share_w2v_text_encoder", False) is False
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
        text_encoder.layer_norm = None
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "share_w2v_text_encoder", False):
        spch_encoder = SharedEncoder(
            spch_encoder.w2v_encoder,
            text_encoder,
            spch_encoder.adaptor,
            args.shared_w2v_layers,
        )

    for k, p in spch_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_w2v_params"
        ) and need_finetuning(args.finetune_w2v_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False
    for k, p in text_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_mbart_encoder_params"
        ) and need_finetuning(args.finetune_mbart_encoder_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False

    cross_attentive_loss_before_last_layer = (
        0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
    )
    encoder = DualInputEncoder(
        args,
        spch_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )
    return encoder
def build_model(cls, args, task): """Build a new model instance.""" from fairseq.tasks.multilingual_translation import MultilingualTranslationTask assert isinstance(task, MultilingualTranslationTask) # make sure all arguments are present in older models base_multilingual_architecture(args) if not safe_hasattr(args, "max_source_positions"): args.max_source_positions = 1024 if not safe_hasattr(args, "max_target_positions"): args.max_target_positions = 1024 src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs] tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs] if args.share_encoders: args.share_encoder_embeddings = True if args.share_decoders: args.share_decoder_embeddings = True def build_embedding(dictionary, embed_dim, path=None): num_embeddings = len(dictionary) padding_idx = dictionary.pad() emb = Embedding(num_embeddings, embed_dim, padding_idx) # if provided, load from preloaded dictionaries if path: embed_dict = utils.parse_embedding(path) utils.load_embedding(embed_dict, dictionary, emb) return emb # build shared embeddings (if applicable) shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None if args.share_all_embeddings: if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" ) if args.decoder_embed_path and ( args.decoder_embed_path != args.encoder_embed_path ): raise ValueError( "--share-all-embeddings not compatible with --decoder-embed-path" ) shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( dicts=task.dicts, langs=task.langs, embed_dim=args.encoder_embed_dim, build_embedding=build_embedding, pretrained_embed_path=args.encoder_embed_path, ) shared_decoder_embed_tokens = shared_encoder_embed_tokens args.share_decoder_input_output_embed = True else: if args.share_encoder_embeddings: shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( dicts=task.dicts, langs=src_langs, embed_dim=args.encoder_embed_dim, build_embedding=build_embedding, pretrained_embed_path=args.encoder_embed_path, ) if args.share_decoder_embeddings: shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( dicts=task.dicts, langs=tgt_langs, embed_dim=args.decoder_embed_dim, build_embedding=build_embedding, pretrained_embed_path=args.decoder_embed_path, ) # encoders/decoders for each language lang_encoders, lang_decoders = {}, {} def get_encoder(lang): if lang not in lang_encoders: if shared_encoder_embed_tokens is not None: encoder_embed_tokens = shared_encoder_embed_tokens else: encoder_embed_tokens = build_embedding( task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path, ) lang_encoders[lang] = cls._get_module_class( True, args, task.dicts[lang], encoder_embed_tokens, src_langs ) return lang_encoders[lang] def get_decoder(lang): if lang not in lang_decoders: if shared_decoder_embed_tokens is not None: decoder_embed_tokens = shared_decoder_embed_tokens else: decoder_embed_tokens = build_embedding( task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path, ) lang_decoders[lang] = cls._get_module_class( False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs ) return lang_decoders[lang] # shared encoders/decoders (if applicable) shared_encoder, shared_decoder = None, None if args.share_encoders: shared_encoder = get_encoder(src_langs[0]) if args.share_decoders: shared_decoder = get_decoder(tgt_langs[0]) encoders, decoders = OrderedDict(), OrderedDict() for lang_pair, 
src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs): encoders[lang_pair] = ( shared_encoder if shared_encoder is not None else get_encoder(src) ) decoders[lang_pair] = ( shared_decoder if shared_decoder is not None else get_decoder(tgt) ) return MultilingualTransformerModel(encoders, decoders)
def encoder_latent_layer(self):
    return (
        safe_hasattr(self.args, "encoder_latent_layer")
        and self.args.encoder_latent_layer
    )