Example 1
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if safe_hasattr(args, "max_target_positions") and not safe_hasattr(
            args, "tokens_per_sample"
        ):
            args.tokens_per_sample = args.max_target_positions

        decoder = FConvDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.tokens_per_sample,
            share_embed=False,
            positional_embeddings=False,
            adaptive_softmax_cutoff=(
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == "adaptive_loss"
                else None
            ),
            adaptive_softmax_dropout=args.adaptive_softmax_dropout,
        )
        return FConvLanguageModel(decoder)
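All of these examples guard optional arguments with fairseq's safe_hasattr/safe_getattr helpers rather than plain hasattr/getattr: safe_hasattr returns True only for attributes that exist and are not None, and safe_getattr is getattr with a default. A minimal sketch of that behavior follows (the canonical versions live in fairseq.utils and additionally handle OmegaConf configs; this is an approximation, not the verbatim implementation):

def safe_hasattr(obj, k):
    # True only if attribute `k` exists on `obj` and is not None
    return getattr(obj, k, None) is not None

def safe_getattr(obj, k, default=None):
    # plain getattr with a default; the fairseq version adds OmegaConf handling
    return getattr(obj, k, default)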
Example 2
 def build_decoder(cls, args, task, embed_tokens):
     if (safe_hasattr(args, "decoder_self_attn_head_select")
             and args.decoder_self_attn_head_select) or (
                 safe_hasattr(args, "dec_enc_attn_head_select")
                 and args.dec_enc_attn_head_select):
         return HeadSelectionTransformerDecoderScriptable(
             args, task.target_dictionary, embed_tokens)
     else:
         return TransformerDecoderScriptable(args, task.target_dictionary,
                                             embed_tokens)
Example 3
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError(
                    "--share-all-embeddings requires a joined dictionary"
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise RuntimeError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens)
        decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
        return LightConvModel(encoder, decoder)
Example 4
 def build_encoder(cls, args, src_dict, embed_tokens):
     if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
         return HeadSelectionTransformerEncoder(
             args, src_dict, embed_tokens
         )
     else:
         return TransformerEncoder(args, src_dict, embed_tokens)
Example 5
 def _copy_keys(args, cls, prefix, seen):
     """
     Copy the prefixed keys (e.g. decoder_embed_dim) onto the corresponding DC fields (e.g. decoder.embed_dim).
     """
     cfg = cls()
     for fld in fields(cls):
         # for all the fields in the DC, find the fields (e.g. embed_dim)
         # in the namespace with the prefix (e.g. decoder)
         # and set it on the dc.
         args_key = f"{prefix}_{fld.name}"
         if safe_hasattr(args, args_key):
             seen.add(args_key)
             setattr(cfg, fld.name, safe_getattr(args, args_key))
         if safe_hasattr(args, fld.name):
             seen.add(fld.name)
             setattr(cfg, fld.name, safe_getattr(args, fld.name))
     return cfg
Example 6
 def __init__(
     self,
     args,
     dictionary,
     embed_tokens,
     no_encoder_attn=False,
     output_projection=None,
 ):
     self.num_tasks = args.decoder_tasks
     self.num_layers = args.decoder_layers
     self.total_num_heads = args.total_decoder_attention_heads
     self.num_heads = args.decoder_attention_heads
     self.select_strategy = args.attn_head_select_strategy
     super().__init__(
         args, dictionary, embed_tokens,
         no_encoder_attn=no_encoder_attn,
         output_projection=output_projection
     )
     self.self_attn_head_selector = None
     self.enc_attn_head_selector = None
     if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select:
         self.self_attn_head_selector = AttnHeadSelector(
             self.num_tasks,
             self.num_layers,
             self.total_num_heads,
             self.num_heads,
             self.select_strategy
         )
     if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
         self.enc_attn_head_selector = AttnHeadSelector(
             self.num_tasks,
             self.num_layers,
             self.total_num_heads,
             self.num_heads,
             self.select_strategy
         )
     self.task_ids = None
     self.layers = nn.ModuleList(
         [
             self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers)
         ]
     )
Example 7
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = LinformerEncoder(args, task.source_dictionary)
        return cls(args, encoder)
Example 8
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        encoder = MaskedLMEncoder(args, task.dictionary)
        return cls(args, encoder)
Example 9
    def build_model(cls, args, task):
        """Build a new model instance."""

        from omegaconf import OmegaConf

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, False)

        # make sure all arguments are present
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            if not safe_hasattr(args, "tokens_per_sample"):
                args.tokens_per_sample = task.max_positions()
            args.max_positions = args.tokens_per_sample

        encoder = RobertaEncoder(args, task.source_dictionary)

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, True)

        return cls(args, encoder)
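The OmegaConf struct toggling in Example 9 matters because a DictConfig in struct mode rejects assignments to keys it does not already have. A minimal illustration (assuming omegaconf is installed; this is not fairseq-specific code):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"tokens_per_sample": 512})
OmegaConf.set_struct(cfg, True)
# cfg.max_positions = 512                  # struct mode: adding a new key raises an error
OmegaConf.set_struct(cfg, False)
cfg.max_positions = cfg.tokens_per_sample  # allowed while struct mode is off
OmegaConf.set_struct(cfg, True)            # re-lock, as build_model does after building the encoder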
Example 10
 def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens,
                       langs):
     if is_encoder:
         if safe_hasattr(
                 args,
                 "encoder_latent_layer") and args.encoder_latent_layer:
             return LatentTransformerEncoder(args,
                                             lang_dict,
                                             embed_tokens,
                                             num_logits=len(langs))
         else:
             return TransformerEncoder(args, lang_dict, embed_tokens)
     else:
         if safe_hasattr(
                 args,
                 "decoder_latent_layer") and args.decoder_latent_layer:
             return LatentTransformerDecoder(args,
                                             lang_dict,
                                             embed_tokens,
                                             num_logits=len(langs))
         else:
             return TransformerDecoder(args, lang_dict, embed_tokens)
Example 11
 def from_namespace(cls, args):
     if args is None:
         return None
     if not isinstance(args, cls):
         seen = set()
         config = cls()
         # Currently we can go generically from DC fields to args hierarchically,
         # but we can't easily deconstruct a flat namespace into a hierarchical DC,
         # mostly because we could have a sub-dc called `decoder-foo` that should not
         # go into the sub struct called `decoder`. There are ways around this,
         # but let's keep it simple for now.
         for fld in fields(cls):
             # concretely, the transformer_config knows which sub-dataclasses it has, so we go through
             # all the dc fields and, for each one that has a sub-dc, build that sub-dc with `_copy_keys()`
             if fld.name == "decoder":
                 if safe_hasattr(args, "decoder"):
                     # in some cases, the args we receive are already structured (as DictConfigs), so just build the correct DC
                     seen.add("decoder")
                     config.decoder = DecoderConfig(**args.decoder)
                 else:
                     config.decoder = cls._copy_keys(
                         args, DecoderConfig, "decoder", seen)
             elif fld.name == "encoder":
                 # same but for encoder
                 if safe_hasattr(args, "encoder"):
                     seen.add("encoder")
                     config.encoder = EncDecBaseConfig(**args.encoder)
                 else:
                     config.encoder = cls._copy_keys(
                         args, EncDecBaseConfig, "encoder", seen)
             elif fld.name == "quant_noise":
                 # same but for quant_noise
                 if safe_hasattr(args, "quant_noise"):
                     seen.add("quant_noise")
                     config.quant_noise = QuantNoiseConfig(
                         **args.quant_noise)
                 else:
                     config.quant_noise = cls._copy_keys(
                         args, QuantNoiseConfig, "quant_noise", seen)
             elif safe_hasattr(args, fld.name):
                 # if it's not a structure field, it's just a normal field, copy it over
                 seen.add(fld.name)
                 setattr(config, fld.name, safe_getattr(args, fld.name))
         # we got all the fields defined in the dataclass, but
         # the argparse namespace might have extra args for two reasons:
         #   - we are in a legacy class, so not all args are declared in the dataclass. Ideally, once everyone has defined a dataclass for their model, we won't need this
         #   - some places expect args to be there but never define them
         args_dict = (args._asdict() if safe_hasattr(args, "_asdict") else
                      vars(args) if safe_hasattr(args, "__dict__") else {}
                      )  # namedtuple doesn't have __dict__ :-/
         for key, value in args_dict.items():
             if key not in seen:
                 setattr(config, key, value)
         return config
     else:
         return args
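A hypothetical usage sketch of from_namespace: a flat legacy namespace is mapped onto the hierarchical dataclass, with prefixed keys such as decoder_embed_dim landing on the decoder sub-config and unknown keys copied over verbatim. The class name TransformerConfig and the argument values below are assumptions for illustration only:

from argparse import Namespace

legacy_args = Namespace(
    decoder_embed_dim=512,   # picked up by _copy_keys(..., DecoderConfig, "decoder", ...)
    decoder_layers=6,
    encoder_embed_dim=512,
    my_custom_flag=True,     # not a dataclass field: copied onto the config verbatim at the end
)
cfg = TransformerConfig.from_namespace(legacy_args)
print(cfg.decoder.embed_dim, cfg.decoder.layers)  # 512 6
print(cfg.my_custom_flag)                         # True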
Example 12
    def build_decoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_target_positions = 1024
        dec_emb = nn.Embedding(
            len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
        )
        decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
        if getattr(args, "load_pretrained_mbart_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "no_final_norm_decoder", False):
            decoder.layer_norm = None
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_decoder_params"
            ) and need_finetuning(
                args.finetune_mbart_decoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False

        compute_cross_attentive_loss = (
            True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
        )
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False
        )
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=decoder,
            text_decoder=decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=True
            if not cross_attentive_loss_without_norm
            else False,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        return decoder
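need_finetuning itself is not shown in these examples. A plausible sketch of such a matcher, stated here as an assumption rather than the verbatim fairseq implementation, is a substring check over a comma-separated list of parameter-name patterns, with "all" unfreezing everything:

def need_finetuning(ft_params, param_name):
    # "all" finetunes every parameter; otherwise any comma-separated pattern
    # occurring in the parameter name marks it as trainable
    if ft_params == "all":
        return True
    return any(pattern in param_name for pattern in ft_params.split(","))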
Example 13
 def __init__(self, args):
     super().__init__(None)
     self.w2v_encoder = Wav2VecEncoder(args)
     encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim
     # Projection + 8x shrinking
     self.adaptor = Conv1dAdaptor(encoder_out_dim,
                                  args.decoder_embed_dim,
                                  n_layers=args.adaptor_n_layers,
                                  kernel_size=args.adaptor_kernel_size,
                                  stride=args.adaptor_stride,
                                  add_layernorm=args.adaptor_layernorm)
     for k, p in self.w2v_encoder.w2v_model.named_parameters():
         # Freeze pretrained models by default
         if safe_hasattr(args, 'finetune_w2v_params'
                         ) and XMTransformerModel.finetune_params(
                             args.finetune_w2v_params, k):
             p.requires_grad = True
         else:
             p.requires_grad = False
Example 14
 def build_encoder(cls, args):
     if safe_hasattr(
             args,
             "encoder_attn_head_select") and args.encoder_attn_head_select:
         encoder = HeadSelectionS2TTransformerEncoder(args)
     else:
         encoder = S2TTransformerEncoder(args)
     pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
     if pretraining_path is not None:
         if not Path(pretraining_path).exists():
             logger.warning(
                 f"skipped pretraining because {pretraining_path} does not exist"
             )
         else:
             encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder, checkpoint=pretraining_path)
             logger.info(
                 f"loaded pretrained encoder from: {pretraining_path}")
     return encoder
Example 15
    def build_decoder(cls, args, task, embed_tokens):
        _args = copy.deepcopy(args)
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout
        _args.max_target_positions = 1024

        decoder = TransformerDecoder(_args, task.target_dictionary,
                                     embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from)
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(args, 'finetune_decoder_params'
                            ) and XMTransformerModel.finetune_params(
                                args.finetune_decoder_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        return decoder
Example 16
def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if safe_hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if safe_hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = safe_getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = safe_getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", False)
    args.activation_fn = safe_getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = safe_getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = safe_getattr(args, "decoder_layers_to_keep", None)
    args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0)

    args.base_layers = safe_getattr(args, "base_layers", 0)
    args.base_sublayers = safe_getattr(args, "base_sublayers", 1)
    args.base_shuffle = safe_getattr(args, "base_shuffle", False)

    args.add_bos_token = safe_getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = safe_getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = safe_getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = safe_getattr(args, "character_embeddings", False)

    args.decoder_output_dim = safe_getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = safe_getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = safe_getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = safe_getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
    args.offload_activations = safe_getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
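Because every default above goes through safe_getattr, calling base_lm_architecture on a sparse namespace (for instance one restored from an old checkpoint) back-fills the missing options while leaving explicitly set values alone. A small sketch with illustrative values:

from argparse import Namespace

old_args = Namespace(dropout=0.3)           # e.g. the only option stored by an old checkpoint
base_lm_architecture(old_args)
assert old_args.dropout == 0.3              # explicitly set values are preserved
assert old_args.decoder_embed_dim == 512    # missing options receive their defaults
assert old_args.decoder_normalize_before    # unconditionally forced to True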
Example 17
 def decoder_latent_layer(self):
     return (
         safe_hasattr(self.args, "decoder_latent_layer")
         and self.args.decoder_latent_layer
     )
Example 18
def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        safe_hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.post_process,
                    has_langtok=encoder_has_langtok,
                )
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample["id"], list):
                sample_ids = sample["id"].tolist()
            else:
                sample_ids = sample["id"]
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.post_process)
                    else:
                        src_str = ""

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample["nsentences"]
            if all_avg_pool.shape[0] >= 1000000:
                with open(
                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
                    "w",
                ) as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(
                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
                    "w",
                ) as sentence_file:
                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    if all_avg_pool is not None:
        with open(
            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
        ) as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(
            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
        ) as sentence_file:
            sentence_file.writelines(f"{line}\n" for line in source_sentences)
    return None
Example 19
    def build_encoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "stack_w2v_mbart_encoder", False):
            assert getattr(args, "share_w2v_text_encoder", False) is False
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        for k, p in spch_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and need_finetuning(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        for k, p in text_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_encoder_params"
            ) and need_finetuning(
                args.finetune_mbart_encoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder
Example 20
    def build_model(cls, args, task):
        """Build a new model instance."""
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)
Example 21
 def encoder_latent_layer(self):
     return (
         safe_hasattr(self.args, "encoder_latent_layer")
         and self.args.encoder_latent_layer
     )
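Examples 2, 14, 17 and 21 all use the same guard, safe_hasattr(args, "x") and args.x, which amounts to a truthiness test that treats a missing or None-valued option as a disabled feature. A terse equivalent sketch (flag_enabled is an illustrative name, not a fairseq helper):

def flag_enabled(args, name):
    # missing attribute, None, False, 0 and "" all count as "disabled"
    return bool(getattr(args, name, None))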