Example #1
    def decode(self, 
               trg: torch.Tensor, 
               trg_mask: torch.Tensor,
               memory: torch.Tensor,
               src_mask: torch.Tensor) -> torch.Tensor:

        """
        Decodes the target sequence, conditioned on the encoder memory.

        :param trg: Our vectorized target sequence. [Batch_size x trg_seq_len]
        :param trg_mask: Mask to be passed to the SelfAttention block when decoding
                the trg sequence -> check SelfAttention.
        :param memory: Our src sequence encoded by the encoding function. This will be used
                as a memory over the source sequence.
        :param src_mask: Mask to be passed to the SelfAttention block when computing attention
                over the source sequence/memory -> check SelfAttention.

        :returns:
            - The log probabilities of the next word over the entire target sequence.
              [Batch_size x trg_seq_len x vocab_size]
        """
        tokens = self.token_embedding(trg)
        b, t, e = tokens.size()
            
        positions = self.pos_embedding(torch.arange(t, device=d()))[None, :, :].expand(b, t, e)
        x = tokens + positions

        trg_mask = trg_mask & subsequent_mask(t).type_as(trg_mask)
        for block in self.decoding_blocks:
            x = block(x, memory, src_mask, trg_mask)

        x = self.toprobs(x.view(b*t, e)).view(b, t, self.vocab_size)
        return F.log_softmax(x, dim=2)
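
Every example in this collection calls a `subsequent_mask` helper that none of the snippets define. A minimal sketch of the usual definition, assuming the Annotated Transformer convention where a true/1 entry marks a position that may be attended to:

import torch

def subsequent_mask(size):
    # Entries strictly above the diagonal mark "future" positions.
    future = torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1)
    # Invert so that position (i, j) is allowed only when j <= i.
    return future == 0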
Example #2
def init_vars(device, src, src_mask, utter_type, model, tokenizer):

    start_index = tokenizer.cls_token_id
    mem, utter_mask, token_features, token_mask = model.encode(
        src, src_mask, utter_type)
    ys = torch.ones(1, 1).fill_(start_index).long().to(device)
    trg_mask = subsequent_mask(ys.size(1)).type_as(ys)

    coverage = torch.zeros_like(utter_mask).contiguous().float()
    token_coverage = torch.zeros_like(token_mask).contiguous().float()
    vocab_dist, tgt_attn_dist, p_gen, next_cov, next_tok_cov, tok_utter_index = model.decode(
        ys, trg_mask, mem, utter_mask, token_features, token_mask, coverage,
        token_coverage)
    index_1 = torch.arange(0, Config.batch_size).long()
    if Config.pointer_gen:
        vocab_dist_ = p_gen * vocab_dist
        attn_dist_ = (1 - p_gen) * tgt_attn_dist
        token_attn_indices = src[tok_utter_index, index_1, :]
        final_dist = vocab_dist_.scatter_add(1, token_attn_indices, attn_dist_)
    else:
        final_dist = vocab_dist
    final_dist = torch.log(final_dist)
    log_scores, ix = final_dist.topk(Config.beam_size)
    outputs = torch.zeros(Config.beam_size,
                          Config.max_decode_output_length).long()

    outputs = outputs.to(device)
    outputs[:, 0] = start_index
    outputs[:, 1] = ix[0]
    e_outputs = torch.zeros(Config.beam_size, mem.size(-2), mem.size(-1))
    e_outputs = e_outputs.to(device)
    e_outputs[:, :] = mem[0]
    return outputs, e_outputs, log_scores, utter_mask, next_cov, next_tok_cov, mem, utter_mask, token_features, token_mask
Example #3
    def forward(self, batch_data):
        src = batch_data['src']
        src_mask = batch_data['src_mask']
        rel_ids = batch_data['rel_ids']
        extra_vocab = batch_data['extra_vocab']
        expanded_x = batch_data['expanded_x']
        semantic_mask = batch_data['semantic_mask']
        batch_size = src.size(0)

        memory = self.model.encode(src, src_mask, rel_ids)

        ys = torch.ones(batch_size, 1).fill_(self.start_pos).type_as(src.data)
        for i in range(self.max_tgt_len - 1):
            tgt_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
            tgt_mask = tgt_mask.unsqueeze(1)
            tgt_embed = self.model.tgt_embed(Variable(ys))
            decoder_outputs, decoder_attn = self.model.decode(
                memory, src_mask, tgt_embed, tgt_mask)

            prob = self.model.generator(
                (memory, decoder_outputs[:, -1].unsqueeze(1),
                 decoder_attn[:, :,
                              -1].unsqueeze(2), tgt_embed[:, -1].unsqueeze(1),
                 extra_vocab, expanded_x, semantic_mask))
            prob = prob.squeeze(1)
            _, next_word = torch.max(prob, dim=1)
            ys = torch.cat([ys, next_word.unsqueeze(1).type_as(src.data)],
                           dim=1)
        return ys, extra_vocab
Example #4
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        # noinspection PyUnresolvedReferences
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(
            subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask
Example #5
def get_batch_data(src,
                   tgt,
                   pad,
                   rel_ids=[],
                   extra_vocab=dict(),
                   expanded_x=[],
                   semantic_mask=[]):

    src_mask = (src != pad).unsqueeze(-2).unsqueeze(1)

    tgt_mask = (tgt != pad).unsqueeze(-2).unsqueeze(1)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))

    batch = {
        'src': src,
        'tgt': tgt,
        'src_mask': src_mask,
        'tgt_mask': tgt_mask,
        'rel_ids': rel_ids,
        'extra_vocab': extra_vocab,
        'expanded_x': expanded_x,
        'semantic_mask': semantic_mask
    }

    return batch
Example #6
    def greedy_decode(self, seq, attn_masks, max_len, vocab):
        start_symbol = vocab.token_to_idx['[CLS]']
        # end_symbol = vocab.token_to_idx['[SEP]']

        # required attn_masks shape for BERT: [batch_size,sequence_length] (zzingae)
        cont_reps, _ = self.bert(seq, attention_mask=attn_masks)
        # required attn_masks shape for decoder: [batch_size,1,sequence_length] (zzingae)
        attn_masks = attn_masks.unsqueeze(1)

        ys = torch.ones(seq.shape[0], 1).fill_(start_symbol).type_as(seq.data)
        for i in range(max_len):
            tgt_mask = subsequent_mask(ys.shape[1]).repeat(seq.shape[0], 1,
                                                           1).type_as(seq.data)
            output = self.decoder(self.tgt_embed(ys), cont_reps, attn_masks,
                                  tgt_mask)
            log_prob = self.generator(output[:, -1, :])

            _, next_word = torch.max(log_prob, dim=1)

            # ys = torch.cat([ys, torch.ones(seq.shape[0], 1).type_as(seq.data).fill_(next_word[0,i])], dim=1)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            if i == (max_len - 1):
                logits = self.generator(output)

        return logits, ys
Example #7
def greedy(args, net, src, src_mask, src_vocab, tgt_vocab):
    # src: torch.LongTensor (bsz, slen)
    # src_mask: torch.ByteTensor (bsz, 1, slen)
    slen = src.size(1)
    max_len = min(args.max_len, int(args.gen_a * slen + args.gen_b))
    src_lan = src_vocab.stoi["<" + args.src_lan.upper() + ">"]
    tgt_lan = tgt_vocab.stoi["<" + args.tgt_lan.upper() + ">"]
    with torch.no_grad():
        net.eval()
        bsz = src.size(0)
        enc_out = net.encode(src=src, src_mask=src_mask, src_lang=src_lan)
        generated = src.new(bsz, max_len)
        generated.fill_(tgt_vocab.stoi[config.PAD])
        generated[:, 0].fill_(tgt_vocab.stoi[config.BOS])
        generated = generated.long()

        cur_len = 1
        gen_len = src.new_ones(bsz).long()
        unfinished_sents = src.new_ones(bsz).long()

        cache = {'cur_len': cur_len - 1}

        while cur_len < max_len:
            x = generated[:, cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:, :cur_len] !=
                        tgt_vocab.stoi[config.PAD]).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))

            logit = net.decode(
                enc_out,
                src_mask,
                x,
                tgt_mask[:, cur_len - 1, :].unsqueeze(-2),
                cache=cache,
                tgt_lang=tgt_lan,
                decoder_lang_id=config.LANG2IDS[args.tgt_lan.upper()])
            scores = net.generator(logit).exp().squeeze()

            next_words = torch.topk(scores, 1)[1].squeeze()

            assert next_words.size() == (bsz, )
            generated[:,
                      cur_len] = next_words * unfinished_sents + tgt_vocab.stoi[
                          config.PAD] * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(
                next_words.ne(tgt_vocab.stoi[config.EOS]).long())
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1

            if unfinished_sents.max() == 0:
                break

        if cur_len == max_len:
            generated[:, -1].masked_fill_(unfinished_sents.bool(),
                                          tgt_vocab.stoi[config.EOS])

        return generated[:, :cur_len], gen_len
Example #8
    def make_std_mask(data, pad):
        """
        Create a mask to hide padding and future words.
        """
        data_mask = (data != pad).unsqueeze(-2)
        data_mask = data_mask & Variable(
            subsequent_mask(data.size(-1)).type_as(data_mask.data))

        return data_mask

    def make_std_mask(tgt, pad):
        """
        Create a mask to hide padding and future words.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(
            subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))

        return tgt_mask
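
A small usage sketch of `make_std_mask`, assuming it is available as a plain function, that pad id is 0, and that `subsequent_mask` is the sketch shown after Example #1 (all values hypothetical):

tgt = torch.tensor([[5, 7, 9, 0]])    # one sequence whose last position is padding
mask = make_std_mask(tgt, pad=0)      # broadcasts (1, 1, 4) & (1, 4, 4) -> (1, 4, 4)
# mask[0, i, j] is True only when j <= i and tgt[0, j] != 0,
# so attention can see neither padding nor future tokens.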
Example #10
def test(model, config):
    max_sentences = config.get("max_sentences", 1e9)
    max_tokens = config.get("max_tokens", 1e9)

    corpus_prefix = Path(config['corpus_prefix']) / "subword"
    model_path = corpus_prefix / "spm.model"
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(model_path))
    test_src = load_corpus(corpus_prefix / Path(config["test_source"]).name,
                           tokenizer)
    num_test_sents = len(test_src)

    eos_id = tokenizer.eos_id()
    test_ids = list(range(num_test_sents))

    test_itr = create_batch_itr(test_src,
                                max_tokens=max_tokens,
                                max_sentences=max_sentences,
                                shuffle=False)
    test_itr = tqdm(test_itr, desc='test')
    for batch_ids in test_itr:
        src_batch = make_batch(test_src, batch_ids, eos_id)
        src_mask = padding_mask(src_batch, eos_id)
        src_encode = model.encode(src_batch, src_mask, train=False)

        trg_ids = [np.array([tokenizer.PieceToId('<s>')] * len(batch_ids))]
        eos_ids = np.array([eos_id] * len(batch_ids))
        while (trg_ids[-1] != eos_ids).any():
            if len(trg_ids) > config['generation_limit']:
                print("Warning: Sentence generation did not finish in",
                      config['generation_limit'],
                      "iterations.",
                      file=sys.stderr)
                trg_ids.append(eos_ids)
                break

            trg_mask = [
                subsequent_mask(len(trg_ids))
                for _ in padding_mask(trg_ids, eos_id)
            ]
            out = model.decode(src_encode,
                               trg_ids,
                               src_mask,
                               trg_mask,
                               train=False)
            y = TF.pick(out, [out.shape()[0] - 1], 0)
            y = np.array(y.argmax(1))
            trg_ids.append(y)

        hyp = [
            hyp_sent[:np.where(hyp_sent == eos_id)[0][0]]
            for hyp_sent in np.array(trg_ids).T
        ]
        for ids in hyp:
            sent = tokenizer.DecodeIds(ids.tolist())
            print(sent)
Example #11
def beam_search(src, net, src_vocab, tgt_vocab, args):
    """
    This implementation of beam search is problematic.
    log_scores should not include scores of special tokens like <pad>, <s> etc.
    """
    outputs, e_outputs, log_scores = init_vars(src, net, src_vocab, tgt_vocab,
                                               args)
    eos_tok = tgt_vocab.stoi[config.EOS]
    src_mask = (src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
    ind = None

    src_len = src.size(-1)
    max_len = int(min(args.max_len, args.max_ratio * src_len))
    for i in range(1, max_len):

        tgt_mask = subsequent_mask(i)
        if args.use_cuda:
            tgt_mask = tgt_mask.cuda()

        out = net.generator.proj(
            net.decode(e_outputs, src_mask, outputs[:, :i], tgt_mask))

        out = F.softmax(out, dim=-1)

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             args.k)

        ones = (outputs == eos_tok).nonzero(
        )  # Occurrences of end symbols for all input sentences.
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long)
        if args.use_cuda:
            sentence_lengths = sentence_lengths.cuda()
        for vec in ones:
            sent_idx = vec[0]  # avoid clobbering the outer loop index `i`
            if sentence_lengths[sent_idx] == 0:  # First end symbol has not been found yet
                sentence_lengths[sent_idx] = vec[1]  # Position of first end symbol

        num_finished_sentences = len([s for s in sentence_lengths if s > 0])

        if num_finished_sentences == args.k:
            alpha = args.length_penalty
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    if ind is None:
        try:
            length = (outputs[0] == eos_tok).nonzero()[0]
            return ' '.join(
                [tgt_vocab.itos[tok] for tok in outputs[0][1:length]])
        except IndexError:
            return ' '.join([tgt_vocab.itos[tok] for tok in outputs[0]])
    else:
        length = (outputs[ind] == eos_tok).nonzero()[0]
        return ' '.join(
            [tgt_vocab.itos[tok] for tok in outputs[ind][1:length]])
Example #12
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
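
A hedged call sketch for this Annotated-Transformer-style `greedy_decode`, assuming pad id 0, start symbol 1, and a trained `model` (the ids and the model are hypothetical):

src = torch.LongTensor([[1, 4, 11, 7, 2]])    # (1, src_len) batch of token ids
src_mask = (src != 0).unsqueeze(-2)           # (1, 1, src_len) padding mask
ys = greedy_decode(model, src, src_mask, max_len=10, start_symbol=1)
# ys has shape (1, max_len) and starts with start_symbol.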
Example #13
    def make_std_mask(tgt, pad):
        # pad_tgt_mask.size(): (batch_size, 1, max_len-1)
        # pad_tgt_mask is for [padding] mask in each sentence
        pad_tgt_mask = (tgt != pad).unsqueeze(-2)
        length = tgt.size(-1)
        # sub_tgt_mask.size(): (max_len-1, max_len-1)
        # sub_tgt_mask is for [future word] mask
        sub_tgt_mask = subsequent_mask(length).type_as(pad_tgt_mask.data)

        # total_tgt_mask.size(): (batch_size, max_len-1, max_len-1)
        # total_tgt_mask is for padding and future words mask
        total_tgt_mask = pad_tgt_mask & sub_tgt_mask
        return total_tgt_mask
Example #14
    def make_mask(target):
        """Create a target mask.

    The mask shape is [batch_size, num_steps, sequence_length].
    The mask blocks out both PADs and access to next tokens.

    Args:
      target: (torch.Tensor) [batch_size, sequence_length].

    Returns:
      target_mask: (torch.Tensor) [batch_size, num_steps, sequence_length].
    """
        target_mask = (target != constants.PAD).unsqueeze(-2)
        target_mask = target_mask & Var(
            utils.subsequent_mask(target.size(-1)).type_as(target_mask.data))
        return target_mask
Example #15
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1, dtype=torch.int64).fill_(start_symbol)
    for i in range(max_len - 1):
        out = model.decode(
            memory.to(device), src_mask,
            Variable(ys).to(device),
            Variable(subsequent_mask(ys.size(1)).type(
                torch.LongTensor)).to(device))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        if next_word == 2:  # assume token id 2 is the end-of-sequence symbol
            break
        ys = torch.cat(
            [ys, torch.ones(1, 1, dtype=torch.int64).fill_(next_word)], dim=1)
    return ys
Example #16
def greedy_decode(tree_transformer_model, batch, max_len, start_pos):
    memory, _ = tree_transformer_model.encode(batch.code, batch.par_matrix,
                                              batch.bro_matrix,
                                              batch.re_par_ids,
                                              batch.re_bro_ids)
    ys = torch.ones(1, 1).fill_(start_pos).type_as(batch.code.data)
    for i in range(max_len - 1):
        #  memory, code_mask, comment, comment_mask
        out, _, _ = tree_transformer_model.decode(
            memory, batch.code_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(batch.code.data)))
        prob = tree_transformer_model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys,
             torch.ones(1, 1).type_as(batch.code.data).fill_(next_word)],
            dim=1)
    return ys
Example #17
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    传入一个训练好的模型,对指定数据进行预测
    """
    # 先用encoder进行encode
    memory = model.encode(src, src_mask)
    # 初始化预测内容为1×1的tensor,填入开始符('BOS')的id,并将type设置为输入数据类型(LongTensor)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    # 遍历输出的长度下标
    for i in range(max_len - 1):
        # decode得到隐层表示
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # 将隐藏表示转为对词典各词的log_softmax概率分布表示
        prob = model.generator(out[:, -1])
        # 获取当前位置最大概率的预测词id
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        # 将当前位置预测的字符id与之前的预测内容拼接起来
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
Example #18
def init_vars(src, net, src_vocab, tgt_vocab, args):
    # src: torch.LongTensor, (1, src_len)
    assert src.dim() == 2
    src_len = src.size(-1)
    max_len = int(min(args.max_len, args.max_ratio * src_len))

    init_tok = tgt_vocab.stoi[config.BOS]
    src_mask = (src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
    e_output = net.encode(src, src_mask)

    outputs = torch.LongTensor([[init_tok]])
    if args.use_cuda:
        outputs = outputs.cuda()

    tgt_mask = subsequent_mask(1)
    if args.use_cuda:
        tgt_mask = tgt_mask.cuda()

    out = net.generator.proj(net.decode(e_output, src_mask, outputs, tgt_mask))
    out = F.softmax(out, dim=-1)

    probs, ix = out[:, -1].data.topk(args.k)
    log_scores = torch.Tensor([math.log(prob)
                               for prob in probs.data[0]]).unsqueeze(0)

    outputs = torch.zeros(args.k, max_len).long()
    if args.use_cuda:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]

    e_outputs = torch.zeros(args.k, e_output.size(-2), e_output.size(-1))
    if args.use_cuda:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]

    return outputs, e_outputs, log_scores
Example #19
def train(model, optimizer, config, best_valid):
    max_epoch = config.get("max_epoch", int(1e9))
    max_iteration = config.get("max_iteration", int(1e9))
    max_sentences = config.get("max_sentences", 1e9)
    max_tokens = config.get("max_tokens", 1e9)
    update_freq = config.get('update_freq', 1)

    optimizer.add(model)

    corpus_prefix = Path(config['corpus_prefix']) / "subword"
    model_path = corpus_prefix / "spm.model"
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(model_path))
    train_src = load_corpus(corpus_prefix / Path(config["train_source"]).name,
                            tokenizer)
    train_trg = load_corpus(corpus_prefix / Path(config["train_target"]).name,
                            tokenizer)
    train_src, train_trg = clean_corpus(train_src, train_trg, config)
    dev_src = load_corpus(corpus_prefix / Path(config["dev_source"]).name,
                          tokenizer)
    dev_trg = load_corpus(corpus_prefix / Path(config["dev_target"]).name,
                          tokenizer)
    dev_src, dev_trg = clean_corpus(dev_src, dev_trg, config)
    num_train_sents = len(train_src)
    num_dev_sents = len(dev_src)
    eos_id = tokenizer.eos_id()

    epoch = 0
    iteration = 0
    while epoch < max_epoch and iteration < max_iteration:
        epoch += 1
        g = Graph()
        Graph.set_default(g)

        train_itr = create_batch_itr(train_src,
                                     train_trg,
                                     max_tokens,
                                     max_sentences,
                                     shuffle=True)
        train_itr = tqdm(train_itr, desc='train epoch {}'.format(epoch))

        train_loss = 0.
        itr_loss = 0.
        itr_tokens = 0
        itr_sentences = 0
        optimizer.reset_gradients()
        for step, batch_ids in enumerate(train_itr):
            src_batch = make_batch(train_src, batch_ids, eos_id)
            trg_batch = make_batch(train_trg, batch_ids, eos_id)
            src_mask = padding_mask(src_batch, eos_id)
            trg_mask = [
                x | subsequent_mask(len(trg_batch) - 1)
                for x in padding_mask(trg_batch[:-1], eos_id)
            ]
            itr_tokens += len(src_batch) * len(src_batch[0])
            itr_sentences += len(batch_ids)

            g.clear()
            loss = model.loss(src_batch, trg_batch, src_mask, trg_mask)

            loss /= update_freq
            loss.backward()
            loss_val = loss.to_float()
            train_loss += loss_val * update_freq * len(batch_ids)
            itr_loss += loss_val

            # with open('graph.dot', 'w') as f:
            #     print(g.dump("dot"), end="", file=f)

            if (step + 1) % update_freq == 0:
                step_num = optimizer.get_epoch() + 1
                new_scale = config['d_model'] ** (-0.5) * \
                    min(step_num ** (-0.5), step_num * config['warmup_steps'] ** (-1.5))
                optimizer.set_learning_rate_scaling(new_scale)

                optimizer.update()
                optimizer.reset_gradients()

                iteration += 1
                train_itr.set_postfix(itr=("%d" % (iteration)),
                                      loss=("%.3lf" % (itr_loss)),
                                      wpb=("%d" % (itr_tokens)),
                                      spb=("%d" % (itr_sentences)),
                                      lr=optimizer.get_learning_rate_scaling())
                itr_loss = 0.
                itr_tokens = 0
                itr_sentences = 0

            if iteration >= max_iteration:
                break
        print("\ttrain loss = %.4f" % (train_loss / num_train_sents))

        g.clear()
        valid_loss = 0.
        valid_itr = create_batch_itr(dev_src,
                                     dev_trg,
                                     max_tokens,
                                     max_sentences,
                                     shuffle=False)
        valid_itr = tqdm(valid_itr, desc='valid epoch {}'.format(epoch))
        for batch_ids in valid_itr:
            src_batch = make_batch(dev_src, batch_ids, eos_id)
            trg_batch = make_batch(dev_trg, batch_ids, eos_id)
            src_mask = padding_mask(src_batch, eos_id)
            trg_mask = [
                x | subsequent_mask(len(trg_batch) - 1)
                for x in padding_mask(trg_batch[:-1], eos_id)
            ]

            loss = model.loss(src_batch,
                              trg_batch,
                              src_mask,
                              trg_mask,
                              train=False)
            valid_loss += loss.to_float() * len(batch_ids)
            valid_itr.set_postfix(loss=loss.to_float())
        print("\tvalid loss = %.4f" % (valid_loss / num_dev_sents))

        if valid_loss < best_valid:
            best_valid = valid_loss
            print('\tsaving model/optimizer ... ', end="", flush=True)
            prefix = config['model_prefix']
            model.save(prefix + '.model')
            optimizer.save(prefix + '.optimizer')
            with Path(prefix).with_suffix('.valid').open('w') as f:
                f.write(str(best_valid))
            print('done.')
Example #20
    def make_std_mask(comment, pad):
        comment_mask = (comment != pad).unsqueeze(-2)
        tgt_mask = comment_mask & Variable(
            subsequent_mask(comment.size(-1)).type_as(comment_mask.data))
        return tgt_mask
Example #21
def generate_beam(args, net, src, src_mask, src_vocab, tgt_vocab):
    """
    Decode a sentence given initial start.
    `x`:
        - LongTensor(bs, slen)
            <EOS> W1 W2 W3 <EOS> <PAD>
            <EOS> W1 W2 W3   W4  <EOS>
    `lengths`:
        - LongTensor(bs) [5, 6]
    """

    max_len = args.max_len
    beam_size = args.beam_size
    length_penalty = args.length_penalty
    early_stopping = args.early_stopping

    with torch.no_grad():
        net.eval()
        # batch size / number of words
        n_words = net.generator.n_vocab
        bs = src.size(0)

        # calculate encoder output
        src_enc = net.encode(src=src, src_mask=src_mask)
        src_len = src_mask.view(bs, -1).sum(dim=-1).long()

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(
            1).expand((bs, beam_size) +
                      src_enc.shape[1:]).contiguous().view((bs * beam_size, ) +
                                                           src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs,
                                              beam_size).contiguous().view(-1)
        src_mask = src_mask.unsqueeze(1).expand(
            (bs, beam_size) +
            src_mask.shape[1:]).contiguous().view((bs * beam_size, ) +
                                                  src_mask.shape[1:])

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(
            tgt_vocab.stoi[config.PAD])  # fill upcoming output with <PAD>
        generated[0].fill_(
            tgt_vocab.stoi[config.BOS])  # start every hypothesis with <BOS>

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(beam_size, max_len, length_penalty, early_stopping)
            for _ in range(bs)
        ]

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'cur_len': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            x = generated[cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:cur_len] !=
                        tgt_vocab.stoi[config.PAD]).transpose(0,
                                                              1).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))

            tensor = net.decode(src_enc, src_mask, x,
                                tgt_mask[:,
                                         cur_len - 1, :].unsqueeze(-2), cache)

            assert tensor.size() == (bs * beam_size, 1, config.d_model)
            tensor = tensor.data.view(bs * beam_size,
                                      -1)  # (bs * beam_size, dim)
            scores = net.generator(tensor)  # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(
                scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size *
                                   n_words)  # (bs, beam_size * n_words)

            next_scores, next_words = torch.topk(_scores,
                                                 2 * beam_size,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
            assert next_scores.size() == next_words.size() == (bs,
                                                               2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[
                    sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend(
                        [(0, tgt_vocab.stoi[config.PAD], 0)] *
                        beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id],
                                      next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == tgt_vocab.stoi[
                            config.EOS] or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(
                            generated[:cur_len,
                                      sent_id * beam_size + beam_id].clone(),
                            value.item())
                    else:
                        next_sent_beam.append(
                            (value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                assert len(next_sent_beam
                           ) == 0 if cur_len + 1 == max_len else beam_size
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, tgt_vocab.stoi[config.PAD], 0)
                                      ] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'cur_len':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1

            # stop when we are done with each sentence
            if all(done):
                break

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(),
                              bs).fill_(tgt_vocab.stoi[config.PAD])
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = tgt_vocab.stoi[config.EOS]  # terminate each hypothesis with <EOS>

        return decoded.transpose(0, 1), tgt_len
Example #22
    def _model_decode(self, comment, memory, code_mask):
        comment_mask = subsequent_mask(comment.size(1)).cuda()
        dec_output, _, _ = self.model.decode(memory, code_mask, comment,
                                             comment_mask)
        return self.model.generator(dec_output)
Example #23
def _get_scores(args, net, active_func, src, src_mask, indices, src_vocab,
                tgt_vocab):
    net.eval()
    slen = src.size(1)
    max_len = min(args.max_len, int(slen * args.gen_a + args.gen_b))
    average_by_length = (active_func != "tte")
    result = []
    with torch.no_grad():
        bsz = src.size(0)
        enc_out = net.encode(src=src, src_mask=src_mask)
        generated = src.new(bsz, max_len)
        generated.fill_(tgt_vocab.stoi[config.PAD])
        generated[:, 0].fill_(tgt_vocab.stoi[config.BOS])
        generated = generated.long()

        max_gen_len = src_mask.sum(dim=-1).float() * args.gen_a + args.gen_b
        max_gen_len = max_gen_len.long().view(bsz, )
        max_gen_len = torch.clamp(max_gen_len, min=0, max=max_len)
        cur_len = 1
        gen_len = src.new_ones(bsz).long()
        unfinished_sents = src.new_ones(bsz).long()
        query_scores = src.new_zeros(bsz).float()

        cache = {'cur_len': cur_len - 1}

        while cur_len < max_len:
            x = generated[:, cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:, :cur_len] !=
                        tgt_vocab.stoi[config.PAD]).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))

            logit = net.decode(enc_out, src_mask, x,
                               tgt_mask[:,
                                        cur_len - 1, :].unsqueeze(-2), cache)
            scores = net.generator(logit).exp().view(bsz, -1).data

            # Calculate the acquisition (active-learning scoring) function value
            # The smaller the query score, the more uncertain the model is about the sentence
            if active_func == "lc":
                q_scores, _ = torch.topk(scores, 1, dim=-1)
                q_scores = -(1.0 - q_scores).squeeze()
            elif active_func == "margin":
                q_scores, _ = torch.topk(scores, 2, dim=-1)
                q_scores = q_scores[:, 0] - q_scores[:, 1]
            elif active_func == "te" or active_func == "tte":
                q_scores = -torch.distributions.categorical.Categorical(
                    probs=scores).entropy()
            else:
                q_scores = scores.new_zeros(bsz)
            q_scores = q_scores.view(bsz)
            assert q_scores.size() == (bsz, ), q_scores

            query_scores = query_scores + unfinished_sents.float() * q_scores

            next_words = torch.topk(scores, 1)[1].squeeze()

            next_words = next_words.view(bsz)
            assert next_words.size() == (bsz, )
            generated[:,
                      cur_len] = next_words * unfinished_sents + tgt_vocab.stoi[
                          config.PAD] * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(
                next_words.ne(tgt_vocab.stoi[config.EOS]).long())
            unfinished_sents.mul_(gen_len.le(max_gen_len).long())
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1

            if unfinished_sents.max() == 0:
                break

        if cur_len == max_len:
            generated[:, -1].masked_fill_(unfinished_sents.bool(),
                                          tgt_vocab.stoi[config.EOS])

        translated = gen_batch2str(generated[:, :cur_len], gen_len, tgt_vocab)

        if average_by_length:
            query_scores = query_scores / gen_len.float()
        query_scores = query_scores.cpu().numpy().tolist()
        indices = indices.tolist()
        assert len(query_scores) == len(indices)
        assert len(query_scores) == len(translated)
        for i in range(len(query_scores)):
            result.append((query_scores[i], indices[i], translated[i]))
    return result
Example #24
    def make_std_mask(tgt, pad_token_idx):
        target_mask = (tgt != pad_token_idx).unsqueeze(-2)
        # make look-ahead mask, growing it one step at a time
        target_mask = target_mask & Variable(
            subsequent_mask(tgt.size(-1)).type_as(target_mask.data))
        # note: squeeze() also drops the batch dimension when batch_size == 1
        return target_mask.squeeze()