示例#1
0
 def __init__(self,
              vocab: Vocabulary,
              source_embedder: TextFieldEmbedder,
              encoder: Seq2SeqEncoder,
              max_decoding_steps: int,
              attention: Attention = None,
              attention_function: SimilarityFunction = None,
              beam_size: int = None,
              target_namespace: str = "tokens",
              target_embedding_dim: int = None,
              scheduled_sampling_ratio: float = 0.,
              use_bleu: bool = True) -> None:
     """Initialise the parent model, then attach two sequence-accuracy metrics.

     All arguments are forwarded unchanged to the parent class initializer:
     the first four positionally, the remainder by keyword.
     """
     forwarded_options = dict(
         attention=attention,
         attention_function=attention_function,
         beam_size=beam_size,
         target_namespace=target_namespace,
         target_embedding_dim=target_embedding_dim,
         scheduled_sampling_ratio=scheduled_sampling_ratio,
         use_bleu=use_bleu,
     )
     super(Seq2Seq, self).__init__(vocab, source_embedder, encoder,
                                   max_decoding_steps, **forwarded_options)
     # Two independent accumulators. Which code path updates each one is not
     # visible here (presumably evaluation vs. training) -- confirm with callers.
     self._seqacc = SequenceAccuracy()
     self._train_seqacc = SequenceAccuracy()
    def test_sequence_accuracy(self):
        """Top-2 exact-match accuracy for this batch is expected to be 2/3."""
        gold_sequences = torch.Tensor([[1, 2, 3], [2, 4, 8], [0, 1, 1]])
        # Instances 0 and 1 each contain a candidate equal to their gold row;
        # instance 2 has no exactly matching candidate.
        candidate_sequences = torch.Tensor([
            [[1, 2, 3], [1, 2, -1]],
            [[2, 4, 8], [2, 5, 9]],
            [[-1, -1, -1], [0, 1, -1]],
        ])

        metric = SequenceAccuracy()
        metric(candidate_sequences, gold_sequences)
        numpy.testing.assert_almost_equal(metric.get_metric(), 2 / 3)
    def test_sequence_accuracy_accumulates_and_resets_correctly(self):
        """Counts accumulate across calls and ``get_metric(reset=True)`` clears them."""
        metric = SequenceAccuracy()
        target = torch.Tensor([[1, 2, 3]])
        hit = torch.Tensor([[[1, 2, 3]]])
        miss = torch.Tensor([[[1, 2, 4]]])
        for candidate in (hit, miss):
            metric(candidate, target)

        # One hit out of two updates.
        numpy.testing.assert_almost_equal(metric.get_metric(reset=True), 1 / 2)
        assert (metric.correct_count, metric.total_count) == (0, 0)
    def test_sequence_accuracy_accumulates_and_resets_correctly(self):
        """One exact match plus one mismatch gives 1/2; resetting zeroes the counters."""
        accuracy = SequenceAccuracy()
        gold = torch.Tensor([[1, 2, 3]])
        for predicted in ([[[1, 2, 3]]], [[[1, 2, 4]]]):
            accuracy(torch.Tensor(predicted), gold)

        observed = accuracy.get_metric(reset=True)
        numpy.testing.assert_almost_equal(observed, 1 / 2)
        for counter in (accuracy.correct_count, accuracy.total_count):
            assert counter == 0
示例#5
0
    def test_sequence_accuracy_accumulates_and_resets_correctly(self, device: str):
        """Accumulate over two updates on ``device``, then check that reset clears the counts."""
        def as_tensor(values):
            return torch.tensor(values, device=device)

        accuracy = SequenceAccuracy()
        gold = as_tensor([[1, 2, 3]])
        accuracy(as_tensor([[[1, 2, 3]]]), gold)  # exact match
        accuracy(as_tensor([[[1, 2, 4]]]), gold)  # last token differs

        assert_allclose(accuracy.get_metric(reset=True), 1 / 2)
        assert accuracy.correct_count == 0
        assert accuracy.total_count == 0
示例#6
0
    def test_sequence_accuracy(self, device: str):
        """Two of three instances have an exactly matching top-2 candidate."""
        gold_rows = [[1, 2, 3], [2, 4, 8], [0, 1, 1]]
        candidate_rows = [
            [[1, 2, 3], [1, 2, -1]],
            [[2, 4, 8], [2, 5, 9]],
            [[-1, -1, -1], [0, 1, -1]],
        ]
        gold, predictions = (
            torch.tensor(rows, device=device) for rows in (gold_rows, candidate_rows)
        )

        metric = SequenceAccuracy()
        metric(predictions, gold)
        assert_allclose(metric.get_metric(), 2 / 3)
示例#7
0
    def test_sequence_accuracy_respects_mask(self):
        """Positions with mask 0 are ignored, so three of the four instances match."""
        metric = SequenceAccuracy()

        gold_labels = torch.Tensor([
            [1, 2, 3],     # only positions 1-2 are compared
            [2, 4, 8],     # every position is compared
            [0, 1, 1],     # last position is ignored
            [11, 13, 17],  # middle position is ignored; no candidate matches
        ])
        candidates = torch.Tensor([
            [[1, 2, 3], [1, 2, -1]],
            [[2, 4, 8], [2, 5, 9]],
            [[-1, -1, -1], [0, 1, -1]],
            [[12, 13, 17], [11, 13, 18]],
        ])
        position_mask = torch.Tensor([
            [0, 1, 1],
            [1, 1, 1],
            [1, 1, 0],
            [1, 0, 1],
        ])

        metric(candidates, gold_labels, position_mask)
        numpy.testing.assert_almost_equal(metric.get_metric(), 3 / 4)
    def test_sequence_accuracy(self):
        """Exact-match accuracy over top-2 candidates is 2/3 for this batch."""
        cases = [
            ([1, 2, 3], [[1, 2, 3], [1, 2, -1]]),     # first candidate matches
            ([2, 4, 8], [[2, 4, 8], [2, 5, 9]]),      # first candidate matches
            ([0, 1, 1], [[-1, -1, -1], [0, 1, -1]]),  # no candidate matches
        ]
        gold = torch.Tensor([gold_row for gold_row, _ in cases])
        predictions = torch.Tensor([candidates for _, candidates in cases])

        accuracy = SequenceAccuracy()
        accuracy(predictions, gold)
        numpy.testing.assert_almost_equal(accuracy.get_metric(), 2 / 3)
示例#9
0
    def test_distributed_sequence_accuracy(self):
        """Accuracy over a batch split between two workers should equal 3/4."""
        full_gold = torch.tensor([[1, 2, 3], [2, 4, 8], [0, 1, 1], [11, 13, 17]])
        full_predictions = torch.tensor([
            [[1, 2, 3], [1, 2, -1]],
            [[2, 4, 8], [2, 5, 9]],
            [[-1, -1, -1], [0, 1, -1]],
            [[12, 13, 17], [11, 13, 18]],
        ])
        full_mask = torch.tensor([
            [False, True, True],
            [True, True, True],
            [True, True, False],
            [True, False, True],
        ])

        # Element i of each list is the shard for worker i: first half, second half.
        shards = (slice(None, 2), slice(2, None))
        metric_kwargs = {
            "predictions": [full_predictions[shard] for shard in shards],
            "gold_labels": [full_gold[shard] for shard in shards],
            "mask": [full_mask[shard] for shard in shards],
        }

        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            SequenceAccuracy(),
            metric_kwargs,
            {"accuracy": 3 / 4},
            exact=False,
        )
示例#10
0
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        """Set up embedders, the transformer, the output projection and metrics.

        ``transformer`` is splatted into ``Transformer`` and must also provide
        ``"d_model"`` and ``"dropout"`` for the positional encoding. When
        ``target_embedder`` is None the source embedder is shared for targets.
        """
        super().__init__(vocab)
        self._target_namespace = target_namespace

        # Special-token ids in the target namespace.
        namespace = self._target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     namespace)

        # BLEU skips the special tokens; it is disabled entirely when use_bleu is False.
        self._bleu = (
            BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})
            if use_bleu
            else None
        )
        self._seq_acc = SequenceAccuracy()
        self._max_decoding_steps = max_decoding_steps
        self._source_embedder = source_embedder

        # Module construction order is kept as-is (parameter init / state_dict order).
        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim, transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        self._target_embedder = (source_embedder if target_embedder is None
                                 else target_embedder)

        self._output_projection_layer = Linear(self._ndim, num_classes)
示例#11
0
    def test_sequence_accuracy_respects_mask(self, device: str):
        """Masked-out positions are ignored; three of four instances then match."""
        def on_device(data):
            return torch.tensor(data, device=device)

        gold = on_device([[1, 2, 3], [2, 4, 8], [0, 1, 1], [11, 13, 17]])
        predictions = on_device([
            [[1, 2, 3], [1, 2, -1]],
            [[2, 4, 8], [2, 5, 9]],
            [[-1, -1, -1], [0, 1, -1]],
            [[12, 13, 17], [11, 13, 18]],
        ])
        mask = on_device([[0, 1, 1], [1, 1, 1], [1, 1, 0], [1, 0, 1]])

        accuracy = SequenceAccuracy()
        accuracy(predictions, gold, mask)
        assert_allclose(accuracy.get_metric(), 3 / 4)
示例#12
0
    def test_sequence_accuracy_respects_mask(self):
        """With the mask applied, instances 0-2 match and instance 3 does not (3/4)."""
        rows = [
            # (gold, top-2 candidates, mask)
            ([1, 2, 3], [[1, 2, 3], [1, 2, -1]], [0, 1, 1]),
            ([2, 4, 8], [[2, 4, 8], [2, 5, 9]], [1, 1, 1]),
            ([0, 1, 1], [[-1, -1, -1], [0, 1, -1]], [1, 1, 0]),
            ([11, 13, 17], [[12, 13, 17], [11, 13, 18]], [1, 0, 1]),
        ]
        gold_rows, candidate_rows, mask_rows = zip(*rows)

        accuracy = SequenceAccuracy()
        accuracy(torch.Tensor(list(candidate_rows)),
                 torch.Tensor(list(gold_rows)),
                 torch.Tensor(list(mask_rows)))
        numpy.testing.assert_almost_equal(accuracy.get_metric(), 3 / 4)
示例#13
0
def multiple_runs(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: "SequenceAccuracy",
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
    number_of_runs: int = 200,
):
    """Update ``metric`` repeatedly with this worker's inputs, then check its accuracy.

    Parameters
    ----------
    global_rank : int
        Rank of this process; ``metric_kwargs[name][global_rank]`` is used for
        every argument name.
    world_size, gpu_id :
        Not used here; presumably supplied by the distributed test runner.
    metric :
        Metric object called with the selected keyword arguments; its
        ``get_metric()`` must return a dict with an ``"accuracy"`` entry.
    metric_kwargs : Dict[str, List[Any]]
        For each argument name, one value per rank.
    desired_values : Dict[str, Any]
        Must contain the expected ``"accuracy"``.
    exact : bool or (rtol, atol)
        ``True`` (default) requires exact equality, ``False`` allows
        ``rtol=1e-4, atol=1e-5``, and a tuple gives explicit tolerances.
    number_of_runs : int
        How many times the metric is updated with the same inputs (default 200).

    Raises
    ------
    AssertionError
        If the accuracy is outside the requested tolerance.
    """
    import math

    # Use the arguments meant for the process with rank `global_rank`.
    kwargs = {name: per_rank[global_rank] for name, per_rank in metric_kwargs.items()}

    for _ in range(number_of_runs):
        metric(**kwargs)

    # Previously `exact` was accepted but ignored (always exact equality).
    if isinstance(exact, bool):
        rtol, atol = (0.0, 0.0) if exact else (1e-4, 1e-5)
    else:
        rtol, atol = exact

    expected = desired_values["accuracy"]
    actual = metric.get_metric()["accuracy"]
    assert math.isclose(actual, expected, rel_tol=rtol, abs_tol=atol), (
        f"accuracy {actual} != desired {expected} (rtol={rtol}, atol={atol})"
    )
示例#14
0
    def __init__(
            self,
            vocab: Vocabulary,
            source_embedder: TextFieldEmbedder,
            transformer: Dict,
            max_decoding_steps: int,
            target_embedders: Dict[str, TextFieldEmbedder] = None,
            loss_coefs: Dict = None,
    ) -> None:
        """Build a transformer with per-namespace decoders and output heads.

        Parameters
        ----------
        vocab : Vocabulary
        source_embedder : TextFieldEmbedder
            Embeds the source tokens.
        transformer : Dict
            Keyword arguments for ``MultiDecodersTPTransformer``. It must contain
            ``"d_model"`` and ``"dropout"``. The keys of its optional
            ``"num_decoder_layers"`` mapping are the decoder namespaces, and
            start/end ids are looked up for each of them.
        max_decoding_steps : int
        target_embedders : Dict[str, TextFieldEmbedder]
            Required despite the ``None`` default; wrapped in a ``ModuleDict``.
        loss_coefs : Dict
            Required despite the ``None`` default. Its keys are the target
            namespaces, and each one gets a pad id, an output projection and a
            ``"<namespace>_acc"`` SequenceAccuracy metric.

        Raises
        ------
        ValueError
            If ``loss_coefs`` or ``target_embedders`` is ``None``. This fails
            fast instead of raising an AttributeError part-way through
            construction.
        """
        if loss_coefs is None:
            raise ValueError("loss_coefs is required: its keys define the target namespaces")
        if target_embedders is None:
            raise ValueError("target_embedders is required")
        super().__init__(vocab)
        self._target_namespaces = loss_coefs.keys()
        self._decoder_namespaces = transformer.get("num_decoder_layers", {}).keys()
        self._start_index_dict = {k: self.vocab.get_token_index(START_SYMBOL, k)
                                  for k in self._decoder_namespaces}
        self._end_index_dict = {k: self.vocab.get_token_index(END_SYMBOL, k)
                                for k in self._decoder_namespaces}
        self._pad_index_dict = {k: self.vocab.get_token_index(self.vocab._padding_token, k)
                                for k in self._target_namespaces}
        self._loss_coefs = loss_coefs
        # One exact-match accuracy metric per target namespace.
        self._metrics = {f'{tn}_acc': SequenceAccuracy() for tn in self._target_namespaces}

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim, transformer["dropout"])

        self._transformer = MultiDecodersTPTransformer(**transformer)
        self._transformer.apply(inplace_relu)

        self._target_embedders = ModuleDict(target_embedders.items())
        # One vocabulary-sized output head per target namespace, in namespace order.
        output_projection_layers = {
            tn: Linear(self._ndim, self.vocab.get_vocab_size(tn))
            for tn in self._target_namespaces
        }
        self._output_projection_layers = ModuleDict(output_projection_layers.items())

        # tp parameters
        self._embedding_project_layer = Linear(self._ndim, self._ndim)
示例#15
0
    def test_get_metric_on_new_object_works(self, device: str):
        """A metric that was never updated reports zero accuracy."""
        fresh_metric = SequenceAccuracy()
        metrics = fresh_metric.get_metric(reset=True)
        assert_allclose(metrics["accuracy"], 0)
示例#16
0
class Seq2Seq(SimpleSeq2Seq):
    """``SimpleSeq2Seq`` that reads ``nl``/``fl`` fields and tracks sequence accuracy.

    Besides the parent's BLEU metric (when enabled), ``forward`` updates a
    ``SequenceAccuracy`` metric, reported as ``"SeqAcc"``. It compares the
    top beam-search prediction with the gold target.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super(Seq2Seq, self).__init__(vocab, source_embedder, encoder, max_decoding_steps,
                                      attention=attention, attention_function=attention_function,
                                      beam_size=beam_size, target_namespace=target_namespace,
                                      target_embedding_dim=target_embedding_dim,
                                      scheduled_sampling_ratio=scheduled_sampling_ratio,
                                      use_bleu=use_bleu)
        # Exact-match accuracy of the best beam against the gold target.
        self._seqacc = SequenceAccuracy()

    @overrides
    def forward(self,  # type: ignore
                nl: Dict[str, torch.LongTensor],
                fl: Dict[str, torch.LongTensor] = None, id=None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode the source, optionally run the teacher-forced loop, then beam-search.

        Parameters
        ----------
        nl : ``Dict[str, torch.LongTensor]``
           Source ``TextField`` tensors; passed to ``self._encode``.
        fl : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Target ``TextField`` tensors. When given, ``_forward_loop`` is run
           (per the inline comment it computes the loss), and BLEU and
           sequence accuracy are updated against ``fl["tokens"]``.
        id : optional
           Accepted but unused.

        Returns
        -------
        Dict[str, torch.Tensor]
            The ``_forward_loop`` outputs (when ``fl`` is given) updated with the
            beam-search outputs, including ``"predictions"`` of shape
            ``(batch_size, beam_size, max_sequence_length)``.
        """
        source_tokens, target_tokens = nl, fl
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        # NOTE: beam search runs in training mode too (a `not self.training`
        # guard was disabled), so every training step also pays for a beam search.
        state = self._init_decoder_state(state)
        predictions = self._forward_beam_search(state)
        output_dict.update(predictions)
        # shape: (batch_size, beam_size, max_sequence_length)
        top_k_predictions = output_dict["predictions"]
        # shape: (batch_size, max_predicted_sequence_length)
        best_predictions = top_k_predictions[:, 0, :]
        if target_tokens and self._bleu:
            self._bleu(best_predictions, target_tokens["tokens"])
        if target_tokens:
            self._update_seqacc(best_predictions, target_tokens["tokens"])

        return output_dict

    def _update_seqacc(self,
                       best_predictions: torch.Tensor,
                       target_ids: torch.Tensor) -> None:
        """Update ``self._seqacc`` with the top beam against the gold target.

        ``target_ids`` still starts with the start token, which is dropped.
        Gold positions equal to 0 (padding) are masked out. The prediction is
        truncated, or right-padded with 0, to the gold length. Slicing alone
        breaks when the prediction is shorter than the gold (e.g. decoding
        ended early), because ``SequenceAccuracy`` needs matching shapes. A
        padded 0 can never match an unmasked (non-zero) gold id.
        """
        gold = target_ids[:, 1:]
        kept = min(best_predictions.size(1), gold.size(1))
        aligned = gold.new_full(gold.size(), 0)
        aligned[:, :kept] = best_predictions[:, :kept]
        self._seqacc(aligned.unsqueeze(1), gold, mask=(gold != 0).long())

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU (when enabled) plus sequence accuracy under ``"SeqAcc"``."""
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics.update({"SeqAcc": self._seqacc.get_metric(reset=reset)})
        return all_metrics
示例#17
0
    def test_get_metric_on_new_object_works(self):
        """Reading (and resetting) an unused metric yields 0 rather than failing."""
        numpy.testing.assert_almost_equal(SequenceAccuracy().get_metric(reset=True), 0)
    def test_get_metric_on_new_object_works(self):
        # No updates have happened, so the accuracy should read as zero.
        untouched = SequenceAccuracy()
        result = untouched.get_metric(reset=True)
        numpy.testing.assert_almost_equal(result, 0)
示例#19
0
class MyTransformer(Model):
    """Transformer encoder-decoder that predicts tokens of one target namespace.

    In training mode with target tokens, the decoder is teacher-forced and a
    sequence cross-entropy loss is returned. Otherwise tokens are decoded
    greedily, one step at a time. Whenever target tokens are given, BLEU (if
    enabled) and sequence accuracy are updated.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        vocab : Vocabulary
        source_embedder : TextFieldEmbedder
            Embeds source tokens; also used for targets when ``target_embedder`` is None.
        transformer : Dict
            Keyword arguments for ``Transformer``. It must include ``"d_model"``
            and ``"dropout"``, which are also used by the positional encoding.
        max_decoding_steps : int
            Upper bound on greedy decoding steps and on the gold target length.
        target_namespace : str
            Vocabulary namespace of the output tokens.
        target_embedder : TextFieldEmbedder, optional
        use_bleu : bool
            Whether to track BLEU (pad/start/end ids excluded).
        """
        super().__init__(vocab)
        self._target_namespace = target_namespace

        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        """Return an additive logit mask that blocks unavailable "position" tokens.

        The result has shape ``(1, len(meta_data), target_vocab_size)``. It is
        0 everywhere except ``-inf`` at each target token whose text contains
        ``'position'`` but is not in that instance's ``md['avail_pos']``.
        """
        mask = torch.zeros(1, len(meta_data),
                           self.vocab.get_vocab_size(
                               self._target_namespace)).float()
        # Scan the vocabulary once rather than once per batch item.
        position_tokens = [
            (token, index) for token, index in
            self.vocab._token_to_index[self._target_namespace].items()
            if 'position' in token
        ]
        for bidx, md in enumerate(meta_data):
            for token, index in position_tokens:
                if token not in md['avail_pos']:
                    mask[:, bidx, index] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` float mask: 0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == False,
                                        float('-inf')).masked_fill(
                                            mask == True, float(0.0))
        return mask

    @overrides
    def forward(
        self,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        """Encode the source, then teacher-force (training) or greedily decode.

        Parameters
        ----------
        source_tokens : Dict[str, torch.LongTensor]
            Source text-field tensors.
        target_tokens : Dict[str, torch.LongTensor], optional
            Target text-field tensors. ``target_tokens["tokens"]`` is assumed
            to begin with the start symbol: the gold targets are ``[:, 1:]``
            and the teacher-forced decoder input is ``[:, :-1]``.
        meta_data : optional
            Per-instance dicts with an ``'avail_pos'`` entry, used by ``_get_mask``.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``"loss"`` (a zero tensor unless teacher forcing), ``"predictions"``
            and ``"class_probabilities"``, both transposed from the decoder's
            time-first layout to batch-first.
        """
        src, src_key_padding_mask = self._encode(self._source_embedder,
                                                 source_tokens)
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)

        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data).to(memory.device)
        else:
            target_vocab_mask = None
        output_dict = {}
        targets = None
        if target_tokens:
            # Gold outputs: the target sequence without its leading start token.
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps
        if self.training and target_tokens:
            # Teacher forcing: feed the gold sequence minus its final token.
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            _, predictions = torch.max(class_probabilities, -1)
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        else:
            assert self.training is False
            output_dict["loss"] = torch.tensor(0.0).to(memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
            else:
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
                memory,
                src_key_padding_mask,
                target_vocab_mask,
                max_target_len=max_target_len)
        predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)

        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                # Right-pad predictions, targets and mask with zeros to a common
                # length so SequenceAccuracy sees matching shapes. Padded
                # positions have mask 0 and are ignored.
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                             target_mask.size(1))
                best_predictions_ = torch.zeros(batch_size, max_sz,
                                                device=memory.device)
                best_predictions_[:, :best_predictions.size(1)] = best_predictions
                targets_ = torch.zeros(batch_size, max_sz, device=memory.device)
                # Copy directly; no host round-trip through .cpu() is needed.
                targets_[:, :targets.size(1)] = targets
                target_mask_ = torch.zeros(batch_size, max_sz,
                                           device=memory.device)
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
                              target_mask_)
        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert predicted ids into token strings, cut at the first end symbol.

        Adds ``output_dict["predicted_tokens"]`` (one list of strings per
        instance) and returns the same dict.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self, embedder: TextFieldEmbedder,
            tokens: Dict[str,
                         torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``tokens`` for the transformer.

        Embeddings are scaled by ``sqrt(d_model)``, transposed on dims 0/1
        (presumably batch-first to time-first -- confirm the embedder layout),
        and passed through the positional encoding. Returns the embeddings and
        a key-padding mask that is True at padding positions.
        """
        src = embedder(tokens) * math.sqrt(self._ndim)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            self,
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Greedy decoding that re-runs the decoder on the whole prefix each step.

        Runs at most ``self._max_decoding_steps`` steps, or at most
        ``max_target_len`` when the ``target_limit_decode_steps`` attribute is
        truthy. It stops early once every sequence's latest token is the end
        or pad id.

        Returns
        -------
        (predictions, class_probabilities)
            ``predictions`` has shape ``(steps, batch)``. ``class_probabilities``
            comes from the final step's decoder pass, which covers all steps.
        """
        import logging

        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            logging.getLogger(__name__).debug("decoding steps: %d",
                                              num_decoding_steps)
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        step_predictions: List[torch.Tensor] = []
        # Decoder input buffer; column t holds the token fed at step t.
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
                                       fill_value=0).long()
        for timestep in range(num_decoding_steps):
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask

            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes[timestep, :]
            step_predictions.append(last_predictions)
            finished = ((last_predictions == self._end_index)
                        | (last_predictions == self._pad_index))
            if finished.all():
                break

        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        """Masked sequence cross-entropy between ``logits`` and gold ``targets``."""
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return BLEU (when enabled) plus sequence accuracy under ``'seq_acc'``."""
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, stripping a leading ``'module.'`` from parameter names.

        This lets checkpoints saved from a wrapper that prefixes names with
        ``'module.'`` load into the bare model. The parent's return value is
        now passed through instead of being discarded.
        """
        prefix = 'module.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        return super(MyTransformer, self).load_state_dict(new_state_dict, strict)