Example #1

import math

import torch
import torch.nn as nn
from torch.nn import Transformer

# `PositionalEncoding` is assumed to be defined elsewhere (for instance, the
# implementation from the PyTorch sequence-to-sequence tutorial).


class TransformerModel(nn.Module):
    def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
        self.decoder_out = nn.Linear(ninp, ntoken_dec)
        self.model = Transformer(d_model=ninp, dim_feedforward=nhid)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        tgt = self.decoder_emb(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        src_mask = src_mask != 1
        tgt_mask = tgt_mask != 1
        subseq_mask = self.model.generate_square_subsequent_mask(
            tgt.size(1)).to(tgt.device)
        output = self.model(src.transpose(0, 1),
                            tgt.transpose(0, 1),
                            tgt_mask=subseq_mask,
                            src_key_padding_mask=src_mask,
                            tgt_key_padding_mask=tgt_mask,
                            memory_key_padding_mask=src_mask)
        output = self.decoder_out(output)
        return output

    def greedy_decode(self, src, src_mask, sos_token, max_length=20):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src_mask = src_mask != 1
        encoded = self.model.encoder(src.transpose(0, 1),
                                     src_key_padding_mask=src_mask)
        generated = encoded.new_full((encoded.size(1), 1),
                                     sos_token,
                                     dtype=torch.long)
        for i in range(max_length - 1):
            subseq_mask = self.model.generate_square_subsequent_mask(
                generated.size(1)).to(src.device)
            decoder_in = self.decoder_emb(generated) * math.sqrt(self.ninp)
            decoder_in = self.pos_encoder(decoder_in)
            logits = self.decoder_out(
                self.model.decoder(decoder_in.transpose(0, 1),
                                   encoded,
                                   tgt_mask=subseq_mask,
                                   memory_key_padding_mask=src_mask)[-1, :, :])
            new_generated = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, new_generated], dim=-1)
        return generated

    def save(self, file_dir):
        torch.save(self.state_dict(), file_dir)

    def load(self, file_dir):
        self.load_state_dict(torch.load(file_dir))
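
A minimal usage sketch for the class above (not part of the original snippet). The sizes, vocabulary ids and `sos_token=1` are illustrative assumptions, and it presumes the externally defined PositionalEncoding accepts the batch-first layout used here, since forward() only switches to (seq, batch, feature) right before calling the Transformer.

model = TransformerModel(ninp=256, ntoken=1000, ntoken_dec=1000, nhid=1024)

src = torch.randint(0, 1000, (8, 12))             # (batch, src_len) source token ids
tgt = torch.randint(0, 1000, (8, 10))             # (batch, tgt_len) target token ids
src_mask = torch.ones(8, 12, dtype=torch.long)    # 1 = real token, 0 = padding
tgt_mask = torch.ones(8, 10, dtype=torch.long)

logits = model(src, tgt, src_mask, tgt_mask)      # (tgt_len, batch, ntoken_dec)
generated = model.greedy_decode(src, src_mask, sos_token=1)   # (batch, max_length)
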
Example #2

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer

# `TokenEmbedding` and `PositionalEncoding` are assumed to be defined elsewhere
# (this snippet mirrors the PyTorch translation tutorial).


class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
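
A minimal usage sketch for Seq2SeqTransformer (not part of the original snippet). It assumes the externally defined TokenEmbedding and PositionalEncoding follow the seq-first layout of the PyTorch translation tutorial; PAD_IDX and all sizes are illustrative.

PAD_IDX = 1                                        # hypothetical padding id

model = Seq2SeqTransformer(num_encoder_layers=3, num_decoder_layers=3,
                           emb_size=256, nhead=8,
                           src_vocab_size=5000, tgt_vocab_size=5000)

src = torch.randint(2, 5000, (15, 4))              # (src_len, batch)
tgt = torch.randint(2, 5000, (12, 4))              # (tgt_len, batch)

tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(0))
src_mask = torch.zeros(src.size(0), src.size(0), dtype=torch.bool)   # no source masking
src_padding_mask = (src == PAD_IDX).transpose(0, 1)                  # (batch, src_len)
tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)                  # (batch, tgt_len)

logits = model(src, tgt, src_mask, tgt_mask,
               src_padding_mask, tgt_padding_mask, src_padding_mask)
# logits: (tgt_len, batch, tgt_vocab_size)
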
Example #3

import math

import torch
import torch.nn as nn
from torch.nn import Transformer

# `PositionalEncoding` is assumed to be defined elsewhere. The original code also
# relies on a module-level `device`, defined here so the example runs as written.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SimpleTransformerModel(nn.Module):
    def __init__(self, nb_tokens: int, emb_size: int, nb_layers=2, nb_heads=4,
                 hid_size=512, dropout=0.25, max_len=30):
        super(SimpleTransformerModel, self).__init__()
        self.emb_size = emb_size
        self.max_len = max_len

        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout, max_len=max_len)
        self.embedder = nn.Embedding(nb_tokens, emb_size)

        self.transformer = Transformer(d_model=emb_size, nhead=nb_heads, num_encoder_layers=nb_layers,
                                       num_decoder_layers=nb_layers, dim_feedforward=hid_size, dropout=dropout)

        self.out_lin = nn.Linear(in_features=emb_size, out_features=nb_tokens)

        self.tgt_mask = None

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).to(device)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedder.weight.data.uniform_(-initrange, initrange)
        self.out_lin.bias.data.zero_()
        self.out_lin.weight.data.uniform_(-initrange, initrange)

    def enc_forward(self, src):
        # Embed source
        src = self.embedder(src) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format (seq element, batch element, embedding)
        src = self.pos_encoder(src.view(src.shape[0], 1, src.shape[1]))

        # Push through encoder
        output = self.transformer.encoder(src)

        return output

    def dec_forward(self, memory, tgt):
        # Generate target mask, if necessary
        if self.tgt_mask is None or self.tgt_mask.size(0) != len(tgt):
            mask = self._generate_square_subsequent_mask(len(tgt)).to(device)
            self.tgt_mask = mask

        # Embed target
        tgt = self.embedder(tgt) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format (seq element, batch element, embedding)
        tgt = self.pos_encoder(tgt.view(tgt.shape[0], 1, tgt.shape[1]))

        # Push through decoder + linear output layer
        output = self.out_lin(self.transformer.decoder(memory=memory, tgt=tgt, tgt_mask=self.tgt_mask))
        # If using the model to evaluate, also take softmax of final layer to obtain probabilities
        if not self.training:
            output = torch.nn.functional.softmax(output, 2)

        return output

    def forward(self, src, tgt):
        memory = self.enc_forward(src)
        output = self.dec_forward(memory, tgt)

        return output

    def greedy_decode(self, src, max_len=None, start_symbol=0, stop_symbol=None):
        """
        Greedy decode input "src": generate output character one at a time, until "stop_symbol" is generated or
        the output exceeds max_len, whichever comes first.

        :param src: input src, 1D tensor
        :param max_len: int
        :param start_symbol: int
        :param stop_symbol: int
        :return: decoded output sequence
        """
        b_training = self.training
        if b_training:
            self.eval()

        if max_len is None:
            max_len = self.max_len
        elif max_len > self.max_len:
            raise ValueError(f"Parameter 'max_len' can not exceed model's own max_len,"
                             f" which is set at {self.max_len}.")
        # Get memory = output from encoder layer
        memory = model.enc_forward(src)
        # Initiate output buffer
        idxs = [start_symbol]
        # Keep track of last predicted symbol
        next_char = start_symbol
        # Keep generating output until "stop_symbol" is generated, or max_len is reached
        while next_char != stop_symbol:
            if len(idxs) == max_len:
                break
            # Convert output buffer to tensor
            ys = torch.LongTensor(idxs).to(device)
            # Push through decoder
            out = self.dec_forward(memory=memory, tgt=ys)
            # Get position of max probability of newly predicted character
            _, next_char = torch.max(out[-1, :, :], dim=1)
            next_char = next_char.item()

            # Append generated character to output buffer
            idxs.append(next_char)

        if b_training:
            self.train()

        return idxs
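
A minimal usage sketch for SimpleTransformerModel (not part of the original snippet). The model operates on unbatched 1D index tensors (a batch dimension of 1 is inserted internally) and relies on the module-level `device` defined above; the vocabulary size and the start/stop symbols are illustrative.

model = SimpleTransformerModel(nb_tokens=100, emb_size=64).to(device)

src = torch.randint(0, 100, (12,)).to(device)      # 1D source sequence
tgt = torch.randint(0, 100, (9,)).to(device)       # 1D target sequence (teacher forcing)

out = model(src, tgt)                               # (tgt_len, 1, nb_tokens) logits
decoded = model.greedy_decode(src, start_symbol=0, stop_symbol=2)   # list of token ids
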
Example #4

import math

import torch
import torch.nn as nn
from torch.nn import Transformer

# `PositionalEncoding` is assumed to be defined elsewhere and to map a tensor of
# token ids to position encodings that are added to the scaled token embeddings.


class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        # self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(vocab_size,
                                             hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )

        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src=None,
                tgt=None,
                memory=None,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        if src is not None:
            src_embeddings = self.token_embeddings(src) * math.sqrt(
                self.hidden_size) + self.position_embeddings(src)
            src_embeddings = self.dropout(src_embeddings)

            if src_key_padding_mask is not None:
                src_key_padding_mask = src_key_padding_mask.t()

            if tgt is None:  # encode
                memory = self.transformer.encoder(
                    src_embeddings, src_key_padding_mask=src_key_padding_mask)
                return memory

        if tgt is not None:
            tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
                self.hidden_size) + self.position_embeddings(tgt)
            tgt_embeddings = self.dropout(tgt_embeddings)
            tgt_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(0)).to(tgt.device)

            if tgt_key_padding_mask is not None:
                tgt_key_padding_mask = tgt_key_padding_mask.t()

            if src is None and memory is not None:  # decode
                if memory_key_padding_mask is not None:
                    memory_key_padding_mask = memory_key_padding_mask.t()

                output = self.transformer.decoder(
                    tgt_embeddings,
                    memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)
                output = self.decoder_embeddings(output)

                return output

        assert not (src is None and tgt is None)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
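
A minimal usage sketch for the class above (not part of the original snippet). Inputs use the default (seq_len, batch) layout of nn.Transformer; the key-padding masks are built as (seq_len, batch) because forward() transposes them internally. All sizes are illustrative.

model = TransformerModel(vocab_size=1000, hidden_size=256, num_attention_heads=8,
                         num_encoder_layers=3, num_decoder_layers=3,
                         intermediate_size=1024)

src = torch.randint(2, 1000, (15, 4))              # (src_len, batch); id 1 is padding
tgt = torch.randint(2, 1000, (12, 4))              # (tgt_len, batch)
src_pad = (src == 1)                               # (src_len, batch)
tgt_pad = (tgt == 1)

# Full encoder-decoder pass
logits = model(src=src, tgt=tgt,
               src_key_padding_mask=src_pad, tgt_key_padding_mask=tgt_pad)

# Staged: encode once, then decode against the cached memory
memory = model(src=src, src_key_padding_mask=src_pad)
step_logits = model(tgt=tgt, memory=memory,
                    tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=src_pad)
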
Example #5

import math

import torch
import torch.nn as nn
from torch.nn import Transformer

# `PositionalEncoding` (taking d_model and max_len) is assumed to be defined elsewhere
# and to map a tensor of token ids to position encodings.


class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 max_len,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = PositionalEncoding(d_model, max_len)
        self.hidden_size = d_model
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(d_model=d_model,
                                       nhead=num_attention_heads,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=intermediate_size,
                                       dropout=dropout)

        self.decoder_embeddings = nn.Linear(d_model, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src,
                tgt,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)

        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)

        output = self.decoder_embeddings(output)
        return output

    def encode(self, src, src_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        memory = self.transformer.encoder(
            src_embeddings, src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self,
               tgt,
               memory,
               tgt_key_padding_mask=None,
               memory_key_padding_mask=None):
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)

        output = self.transformer.decoder(
            tgt_embeddings,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
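
A minimal sketch of greedy generation with encode()/decode() (not part of the original snippet). The BOS/EOS ids, the sizes and the seq-first layout are assumptions.

model = TransformerModel(vocab_size=1000, d_model=256, num_attention_heads=8,
                         num_encoder_layers=3, num_decoder_layers=3,
                         intermediate_size=1024, max_len=100)
model.eval()

BOS, EOS = 1, 2                                    # hypothetical special token ids
src = torch.randint(3, 1000, (15, 2))              # (src_len, batch)

with torch.no_grad():
    memory = model.encode(src)                     # (src_len, batch, d_model)
    ys = torch.full((1, src.size(1)), BOS, dtype=torch.long)    # (1, batch)
    for _ in range(20):
        logits = model.decode(ys, memory)          # (cur_len, batch, vocab_size)
        next_tok = logits[-1].argmax(dim=-1)       # (batch,)
        ys = torch.cat([ys, next_tok.unsqueeze(0)], dim=0)
        if (next_tok == EOS).all():
            break
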
Example #6

import math
from typing import Any, Dict, List, Tuple

import numpy
import torch
import torch.nn.functional as F
from torch.nn import Linear, Transformer

# AllenNLP imports for the API this snippet appears to target; `PositionalEncoding`
# and `inplace_relu` are assumed to be defined elsewhere in the original project.
from allennlp.common.util import END_SYMBOL, START_SYMBOL
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder
from allennlp.nn import util
from allennlp.training.metrics import BLEU, SequenceAccuracy
from overrides import overrides


class MyTransformer(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace

        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        mask = torch.zeros(1, len(meta_data),
                           self.vocab.get_vocab_size(
                               self._target_namespace)).float()
        for bidx, md in enumerate(meta_data):
            for k, v in self.vocab._token_to_index[
                    self._target_namespace].items():
                if 'position' in k and k not in md['avail_pos']:
                    mask[:, bidx, v] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    @overrides
    def forward(
        self,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        src, src_key_padding_mask = self._encode(self._source_embedder,
                                                 source_tokens)
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)

        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data)
            target_vocab_mask = target_vocab_mask.to(memory.device)
        else:
            target_vocab_mask = None
        output_dict = {}
        targets = None
        if target_tokens:
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps
        if self.training and target_tokens:
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            _, predictions = torch.max(class_probabilities, -1)
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        else:
            assert self.training is False
            output_dict["loss"] = torch.tensor(0.0).to(memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
            else:
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
                memory,
                src_key_padding_mask,
                target_vocab_mask,
                max_target_len=max_target_len)
        predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)

        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                             target_mask.size(1))
                best_predictions_ = torch.zeros(batch_size,
                                                max_sz).to(memory.device)
                best_predictions_[:, :best_predictions.
                                  size(1)] = best_predictions
                targets_ = torch.zeros(batch_size, max_sz).to(memory.device)
                targets_[:, :targets.size(1)] = targets.cpu()
                target_mask_ = torch.zeros(batch_size,
                                           max_sz).to(memory.device)
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
                              target_mask_)
        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()
            # class_probabilities = output_dict["class_probabilities"].detach().cpu()
            # sample_predicted_indices = []
            # for cp in class_probabilities:
            #     sample = torch.multinomial(cp, num_samples=1)
            #     sample_predicted_indices.append(sample)
            # # shape: (batch_size, num_decoding_steps, num_samples)
            # sample_predicted_indices = torch.stack(sample_predicted_indices)

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self, embedder: TextFieldEmbedder,
            tokens: Dict[str,
                         torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        src = embedder(tokens) * math.sqrt(self._ndim)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            self,
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            print('decoding steps: ', num_decoding_steps)
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        step_predictions: List[torch.Tensor] = []
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
                                       fill_value=0).long()
        for timestep in range(num_decoding_steps):
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask

            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes[timestep, :]
            step_predictions.append(last_predictions)
            if ((last_predictions == self._end_index) |
                (last_predictions == self._pad_index)).all():
                break

        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v

        super(MyTransformer, self).load_state_dict(new_state_dict, strict)
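
The _get_mask / _decoder_step_by_step pair above rules tokens out by adding -inf to their logits before the softmax. A tiny standalone illustration of that additive-mask trick (not from the original source):

masked_logits = torch.tensor([[2.0, 0.5, 1.0, 0.3]]) + \
    torch.tensor([[0.0, float('-inf'), 0.0, float('-inf')]])
probs = F.softmax(masked_logits, dim=-1)
# probs puts all probability mass on positions 0 and 2; the -inf entries get 0.
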
Example #7

from typing import Tuple

import torch
from torch.nn import Embedding, Module, Transformer

# `PositionalEncoding` is assumed to be defined elsewhere.


class FullTransformer(Module):
    def __init__(self,
                 num_vocab,
                 num_embedding=128,
                 dim_feedforward=512,
                 num_encoder_layer=4,
                 num_decoder_layer=4,
                 dropout=0.3,
                 padding_idx=1,
                 max_seq_len=140,
                 nhead=8):
        super(FullTransformer, self).__init__()

        self.padding_idx = padding_idx

        # [x : seq_len,  batch_size ]
        self.inp_embedding = Embedding(num_vocab,
                                       num_embedding,
                                       padding_idx=padding_idx)

        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding,
                                                dropout,
                                                max_len=max_seq_len)

        self.trfm = Transformer(d_model=num_embedding,
                                dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer,
                                num_decoder_layers=num_decoder_layer,
                                dropout=dropout,
                                nhead=nhead)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)

    def make_pad_mask(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Build a key-padding mask: positions marked 'True' will not be attended to
        (i.e. they are ignored). Every position equal to self.padding_idx is masked.

        :param inp: input tensor of token indices, shape [seq_len, batch_size]
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Full encoder-decoder forward pass.

        :param src: source tensor, shape [src_seq_len, batch_size]
        :param tgt: target tensor, shape [tgt_seq_len, batch_size]
        """
        # Generate mask for decoder attention
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(
            tgt.device)

        # tgt_mask shape = [tgt_seq_len, tgt_seq_len]
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)

        # [ src : seq_len, batch_size, num_embedding ]

        out_emb_enc = self.pos_embedding(self.inp_embedding(src))

        # [ src : seq_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))

        out_trf = self.trfm(out_emb_enc,
                            out_emb_dec,
                            src_mask=None,
                            tgt_mask=tgt_mask,
                            memory_mask=None,
                            src_key_padding_mask=src_pad_mask,
                            tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)

        # [ out_trf : seq_len, batch_size, num_embedding]

        out_to_logit = self.linear_out(out_trf)

        # final_out : [ seq_len, batch_size, vocab_size ]
        return out_to_logit

    def forward_encoder(
            self, src: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        src_pad_mask = self.make_pad_mask(src)
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))
        return self.trfm.encoder(
            out_emb_enc, src_key_padding_mask=src_pad_mask), src_pad_mask

    def forward_decoder(self, tgt: torch.Tensor, memory: torch.Tensor,
                        src_pad_mask: torch.Tensor) -> torch.Tensor:
        tgt_pad_mask = self.make_pad_mask(tgt)
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(
            tgt.device)
        out_trf = self.trfm.decoder(out_emb_dec,
                                    memory,
                                    tgt_mask=tgt_mask,
                                    tgt_key_padding_mask=tgt_pad_mask,
                                    memory_key_padding_mask=src_pad_mask)
        out_trf = self.linear_out(out_trf)
        return out_trf
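
A minimal sketch of incremental decoding with forward_encoder / forward_decoder (not part of the original snippet). The vocabulary size, BOS id and sequence lengths are assumptions; inputs use the (seq_len, batch) layout documented in the comments above.

model = FullTransformer(num_vocab=300)
model.eval()

BOS = 2                                            # hypothetical start-of-sequence id
src = torch.randint(3, 300, (20, 4))               # (src_len, batch); id 1 is padding

with torch.no_grad():
    memory, src_pad_mask = model.forward_encoder(src)
    ys = torch.full((1, src.size(1)), BOS, dtype=torch.long)    # (1, batch)
    for _ in range(30):
        logits = model.forward_decoder(ys, memory, src_pad_mask)   # (cur_len, batch, num_vocab)
        next_tok = logits[-1].argmax(dim=-1)                        # (batch,)
        ys = torch.cat([ys, next_tok.unsqueeze(0)], dim=0)
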