def test_multibatch(self):
    """Check that batched loss equals the length-weighted mean of per-sequence naive losses.

    Builds one padded batch from four (src, tgt) pairs, runs the batched
    encoder-decoder, then replays each pair one-by-one through
    ``naive_call`` and compares the aggregated per-token loss against the
    batched loss within 1e-6.
    """
    Vi, Ei, Hi, Vo, Eo, Ho, Ha, Hl = 29, 37, 13, 17, 7, 12, 19, 33
    encdec = EncoderDecoderNaive(Vi, Ei, Hi, Vo, Eo, Ho, Ha, Hl)
    eos_idx = Vo - 1

    # Four hand-written (source, target) pairs of varying lengths.
    trg_data = [
        ([2, 5, 0, 3], [4, 6]),
        ([2, 5, 4, 3, 0, 0, 1, 11, 3], [4, 8, 9, 12, 0]),
        ([2, 5, 4, 3, 0, 11, 3], [5, 7, 1, 4, 4, 1, 0, 0, 5, 5, 3, 4, 6, 7, 8]),
        ([5, 3, 0, 0, 1, 11, 3], [0, 0, 1, 1]),
    ]

    src_batch, tgt_batch, src_mask = utils.make_batch_src_tgt(
        trg_data, eos_idx=eos_idx)
    loss, attn = encdec(src_batch, tgt_batch, src_mask)

    total_loss_naive = 0
    total_length = 0
    for src_raw, tgt_raw in trg_data:
        # The batched path appends EOS internally; mirror that here.
        tgt_with_eos = tgt_raw + [eos_idx]
        input_seq = [Variable(np.array([v], dtype=np.int32)) for v in src_raw]
        tgt_seq = [Variable(np.array([v], dtype=np.int32)) for v in tgt_with_eos]
        loss_naive, attn_naive = encdec.naive_call(input_seq, tgt_seq, None)
        # Weight each sequence's mean loss by its (EOS-inclusive) length.
        total_loss_naive += float(loss_naive.data) * len(tgt_with_eos)
        total_length += len(tgt_with_eos)

    assert abs(total_loss_naive / total_length - float(loss.data)) < 1e-6
def batch_align(encdec, eos_idx, src_tgt_data, batch_size=80, gpu=None):
    """Compute attention alignments (and summed loss) for a dataset in minibatches.

    Args:
        encdec: the encoder-decoder model; called with keep_attn_values=True.
        eos_idx: index of the end-of-sequence token, forwarded to batching.
        src_tgt_data: list of (source, target) raw sequence pairs.
        batch_size: number of pairs per minibatch.
        gpu: device passed through to make_batch_src_tgt (None for CPU).

    Returns:
        (sum_loss, attn_all): the sum of per-batch losses (float) and the
        de-batched attention matrices restored to the original example order.
    """
    nb_ex = len(src_tgt_data)
    # Ceiling division via //. The original `nb_ex / batch_size` produced a
    # float under Python 3, which made range() raise TypeError.
    nb_batch = (nb_ex + batch_size - 1) // batch_size
    sum_loss = 0
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_tgt_data[i * batch_size:(i + 1) * batch_size]
        # need_arg_sort=True: batching sorts examples by length, arg_sort
        # records how to undo that sort.
        src_batch, tgt_batch, src_mask, arg_sort = make_batch_src_tgt(
            current_batch_raw_data, eos_idx=eos_idx, gpu=gpu, volatile="on",
            need_arg_sort=True)
        loss, attn_list = encdec(src_batch, tgt_batch, src_mask,
                                 keep_attn_values=True)
        deb_attn = de_batch(attn_list, mask=None, eos_idx=None,
                            is_variable=True, raw=True)
        assert len(arg_sort) == len(deb_attn)
        # Restore original example order (was `xrange`, Python-2-only).
        de_sorted_attn = [None] * len(deb_attn)
        for xpos, original_pos in enumerate(arg_sort):
            de_sorted_attn[original_pos] = deb_attn[xpos]
        attn_all += de_sorted_attn
        sum_loss += float(loss.data)
    return sum_loss, attn_all
def sample_extension(trainer):
    """Trainer extension: decode a sample from the next training minibatch.

    Peeks at the upcoming minibatch (without consuming it), then runs the
    appropriate sampling routine depending on the encoder-decoder type.
    CUDA runtime errors are logged and skipped rather than aborting training.

    NOTE(review): relies on closure variables (src_indexer, tgt_indexer,
    eos_idx, gpu, use_chainerx, log, CudaException) — presumably defined in
    the enclosing scope; confirm against the full file.
    """
    encdec = trainer.updater.get_optimizer("main").target
    iterator = trainer.updater.get_iterator("main")
    mb_raw = iterator.peek()

    def s_unk_tag(num, utag):
        return "S_UNK_%i" % utag

    def t_unk_tag(num, utag):
        return "T_UNK_%i" % utag

    try:
        if encdec.encdec_type() == "ff":
            src_seqs, tgt_seqs = list(six.moves.zip(*mb_raw))
            sample_once_ff(encdec, src_seqs, tgt_seqs, src_indexer, tgt_indexer,
                           max_nb=20, s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
        else:
            src_batch, tgt_batch, src_mask = make_batch_src_tgt(
                mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu,
                need_arg_sort=False, use_chainerx=use_chainerx)
            sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer,
                        tgt_indexer, eos_idx, max_nb=20, s_unk_tag=s_unk_tag,
                        t_unk_tag=t_unk_tag)
    except CudaException:
        # `warn` is a deprecated alias of `warning` in the logging API.
        log.warning("CUDARuntimeError during sample. Skipping sample")
def convert_mb(mb_raw, device):
    """Convert a raw minibatch into padded (src, tgt, mask) arrays on `device`.

    NOTE(review): `make_batch_src_tgt` and `eos_idx` come from the enclosing
    scope — confirm against the full file.
    """
    batch = make_batch_src_tgt(
        mb_raw,
        eos_idx=eos_idx,
        padding_idx=0,
        gpu=device,
        volatile="off",
        need_arg_sort=False,
    )
    return batch
def sample_extension(trainer):
    """Trainer extension: decode a sample from the upcoming training minibatch.

    Peeks at the next minibatch without consuming it, builds a padded batch,
    and runs `sample_once` to print example translations.

    NOTE(review): uses closure variables (eos_idx, gpu, src_indexer,
    tgt_indexer, make_batch_src_tgt, sample_once) — presumably bound in the
    enclosing scope; confirm against the full file.
    """
    encdec = trainer.updater.get_optimizer("main").target
    mb_raw = trainer.updater.get_iterator("main").peek()

    src_batch, tgt_batch, src_mask = make_batch_src_tgt(
        mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu,
        volatile="on", need_arg_sort=False)

    # Tag renderers for out-of-vocabulary source/target tokens.
    def s_unk_tag(num, utag):
        return "S_UNK_%i" % utag

    def t_unk_tag(num, utag):
        return "T_UNK_%i" % utag

    sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer,
                tgt_indexer, eos_idx, max_nb=20,
                s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
def convert_mb(mb_raw, device):
    """Convert a raw minibatch into padded (src, tgt, mask) arrays on `device`.

    NOTE(review): `make_batch_src_tgt`, `eos_idx` and `use_chainerx` come
    from the enclosing scope — confirm against the full file.
    """
    batch = make_batch_src_tgt(
        mb_raw,
        eos_idx=eos_idx,
        padding_idx=0,
        gpu=device,
        need_arg_sort=False,
        use_chainerx=use_chainerx,
    )
    return batch