    def __init__(self, cfg: MTBottleneckModelConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)

        self.model_type: str = cfg.get("model_type", "nll")
        self.min_logv: float = cfg.get("min_logv", -6)
        self.latent_size: int = cfg.get("latent_size", -1)
        self.non_recon_warmup_batches: int = cfg.get("non_recon_warmup_batches", 200000)
        self.recon_per_token: bool = cfg.get("recon_per_token", True)
        self.log_timing: bool = cfg.get("log_timing", True)

        # if True, translation uses the mean of the latent distribution for VAE and MIM
        self.deterministic_translate = True

        # a negative latent_size (e.g., -1) falls back to encoder.hidden_size
        if self.latent_size < 0:
            self.latent_size = self.encoder.hidden_size

        if not self.recon_per_token:
            # disable reduction for train and eval loss
            self.eval_loss_fn = NLLLoss(ignore_index=self.decoder_tokenizer.pad_id, reduction='none')
            self.loss_fn._per_token_reduction = False
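            # (with reduction='none' the eval NLL is returned per token, unreduced, so the
            # reconstruction loss can later be aggregated per sample instead of per token)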

        if self.model_type not in ["nll", "mim", "vae"]:
            raise ValueError(f"Unknown model_type = {self.model_type}")

        # project latent dimension back to decoder hidden dimension
        self.latent2hidden = build_linear_or_identity(self.latent_size, self.decoder.hidden_size)

        # project dimension of encoder hidden to latent dimension
        self.hidden2latent_mean = build_linear_or_identity(self.encoder.hidden_size, self.latent_size)

        # MIM or VAE
        if self.model_type != "nll":
            # for probabilistic latent variable models we also need variance
            self.hidden2latent_logv = build_linear_or_identity(self.encoder.hidden_size, self.latent_size)
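            # (mim/vae map the encoder output to a mean and log-variance; training presumably
            # samples the latent from these, while translation can decode from the mean when
            # deterministic_translate is True, per the flag set above)
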
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus
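            # e.g., 2 nodes x 8 GPUs per node -> world_size = 16 workers for IterableDataset partitioning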

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language = cfg.get("src_language", None)
        self.tgt_language = cfg.get("tgt_language", None)

        self.multilingual = cfg.get("multilingual", False)
        self.multilingual_ids = []

        self.encoder_tokenizer_library = cfg.encoder_tokenizer.get(
            'library', 'yttm')
        self.decoder_tokenizer_library = cfg.decoder_tokenizer.get(
            'library', 'yttm')

        # Instantiate tokenizers and register them to be saved with the NeMo Model archive.
        # After this call, self.encoder_tokenizer and self.decoder_tokenizer will exist,
        # which can convert between tokens and token_ids for the SRC and TGT languages respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=self.encoder_tokenizer_library,
            encoder_tokenizer_model=cfg.encoder_tokenizer.get(
                'tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.encoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(
                cfg.encoder, 'model_name') else None,
            encoder_r2l=cfg.encoder_tokenizer.get('r2l', False),
            decoder_tokenizer_library=self.decoder_tokenizer_library,
            encoder_tokenizer_vocab_file=cfg.encoder_tokenizer.get(
                'vocab_file', None),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.decoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(
                cfg.decoder, 'model_name') else None,
            decoder_r2l=cfg.decoder_tokenizer.get('r2l', False),
        )
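        # Illustrative (assumed) tokenizer section of the config consumed above; the keys match
        # the cfg.*_tokenizer fields read here, values and paths are hypothetical:
        #   encoder_tokenizer: {library: yttm, tokenizer_model: /path/to/encoder.bpe.model,
        #                       bpe_dropout: 0.1, r2l: false}
        #   decoder_tokenizer: {library: yttm, tokenizer_model: /path/to/decoder.bpe.model,
        #                       bpe_dropout: 0.1, r2l: false}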

        if self.multilingual:
            if isinstance(self.src_language, ListConfig) and isinstance(
                    self.tgt_language, ListConfig):
                raise ValueError(
                    "cfg.src_language and cfg.tgt_language cannot both be lists. We only support many-to-one or one-to-many multilingual models."
                )
            elif isinstance(self.src_language, ListConfig):
                for lng in self.src_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            elif isinstance(self.tgt_language, ListConfig):
                for lng in self.tgt_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            else:
                raise ValueError(
                    "Expect either cfg.src_language or cfg.tgt_language to be a list when multilingual=True."
                )

            if isinstance(self.src_language, ListConfig):
                self.tgt_language = [self.tgt_language] * len(
                    self.src_language)
            else:
                self.src_language = [self.src_language] * len(
                    self.tgt_language)

            self.source_processor_list = []
            self.target_processor_list = []
            for src_lng, tgt_lng in zip(self.src_language, self.tgt_language):
                src_prcsr, tgt_prscr = self.setup_pre_and_post_processing_utils(
                    src_lng, tgt_lng)
                self.source_processor_list.append(src_prcsr)
                self.target_processor_list.append(tgt_prscr)
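
            # Illustrative many-to-one setup: src_language=['de', 'fr'], tgt_language='en'
            # -> multilingual_ids holds the ids of '<de>' and '<fr>', tgt_language becomes
            #    ['en', 'en'], and one (source, target) processor pair is built per language pair.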

        else:
            # After this call, the model will have self.source_processor and self.target_processor objects
            self.setup_pre_and_post_processing_utils(self.src_language,
                                                     self.tgt_language)
            self.multilingual_ids = [None]

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        checkpoint_file = encoder_cfg_dict.pop('checkpoint_file', None)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
            checkpoint_file=checkpoint_file,
        )
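        # Illustrative (assumed) encoder section of the config: library/model_name/pretrained/
        # checkpoint_file are popped above and the remaining keys form the transformer config, e.g.
        #   encoder: {library: nemo, pretrained: false, hidden_size: 512, num_layers: 6, ...}
        # (hidden_size/num_layers values here are hypothetical)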

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=decoder_cfg_dict,
            encoder=False,
            pre_ln_final_layer_norm=decoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight
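        # (the output projection and the decoder token embedding now share a single Parameter,
        # so both receive the same updates and the vocabulary matrix is stored only once)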

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size**0.5
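        # e.g., hidden_size=512 -> std_init_range = 1 / sqrt(512) ≈ 0.044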

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        self.log_softmax.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id,
            label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = NLLLoss(ignore_index=self.decoder_tokenizer.pad_id)
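        # (training uses label-smoothed cross entropy, while eval_loss_fn is an unsmoothed NLL
        # that ignores pad tokens)
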
    def __init__(self, cfg: MTBottleneckModelConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)

        self.model_type: str = cfg.get("model_type", "seq2seq-br")
        self.min_logv: float = cfg.get("min_logv", -6)
        self.ortho_loss_coef: float = cfg.get("ortho_loss_coef", 0.0)
        self.att_bridge_size: int = cfg.get("att_bridge_size", 512)
        self.att_bridge_k: int = cfg.get("att_bridge_k", 16)
        self.att_bridge_inner_size: int = cfg.get("att_bridge_inner_size",
                                                  1024)
        self.non_recon_warmup_batches: int = cfg.get(
            "non_recon_warmup_batches", 200000)
        self.recon_per_token: bool = cfg.get("recon_per_token", True)

        # TODO: add support in label smoothing for per-sample reconstruction loss
        if not self.recon_per_token:
            loss_fn = NLLLoss(
                ignore_index=self.decoder_tokenizer.pad_id,
                reduction='none',
            )
            self.loss_fn = self.eval_loss_fn = loss_fn

        if self.model_type not in [
                "seq2seq", "seq2seq-br", "seq2seq-mim", "seq2seq-vae"
        ]:
            raise ValueError("Unknown model_type = {model_type}".format(
                model_type=self.model_type, ))

        if self.model_type != "seq2seq":
            # project bridge dimension back to decoder hidden dimensions
            if self.att_bridge_size != self.encoder.hidden_size:
                self.latent2hidden = torch.nn.Linear(self.att_bridge_size,
                                                     self.encoder.hidden_size)
            else:
                self.latent2hidden = torch.nn.Identity()

            self.att_bridge = AttentionBridge(
                hidden_size=self.encoder.hidden_size,
                k=self.att_bridge_k,
                bridge_size=self.att_bridge_size,
            )
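            # (the attention bridge is expected to pool the variable-length encoder states into
            # k fixed vectors of size att_bridge_size, producing a fixed-size latent representation)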

            # project dimension of encoder hidden to bridge dimension
            if self.encoder.hidden_size != self.att_bridge_size:
                self.hidden2latent_mean = torch.nn.Linear(
                    self.encoder.hidden_size, self.att_bridge_size)
            else:
                self.hidden2latent_mean = torch.nn.Identity()

            # for probabilistic latent variable models we also need variance
            if self.model_type in ["seq2seq-mim", "seq2seq-vae"]:
                if self.encoder.hidden_size != self.att_bridge_size:
                    self.hidden2latent_logv = torch.nn.Linear(
                        self.encoder.hidden_size, self.att_bridge_size)
                else:
                    self.hidden2latent_logv = torch.nn.Identity()
        else:
            # seq2seq
            self.latent2hidden = torch.nn.Identity()
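
        # Summary of what each model_type wires up above:
        #   seq2seq      -> identity latent2hidden only (plain encoder-decoder)
        #   seq2seq-br   -> attention bridge + mean projection (deterministic bottleneck)
        #   seq2seq-mim  -> bridge + mean and log-variance projections (MIM latent variable model)
        #   seq2seq-vae  -> bridge + mean and log-variance projections (VAE latent variable model)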