Example #1
0
def batch_align(encdec, eos_idx, src_tgt_data, batch_size=80, gpu=None):
    """Run the model on source/target pairs and collect per-example attention.

    ``src_tgt_data`` is processed in mini-batches of ``batch_size`` examples.
    Each batch is built with ``make_batch_src_tgt`` (which also returns a
    permutation ``arg_sort``, presumably because it reorders the batch), fed
    to ``encdec`` with attention values kept, and the attention is split back
    into one entry per example and restored to the input order.

    Args:
        encdec: model called as
            ``encdec(src_batch, tgt_batch, src_mask, keep_attn_values=True)``
            and returning ``(loss, attn_list)``.
        eos_idx: end-of-sentence index forwarded to ``make_batch_src_tgt``.
        src_tgt_data: sequence of (source, target) examples.
        batch_size: number of examples per mini-batch.
        gpu: device forwarded to ``make_batch_src_tgt``.

    Returns:
        ``(sum_loss, attn_all)``: the sum of ``float(loss.data)`` over all
        batches (``0`` when there is no data) and a list holding one attention
        entry per example, in the order of ``src_tgt_data``.
    """
    nb_ex = len(src_tgt_data)
    # Ceiling division; "//" keeps nb_batch an int under Python 3 too
    # (plain "/" would yield a float and make range() raise TypeError).
    nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
    sum_loss = 0
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_tgt_data[i * batch_size:(i + 1) *
                                              batch_size]
        src_batch, tgt_batch, src_mask, arg_sort = make_batch_src_tgt(
            current_batch_raw_data,
            eos_idx=eos_idx,
            gpu=gpu,
            volatile="on",
            need_arg_sort=True)
        loss, attn_list = encdec(src_batch,
                                 tgt_batch,
                                 src_mask,
                                 keep_attn_values=True)
        deb_attn = de_batch(attn_list,
                            mask=None,
                            eos_idx=None,
                            is_variable=True,
                            raw=True)

        assert len(arg_sort) == len(deb_attn)
        # arg_sort[k] is used as the index, within current_batch_raw_data,
        # of the example that ended up in batch row k; undo that permutation.
        de_sorted_attn = [None] * len(deb_attn)
        for sorted_pos, original_pos in enumerate(arg_sort):
            de_sorted_attn[original_pos] = deb_attn[sorted_pos]

        attn_all += de_sorted_attn
        sum_loss += float(loss.data)
    return sum_loss, attn_all
Example #2
0
 def test_multiple_length_variable_raw(self):
     """de_batch(raw=True) must return each row of a 2-D step as one element."""
     steps = [[1, 3, 4, 8], [1, 5, 6, 9], [[7, 9], [5, 8]], [10]]
     batch = [Variable(np.array(step, dtype=np.int32)) for step in steps]
     expected = [[1, 1, [7, 9], 10], [3, 5, [5, 8]], [4, 6], [8, 9]]

     seq_list = de_batch(batch, is_variable=True, raw=True)

     assert len(seq_list) == 4
     for got_seq, want_seq in zip(seq_list, expected):
         assert len(got_seq) == len(want_seq)
         for got, want in zip(got_seq, want_seq):
             # Elements may be arrays, so compare element-wise.
             assert np.all(got == want)
Example #3
0
    def test_multiple_length_variable(self):
        """de_batch on Variables regroups a shrinking batch into sequences."""
        batch = []
        for step in ([1, 3, 4, 8], [1, 5, 6, 9], [7, 5], [10]):
            batch.append(Variable(np.array(step, dtype=np.int32)))

        expected = [[1, 1, 7, 10], [3, 5, 5], [4, 6], [8, 9]]
        assert de_batch(batch, is_variable=True) == expected
Example #4
0
 def test_multiple_length_eos_idx(self):
     """Each sequence stops at its first eos_idx, which is kept in the output."""
     rows = [[1, 3, 4, 8], [3, 3, 6, 9], [7, 5], [10]]
     batch = [np.array(row) for row in rows]

     expected = [[1, 3], [3], [4, 6], [8, 9]]
     assert de_batch(batch, eos_idx=3) == expected
Example #5
0
    def test_multiple_length(self):
        """Plain arrays: de_batch turns per-step rows into per-item sequences."""
        batch = list(map(np.array, ([1, 3, 4, 8], [1, 5, 6, 9], [7, 5], [10])))
        assert de_batch(batch) == [[1, 1, 7, 10], [3, 5, 5], [4, 6], [8, 9]]
Example #6
0
 def test_mask3(self):
     """A False mask entry ends the matching sequence at that step."""
     steps = ([1, 3, 4, 8], [1, 5, 6, 9], [7, 5, 3, 4])
     batch = [np.array(step) for step in steps]
     # Judging by the expected output, mask[k] governs batch[k + 1]; the
     # first step has no mask entry.
     flags = ([True, True, False, True], [True, True, False, False])
     mask = [np.array(row) for row in flags]

     expected = [[1, 1, 7], [3, 5, 5], [4], [8, 9]]
     assert de_batch(batch, mask=mask) == expected
Example #7
0
    def train_once_reinf(src_batch, tgt_batch, src_mask):
        """Run one update step on the model's reinforcement loss for a batch.

        Uses ``encdec``, ``optimizer`` and ``eos_idx`` from the enclosing
        scope. Gradients are zeroed, ``encdec.get_reinf_loss`` is computed
        against the de-batched target references, back-propagated, and the
        optimizer is updated. The loss and per-phase timings are printed.

        Args:
            src_batch: batched source passed to ``get_reinf_loss``.
            tgt_batch: batched target, de-batched (as Variables) into the
                reference sequences.
            src_mask: source mask passed to ``get_reinf_loss``.

        Returns:
            ``(float(reinf_loss.data), len(src_batch))``.
        """
        # time.clock() was removed in Python 3.8: use perf_counter when it
        # exists and only fall back to time.clock on older interpreters.
        timer = getattr(time, "perf_counter", None) or time.clock
        t0 = timer()
        encdec.zerograds()
        t1 = timer()

        from nmt_chainer.utilities import utils
        test_ref = utils.de_batch(tgt_batch, is_variable=True)

        reinf_loss = encdec.get_reinf_loss(src_batch, src_mask, eos_idx,
                                           test_ref, nb_steps=50, nb_samples=5,
                                           use_best_for_sample=False,
                                           temperature=None,
                                           mode="test")

        t2 = timer()
        reinf_loss.backward()
        t3 = timer()
        optimizer.update()
        t4 = timer()
        # A single pre-formatted argument prints the same text whether print
        # is a statement (Python 2) or a function (Python 3).
        print("reinf loss: %s %s" % (reinf_loss.data,
                                     reinf_loss.data / len(src_batch)))
        print(" time %f zgrad:%f fwd:%f bwd:%f upd:%f" % (
            t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        return float(reinf_loss.data), len(src_batch)
Example #8
0
def sample_once(encdec,
                src_batch,
                tgt_batch,
                src_mask,
                src_indexer,
                tgt_indexer,
                eos_idx,
                max_nb=None,
                s_unk_tag="#S_UNK#",
                t_unk_tag="#T_UNK#"):
    """Print greedy and random samples for a batch next to source and target.

    With Chainer's ``train`` config off and backprop disabled, ``encdec`` is
    run twice on ``src_batch`` for 50 steps: once with
    ``use_best_for_sample=True`` and once with ``use_best_for_sample=False``.
    For every sentence, the source, the reference target and both samples are
    printed as index sequences, as ``deconvert_swallow`` output ("raw") and as
    ``deconvert`` output ("postp").

    Args:
        encdec: model called as ``encdec(src_batch, 50, src_mask,
            use_best_for_sample=..., need_score=True)`` and returning a
            3-tuple whose first item is the batched sample.
        src_batch, tgt_batch: batched inputs; presumably per-time-step
            Variables, since they are de-batched with ``is_variable=True``.
        src_mask: source mask forwarded to the model and to ``de_batch``.
        src_indexer, tgt_indexer: objects providing ``deconvert_swallow`` and
            ``deconvert`` for the source / target side.
        eos_idx: target end-of-sentence index.
        max_nb: if not None, stop after sentence index ``max_nb``. Note this
            prints ``max_nb + 1`` sentences (indices 0..max_nb).
        s_unk_tag, t_unk_tag: unknown-word tags for source / target display.
    """

    def _printable(text):
        # Python 2: keep the original UTF-8 encoding before printing.
        # Python 3: print the str itself; encoding it would make print()
        # show a b'...' bytes literal instead of the sentence.
        return text.encode('utf-8') if six.PY2 else text

    with chainer.using_config("train", False), chainer.no_backprop_mode():
        print("sample")
        # Score and attention outputs are not used here.
        sample_greedy, _, _ = encdec(src_batch,
                                     50,
                                     src_mask,
                                     use_best_for_sample=True,
                                     need_score=True)

        # Sanity checks that source, target and sample agree in size
        # (presumably the batch size).
        assert len(src_batch[0].data) == len(tgt_batch[0].data)
        assert len(sample_greedy[0]) == len(src_batch[0].data)

        debatched_src = de_batch(src_batch,
                                 mask=src_mask,
                                 eos_idx=None,
                                 is_variable=True)
        debatched_tgt = de_batch(tgt_batch, eos_idx=eos_idx, is_variable=True)
        debatched_sample = de_batch(sample_greedy, eos_idx=eos_idx)

        sample_random, _, _ = encdec(src_batch,
                                     50,
                                     src_mask,
                                     use_best_for_sample=False,
                                     need_score=True)
        debatched_sample_random = de_batch(sample_random, eos_idx=eos_idx)

        # Per-row display settings; identical for every sentence.
        names = "src tgt sample sample_random".split(" ")
        unk_tags = [s_unk_tag, t_unk_tag, t_unk_tag, t_unk_tag]
        indexers = [src_indexer, tgt_indexer, tgt_indexer, tgt_indexer]
        eos_indices = [None, eos_idx, eos_idx, eos_idx]

        for sent_num in six.moves.range(len(debatched_src)):
            if max_nb is not None and sent_num > max_nb:
                break
            seqs = [
                debatched_src[sent_num], debatched_tgt[sent_num],
                debatched_sample[sent_num], debatched_sample_random[sent_num]
            ]

            print("sent num", sent_num)

            for name, seq, unk_tag, indexer, this_eos_idx in six.moves.zip(
                    names, seqs, unk_tags, indexers, eos_indices):
                print(name, "idx:", seq)
                raw = " ".join(
                    indexer.deconvert_swallow(seq,
                                              unk_tag=unk_tag,
                                              eos_idx=this_eos_idx))
                print(name, "raw:", _printable(raw))
                postp = indexer.deconvert(seq,
                                          unk_tag=unk_tag,
                                          eos_idx=this_eos_idx)
                print(name, "postp:", _printable(postp))
Example #9
0
def greedy_batch_translate(encdec,
                           eos_idx,
                           src_data,
                           batch_size=80,
                           gpu=None,
                           get_attention=False,
                           nb_steps=50,
                           reverse_src=False,
                           reverse_tgt=False,
                           use_chainerx=False):
    """Translate ``src_data`` greedily, ``batch_size`` sentences at a time.

    Runs with Chainer's ``train`` config off and backprop disabled. Models
    whose ``encdec_type()`` is ``"ff"`` delegate to their own
    ``greedy_batch_translate``. Other models are called batch by batch with
    ``use_best_for_sample=True``, and the output is split back into one index
    sequence per sentence with ``de_batch``.

    Args:
        encdec: the model.
        eos_idx: target end-of-sentence index, passed to ``de_batch`` to cut
            each output sequence.
        src_data: list of source index sequences.
        batch_size: sentences per mini-batch.
        gpu: device forwarded to ``make_batch_src``.
        get_attention: also return per-sentence attention values. For "ff"
            models these are all-zero arrays of shape ``(len(src), len(tgt))``.
        nb_steps: decoding step budget passed to the model.
        reverse_src: feed every source sentence reversed.
        reverse_tgt: reverse every output sequence, keeping a trailing
            ``eos_idx`` last (not applied to "ff" models).
        use_chainerx: forwarded to ``make_batch_src``.

    Returns:
        The list of translations or, if ``get_attention`` is true,
        ``(translations, attentions)``.

    Raises:
        AssertionError: ``get_attention`` together with ``reverse_tgt`` on a
            non-"ff" model ("not implemented").
    """
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        if encdec.encdec_type() == "ff":
            # The "ff" model has its own batched greedy search; reverse_src
            # and reverse_tgt are not applied on this path.
            result = encdec.greedy_batch_translate(src_data,
                                                   mb_size=batch_size,
                                                   nb_steps=nb_steps)
            if get_attention:
                # No attention is available: return zero placeholders.
                dummy_attention = [
                    np.zeros((len(src), len(tgt)), dtype=np.float32)
                    for src, tgt in six.moves.zip(src_data, result)
                ]
                return result, dummy_attention
            return result

        # Fail before doing any translation work rather than after it.
        assert not (get_attention and reverse_tgt), "not implemented"

        nb_ex = len(src_data)
        nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
        res = []
        attn_all = []
        for i in six.moves.range(nb_batch):
            current_batch_raw_data = src_data[i * batch_size:(i + 1) *
                                              batch_size]
            if reverse_src:
                current_batch_raw_data = [
                    src_side[::-1] for src_side in current_batch_raw_data
                ]

            src_batch, src_mask = make_batch_src(current_batch_raw_data,
                                                 gpu=gpu,
                                                 use_chainerx=use_chainerx)
            # The score output is not used here.
            sample_greedy, _, attn_list = encdec(
                src_batch,
                nb_steps,
                src_mask,
                use_best_for_sample=True,
                keep_attn_values=get_attention)
            res += de_batch(sample_greedy,
                            mask=None,
                            eos_idx=eos_idx,
                            is_variable=False)
            if get_attention:
                attn_all += de_batch(attn_list,
                                     mask=None,
                                     eos_idx=None,
                                     is_variable=True,
                                     raw=True,
                                     reverse=reverse_tgt)

        if reverse_tgt:
            # Reverse each output, leaving a final EOS in place. The emptiness
            # check prevents an IndexError on an empty output sequence.
            new_res = []
            for t in res:
                if t and t[-1] == eos_idx:
                    new_res.append(t[:-1][::-1] + [t[-1]])
                else:
                    new_res.append(t[::-1])
            res = new_res

        if get_attention:
            return res, attn_all
        return res
Example #10
0
def greedy_batch_translate(encdec,
                           eos_idx,
                           src_data,
                           batch_size=80,
                           gpu=None,
                           get_attention=False,
                           nb_steps=50,
                           reverse_src=False,
                           reverse_tgt=False):
    """Translate ``src_data`` greedily, ``batch_size`` sentences at a time.

    Each mini-batch is built with ``make_batch_src`` and decoded by ``encdec``
    with ``use_best_for_sample=True`` and ``mode="test"``. The output is split
    back into one index sequence per sentence with ``de_batch``.

    Args:
        encdec: the model.
        eos_idx: target end-of-sentence index, passed to ``de_batch`` to cut
            each output sequence.
        src_data: list of source index sequences.
        batch_size: sentences per mini-batch.
        gpu: device forwarded to ``make_batch_src``.
        get_attention: also return per-sentence attention values.
        nb_steps: decoding step budget passed to the model.
        reverse_src: feed every source sentence reversed.
        reverse_tgt: reverse every output sequence, keeping a trailing
            ``eos_idx`` last.

    Returns:
        The list of translations or, if ``get_attention`` is true,
        ``(translations, attentions)``.

    Raises:
        AssertionError: ``get_attention`` together with ``reverse_tgt``
            ("not implemented").
    """
    # Fail before doing any translation work rather than after it.
    assert not (get_attention and reverse_tgt), "not implemented"

    nb_ex = len(src_data)
    # Ceiling division; "//" keeps nb_batch an int under Python 3 too
    # (plain "/" would yield a float and make range() raise TypeError).
    nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
    res = []
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_data[i * batch_size:(i + 1) * batch_size]

        if reverse_src:
            current_batch_raw_data = [
                src_side[::-1] for src_side in current_batch_raw_data
            ]

        src_batch, src_mask = make_batch_src(current_batch_raw_data,
                                             gpu=gpu,
                                             volatile="on")
        # The score output is not used here.
        sample_greedy, _, attn_list = encdec(
            src_batch,
            nb_steps,
            src_mask,
            use_best_for_sample=True,
            keep_attn_values=get_attention,
            mode="test")
        res += de_batch(sample_greedy,
                        mask=None,
                        eos_idx=eos_idx,
                        is_variable=False)
        if get_attention:
            attn_all += de_batch(attn_list,
                                 mask=None,
                                 eos_idx=None,
                                 is_variable=True,
                                 raw=True,
                                 reverse=reverse_tgt)

    if reverse_tgt:
        # Reverse each output, leaving a final EOS in place. The emptiness
        # check prevents an IndexError on an empty output sequence.
        new_res = []
        for t in res:
            if t and t[-1] == eos_idx:
                new_res.append(t[:-1][::-1] + [t[-1]])
            else:
                new_res.append(t[::-1])
        res = new_res

    if get_attention:
        return res, attn_all
    return res