Example #1
    def __init__(
        self,
        vocab: Vocabulary,
        embed: TextFieldEmbedder,
        encoder_size: int,
        decoder_size: int,
        num_layers: int,
        beam_size: int,
        max_decoding_steps: int,
        use_bleu: bool = True,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super().__init__(vocab)

        self.START, self.END = self.vocab.get_token_index(
            START_SYMBOL), self.vocab.get_token_index(END_SYMBOL)
        self.OOV = self.vocab.get_token_index(self.vocab._oov_token)  # pylint: disable=protected-access
        self.PAD = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        self.COPY = self.vocab.get_token_index("@@COPY@@")
        self.KEEP = self.vocab.get_token_index("@@KEEP@@")
        self.DROP = self.vocab.get_token_index("@@DROP@@")

        self.SYMBOL = (self.START, self.END, self.PAD, self.KEEP, self.DROP)
        self.vocab_size = vocab.get_vocab_size()
        self.EMB = embed

        self.emb_size = self.EMB.token_embedder_tokens.output_dim
        self.encoder_size, self.decoder_size = encoder_size, decoder_size
        self.FACT_ENCODER = FeedForward(3 * self.emb_size, 1, encoder_size,
                                        nn.Tanh())
        self.ATTN = AdditiveAttention(encoder_size + decoder_size,
                                      encoder_size)
        self.COPY_ATTN = AdditiveAttention(decoder_size, encoder_size)
        module = nn.LSTM(self.emb_size,
                         encoder_size // 2,
                         num_layers,
                         bidirectional=True,
                         batch_first=True)
        self.BUFFER = PytorchSeq2SeqWrapper(
            module)  # BiLSTM to encode draft text
        self.STREAM = nn.LSTMCell(2 * encoder_size,
                                  decoder_size)  # Store revised text

        self.BEAM = BeamSearch(self.END,
                               max_steps=max_decoding_steps,
                               beam_size=beam_size)

        self.U = nn.Sequential(nn.Linear(2 * encoder_size, decoder_size),
                               nn.Tanh())
        self.ADD = nn.Sequential(nn.Linear(self.emb_size, encoder_size),
                                 nn.Tanh())

        self.P = nn.Sequential(
            nn.Linear(encoder_size + decoder_size, decoder_size), nn.Tanh())
        self.W = nn.Linear(decoder_size, self.vocab_size)
        self.G = nn.Sequential(nn.Linear(decoder_size, 1), nn.Sigmoid())

        initializer(self)
        self._bleu = BLEU(
            exclude_indices=set(self.SYMBOL)) if use_bleu else None
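All of these snippets construct the same AllenNLP `BLEU` metric, update it batch by batch, and read the score back with `get_metric`. A minimal usage sketch (the token ids and padding index below are made up for illustration, not taken from any snippet above):

import torch
from allennlp.training.metrics import BLEU

bleu = BLEU(exclude_indices={0})  # assume 0 is the padding index
predictions = torch.tensor([[4, 7, 9, 3, 6, 0]])   # decoded token ids (hypothetical)
gold_targets = torch.tensor([[4, 7, 9, 3, 6, 0]])  # reference token ids (hypothetical)
bleu(predictions=predictions, gold_targets=gold_targets)  # accumulate n-gram counts
score = bleu.get_metric(reset=True)["BLEU"]  # BLEU for all batches seen since the last reset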
Example #2
def build_model(vocab: Vocabulary) -> Model: 
    print("Building the model")
    vocab_size_s = vocab.get_vocab_size("source_tokens")
    vocab_size_t = vocab.get_vocab_size("target_tokens") 
    
    bleu = BLEU(exclude_indices = {0,2,3})

    source_text_embedder = BasicTextFieldEmbedder({"source_tokens": Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size_s)})
    encoder = PytorchTransformer(input_dim=embedding_dim, num_layers=num_layers, positional_encoding="sinusoidal",
                                 feedforward_hidden_dim=dff, num_attention_heads=num_head,
                                 positional_embedding_size=embedding_dim, dropout_prob=dropout)

    
    # target_text_embedder = BasicTextFieldEmbedder({"target_tokens":Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size_t)})
    target_text_embedder = Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size_t)
    decoder_net = StackedSelfAttentionDecoderNet(decoding_dim=embedding_dim, target_embedding_dim=embedding_dim, 
                                feedforward_hidden_dim=dff, num_layers=num_layers, num_attention_heads=num_head, dropout_prob = dropout)
    decoder_net.decodes_parallel=True

    if args.pseudo:
        decoder = PseudoAutoRegressiveSeqDecoder(vocab, decoder_net, max_len, target_text_embedder, target_namespace="target_tokens", tensor_based_metric=bleu, scheduled_sampling_ratio=0.0, decoder_lin_emb = args.dec)
        return PseudoComposedSeq2Seq(vocab, source_text_embedder, encoder, decoder, num_virtual_models = num_virtual_models)
    else:
        decoder = AutoRegressiveSeqDecoder(vocab, decoder_net, max_len, target_text_embedder, target_namespace="target_tokens", tensor_based_metric=bleu, scheduled_sampling_ratio=0.0)
        return ComposedSeq2Seq(vocab, source_text_embedder, encoder, decoder)
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 loss_ratio: float = 1.0,
                 remove_sos: bool = True,
                 remove_eos: bool = False,
                 target_namespace: str = "tokens",
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(CTCLayer, self).__init__(vocab, regularizer)
        self.loss_ratio = loss_ratio
        self._remove_sos = remove_sos
        self._remove_eos = remove_eos
        self._target_namespace = target_namespace
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                     self._target_namespace)
        self._loss = CTCLoss(blank=self._pad_index)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        exclude_indices = {self._pad_index, self._end_index, self._start_index}
        self._wer: Metric = WER(exclude_indices=exclude_indices)
        self._bleu: Metric = BLEU(exclude_indices=exclude_indices)
        self._dal: Metric = Average()

        initializer(self)
Example #4
    def test_multiple_distributed_runs(self):
        predictions = [
            torch.tensor([[1, 0, 0], [1, 1, 0]]),
            torch.tensor([[1, 1, 1]]),
        ]
        gold_targets = [
            torch.tensor([[2, 0, 0], [1, 0, 0]]),
            torch.tensor([[1, 1, 2]]),
        ]

        check = math.exp(0.5 * (math.log(3) - math.log(6)) + 0.5 *
                         (math.log(1) - math.log(3)))
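        # The expected value: with index 0 excluded, the predictions across both
        # workers contain 3 matching unigrams out of 6 and 1 matching bigram out
        # of 3, so BLEU is the weighted geometric mean
        # exp(0.5 * log(3/6) + 0.5 * log(1/3)).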
        metric_kwargs = {
            "predictions": predictions,
            "gold_targets": gold_targets
        }
        desired_values = {"BLEU": check}
        run_distributed_test(
            [-1, -1],
            multiple_runs,
            BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0}),
            metric_kwargs,
            desired_values,
            exact=False,
        )
Example #5
    def __init__(
        self,
        vocab: Vocabulary,
        variational_encoder: VariationalEncoder,
        decoder: Decoder,
        kl_weight: LossWeight,
        temperature: float = 1.0,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super(VAE, self).__init__(vocab)

        self._encoder = variational_encoder
        self._decoder = decoder

        self._latent_dim = variational_encoder.latent_dim

        self._encoder_output_dim = self._encoder.get_encoder_output_dim()

        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        self._bleu = BLEU(exclude_indices={
            self._pad_index, self._end_index, self._start_index
        })

        self._kl_metric = Average()
        self.kl_weight = kl_weight

        self._temperature = temperature
        initializer(self)
Example #6
    def __init__(self,
                 vocab,
                 cfg,
                 device,
                 name=None,
                 bi=True,
                 att=True,
                 batch_norm=True,
                 teach_forc_ratio=0.5,
                 patience=3,
                 dropout=0.0,
                 write_idx=3):
        self.device = device
        self.model = Seq2SeqArch(vocab,
                                 cfg,
                                 device,
                                 bi=bi,
                                 att=att,
                                 batch_norm=batch_norm,
                                 teach_forc_ratio=teach_forc_ratio,
                                 dropout=dropout)
        self.vocab = vocab
        self.name = name if name else 'seq2seq'
        self.cfg = cfg
        self.write_idx = write_idx

        # Evaluation metrics
        self.bleu = BLEU(exclude_indices=set([0]))  # Exclude padding

        # Logging variables
        self.train_losses, self.valid_losses, self.test_losses = [], [], []
        self.train_bleu, self.valid_bleu, self.test_bleu = [], [], []
        self.patience = patience  # For early stopping
        self.outputs = []
Example #7
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str,
                 beam_search: Lazy[BeamSearch] = Lazy(BeamSearch,
                                                      beam_size=3,
                                                      max_steps=50),
                 checkpoint_wrapper: Optional[CheckpointWrapper] = None,
                 weights_path: Optional[Union[str, PathLike]] = None,
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._model_name = model_name
        # We only instantiate this when we need it.
        self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
        self.t5 = T5Module.from_pretrained_module(
            model_name,
            beam_search=beam_search,
            ddp_accelerator=self.ddp_accelerator,
            checkpoint_wrapper=checkpoint_wrapper,
            weights_path=weights_path,
        )

        exclude_indices = {
            self.t5.pad_token_id,
            self.t5.decoder_start_token_id,
            self.t5.eos_token_id,
        }
        self._metrics = [
            ROUGE(exclude_indices=exclude_indices),
            BLEU(exclude_indices=exclude_indices),
        ]
Example #8
    def test_train(self):
        vocab = {
            '<boundary>': 0,
            '<unk>': 1,
            'have': 2,
            'to': 3,
            'do': 4,
            'it': 5,
            'go': 6,
            'as': 7
        }
        fake_img = torch.rand(1, 3, 255, 255)
        fake_target = torch.tensor([[0, 5, 6, 3, 6, 2, 0]]).long()
        data = {'images': fake_img, 'captions': fake_target}
        encoder = MaskRCNN_Benchmark()
        decoder = UpDownCaptioner(vocab=vocab)
        model = CaptioningModel(encoder=encoder, captioner=decoder)
        optimizer = optim.Adam(params=model.parameters(), lr=4e-4)

        for epoch in range(7):
            loss = train_batch(data, model, optimizer)

            evaluator = BLEU(
                exclude_indices={vocab['<unk>'], vocab['<boundary>']})
            with torch.no_grad():
                eval_batch(data, model, evaluator)
            bleu_score = evaluator.get_metric()['BLEU']

            print(
                'Epoch: {0:2d} | Epoch Loss: {1:7.3f} | BLEU_4 Score: {2:5.2f}'
                .format(epoch, loss, bleu_score))

        self.assertEqual(True, True)
Example #9
    def test_auto_regressive_seq_decoder_tensor_and_token_based_metric(self):
        # set all seeds to a fixed value (torch, numpy, etc.).
        # this enables deterministic behavior of the `auto_regressive_seq_decoder`
        # below (i.e., parameter initialization and `encoded_state = torch.randn(..)`)
        prepare_environment(Params({}))

        batch_size, time_steps, decoder_inout_dim = 2, 3, 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        auto_regressive_seq_decoder = AutoRegressiveSeqDecoder(
            vocab,
            decoder_net,
            10,
            Embedding(vocab.get_vocab_size(), decoder_inout_dim),
            tensor_based_metric=BLEU(),
            token_based_metric=DummyMetric(),
        ).eval()

        encoded_state = torch.randn(batch_size, time_steps, decoder_inout_dim)
        source_mask = torch.ones(batch_size, time_steps).long()
        target_tokens = {"tokens": torch.ones(batch_size, time_steps).long()}
        source_mask[0, 1:] = 0
        encoder_out = {"source_mask": source_mask, "encoder_outputs": encoded_state}

        auto_regressive_seq_decoder.forward(encoder_out, target_tokens)
        assert auto_regressive_seq_decoder.get_metrics()["BLEU"] == 1.388809517005903e-11
        assert auto_regressive_seq_decoder.get_metrics()["em"] == 0.0
        assert auto_regressive_seq_decoder.get_metrics()["f1"] == 1 / 3
Example #10
    def __init__(self,
                 vocab: Vocabulary,
                 token_embedder: TextFieldEmbedder,
                 document_encoder: Seq2VecEncoder,
                 utterance_encoder: Seq2VecEncoder,
                 context_encoder: Seq2SeqEncoder,
                 beam_size: int = None,
                 max_decoding_steps: int = 50,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super(MultiTurnHred, self).__init__(vocab)
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        self._beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=self._beam_size)

        # Dense embedding of word level tokens.
        self._token_embedder = token_embedder

        # Document word level encoder.
        self._document_encoder = document_encoder

        # Dialogue word level encoder.
        self._utterance_encoder = utterance_encoder

        # Sentence level encoder.
        self._context_encoder = context_encoder

        num_classes = self.vocab.get_vocab_size()

        document_output_dim = self._document_encoder.get_output_dim()
        utterance_output_dim = self._utterance_encoder.get_output_dim()
        context_output_dim = self._context_encoder.get_output_dim()
        decoder_output_dim = utterance_output_dim
        decoder_input_dim = token_embedder.get_output_dim() + document_output_dim + context_output_dim

        # We'll use a GRU cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = GRUCell(decoder_input_dim, decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(decoder_output_dim, num_classes)
Example #11
    def test_training(self):
        dir_main = os.path.abspath(os.path.join(
            __file__, "../.."))  # the root directory of project
        # embedding_path = os.path.join(dir_main, 'vocab', 'embedding.npy')
        # vocab_path = os.path.join(dir_main, 'vocab', 'vocab.json')
        # embeddings = np.load(embedding_path)
        # embeddings = torch.from_numpy(embeddings)
        # with open(vocab_path) as j:
        #     vocab = json.load(j)
        vocab = {
            '<start>': 0,
            '<pad>': 1,
            '<end>': 2,
            'to': 3,
            'do': 4,
            'it': 5,
            'go': 6,
            'as': 7
        }
        model = UpDownCaptioner(vocab=vocab, embed_dim=10)
        # model.double()
        model.cuda()

        test_input = torch.rand(1, 3, 224, 224).cuda()
        # test_input = [test_input[i, :, :, :] for i in range(test_input.shape[0])]
        caption = torch.tensor([[0, 5, 6, 3, 6, 2]]).long().cuda()
        # optimizer = optim.SGD(model.parameters(), lr=0.015, momentum=0.9, weight_decay=0.001)
        # lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda iteration: 1 - iteration / 500)
        optimizer = optim.Adam(params=filter(lambda p: p.requires_grad,
                                             model.parameters()),
                               lr=4e-4,
                               weight_decay=0.001)
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda iteration: 1 - iteration / 500)

        for epoch in range(500):
            model.train()
            output_dict = model(test_input, caption)
            loss = output_dict['loss']
            model.zero_grad()

            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), True)

            optimizer.step()
            lr_scheduler.step()

            model.eval()
            bleu_eval = BLEU(exclude_indices={0, 1, 2})
            output_dict = model(test_input)
            seq = output_dict['seq']
            bleu_eval(predictions=seq, gold_targets=caption)
            bleu = bleu_eval.get_metric()['BLEU']
            if epoch % 10 == 0:
                print(loss)
                print(bleu)
        self.assertEqual(True, True)
Example #12
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to be used in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
Example #13
    def __init__(
        self,
        vocab: Vocabulary,
        encoder=None,
        source_encoder=None,
        trainable: bool = True,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._bleu = BLEU()
        self._perplexity = Perplexity()
Example #14
    def __init__(self, vocab: Vocabulary):
        super().__init__()  # type: ignore
        self.vocab = vocab
        self._nll = Average()
        self._ppl = WordPPL()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token)  # noqa: WPS437
        self._bleu = BLEU(exclude_indices={
            self._pad_index, self._end_index, self._start_index
        })
Example #15
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, update the BeamSearch object before constructing and raise a DeprecationWarning
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
Example #16
    def __init__(self,
                 vocab: Vocabulary,
                 input_size: int,
                 hidden_size: int,
                 loss_ratio: float = 1.0,
                 recurrency: nn.LSTM = None,
                 num_layers: int = None,
                 remove_sos: bool = True,
                 remove_eos: bool = False,
                 target_embedder: Embedding = None,
                 target_embedding_dim: int = None,
                 target_namespace: str = "tokens",
                 slow_decode: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(RNNTLayer, self).__init__(vocab, regularizer)
        import warprnnt_pytorch
        self.loss_ratio = loss_ratio
        self._remove_sos = remove_sos
        self._remove_eos = remove_eos
        self._slow_decode = slow_decode
        self._target_namespace = target_namespace
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                     self._target_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        self._loss = warprnnt_pytorch.RNNTLoss(blank=self._pad_index,
                                               reduction='mean')
        self._recurrency = recurrency or \
            nn.LSTM(input_size=target_embedding_dim,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    batch_first=True)

        self._target_embedder = target_embedder or Embedding(
            self._num_classes, target_embedding_dim)
        self.w_enc = nn.Linear(input_size, hidden_size, bias=True)
        self.w_dec = nn.Linear(input_size, hidden_size, bias=False)
        self._proj = nn.Linear(hidden_size, self._num_classes)

        exclude_indices = {self._pad_index, self._end_index, self._start_index}
        self._wer: Metric = WER(exclude_indices=exclude_indices)
        self._bleu: Metric = BLEU(exclude_indices=exclude_indices)
        self._dal = Average()

        initializer(self)
Example #17
    def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._model_name = model_name
        # We only instantiate this when we need it.
        self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
        self.t5 = T5Module.from_pretrained_module(model_name)

        exclude_indices = {
            self.t5.pad_token_id,
            self.t5.decoder_start_token_id,
            self.t5.eos_token_id,
        }
        self._metrics = [
            ROUGE(exclude_indices=exclude_indices),
            BLEU(exclude_indices=exclude_indices),
        ]
Example #18
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace

        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)
Example #19
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: torch.nn.Module,
                 decoder: AdaptiveRNNCell,
                 word_projection: torch.nn.Module,
                 source_embedding: TokenEmbedder,
                 target_embedding: TokenEmbedder,
                 target_namespace: str = "target_tokens",
                 start_symbol: str = '<GO>',
                 eos_symbol: str = '<EOS>',
                 max_decoding_step: int = 50,
                 use_bleu: bool = True,
                 label_smoothing: Optional[float] = None,
                 enc_attention: Union[AllenNLPAttentionWrapper, SingleTokenMHAttentionWrapper, None] = None,
                 dec_hist_attn: Union[AllenNLPAttentionWrapper, SingleTokenMHAttentionWrapper, None] = None,
                 scheduled_sampling_ratio: float = 0.,
                 act_loss_weight: float = 1.,
                 prediction_dropout: float = .1,
                 embedding_dropout: float = .1,
                 ):
        super(AdaptiveSeq2Seq, self).__init__(vocab)
        self._enc_attn = enc_attention
        self._dec_hist_attn = dec_hist_attn
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._encoder = encoder
        self._decoder = decoder
        self._src_embedding = source_embedding
        self._tgt_embedding = target_embedding

        self._start_id = vocab.get_token_index(start_symbol, target_namespace)
        self._eos_id = vocab.get_token_index(eos_symbol, target_namespace)
        self._max_decoding_step = max_decoding_step

        self._target_namespace = target_namespace
        self._label_smoothing = label_smoothing

        self._act_loss_weight = act_loss_weight

        self._pre_projection_dropout = torch.nn.Dropout(prediction_dropout)
        self._embedding_dropout = torch.nn.Dropout(embedding_dropout)
        self._output_projection = word_projection

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, target_namespace)
            self._bleu = BLEU(exclude_indices={pad_index, self._eos_id, self._start_id})
        else:
            self._bleu = None
Example #20
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta

        # self.target_embedder = torch.nn.Embedding(vocab.get_vocab_size(), decoder_embedding_dim)

        # Since we are using the sentence encoding as the initial hidden state to the decoder, the
        # decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(self.decoder_embedding_dim + self.claim_encoder_dim,
                                                self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim,
                                                self.decoder_output_dim)

        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding, this will be concatenated with the decoder cell output before being
        # fed to the projection layer. Hence the expected input size is:
        #   decoder output dim + claim encoder output dim
        projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim
        self.output_projection_layer = torch.nn.Linear(projection_input_dim,
                                                       vocab.get_vocab_size())

        self._start_index = self.vocab.get_token_index('<s>')
        self._end_index = self.vocab.get_token_index('</s>')

        self.beam_search = BeamSearch(self._end_index, max_steps=max_steps, beam_size=beam_size)
        pad_index = vocab.get_token_index(vocab._padding_token)
        self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index})
        self.avg_reconstruction_loss = Average()
        self.avg_claim_scoring_loss = Average()
Example #21
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        upsample: torch.nn.Module = None, 
        net: Seq2SeqEncoder = None,
        target_namespace: str = "target_tokens",
        target_embedding_dim: int = None,
        use_bleu: bool = True,
        loss_type: str = "ctc",
        label_smoothing: float = None,
    ) -> None:
        super(LatentAignmentCTC, self).__init__(vocab)
        self._target_namespace = target_namespace
        
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, 
                                                    self._target_namespace)
        self._blank_index = self.vocab.get_token_index(SPECIAL_BLANK_TOKEN, 
                                                       self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={self._pad_index, self._blank_index})
        else:
            self._bleu = None

        self._source_embedder = source_embedder
        source_embedding_dim = source_embedder.get_output_dim()

        self._upsample = upsample or LinearUpsample(source_embedding_dim, s = 3)
        self._net = net or StackedSelfAttentionEncoder(input_dim = source_embedding_dim,
                                                       hidden_dim = 128,
                                                       projection_dim = 128,
                                                       feedforward_hidden_dim = 512,
                                                       num_layers = 4,
                                                       num_attention_heads = 4)
        
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        target_embedding_dim = self._net.get_output_dim()

        self._output_projection = torch.nn.Linear(target_embedding_dim, num_classes)
        self.loss_type = loss_type
        self.label_smoothing = label_smoothing
Example #22
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model_path,
                 beam_size=5,
                 max_decoding_steps=140,
                 indexer=None):
        super().__init__(vocab)
        self.plm = MT5ForConditionalGeneration.from_pretrained(pretrained_model_path)
        self._indexer = indexer or PretrainedTransformerIndexer(pretrained_model_path, namespace="tokens")
        self._start_id = self.plm.config.decoder_start_token_id
        self._decoder_start_id = self.plm.config.decoder_start_token_id
        self._end_id = self.plm.config.eos_token_id
        self._pad_id = self.plm.config.pad_token_id

        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})
Example #23
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
Example #24
    def __init__(self, vocab: Vocabulary, max_timesteps: int = 50, encoder_size: int = 14,
                 encoder_dim: int = 512, embedding_dim: int = 64, attention_dim: int = 64,
                 decoder_dim: int = 64, beam_size: int = 3, teacher_forcing: bool = True) -> None:
        super().__init__(vocab)
        
        self._max_timesteps = max_timesteps
        
        self._vocab_size = self.vocab.get_vocab_size()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        # POSSIBLE CHANGE LATER
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')
        
        self._encoder_size = encoder_size
        self._encoder_dim = encoder_dim
        self._embedding_dim = embedding_dim
        self._attention_dim = attention_dim
        self._decoder_dim = decoder_dim
        
        self._beam_size = beam_size
        self._teacher_forcing = teacher_forcing

        self._init_h = nn.Linear(self._encoder_dim, self._decoder_dim)
        self._init_c = nn.Linear(self._encoder_dim, self._decoder_dim)
        
        self._resnet = torchvision.models.resnet18()
        modules = list(self._resnet.children())[:-2]
        self._encoder = nn.Sequential(
            *modules,
            nn.AdaptiveAvgPool2d(self._encoder_size)
        )

        self._decoder = ImageCaptioningDecoder(self._vocab_size, self._encoder_dim, self._embedding_dim, self._attention_dim, self._decoder_dim)
        
        self.beam_search = BeamSearch(self._end_index, self._max_timesteps, self._beam_size)
        
        self._bleu = BLEU(exclude_indices={self._start_index, self._end_index, self._pad_index})
        self._exprate = Exprate(self._end_index)
Example #25
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        source_embedding: TokenEmbedder,
        target_embedding: TokenEmbedder,
        target_namespace: str = "target_tokens",
        start_symbol: str = '<GO>',
        eos_symbol: str = '<EOS>',
        max_decoding_step: int = 50,
        use_bleu: bool = True,
        label_smoothing: Optional[float] = None,
    ):
        super(ParallelSeq2Seq, self).__init__(vocab)
        self._encoder = encoder
        self._decoder = decoder
        self._src_embedding = source_embedding
        self._tgt_embedding = target_embedding

        self._start_id = vocab.get_token_index(start_symbol, target_namespace)
        self._eos_id = vocab.get_token_index(eos_symbol, target_namespace)
        self._max_decoding_step = max_decoding_step

        self._target_namespace = target_namespace
        self._label_smoothing = label_smoothing

        self._output_projection_layer = torch.nn.Linear(
            decoder.hidden_dim, vocab.get_vocab_size(target_namespace))

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)
            self._bleu = BLEU(
                exclude_indices={pad_index, self._eos_id, self._start_id})
        else:
            self._bleu = None
Example #26
    def test_eval(self):
        dir_main = os.path.abspath(os.path.join(__file__, "../.."))

        eval_set_path = os.path.join(dir_main, 'dataset', 'VAL.hdf5')
        eval_set = CaptionDataset(eval_set_path)
        eval_loader = DataLoader(dataset=eval_set, batch_size=10)

        embedding_path = os.path.join(dir_main, 'vocab', 'embedding.npy')
        vocab_path = os.path.join(dir_main, 'vocab', 'vocab.json')
        embeddings = np.load(embedding_path)
        embeddings = torch.from_numpy(embeddings)
        with open(vocab_path) as j:
            vocab = json.load(j)

        model_path = os.path.join(dir_main, 'UpDown.pth')
        model = UpDownCaptioner(vocab=vocab, pre_trained_embedding=embeddings)
        model.load_state_dict(torch.load(model_path))
        model.cuda()
        model.eval()

        bleu_eval = BLEU(
            exclude_indices={vocab['<start>'], vocab['<end>'], vocab['<pad>']},
            ngram_weights=[0.5, 0.5, 0, 0])
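        # ngram_weights=[0.5, 0.5, 0, 0] weights only the unigram and bigram
        # precisions in the geometric mean, i.e. this is effectively a BLEU-2 score.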

        with torch.no_grad():
            for data_batch in tqdm(eval_loader):
                # load the batch data
                imgs, caps = data_batch['image'], data_batch['caption']
                imgs = imgs.cuda()

                output_dict = model(imgs)
                seq = output_dict['seq']
                bleu_eval(predictions=seq, gold_targets=caps)
        bleu_score = bleu_eval.get_metric()['BLEU']
        print(bleu_score)
        self.assertEqual(True, True)
Example #27
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Encoder,
                 decoder: CaptioningDecoder,
                 max_timesteps: int = 75,
                 teacher_forcing: bool = True,
                 scheduled_sampling_ratio: float = 1,
                 beam_size: int = 10) -> None:
        super().__init__(vocab)

        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')

        self._max_timesteps = max_timesteps
        self._teacher_forcing = teacher_forcing
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._beam_size = beam_size

        self._encoder = encoder
        self._decoder = decoder

        self._init_h = nn.Linear(self._encoder.get_output_dim(),
                                 self._decoder.get_input_dim())
        self._init_c = nn.Linear(self._encoder.get_output_dim(),
                                 self._decoder.get_input_dim())

        self.beam_search = BeamSearch(self._end_index, self._max_timesteps,
                                      self._beam_size)

        self._bleu = BLEU(exclude_indices={
            self._start_index, self._end_index, self._pad_index
        })
        self._exprate = Exprate(self._end_index, self.vocab)

        self._attention_weights = None
Example #28
    def __init__(
        self,
        vocab: Vocabulary,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        target_embedding_dim: int = 30,
        copy_token: str = "@COPY@",
        source_namespace: str = "bert",
        target_namespace: str = "target_tokens",
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        super().__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._src_start_index = self.vocab.get_token_index(
            START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(
            END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)
        self._copy_index = self.vocab.add_token_to_namespace(
            copy_token, self._target_namespace)

        self._tensor_based_metric = tensor_based_metric or BLEU(
            exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        self._token_based_metric = token_based_metric

        self._target_vocab_size = self.vocab.get_vocab_size(
            self._target_namespace)

        # Encoding modules.
        bert_token_embedding = PretrainedBertEmbedder('bert-base-uncased',
                                                      requires_grad=True)

        self._source_embedder = bert_token_embedding
        self._encoder = PassThroughEncoder(
            input_dim=self._source_embedder.get_output_dim())

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence that matches the target
        # token from the previous timestep.
        self._target_embedder = Embedding(target_vocab_size,
                                          target_embedding_dim)
        self._attention = attention
        self._input_projection_layer = Linear(
            target_embedding_dim + self.encoder_output_dim * 2,
            self.decoder_input_dim)

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim,
                                               target_vocab_size)

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim,
                                            self.decoder_output_dim)

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        initializer(self)
Example #29
    def __init__(self,
                 vocab: Vocabulary,
                 max_decoding_steps: int,
                 decoder_net: DecoderNet,
                 target_embedder: Embedding,
                 loss_criterion: LossCriterion,
                
                 generation_batch_size: int = 200,
                 use_in_seq2seq_mode: bool = False,
                 target_namespace: str = "tokens",
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 scheduled_sampling_k: int = 100,
                 scheduled_sampling_type: str = 'uniform',
                 rollin_mode: str = 'mixed',
                 rollout_mode: str = 'learned',

                 dropout: float = None,
                 start_token: str = START_SYMBOL,
                 end_token: str = END_SYMBOL,
                 num_decoder_layers: int = 1,
                 mask_pad_and_oov: bool = False,
                 tie_output_embedding: bool = False,

                 rollout_mixing_prob:float = 0.5,

                 use_bleu: bool = False,
                 use_hamming: bool = False,

                 sample_rollouts: bool = False,
                 beam_search_sampling_temperature: float = 1.,
                 top_k=0, 
                 top_p=0,
                 tensor_based_metric: Metric = None,
                 tensor_based_metric_mask: Metric = None,
                 token_based_metric: Metric = None,
                 eval_beam_size: int = 1,
                ) -> None:
        super().__init__(target_embedder)

        self.current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        self._vocab = vocab
        self._seq2seq_mode = use_in_seq2seq_mode

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._max_decoding_steps = max_decoding_steps
        self._generation_batch_size = generation_batch_size
        self._decoder_net = decoder_net

        self._target_namespace = target_namespace

        # TODO #4 (Kushal): Maybe make them modules so that we can add more of these later.
        # TODO #8 #7 (Kushal): Rename "mixed" rollin mode to "scheduled sampling".
        self._rollin_mode = rollin_mode
        self._rollout_mode = rollout_mode

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._scheduled_sampling_k = scheduled_sampling_k
        self._scheduled_sampling_type = scheduled_sampling_type
        self._sample_rollouts = sample_rollouts
        self._mask_pad_and_oov = mask_pad_and_oov

        self._rollout_mixing_prob = rollout_mixing_prob

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(start_token, self._target_namespace)
        self._end_index = self._vocab.get_token_index(end_token, self._target_namespace)

        self._padding_index = self._vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace)
        self._oov_index = self._vocab.get_token_index(DEFAULT_OOV_TOKEN, self._target_namespace)

        if self._mask_pad_and_oov:
            self._vocab_mask = torch.ones(self._vocab.get_vocab_size(self._target_namespace),
                                            device=self.current_device) \
                                    .scatter(0, torch.tensor([self._padding_index, self._oov_index, self._start_index],
                                                                device=self.current_device),
                                                0)
        if use_bleu:
            pad_index = self._vocab.get_token_index(self._vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        if use_hamming:
            self._hamming = HammingLoss()
        else:
            self._hamming = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1

        # TODO(Kushal): Pass in the arguments for sampled. Also, make sure you do not sample in case of Seq2Seq models.
        self._beam_search = SampledBeamSearch(self._end_index, 
                                                max_steps=max_decoding_steps, 
                                                beam_size=beam_size, temperature=beam_search_sampling_temperature)

        self._num_classes = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input. " +
                f"target_embedder_dim: {self.target_embedder.get_output_dim()}, " +
                f"decoder input dim: {self._decoder_net.target_embedding_dim}."
            )

        self._ss_ratio = Average()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self.training_iteration = 0
        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_net.get_output_dim(), self._num_classes)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    f"Can't tie embeddings with output linear layer, due to shape mismatch. " + 
                    f"{self._output_projection_layer.weight.shape} and {self.target_embedder.weight.shape}"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._loss_criterion = loss_criterion

        self._top_k = top_k
        self._top_p = top_p
        self._eval_beam_size = eval_beam_size
        self._mle_loss = MaximumLikelihoodLossCriterion()
        self._perplexity = Perplexity()

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._tensor_based_metric_mask = tensor_based_metric_mask
        
        self._decode_tokens = partial(decode_tokens, 
                                    vocab=self._vocab,
                                    start_index=self._start_index,
                                    end_index=self._end_index)
Example #30
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
                 attention: Attention = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 use_bleu: bool = True,
                 bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25,
                                                        0.25),
                 target_pretrain_file: str = None,
                 target_decoder_layers: int = 1,
                 **kwargs) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._target_decoder_layers = target_decoder_layers
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)
            self._bleu = BLEU(
                bleu_ngram_weights,
                exclude_indices={
                    pad_index, self._end_index, self._start_index
                },
            )
        else:
            self._bleu = None

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, update the BeamSearch object before constructing and raise a DeprecationWarning
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead.")
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"),
                          DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"),
                          DeprecationWarning)
        self._beam_search = beam_search.construct(end_index=self._end_index,
                                                  vocab=self.vocab,
                                                  **beam_search_extras)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        if not target_pretrain_file:
            self._target_embedder = Embedding(
                num_embeddings=num_classes, embedding_dim=target_embedding_dim)
        else:
            self._target_embedder = Embedding(
                embedding_dim=target_embedding_dim,
                pretrained_file=target_pretrain_file,
                vocab_namespace=self._target_namespace,
                vocab=self.vocab,
            )

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        if self._target_decoder_layers > 1:
            self._decoder_cell = LSTM(
                self._decoder_input_dim,
                self._decoder_output_dim,
                self._target_decoder_layers,
            )
        else:
            self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                          self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)