Example #1
    def test_transfer_copy_probs(self):

        # not actual probs, but that's ok
        probs = torch.FloatTensor([
            # special    base      copy
            [1, 2, 3,  10, 11,  12, 14, 16],
            [4, 5, 6,  20, 21,  22, 24, 26],
        ])

        base_vocab = HardCopyVocab('a b'.split(), num_copy_tokens=3)
        dynamic_vocabs = [
            # d, e, f, g don't get assigned copy tokens (not enough)
            HardCopyDynamicVocab(base_vocab, 'b a c d e f g'.split()),
            # e, f, a get assigned copy tokens
            HardCopyDynamicVocab(base_vocab, 'e f f f a d'.split()),
        ]

        AttentionDecoderCellOutput._transfer_copy_probs(
            probs, dynamic_vocabs, base_vocab)

        assert_tensor_equal(probs, [
            # copy prob for 'c' is not transferred, since it's not in base
            [1, 2, 3, 24, 23, 0, 0, 16],
            # only the prob for 'a' gets transferred
            [4, 5, 6, 46, 21, 22, 24, 0],
        ])
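
Judging from the expected tensor above, _transfer_copy_probs folds each copy token's probability into the base-vocab slot of its source word and zeroes the copy slot, while copy tokens whose source word is not in the base vocab (like 'c') keep their mass. Below is a minimal self-contained sketch of that rule, with plain dicts standing in for the vocab objects; the column layout and word-to-copy-token mappings are read off the test, not taken from the library's API.

import torch

# Column layout from the test: three special tokens, base words 'a'/'b',
# then three copy tokens.
BASE = {'a': 3, 'b': 4}
COPY = {'<copy0>': 5, '<copy1>': 6, '<copy2>': 7}

def transfer_copy_probs_sketch(probs, word_to_copy_per_row):
    for row, word_to_copy in zip(probs, word_to_copy_per_row):
        for word, copy_token in word_to_copy.items():
            if word in BASE:  # 'c', 'e', 'f' have no base slot and stay put
                row[BASE[word]] += row[COPY[copy_token]]
                row[COPY[copy_token]] = 0.0

probs = torch.FloatTensor([[1, 2, 3, 10, 11, 12, 14, 16],
                           [4, 5, 6, 20, 21, 22, 24, 26]])
transfer_copy_probs_sketch(probs, [
    {'b': '<copy0>', 'a': '<copy1>', 'c': '<copy2>'},  # from 'b a c d e f g'
    {'e': '<copy0>', 'f': '<copy1>', 'a': '<copy2>'},  # from 'e f f f a d'
])
# probs now equals the expected tensor in the assert above.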
Example #2
    def dynamic_vocabs(self, base_vocab):
        return [
            HardCopyDynamicVocab(base_vocab, 'a b c d e'.split()),  # a, b, c, d, e
            HardCopyDynamicVocab(base_vocab, 'c c e d z'.split()),  # c, e, d, z
        ]
Example #3
import numpy as np


def eval_batch_ret(ex):
    # `edit_model`, `valid_eval` and `HardCopyDynamicVocab` come from the
    # surrounding script / project imports.
    editor_input = edit_model.preprocess(ex)
    train_decoder = edit_model.train_decoder
    encoder_output, enc_loss = edit_model.encoder(editor_input.encoder_input)
    vocab_probs = train_decoder.vocab_probs(
        encoder_output, editor_input.train_decoder_input)
    token_list = editor_input.train_decoder_input.target_words.split()
    base_vocab = edit_model.base_vocab
    unk_idx = base_vocab.word2index(base_vocab.UNK)

    # For each example, map the retrieved target sequence (y') to
    # copy-token indices in its dynamic vocab.
    idx_lists = []
    for k in range(len(ex)):
        hcdv = HardCopyDynamicVocab(base_vocab, valid_eval[k].input_words,
                                    edit_model.copy_lens)
        # the last input sequence is the retrieved target (y')
        copy_tok_list = [
            hcdv.word_to_copy_token.get(tok, base_vocab.UNK)
            for tok in valid_eval[k].input_words[-1]
        ]
        idx_lists.append([hcdv.word2index(tok) for tok in copy_tok_list])

    ret_mix_pr = 0.0  # weight for mixing in the retrieved token (0 = off)
    all_ranks = [[] for _ in range(len(ex))]
    all_ranks_ret = [[] for _ in range(len(ex))]
    for position, (token, vout) in enumerate(zip(token_list, vocab_probs)):
        target_idx = token.values.data.cpu().numpy()
        target_mask = token.mask.data.cpu().numpy()
        in_vocab_id = target_idx[:, 0]    # gold word's base-vocab index
        copy_token_id = target_idx[:, 1]  # gold word's copy-token index
        vocab_matrix = vout.data.cpu().numpy()
        for i in range(len(in_vocab_id)):
            voc_vec = vocab_matrix[i, :].copy()
            # competitor scores: zero out the gold word's own two slots
            voc_vec_rest = voc_vec.copy()
            voc_vec_rest[copy_token_id[i]] = 0
            voc_vec_rest[in_vocab_id[i]] = 0
            direct_copy_idx = None  # guard: y' may be shorter than the target
            if position < len(idx_lists[i]):
                direct_copy_idx = idx_lists[i][position]
                voc_vec = voc_vec * (1 - ret_mix_pr)
                voc_vec[direct_copy_idx] += ret_mix_pr
            # rank of the gold token: number of competitors scoring at
            # least as high as its combined (in-vocab + copy) probability
            if in_vocab_id[i] == unk_idx:
                gold_rank = np.sum(voc_vec_rest >= voc_vec[copy_token_id[i]])
            else:
                gold_rank = np.sum(voc_vec_rest >= voc_vec[copy_token_id[i]] +
                                   voc_vec[in_vocab_id[i]])
            if target_mask[i] == 1.0:
                all_ranks[i].append(gold_rank)
                # 0 if the retrieved token matches the gold copy token, else 100
                all_ranks_ret[i].append(
                    100 * (1.0 - (direct_copy_idx == copy_token_id[i])))
    del token_list
    del vocab_probs
    return all_ranks, all_ranks_ret
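
To make the ranking arithmetic concrete: the gold token's score is the sum of its in-vocab slot and its copy slot, and gold_rank counts how many other slots score at least as high, so a rank of 0 means the gold token beats every competitor. A toy illustration with made-up numbers (the five-slot vector is purely for demonstration):

import numpy as np

voc_vec = np.array([0.05, 0.10, 0.30, 0.15, 0.40])
in_vocab_id, copy_token_id = 2, 4                # the gold word's two slots
voc_vec_rest = voc_vec.copy()
voc_vec_rest[[in_vocab_id, copy_token_id]] = 0   # competitors only
gold_score = voc_vec[in_vocab_id] + voc_vec[copy_token_id]  # 0.70
gold_rank = np.sum(voc_vec_rest >= gold_score)   # 0: gold wins outright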
Example #4
    def _compute_dynamic_vocabs(self, input_batches, vocab):
        """Compute dynamic vocabs for each example.

        Args:
            input_batches (list[list[list[unicode]]]): a batch of input lists,
                where each input list is a list of sentences
            vocab (HardCopyVocab)

        Returns:
            list[HardCopyDynamicVocab]: a batch of dynamic vocabs, one for each example
        """
        dynamic_vocabs = []
        for input_words in input_batches:
            # The input sequences are no longer flattened into one list
            # (formerly: concat = flatten(input_words)); each example's
            # sentence lists are passed through with self.copy_lens.
            dynamic_vocabs.append(
                HardCopyDynamicVocab(vocab, input_words, self.copy_lens))
        return dynamic_vocabs
Example #5
    def test_too_many_copy_tokens(self, base_vocab):
        vocab = HardCopyDynamicVocab(base_vocab, 'The'.split())
        # should only use one copy token (note the word is lowercased)
        assert vocab.word_to_copy_token == {
            'the': '<copy0>',
        }
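
Taken together with Examples #1 and #2, this test pins down how word_to_copy_token appears to be built: words are lowercased, duplicates collapse to their first occurrence, and assignment stops once the copy tokens run out. A hedged reconstruction of that rule (the function name and token strings are illustrative, not the library's API):

def assign_copy_tokens(words, copy_tokens):
    mapping = {}
    for word in words:
        word = word.lower()  # 'The' -> 'the'
        if word not in mapping and len(mapping) < len(copy_tokens):
            mapping[word] = copy_tokens[len(mapping)]
    return mapping

assign_copy_tokens('The'.split(), ['<copy0>', '<copy1>', '<copy2>'])
# {'the': '<copy0>'}
assign_copy_tokens('b a c d e f g'.split(), ['<copy0>', '<copy1>', '<copy2>'])
# {'b': '<copy0>', 'a': '<copy1>', 'c': '<copy2>'} -- d, e, f, g left out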
Example #6
    def vocab(self, base_vocab):
        return HardCopyDynamicVocab(
            base_vocab, 'apple The bat is the BEST time ever'.split())
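
Given this fixture, and assuming the base vocab supplies at least seven copy tokens, the lowercasing and deduplication seen in the tests above would yield the following mapping:

# Assumption: base_vocab has >= 7 copy tokens available.
assert vocab.word_to_copy_token == {
    'apple': '<copy0>', 'the': '<copy1>', 'bat': '<copy2>', 'is': '<copy3>',
    'best': '<copy4>', 'time': '<copy5>', 'ever': '<copy6>',
}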