Code Example #1
 def __init__(self, input_shape):
     self.model = Sequential()
     # Append each sub-module's layers in order; every block is sized from the
     # output shape of the last layer added so far.
     for layer in Encoder(input_shape).model.layers:
         self.model.add(layer)
     for layer in Attention(self.model.layers[-1].output_shape).model.layers:
         self.model.add(layer)
     for layer in State(self.model.layers[-1].output_shape).model.layers:
         self.model.add(layer)
     for layer in Decoder(self.model.layers[-1].output_shape).model.layers:
         self.model.add(layer)
     self.model.add(Activation('softmax'))
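The constructor above stitches together one Keras Sequential by appending every layer of each sub-module in turn. Below is a minimal, self-contained sketch of the same composition pattern, using small Dense stand-ins in place of the Encoder, Attention, State, and Decoder classes (which the snippet does not show); the stand-in names and sizes are illustrative only.

# Sketch only: StubBlock and its Dense layers are stand-ins, not the real sub-modules.
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential


class StubBlock:
    """Stand-in sub-module exposing a .model attribute, like the blocks above."""

    def __init__(self, units):
        self.model = Sequential([Dense(units)])


model = Sequential()
for block in (StubBlock(32), StubBlock(32), StubBlock(16), StubBlock(10)):
    for layer in block.model.layers:  # append each sub-module's layers in order
        model.add(layer)
model.add(Activation('softmax'))
model.build(input_shape=(None, 8))  # declare an input shape so the stacked model can be built
model.summary()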
Code Example #2
    def __init__(self,
                 word_vocab: Vocab,
                 bio_vocab: Vocab,
                 feat_vocab: Vocab,
                 word_embed_size,
                 bio_embed_size,
                 feat_embed_size,
                 hidden_size,
                 enc_bidir,
                 dropout=0.2):
        """ Init NMT Model.

        @param word_vocab (Vocab): word vocabulary (see vocab.py for documentation)
        @param bio_vocab (Vocab): vocabulary of the BIO answer-tag feature
        @param feat_vocab (Vocab): vocabulary of the lexical features (case, NER, POS)
        @param word_embed_size (int): word embedding size (dimensionality)
        @param bio_embed_size (int): BIO-tag embedding size (dimensionality)
        @param feat_embed_size (int): feature embedding size (dimensionality)
        @param hidden_size (int): hidden size (dimensionality)
        @param enc_bidir (bool): whether the encoder is bidirectional
        @param dropout (float): dropout probability
        """
        super(NMT, self).__init__()
        self.word_vocab = word_vocab
        self.bio_vocab = bio_vocab
        self.feat_vocab = feat_vocab
        self.args = {
            'word_embed_size': word_embed_size,
            'bio_embed_size': bio_embed_size,
            'feat_embed_size': feat_embed_size,
            'hidden_size': hidden_size,
            'enc_bidir': enc_bidir,
            'dropout': dropout
        }
        self.embedding = FeatureRichEmbedding(len(word_vocab), word_embed_size,
                                              len(bio_vocab), bio_embed_size,
                                              len(feat_vocab), feat_embed_size)
        self.encoder = Encoder(
            word_embed_size + bio_embed_size + feat_embed_size * 3,
            hidden_size, dropout, enc_bidir)
        self.decoder_init_hidden_proj = nn.Linear(self.encoder.hidden_size,
                                                  hidden_size)
        self.decoder = Decoder(word_embed_size, hidden_size, hidden_size,
                               len(word_vocab), dropout)
Code Example #3
 def __init__(self, word_vocab: Vocab, bio_vocab: Vocab, feat_vocab: Vocab,
              albert: bool, word_embed_size, bio_embed_size,
              feat_embed_size, hidden_size, dropout, enc_bidir, n_head,
              max_out_cpy: bool, **kwargs):
     super(QGModel, self).__init__()
     self.word_vocab = word_vocab
     self.bio_vocab = bio_vocab
     self.feat_vocab = feat_vocab
     self.args = {
         'albert': albert,
         'word_embed_size': word_embed_size,
         'bio_embed_size': bio_embed_size,
         'feat_embed_size': feat_embed_size,
         'hidden_size': hidden_size,
         'dropout': dropout,
         'enc_bidir': enc_bidir,
         'n_head': n_head,
         'max_out_cpy': max_out_cpy
     }
     self.args.update(kwargs)
     if albert:
         self.embedding = AlbertFeatureRichEmbedding(
             kwargs['albert_model_name'], len(bio_vocab), bio_embed_size,
             len(feat_vocab), feat_embed_size, kwargs['albert_cache_dir'])
         decoder_word_embed_size = kwargs['albert_word_embed_size']
     else:
         self.embedding = FeatureRichEmbedding(len(word_vocab),
                                               word_embed_size,
                                               len(bio_vocab),
                                               bio_embed_size,
                                               len(feat_vocab),
                                               feat_embed_size)
         decoder_word_embed_size = word_embed_size
     self.encoder = Encoder(
         word_embed_size + bio_embed_size + feat_embed_size * 3,
         word_embed_size, hidden_size, dropout, enc_bidir, n_head)
     self.decoder = Decoder(decoder_word_embed_size, hidden_size,
                            hidden_size, len(word_vocab), dropout,
                            max_out_cpy)
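When albert is true, the constructor above pulls three extra entries out of **kwargs. A hedged sketch of that extra configuration (only the key names come from the snippet; the values are placeholders, not the project's defaults):

# Placeholder values; only the key names are taken from the kwargs lookups above.
albert_kwargs = {
    'albert_model_name': 'albert-base-v2',   # placeholder HuggingFace model identifier
    'albert_cache_dir': './albert_cache',    # placeholder cache directory for pretrained weights
    'albert_word_embed_size': 128,           # word-embedding width handed to the decoder
}
# model = QGModel(word_vocab, bio_vocab, feat_vocab, albert=True, ..., **albert_kwargs)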
Code Example #4
 def __init__(self,
              vocab,
              embed_size,
              hidden_size,
              enc_bidir,
              attn_size,
              dropout=0.2):
     super(QGModel, self).__init__()
     self.vocab = vocab
     self.args = {
         'embed_size': embed_size,
         'hidden_size': hidden_size,
         'dropout': dropout,
         'enc_bidir': enc_bidir,
         'attn_size': attn_size
     }
     self.embeddings = ModelEmbeddings(embed_size, vocab)
     self.encoder = Encoder(embed_size, hidden_size, dropout, enc_bidir)
     self.decoder_init_hidden_proj = nn.Linear(self.encoder.hidden_size,
                                               hidden_size)
     self.decoder = Decoder(embed_size, hidden_size, attn_size,
                            len(vocab.tgt), dropout)
Code Example #5
    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 dropout=0.2,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 in_ch=2,
                 out_ch=3):
        """ Initializes the model, a cross between the 3D U-net
            and 2018 BraTS Challenge top model with VAE regularization.

            References:
                - [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://arxiv.org/pdf/1606.06650.pdf)
                - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
        """
        super(Model, self).__init__()
        self.epoch = tf.Variable(0, name='epoch', trainable=False)
        self.encoder = Encoder(data_format=data_format,
                               groups=groups,
                               reduction=reduction,
                               l2_scale=l2_scale,
                               dropout=dropout,
                               downsampling=downsampling,
                               base_filters=base_filters,
                               depth=depth)
        self.decoder = Decoder(data_format=data_format,
                               groups=groups,
                               reduction=reduction,
                               l2_scale=l2_scale,
                               upsampling=upsampling,
                               base_filters=base_filters,
                               depth=depth,
                               out_ch=out_ch)
        self.vae = VariationalAutoencoder(data_format=data_format,
                                          groups=groups,
                                          reduction=reduction,
                                          l2_scale=l2_scale,
                                          upsampling=upsampling,
                                          base_filters=base_filters,
                                          depth=depth,
                                          out_ch=in_ch)
Code Example #6
class NMT(nn.Module):
    def __init__(self,
                 word_vocab: Vocab,
                 bio_vocab: Vocab,
                 feat_vocab: Vocab,
                 word_embed_size,
                 bio_embed_size,
                 feat_embed_size,
                 hidden_size,
                 enc_bidir,
                 dropout=0.2):
        """ Init NMT Model.

        @param word_vocab (Vocab): word vocabulary (see vocab.py for documentation)
        @param bio_vocab (Vocab): vocabulary of the BIO answer-tag feature
        @param feat_vocab (Vocab): vocabulary of the lexical features (case, NER, POS)
        @param word_embed_size (int): word embedding size (dimensionality)
        @param bio_embed_size (int): BIO-tag embedding size (dimensionality)
        @param feat_embed_size (int): feature embedding size (dimensionality)
        @param hidden_size (int): hidden size (dimensionality)
        @param enc_bidir (bool): whether the encoder is bidirectional
        @param dropout (float): dropout probability
        """
        super(NMT, self).__init__()
        self.word_vocab = word_vocab
        self.bio_vocab = bio_vocab
        self.feat_vocab = feat_vocab
        self.args = {
            'word_embed_size': word_embed_size,
            'bio_embed_size': bio_embed_size,
            'feat_embed_size': feat_embed_size,
            'hidden_size': hidden_size,
            'enc_bidir': enc_bidir,
            'dropout': dropout
        }
        self.embedding = FeatureRichEmbedding(len(word_vocab), word_embed_size,
                                              len(bio_vocab), bio_embed_size,
                                              len(feat_vocab), feat_embed_size)
        self.encoder = Encoder(
            word_embed_size + bio_embed_size + feat_embed_size * 3,
            hidden_size, dropout, enc_bidir)
        self.decoder_init_hidden_proj = nn.Linear(self.encoder.hidden_size,
                                                  hidden_size)
        self.decoder = Decoder(word_embed_size, hidden_size, hidden_size,
                               len(word_vocab), dropout)

    def batch_to_tensor(self, batch: SquadBatch):
        src_indexes = [
            torch.tensor(x, dtype=torch.long,
                         device=self.device).transpose(0, 1)
            for x in (batch.src_index, batch.bio_index, batch.case_index,
                      batch.ner_index, batch.pos_index)
        ]
        src_len = torch.tensor(batch.src_len,
                               dtype=torch.int,
                               device=self.device)
        src_mask = self.generate_mask(src_len, src_indexes[0].size(0))
        tgt_index = torch.tensor(batch.tgt_index,
                                 dtype=torch.long,
                                 device=self.device).transpose(0, 1)
        return src_indexes, src_len, src_mask, tgt_index

    def forward(self, batch: SquadBatch) -> torch.Tensor:
        """ Take a mini-batch of source and target sentences, compute the log-likelihood of
        target sentences under the language models learned by the NMT system.

        @param source (List[List[str]]): list of source sentence tokens
        @param target (List[List[str]]): list of target sentence tokens, wrapped by `<s>` and `</s>`

        @returns scores (Tensor): a variable/tensor of shape (b, ) representing the
                                    log-likelihood of generating the gold-standard target sentence for
                                    each example in the input batch. Here b = batch size.
        """
        src_indexes, src_len, src_mask, tgt_index = self.batch_to_tensor(batch)
        src_embed = self.embedding(*src_indexes)
        memory, last_hidden = self.encoder(src_embed, src_len)
        memory = memory.transpose(0, 1)
        dec_init_hidden = torch.tanh(
            self.decoder_init_hidden_proj(last_hidden))
        tgt_embed = self.embedding.word_embeddings(
            tgt_index)  # (tgt_len, B, embed_size)
        gen_output = self.decoder(memory, src_mask, tgt_embed,
                                  dec_init_hidden)  # (tgt_len-1, B, hidden)
        return gen_output

    def generate_mask(self, length, max_length):
        mask = torch.zeros(length.size(0),
                           max_length,
                           dtype=torch.int,
                           device=self.device)
        for i, x in enumerate(length):
            mask[i, x:] = 1
        return mask

    def beam_search(self, batch, beam_size, max_decoding_step):
        """ Given a single source sentence, perform beam search, yielding translations in the target language.
        @param src_sent (List[str]): a single source sentence (words)
        @param beam_size (int): beam size
        @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN
        @returns hypotheses (List[Hypothesis]): a list of hypothesis, each hypothesis has two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """
        src_indexes, src_len, src_mask, tgt_index = self.batch_to_tensor(batch)
        src_embed = self.embedding(*src_indexes)

        memory, last_hidden = self.encoder(src_embed, src_len)
        memory = memory.transpose(0, 1)
        memory_for_attn = self.decoder.memory_attn_proj(memory)

        dec_hidden_tm1 = torch.tanh(self.decoder_init_hidden_proj(last_hidden))
        ctxt_tm1 = torch.zeros_like(dec_hidden_tm1, device=self.device)

        hypotheses = [[self.word_vocab.SOS]]
        hyp_scores = torch.zeros(len(hypotheses),
                                 dtype=torch.float,
                                 device=self.device)
        completed_hypotheses = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_decoding_step:
            t += 1
            hyp_num = len(hypotheses)

            memory_tm1 = memory.expand(hyp_num, memory.size(1), memory.size(2))

            memory_for_attn_tm1 = memory_for_attn.expand(
                hyp_num, memory_for_attn.size(1), memory_for_attn.size(2))

            prev_word = torch.tensor(
                [self.word_vocab[hyp[-1]] for hyp in hypotheses],
                dtype=torch.long,
                device=self.device)  # (hyp_num, )
            tgt_tm1 = self.embedding.word_embeddings(prev_word)  # (hyp_num, e)

            gen_t, dec_hidden_t, ctxt_t = self.decoder.decode_step(
                tgt_tm1, ctxt_tm1, dec_hidden_tm1, memory_for_attn_tm1,
                memory_tm1)

            # log probabilities over target words
            log_p_t = F.log_softmax(gen_t, dim=-1)  # (hyp_num, vocab)

            live_hyp_num = beam_size - len(completed_hypotheses)
            continuing_hyp_scores = (
                hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(
                    -1)  # (hyp_num * vocab)
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
                continuing_hyp_scores, k=live_hyp_num)

            # floor division keeps the indices integral (plain `/` yields floats in recent PyTorch)
            prev_hyp_ids = torch.div(top_cand_hyp_pos, len(self.word_vocab),
                                     rounding_mode='floor')
            hyp_word_ids = top_cand_hyp_pos % len(self.word_vocab)
            new_hypotheses = []
            live_hyp_ids = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
                    prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
                prev_hyp_id = prev_hyp_id.item()
                hyp_word_id = hyp_word_id.item()
                cand_new_hyp_score = cand_new_hyp_score.item()

                hyp_word = self.word_vocab.id2word[hyp_word_id]
                new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
                if hyp_word == self.word_vocab.EOS:
                    completed_hypotheses.append(
                        (new_hyp_sent[1:-1], cand_new_hyp_score))
                else:
                    new_hypotheses.append(new_hyp_sent)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.tensor(live_hyp_ids,
                                        dtype=torch.long,
                                        device=self.device)
            dec_hidden_tm1 = dec_hidden_t[live_hyp_ids]
            ctxt_tm1 = ctxt_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = torch.tensor(new_hyp_scores,
                                      dtype=torch.float,
                                      device=self.device)

        has_comp = True
        if len(completed_hypotheses) == 0:
            has_comp = False
            completed_hypotheses.append(
                (hypotheses[0][1:], hyp_scores[0].item()))

        completed_hypotheses.sort(key=lambda hyp: hyp[1], reverse=True)

        return completed_hypotheses, has_comp

    @property
    def device(self):
        """ Determine which device to place the Tensors upon, CPU or GPU.
        """
        return self.decoder_init_hidden_proj.weight.device

    @staticmethod
    def load(model_path: str):
        """ Load the model from a file.
        @param model_path (str): path to model
        """
        params = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
        args = params['args']
        word_vocab = Vocab.load(params['word_vocab'])
        bio_vocab = Vocab.load(params['bio_vocab'])
        feat_vocab = Vocab.load(params['feat_vocab'])
        model = NMT(word_vocab, bio_vocab, feat_vocab, **args)
        model.load_state_dict(params['state_dict'])

        return model

    def save(self, path: str):
        """ Save the odel to a file.
        @param path (str): path to the model
        """
        print('save model parameters to [%s]' % path, file=sys.stderr)

        params = {'args': self.args, 'state_dict': self.state_dict()}
        params['word_vocab'] = self.word_vocab.state_dict()
        params['bio_vocab'] = self.bio_vocab.state_dict()
        params['feat_vocab'] = self.feat_vocab.state_dict()

        torch.save(params, path)
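The trickiest bookkeeping in beam_search above is splitting the flattened top-k positions back into a previous-hypothesis index and a word index. A small, self-contained sketch of that step with toy sizes (the numbers are illustrative, not the model's):

import torch

hyp_num, vocab_size, beam_size = 3, 5, 4
log_p_t = torch.log_softmax(torch.randn(hyp_num, vocab_size), dim=-1)  # per-hypothesis word log-probs
hyp_scores = torch.zeros(hyp_num)                                      # running scores of the live beams

# add each beam's running score to its word log-probs, then flatten to (hyp_num * vocab_size,)
cont_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
top_scores, top_pos = torch.topk(cont_scores, k=beam_size)

prev_hyp_ids = torch.div(top_pos, vocab_size, rounding_mode='floor')  # which beam each candidate extends
hyp_word_ids = top_pos % vocab_size                                   # which word extends it
print(prev_hyp_ids.tolist(), hyp_word_ids.tolist(), top_scores.tolist())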
Code Example #7
 def __init__(self, options):
     super(Generator_deTree, self).__init__()
     self.node = TreeNode(options['TreeNode'])
     self.Decoder = Decoder(options['Decoder'])
Code Example #8
 def __init__(self, options):
     super(Generator_Flat, self).__init__()
     self.Decoder = Decoder(options['Decoder'])
Code Example #9
class QGModel(nn.Module):
    def __init__(self, word_vocab: Vocab, bio_vocab: Vocab, feat_vocab: Vocab,
                 albert: bool, word_embed_size, bio_embed_size,
                 feat_embed_size, hidden_size, dropout, enc_bidir, n_head,
                 max_out_cpy: bool, **kwargs):
        super(QGModel, self).__init__()
        self.word_vocab = word_vocab
        self.bio_vocab = bio_vocab
        self.feat_vocab = feat_vocab
        self.args = {
            'albert': albert,
            'word_embed_size': word_embed_size,
            'bio_embed_size': bio_embed_size,
            'feat_embed_size': feat_embed_size,
            'hidden_size': hidden_size,
            'dropout': dropout,
            'enc_bidir': enc_bidir,
            'n_head': n_head,
            'max_out_cpy': max_out_cpy
        }
        self.args.update(kwargs)
        if albert:
            self.embedding = AlbertFeatureRichEmbedding(
                kwargs['albert_model_name'], len(bio_vocab), bio_embed_size,
                len(feat_vocab), feat_embed_size, kwargs['albert_cache_dir'])
            decoder_word_embed_size = kwargs['albert_word_embed_size']
        else:
            self.embedding = FeatureRichEmbedding(len(word_vocab),
                                                  word_embed_size,
                                                  len(bio_vocab),
                                                  bio_embed_size,
                                                  len(feat_vocab),
                                                  feat_embed_size)
            decoder_word_embed_size = word_embed_size
        self.encoder = Encoder(
            word_embed_size + bio_embed_size + feat_embed_size * 3,
            word_embed_size, hidden_size, dropout, enc_bidir, n_head)
        self.decoder = Decoder(decoder_word_embed_size, hidden_size,
                               hidden_size, len(word_vocab), dropout,
                               max_out_cpy)

    def batch_to_tensor(self, batch: SquadBatch):
        src_indexes = [
            torch.tensor(x, dtype=torch.long, device=self.device)
            for x in (batch.src_index, batch.bio_index, batch.case_index,
                      batch.ner_index, batch.pos_index)
        ]
        src_ext_index = torch.tensor(batch.src_extended_index,
                                     dtype=torch.long,
                                     device=self.device)
        src_len = torch.tensor(batch.src_len,
                               dtype=torch.int,
                               device=self.device)
        src_mask = self.generate_mask(src_len, src_indexes[0].size(1))
        tgt_index = torch.tensor(batch.tgt_index,
                                 dtype=torch.long,
                                 device=self.device).transpose(0, 1)
        ans_len = torch.tensor(batch.ans_len,
                               dtype=torch.int,
                               device=self.device)
        ans_index = torch.tensor(batch.ans_index,
                                 dtype=torch.long,
                                 device=self.device)
        ans_mask = self.generate_mask(ans_len, ans_index.size(1))
        return src_indexes, src_ext_index, src_len, src_mask, tgt_index, ans_len, ans_index, ans_mask

    def encode(self, batch):
        src_indexes, src_ext_index, src_len, src_mask, tgt_index, ans_len, ans_index, ans_mask = self.batch_to_tensor(
            batch)
        src_embed, ans_embed = self.embedding(
            src_indexes[0], src_mask, src_len, ans_index, ans_mask, ans_len,
            src_indexes[1], *src_indexes[2:])  # (B, src_len, embed_size)
        tgt_embed = self.embedding.word_embeddings(
            tgt_index)  # (tgt_len, B, embed_size)
        memory, dec_init_hidden = self.encoder(src_embed, src_mask, src_len,
                                               ans_embed)
        # dec_init_hidden: (B, hidden)
        return memory, dec_init_hidden, tgt_embed, src_mask, src_ext_index

    def forward(self, batch: SquadBatch):
        memory, dec_init_hidden, tgt_embed, src_mask, src_ext_index = self.encode(
            batch)
        logger.debug("memory mask shape {}".format(src_mask))
        logger.debug("memory shape {}".format(memory.shape))

        gen_output, atten_output = self.decoder(memory, src_mask,
                                                src_ext_index, tgt_embed,
                                                dec_init_hidden)
        # (tgt_len - 1, B, vocab + oov), not probability
        return gen_output

    def generate_mask(self, length, max_length):
        mask = torch.zeros(length.size(0), max_length, device=self.device)
        for i, x in enumerate(length):
            mask[i, :x] = 1
        return mask

    def beam_search(self, batch: SquadBatch, beam_size, max_decoding_step):
        """
        :param batch: batch size is 1
        :param beam_size:
        :return:
        """
        oov_word = batch.oov_word[0]
        memory, dec_init_hidden, tgt_embed, src_mask, src_ext_index = self.encode(
            batch)
        # memory: (B, src_len, hidden)
        no_copy_hypothesis = [[self.word_vocab.SOS]]
        copy_hypothesis = [[self.word_vocab.SOS]]
        atten_engy = [[]]
        hyp_scores = torch.zeros(len(copy_hypothesis),
                                 dtype=torch.float,
                                 device=self.device)
        completed_hypothesis = []
        t = 0
        ctxt_tm1 = torch.zeros(len(copy_hypothesis),
                               self.args['hidden_size'],
                               device=self.device)
        dec_hidden_tm1 = dec_init_hidden
        memory_for_attn = self.decoder.memory_attn_proj(memory)
        while len(completed_hypothesis) < beam_size and t < max_decoding_step:
            t += 1
            hyp_num = len(copy_hypothesis)
            prev_word = [x[-1] for x in copy_hypothesis]
            tgt_tm1 = self.embedding.word_embeddings(
                torch.tensor(self.word_vocab.index(prev_word),
                             dtype=torch.long,
                             device=self.device))  # (B, word_embed_size)

            memory_for_attn_tm1 = memory_for_attn.expand(
                (hyp_num, *memory_for_attn.shape[1:]))
            memory_tm1 = memory.expand((hyp_num, *memory.shape[1:]))
            gen_t, dec_hidden_t, ctxt_t, atten_engy_t = self.decoder.decode_step(
                tgt_tm1, ctxt_tm1, dec_hidden_tm1, memory_for_attn_tm1,
                memory_tm1, src_ext_index)
            gen_t = torch.log_softmax(gen_t, dim=-1)  # (B, vocab)
            live_hyp_num = beam_size - len(completed_hypothesis)
            continuating_hyp_scores = (
                hyp_scores.unsqueeze(1).expand_as(gen_t) + gen_t).view(
                    -1)  # (hyp_num * V)
            top_candi_scores, top_candi_position = torch.topk(
                continuating_hyp_scores, k=live_hyp_num)
            # floor division keeps the indices integral (plain `/` yields floats in recent PyTorch)
            prev_hyp_indexes = torch.div(top_candi_position, gen_t.shape[-1],
                                         rounding_mode='floor')
            hyp_word_indexes = top_candi_position % gen_t.shape[-1]

            new_copy_hypothesis = []
            new_no_copy_hypothesis = []
            new_atten_engy = []
            live_hyp_index = []
            new_hyp_scores = []
            num_unk = 0
            for prev_hyp_index, hyp_word_index, new_hyp_score in zip(
                    prev_hyp_indexes, hyp_word_indexes, top_candi_scores):
                prev_hyp_index = prev_hyp_index.item()
                hyp_word_index = hyp_word_index.item()
                new_hyp_score = new_hyp_score.item()

                if hyp_word_index < len(self.word_vocab):
                    hyp_word = self.word_vocab.id2word[hyp_word_index]
                    copy_new_hypo = copy_hypothesis[prev_hyp_index] + [
                        hyp_word
                    ]
                    no_copy_new_hypo = no_copy_hypothesis[prev_hyp_index] + [
                        hyp_word
                    ]
                else:
                    hyp_word = oov_word[hyp_word_index - len(self.word_vocab)]
                    copy_new_hypo = copy_hypothesis[prev_hyp_index] + [
                        hyp_word
                    ]
                    no_copy_new_hypo = no_copy_hypothesis[prev_hyp_index] + [
                        '[COPY]'
                    ]
                new_atten_hypo = atten_engy[prev_hyp_index] + [
                    atten_engy_t[prev_hyp_index, :]
                ]
                if hyp_word == self.word_vocab.EOS:
                    completed_hypothesis.append(
                        (copy_new_hypo[1:-1], no_copy_new_hypo[1:-1],
                         torch.stack(new_atten_hypo[:-1]).tolist(),
                         new_hyp_score))
                else:
                    new_copy_hypothesis.append(copy_new_hypo)
                    new_no_copy_hypothesis.append(no_copy_new_hypo)
                    new_atten_engy.append(new_atten_hypo)
                    live_hyp_index.append(prev_hyp_index)
                    new_hyp_scores.append(new_hyp_score)
            if len(completed_hypothesis) == beam_size:
                break
            live_hyp_index = torch.tensor(live_hyp_index,
                                          dtype=torch.long,
                                          device=self.device)
            dec_hidden_tm1 = dec_hidden_t[live_hyp_index]
            ctxt_tm1 = ctxt_t[live_hyp_index]

            copy_hypothesis = new_copy_hypothesis
            no_copy_hypothesis = new_no_copy_hypothesis
            atten_engy = new_atten_engy
            hyp_scores = torch.tensor(new_hyp_scores,
                                      dtype=torch.float,
                                      device=self.device)
        has_completed = True
        if len(completed_hypothesis) == 0:
            has_completed = False
            completed_hypothesis.append(
                (copy_hypothesis[0][1:], no_copy_hypothesis[0][1:],
                 torch.stack(atten_engy[0]).tolist(), hyp_scores[0].item()))
        completed_hypothesis.sort(key=lambda x: x[3], reverse=True)
        return completed_hypothesis, has_completed

    def nucleus_sampling(self,
                         batch: SquadBatch,
                         max_decoding_step,
                         nucleus_p=0.9):
        """
        :param batch: batch size is 1
        :param beam_size:
        :return:
        """
        oov_word = batch.oov_word[0]
        memory, dec_init_hidden, tgt_embed, src_mask, src_ext_index = self.encode(
            batch)
        # memory: (B, src_len, hidden)
        copy_hypothesis = [[self.word_vocab.SOS]]
        no_copy_hypothesis = [[self.word_vocab.SOS]]
        hyp_score = [0]
        has_completed = False
        t = 0
        ctxt_tm1 = torch.zeros(len(copy_hypothesis),
                               self.args['hidden_size'],
                               device=self.device)
        dec_hidden_tm1 = dec_init_hidden
        memory_for_attn = self.decoder.memory_attn_proj(memory)
        while t < max_decoding_step:
            t += 1
            hyp_num = len(copy_hypothesis)
            prev_word = [x[-1] for x in copy_hypothesis]
            tgt_tm1 = self.embedding.word_embeddings(
                torch.tensor(self.word_vocab.index(prev_word),
                             dtype=torch.long,
                             device=self.device))  # (B, word_embed_size)

            memory_for_attn_tm1 = memory_for_attn.expand(
                (hyp_num, *memory_for_attn.shape[1:]))
            memory_tm1 = memory.expand((hyp_num, *memory.shape[1:]))
            # decode_step also returns the attention energies (as in beam_search above); unused here
            gen_t, dec_hidden_t, ctxt_t, _ = self.decoder.decode_step(
                tgt_tm1, ctxt_tm1, dec_hidden_tm1, memory_for_attn_tm1,
                memory_tm1, src_ext_index)
            sorted, sorted_indexes = torch.sort(gen_t, dim=-1, descending=True)
            sorted_p = torch.softmax(sorted, dim=-1)
            cum_p = torch.cumsum(sorted_p, dim=-1)
            is_greater = torch.gt(cum_p, nucleus_p)  # (B, V+extended)
            le_index = torch.min(is_greater, dim=-1)[1].item()
            new_prob = cum_p / cum_p[0, le_index]
            random_v = torch.rand(1).item()
            sampled_index = torch.gt(new_prob, random_v).min(-1)[1].item()

            hyp_word_score = sorted_p[0, sampled_index].item()
            hyp_word_index = sorted_indexes[0, sampled_index].item()
            if hyp_word_index < len(self.word_vocab):
                hyp_word = self.word_vocab.id2word[hyp_word_index]
                copy_hypothesis[0].append(hyp_word)
                no_copy_hypothesis[0].append(hyp_word)
            else:
                hyp_word = oov_word[hyp_word_index - len(self.word_vocab)]
                copy_hypothesis[0].append(hyp_word)
                no_copy_hypothesis[0].append('[COPY]')
            if hyp_word == self.word_vocab.EOS:
                has_completed = True
                break
            hyp_score[0] += hyp_word_score
            ctxt_tm1 = ctxt_t
            dec_hidden_tm1 = dec_hidden_t
        if has_completed:
            return [(copy_hypothesis[0][1:-1], no_copy_hypothesis[0][1:-1],
                     hyp_score[0])], has_completed
        else:
            return [(copy_hypothesis[0][1:], no_copy_hypothesis[0][1:],
                     hyp_score[0])], has_completed

    @property
    def device(self):
        return self.decoder.memory_attn_proj.weight.device

    def save(self, path):
        directory = Path(path).parent
        directory.mkdir(parents=True, exist_ok=True)
        state_dict = {
            'word_vocab': self.word_vocab.state_dict(),
            'bio_vocab': self.bio_vocab.state_dict(),
            'feat_vocab': self.feat_vocab.state_dict(),
            'args': self.args,
            'model_state': self.state_dict()
        }
        torch.save(state_dict, path)

    @staticmethod
    def load(path, device):
        params = torch.load(path, map_location=lambda storage, loc: storage)
        if params['args']['albert']:
            word_vocab = AlbertVocab.load(params['word_vocab'])
        else:
            word_vocab = Vocab.load(params['word_vocab'])
        bio_vocab = Vocab.load(params['bio_vocab'])
        feat_vocab = Vocab.load(params['feat_vocab'])
        model = QGModel(word_vocab, bio_vocab, feat_vocab,
                        **params['args'])  # type:nn.Module
        model.load_state_dict(params['model_state'])
        return model.to(device)
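nucleus_sampling above keeps the smallest set of words whose cumulative probability exceeds nucleus_p and draws the next word from inside that set. A self-contained sketch of the same cutoff on a toy logits vector; it samples with torch.multinomial rather than the inverse-CDF draw used in the method above, and the sizes are toy values:

import torch

def top_p_sample(logits: torch.Tensor, p: float = 0.9) -> int:
    """Sample one index from the nucleus (smallest prefix with cumulative probability > p)."""
    probs = torch.softmax(logits, dim=-1)
    sorted_p, sorted_idx = torch.sort(probs, descending=True)
    cum_p = torch.cumsum(sorted_p, dim=-1)
    cutoff = int((cum_p < p).sum().item()) + 1            # number of words kept in the nucleus
    kept_p = sorted_p[:cutoff] / sorted_p[:cutoff].sum()  # renormalise inside the nucleus
    choice = torch.multinomial(kept_p, 1).item()
    return int(sorted_idx[choice].item())

print(top_p_sample(torch.randn(10), p=0.9))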
Code Example #10
class QGModel(nn.Module):
    def __init__(self,
                 vocab,
                 embed_size,
                 hidden_size,
                 enc_bidir,
                 attn_size,
                 dropout=0.2):
        super(QGModel, self).__init__()
        self.vocab = vocab
        self.args = {
            'embed_size': embed_size,
            'hidden_size': hidden_size,
            'dropout': dropout,
            'enc_bidir': enc_bidir,
            'attn_size': attn_size
        }
        self.embeddings = ModelEmbeddings(embed_size, vocab)
        self.encoder = Encoder(embed_size, hidden_size, dropout, enc_bidir)
        self.decoder_init_hidden_proj = nn.Linear(self.encoder.hidden_size,
                                                  hidden_size)
        self.decoder = Decoder(embed_size, hidden_size, attn_size,
                               len(vocab.tgt), dropout)

    def batch_to_tensor(self, source, target):
        # Compute sentence lengths
        source_lengths = [len(s) for s in source]

        # Convert list of lists into tensors
        source_padded = self.vocab.src.to_input_tensor(
            source, device=self.device)  # Tensor: (src_len, b)
        target_padded = self.vocab.tgt.to_input_tensor(
            target, device=self.device)  # Tensor: (tgt_len, b)
        source_mask = self.generate_mask(source_lengths,
                                         source_padded.shape[0])
        return source_padded, target_padded, source_lengths, source_mask

    def forward(self, source: List[List[str]], target: List[List[str]]):
        source_padded, target_padded, source_lengths, source_mask = self.batch_to_tensor(
            source, target)

        source_embedding = self.embeddings.source(
            source_padded)  # (src_len, b, embed_size)
        target_embedding = self.embeddings.target(
            target_padded)  # (tgt_len, B, embed_size)
        memory, last_hidden = self.encoder(source_embedding, source_lengths)
        # last_hidden: (B, hidden)
        memory = memory.transpose(0, 1)  # memory: (B, src_len, hidden)
        dec_init_hidden = torch.tanh(
            self.decoder_init_hidden_proj(last_hidden))
        gen_output = self.decoder(memory, source_mask, target_embedding,
                                  dec_init_hidden)
        # (tgt_len - 1, B, word_vocab_size), not probability
        P = F.log_softmax(gen_output, dim=-1)

        # Zero out probabilities for which we have nothing in the target text
        target_masks = (target_padded != self.vocab.tgt['<pad>']).float()

        # Compute log probability of generating true target words
        target_gold_words_log_prob = torch.gather(
            P, index=target_padded[1:].unsqueeze(-1),
            dim=-1).squeeze(-1) * target_masks[1:]
        scores = target_gold_words_log_prob.sum(dim=0)
        return scores

    def generate_mask(self, length, max_length):
        mask = torch.zeros(len(length),
                           max_length,
                           dtype=torch.int,
                           device=self.device)
        for i, x in enumerate(length):
            mask[i, x:] = 1
        return mask

    def beam_search(self,
                    src_sent: List[str],
                    beam_size: int = 5,
                    max_decoding_time_step: int = 70):
        """
        :param batch: batch size is 1
        :param beam_size:
        :return:
        """
        src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device)
        src_len = torch.tensor([len(src_sent)],
                               dtype=torch.int,
                               device=self.device)
        source_embedding = self.embeddings.source(
            src_sents_var)  # (src_len, b, embed_size)

        memory, last_hidden = self.encoder(source_embedding, src_len)
        # last_hidden: (B, hidden)
        memory = memory.transpose(0, 1)  # memory: (B, src_len, hidden)
        dec_init_hidden = torch.tanh(
            self.decoder_init_hidden_proj(last_hidden))  # (B, hidden)
        hypotheses = [['<s>']]
        hyp_scores = torch.zeros(len(hypotheses),
                                 dtype=torch.float,
                                 device=self.device)
        completed_hypotheses = []
        t = 0
        ctxt_tm1 = torch.zeros(len(hypotheses),
                               self.args['hidden_size'],
                               device=self.device)
        dec_hidden_tm1 = dec_init_hidden
        while len(completed_hypotheses
                  ) < beam_size and t < max_decoding_time_step:
            t += 1
            hyp_num = len(hypotheses)
            prev_word = torch.tensor(
                [self.vocab.tgt[x[-1]] for x in hypotheses],
                dtype=torch.long,
                device=self.device)
            tgt_tm1 = self.embeddings.target(prev_word)  # (B, word_embed_size)

            memory_tm1 = memory.expand((hyp_num, *memory.shape[1:]))
            gen_t, dec_hidden_t, ctxt_t = self.decoder.decode_step(
                tgt_tm1, ctxt_tm1, dec_hidden_tm1, memory_tm1)
            gen_t = torch.log_softmax(gen_t, dim=-1)  # (B, vocab)
            live_hyp_num = beam_size - len(completed_hypotheses)
            continuating_hyp_scores = (
                hyp_scores.unsqueeze(1).expand_as(gen_t) + gen_t).view(
                    -1)  # (hyp_num * V)
            top_candi_scores, top_candi_position = torch.topk(
                continuating_hyp_scores, k=live_hyp_num)
            # floor division keeps the indices integral (plain `/` yields floats in recent PyTorch)
            prev_hyp_indexes = torch.div(top_candi_position, len(self.vocab.tgt),
                                         rounding_mode='floor')
            hyp_word_indexes = top_candi_position % len(self.vocab.tgt)

            new_hypothesis = []
            live_hyp_index = []
            new_hyp_scores = []
            num_unk = 0
            for prev_hyp_index, hyp_word_index, new_hyp_score in zip(
                    prev_hyp_indexes, hyp_word_indexes, top_candi_scores):
                prev_hyp_index = prev_hyp_index.item()
                hyp_word_index = hyp_word_index.item()
                new_hyp_score = new_hyp_score.item()

                hyp_word = self.vocab.tgt.id2word[hyp_word_index]
                new_hypo = hypotheses[prev_hyp_index] + [hyp_word]
                if hyp_word == '</s>':
                    completed_hypotheses.append(
                        Hypothesis(value=new_hypo[1:-1], score=new_hyp_score))
                else:
                    new_hypothesis.append(new_hypo)
                    live_hyp_index.append(prev_hyp_index)
                    new_hyp_scores.append(new_hyp_score)
            if len(completed_hypotheses) == beam_size:
                break
            live_hyp_index = torch.tensor(live_hyp_index,
                                          dtype=torch.long,
                                          device=self.device)
            dec_hidden_tm1 = dec_hidden_t[live_hyp_index]
            ctxt_tm1 = ctxt_t[live_hyp_index]

            hypotheses = new_hypothesis
            hyp_scores = torch.tensor(new_hyp_scores,
                                      dtype=torch.float,
                                      device=self.device)

        has_comp = True
        if len(completed_hypotheses) == 0:
            has_comp = False
            completed_hypotheses.append(
                Hypothesis(value=hypotheses[0][1:],
                           score=hyp_scores[0].item()))
        completed_hypotheses.sort(key=lambda x: x.score, reverse=True)
        return completed_hypotheses, has_comp

    @property
    def device(self):
        return self.decoder_init_hidden_proj.weight.device

    def save(self, path):
        path = path + ".qg"
        directory = Path(path).parent
        directory.mkdir(parents=True, exist_ok=True)
        state_dict = {}
        state_dict['vocab'] = self.vocab
        state_dict['args'] = self.args
        state_dict['model_state'] = self.state_dict()
        torch.save(state_dict, path)

    @staticmethod
    def load(path, device):
        params = torch.load(path, map_location=device)

        model = QGModel(vocab=params['vocab'],
                        **params['args'])  # type:nn.Module
        model.load_state_dict(params['model_state'])
        return model.to(device)
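forward above scores the gold target sequence by gathering log-probabilities at the reference word indices and masking out padding. A self-contained sketch of that scoring step with toy shapes (the values are illustrative, not the model's):

import torch
import torch.nn.functional as F

tgt_len, batch_size, vocab_size, pad_id = 5, 2, 7, 0
gen_output = torch.randn(tgt_len - 1, batch_size, vocab_size)        # decoder outputs for steps 1..tgt_len-1
target_padded = torch.randint(1, vocab_size, (tgt_len, batch_size))  # gold indices, starting with <s>

P = F.log_softmax(gen_output, dim=-1)
target_masks = (target_padded != pad_id).float()                     # 1 for real tokens, 0 for padding
gold_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1),
                             dim=-1).squeeze(-1) * target_masks[1:]
scores = gold_log_prob.sum(dim=0)                                    # per-sentence log-likelihood, shape (batch_size,)
print(scores)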