コード例 #1
0
ファイル: models_test.py プロジェクト: Tzawa/knmt
    def test_multibatch(self):
        """Check that a batched forward pass agrees with per-example naive calls.

        Four (source, target) pairs of different lengths are batched with
        ``utils.make_batch_src_tgt`` and run through ``encdec`` in one call.
        Each pair is then run alone through ``encdec.naive_call`` with
        ``eos_idx`` appended to its target.  Each naive loss is weighted by
        that target length (EOS included).  The weighted mean must match the
        batched loss within an absolute tolerance of 1e-6.
        """
        Vi, Ei, Hi, Vo, Eo, Ho, Ha, Hl = 29, 37, 13, 17, 7, 12, 19, 33
        encdec = EncoderDecoderNaive(Vi, Ei, Hi, Vo, Eo, Ho, Ha, Hl)
        # The last target-vocabulary index is used as end-of-sequence.
        eos_idx = Vo - 1

        raw_seq1 = [2, 5, 0, 3], [4, 6]
        raw_seq2 = [2, 5, 4, 3, 0, 0, 1, 11, 3], [4, 8, 9, 12, 0]
        raw_seq3 = ([2, 5, 4, 3, 0, 11, 3],
                    [5, 7, 1, 4, 4, 1, 0, 0, 5, 5, 3, 4, 6, 7, 8])
        raw_seq4 = [5, 3, 0, 0, 1, 11, 3], [0, 0, 1, 1]

        trg_data = [raw_seq1, raw_seq2, raw_seq3, raw_seq4]
        src_batch, tgt_batch, src_mask = utils.make_batch_src_tgt(
            trg_data, eos_idx=eos_idx)

        loss, _ = encdec(src_batch, tgt_batch, src_mask)

        total_loss_naive = 0
        total_length = 0
        for raw_src, raw_tgt in trg_data:
            # Build a new list so the shared test data is not mutated.
            tgt_with_eos = raw_tgt + [eos_idx]
            input_seq = [
                Variable(np.array([v], dtype=np.int32)) for v in raw_src
            ]
            tgt_seq = [
                Variable(np.array([v], dtype=np.int32)) for v in tgt_with_eos
            ]
            loss_naive, _ = encdec.naive_call(input_seq, tgt_seq, None)
            # Weight by target length so the mean is taken over tokens.
            total_loss_naive += float(loss_naive.data) * len(tgt_with_eos)
            total_length += len(tgt_with_eos)

        naive_mean = total_loss_naive / total_length
        batched = float(loss.data)
        assert abs(naive_mean - batched) < 1e-6, (
            "batched loss %r differs from token-weighted naive loss %r"
            % (batched, naive_mean))
コード例 #2
0
def batch_align(encdec, eos_idx, src_tgt_data, batch_size=80, gpu=None):
    """Run ``encdec`` over ``src_tgt_data`` in minibatches and collect attention.

    Args:
        encdec: model called as
            ``encdec(src_batch, tgt_batch, src_mask, keep_attn_values=True)``
            and returning ``(loss, attn_list)``.
        eos_idx: end-of-sequence index forwarded to ``make_batch_src_tgt``.
        src_tgt_data: sliceable sequence of examples, processed
            ``batch_size`` at a time (the last batch may be shorter).
        batch_size: number of examples per minibatch; must be >= 1.
        gpu: forwarded to ``make_batch_src_tgt``.

    Returns:
        ``(sum_loss, attn_all)``:
        - ``sum_loss`` is the sum of ``float(loss.data)`` over all batches.
        - ``attn_all`` holds one de-batched attention entry per example, in
          the original order of ``src_tgt_data``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    sum_loss = 0
    attn_all = []
    # Step over batch start offsets.  Integer stepping avoids the Python 3
    # float result of ``len / batch_size``, which ``range`` rejects.
    for start in range(0, len(src_tgt_data), batch_size):
        current_batch_raw_data = src_tgt_data[start:start + batch_size]
        src_batch, tgt_batch, src_mask, arg_sort = make_batch_src_tgt(
            current_batch_raw_data,
            eos_idx=eos_idx,
            gpu=gpu,
            volatile="on",
            need_arg_sort=True)
        loss, attn_list = encdec(src_batch,
                                 tgt_batch,
                                 src_mask,
                                 keep_attn_values=True)
        deb_attn = de_batch(attn_list,
                            mask=None,
                            eos_idx=None,
                            is_variable=True,
                            raw=True)

        assert len(arg_sort) == len(deb_attn)
        # Rows of the batch may be reordered.  arg_sort[xpos] gives the
        # position in current_batch_raw_data of batch row xpos, so scatter
        # each row back to that position.
        de_sorted_attn = [None] * len(deb_attn)
        for xpos, original_pos in enumerate(arg_sort):
            de_sorted_attn[original_pos] = deb_attn[xpos]

        attn_all += de_sorted_attn
        sum_loss += float(loss.data)
    return sum_loss, attn_all
コード例 #3
0
    def sample_extension(trainer):
        """Sample outputs for the minibatch the "main" iterator will yield next.

        The model is the target of the "main" optimizer, and the raw
        minibatch comes from ``iterator.peek()``.

        - Models whose ``encdec_type()`` is ``"ff"`` go through
          ``sample_once_ff`` with the unzipped source and target sequences.
        - Other models are first batched with ``make_batch_src_tgt`` and then
          passed to ``sample_once``.

        If a ``CudaException`` is raised while batching or sampling, it is
        logged and the sample is skipped.
        """
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        mb_raw = iterator.peek()

        # ``num`` is accepted but unused: the tag depends only on ``utag``.
        def s_unk_tag(num, utag):
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            return "T_UNK_%i" % utag

        try:
            if encdec.encdec_type() == "ff":
                src_seqs, tgt_seqs = list(six.moves.zip(*mb_raw))
                sample_once_ff(encdec, src_seqs, tgt_seqs,
                               src_indexer, tgt_indexer, max_nb=20,
                               s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
            else:
                src_batch, tgt_batch, src_mask = make_batch_src_tgt(
                    mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu,
                    need_arg_sort=False, use_chainerx=use_chainerx)
                sample_once(encdec, src_batch, tgt_batch, src_mask,
                            src_indexer, tgt_indexer, eos_idx,
                            max_nb=20,
                            s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
        except CudaException:
            # Best-effort: the exception is logged and swallowed so the
            # caller carries on.  ``warning`` is the non-deprecated spelling
            # of ``warn`` on a stdlib logging.Logger (assumes ``log`` is one
            # — confirm).
            log.warning("CUDARuntimeError during sample. Skipping sample")
コード例 #4
0
 def convert_mb(mb_raw, device):
     """Batch ``mb_raw`` for ``device`` via ``make_batch_src_tgt``.

     ``eos_idx`` is taken from the enclosing scope.  Padding index is 0,
     ``volatile`` is "off", and no argsort is requested.  The result of
     ``make_batch_src_tgt`` is returned unchanged.
     """
     options = dict(eos_idx=eos_idx, padding_idx=0, volatile="off",
                    need_arg_sort=False)
     return make_batch_src_tgt(mb_raw, gpu=device, **options)
コード例 #5
0
    def sample_extension(trainer):
        """Run ``sample_once`` on the minibatch the "main" iterator yields next.

        The model is the target of the "main" optimizer.  The raw minibatch
        comes from ``iterator.peek()`` and is batched with
        ``make_batch_src_tgt`` (padding 0, ``volatile="on"``, no argsort).
        At most 20 samples are requested.
        """
        # Placeholder tags for unknown source/target words; ``num`` is unused.
        def source_unk(num, utag):
            return "S_UNK_%i" % utag

        def target_unk(num, utag):
            return "T_UNK_%i" % utag

        updater = trainer.updater
        model = updater.get_optimizer("main").target
        raw_minibatch = updater.get_iterator("main").peek()

        batch = make_batch_src_tgt(raw_minibatch, eos_idx=eos_idx,
                                   padding_idx=0, gpu=gpu, volatile="on",
                                   need_arg_sort=False)
        src_batch, tgt_batch, src_mask = batch

        sample_once(model, src_batch, tgt_batch, src_mask,
                    src_indexer, tgt_indexer, eos_idx, max_nb=20,
                    s_unk_tag=source_unk, t_unk_tag=target_unk)
コード例 #6
0
 def convert_mb(mb_raw, device):
     """Batch ``mb_raw`` for ``device`` via ``make_batch_src_tgt``.

     ``eos_idx`` and ``use_chainerx`` are taken from the enclosing scope.
     Padding index is 0 and no argsort is requested.  The result of
     ``make_batch_src_tgt`` is returned unchanged.
     """
     options = dict(eos_idx=eos_idx, padding_idx=0, need_arg_sort=False,
                    use_chainerx=use_chainerx)
     return make_batch_src_tgt(mb_raw, gpu=device, **options)