def batch_align(encdec, eos_idx, src_tgt_data, batch_size=80, gpu=None):
    """Compute total loss and per-sentence attention over a parallel dataset.

    Args:
        encdec: encoder-decoder model, called as
            ``encdec(src_batch, tgt_batch, src_mask, keep_attn_values=True)``
            and expected to return ``(loss, attn_list)``.
        eos_idx: index of the end-of-sequence token.
        src_tgt_data: list of (source, target) sentence pairs.
        batch_size: number of sentence pairs per minibatch.
        gpu: device passed through to ``make_batch_src_tgt`` (None = CPU).

    Returns:
        ``(sum_loss, attn_all)``: the summed loss over all minibatches, and
        the de-batched attention values re-ordered to match ``src_tgt_data``.
    """
    nb_ex = len(src_tgt_data)
    # '//' keeps nb_batch an int under Python 3 (the original '/' would
    # produce a float and make range() below fail).
    nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
    sum_loss = 0
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_tgt_data[i * batch_size:(i + 1) * batch_size]
        src_batch, tgt_batch, src_mask, arg_sort = make_batch_src_tgt(
            current_batch_raw_data, eos_idx=eos_idx, gpu=gpu, volatile="on",
            need_arg_sort=True)
        loss, attn_list = encdec(src_batch, tgt_batch, src_mask,
                                 keep_attn_values=True)
        deb_attn = de_batch(attn_list, mask=None, eos_idx=None,
                            is_variable=True, raw=True)
        assert len(arg_sort) == len(deb_attn)
        # make_batch_src_tgt sorts the batch (arg_sort gives the mapping);
        # undo that sort so attention entries line up with the input order.
        # enumerate replaces the Python-2-only xrange loop of the original.
        de_sorted_attn = [None] * len(deb_attn)
        for xpos, original_pos in enumerate(arg_sort):
            de_sorted_attn[original_pos] = deb_attn[xpos]
        attn_all += de_sorted_attn
        sum_loss += float(loss.data)
    return sum_loss, attn_all
def test_multiple_length_variable_raw(self):
    """de_batch with is_variable=True and raw=True: rows of Variables are
    transposed into per-position sequences; raw entries may themselves be
    arrays, so compare element-wise with np.all."""
    rows = [[1, 3, 4, 8], [1, 5, 6, 9], [[7, 9], [5, 8]], [10]]
    batch = [Variable(np.array(row, dtype=np.int32)) for row in rows]
    seq_list = de_batch(batch, is_variable=True, raw=True)
    expected = [[1, 1, [7, 9], 10], [3, 5, [5, 8]], [4, 6], [8, 9]]
    assert len(seq_list) == 4
    for got, want in zip(seq_list, expected):
        assert len(got) == len(want)
        for got_elem, want_elem in zip(got, want):
            assert np.all(got_elem == want_elem)
def test_multiple_length_variable(self):
    """de_batch with is_variable=True transposes variable-length rows of
    Variables into plain per-position lists."""
    rows = [[1, 3, 4, 8], [1, 5, 6, 9], [7, 5], [10]]
    batch = [Variable(np.array(row, dtype=np.int32)) for row in rows]
    seq_list = de_batch(batch, is_variable=True)
    assert seq_list == [[1, 1, 7, 10], [3, 5, 5], [4, 6], [8, 9]]
def test_multiple_length_eos_idx(self):
    """de_batch with eos_idx=3: each transposed sequence stops at the EOS
    token (the EOS itself is kept)."""
    rows = ([1, 3, 4, 8], [3, 3, 6, 9], [7, 5], [10])
    batch = [np.array(row) for row in rows]
    seq_list = de_batch(batch, eos_idx=3)
    assert seq_list == [[1, 3], [3], [4, 6], [8, 9]]
def test_multiple_length(self):
    """de_batch transposes a list of variable-length numpy rows into
    per-position lists."""
    rows = ([1, 3, 4, 8], [1, 5, 6, 9], [7, 5], [10])
    batch = [np.array(row) for row in rows]
    seq_list = de_batch(batch)
    assert seq_list == [[1, 1, 7, 10], [3, 5, 5], [4, 6], [8, 9]]
def test_mask3(self):
    """de_batch with a mask: masked-out positions are dropped from the
    transposed per-position sequences."""
    rows = ([1, 3, 4, 8], [1, 5, 6, 9], [7, 5, 3, 4])
    flags = ([True, True, False, True], [True, True, False, False])
    batch = [np.array(row) for row in rows]
    mask = [np.array(flag_row) for flag_row in flags]
    seq_list = de_batch(batch, mask=mask)
    assert seq_list == [[1, 1, 7], [3, 5, 5], [4], [8, 9]]
def train_once_reinf(src_batch, tgt_batch, src_mask):  # , lexicon_matrix = None):
    """Run one REINFORCE training step on a minibatch and print timings.

    Relies on ``encdec``, ``optimizer`` and ``eos_idx`` from the enclosing
    scope.

    Args:
        src_batch, tgt_batch: batched source / target sequences.
        src_mask: mask for the source batch.

    Returns:
        ``(loss, batch_width)``: the reinforcement loss as a float and the
        number of sentences in the batch.
    """
    # time.clock() was deprecated since 3.3 and removed in Python 3.8;
    # perf_counter() is its recommended replacement for interval timing.
    t0 = time.perf_counter()
    encdec.zerograds()
    t1 = time.perf_counter()
    from nmt_chainer.utilities import utils
    test_ref = utils.de_batch(tgt_batch, is_variable=True)
    reinf_loss = encdec.get_reinf_loss(
        src_batch, src_mask, eos_idx, test_ref,
        nb_steps=50, nb_samples=5, use_best_for_sample=False,
        temperature=None, mode="test")
    t2 = time.perf_counter()
    reinf_loss.backward()
    t3 = time.perf_counter()
    optimizer.update()
    t4 = time.perf_counter()
    # Python-2 print statements converted to the print function, consistent
    # with the rest of this codebase.
    print("reinf loss:", reinf_loss.data, reinf_loss.data / len(src_batch))
    print(" time %f zgrad:%f fwd:%f bwd:%f upd:%f"
          % (t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3))
    return float(reinf_loss.data), len(src_batch)
def sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer,
                tgt_indexer, eos_idx, max_nb=None,
                s_unk_tag="#S_UNK#", t_unk_tag="#T_UNK#"):
    """Print greedy and random samples for a minibatch, for inspection.

    For each sentence of the batch (up to max_nb), prints the source, the
    reference target, a greedy sample and a random sample, each as raw
    indices, as space-joined tokens, and as post-processed text.

    Args:
        encdec: encoder-decoder model; called with (src_batch, 50, src_mask)
            and expected to return (samples, score, attn_list).
        src_batch, tgt_batch: batched source / target sequences.
        src_mask: mask for the source batch.
        src_indexer, tgt_indexer: vocabulary indexers providing
            deconvert_swallow() and deconvert().
        eos_idx: end-of-sequence token index, used to cut target sequences.
        max_nb: if not None, stop after this many sentences.
        s_unk_tag, t_unk_tag: placeholder strings for unknown source /
            target tokens.
    """
    # Disable training mode and gradient bookkeeping: sampling only.
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        print("sample")
        # Greedy decoding, at most 50 steps per sentence.
        sample_greedy, score, attn_list = encdec(
            src_batch, 50, src_mask, use_best_for_sample=True,
            need_score=True)
        # sample, score = encdec(src_batch, 50, src_mask, use_best_for_sample = False)
        assert len(src_batch[0].data) == len(tgt_batch[0].data)
        assert len(sample_greedy[0]) == len(src_batch[0].data)
        # De-batch everything into per-sentence index sequences.  The source
        # keeps its mask and has no EOS cut; targets/samples cut at eos_idx.
        debatched_src = de_batch(src_batch, mask=src_mask, eos_idx=None,
                                 is_variable=True)
        debatched_tgt = de_batch(tgt_batch, eos_idx=eos_idx, is_variable=True)
        debatched_sample = de_batch(sample_greedy, eos_idx=eos_idx)
        # Second pass with stochastic sampling instead of greedy argmax.
        sample_random, score_random, attn_list_random = encdec(
            src_batch, 50, src_mask, use_best_for_sample=False,
            need_score=True)
        debatched_sample_random = de_batch(sample_random, eos_idx=eos_idx)
        for sent_num in six.moves.range(len(debatched_src)):
            if max_nb is not None and sent_num > max_nb:
                break
            src_idx_seq = debatched_src[sent_num]
            tgt_idx_seq = debatched_tgt[sent_num]
            sample_idx_seq = debatched_sample[sent_num]
            sample_random_idx_seq = debatched_sample_random[sent_num]
            print("sent num", sent_num)
            # Print the four views of this sentence in lockstep: source uses
            # the source indexer / unk tag and no EOS; the rest use the
            # target indexer / unk tag and eos_idx.
            for name, seq, unk_tag, indexer, this_eos_idx in six.moves.zip(
                    "src tgt sample sample_random".split(" "),
                    [src_idx_seq, tgt_idx_seq, sample_idx_seq,
                     sample_random_idx_seq],
                    [s_unk_tag, t_unk_tag, t_unk_tag, t_unk_tag],
                    [src_indexer, tgt_indexer, tgt_indexer, tgt_indexer],
                    [None, eos_idx, eos_idx, eos_idx]):
                print(name, "idx:", seq)
                print(
                    name, "raw:", " ".join(
                        indexer.deconvert_swallow(
                            seq, unk_tag=unk_tag,
                            eos_idx=this_eos_idx)).encode('utf-8'))
                print(
                    name, "postp:",
                    indexer.deconvert(seq, unk_tag=unk_tag,
                                      eos_idx=this_eos_idx).encode('utf-8'))
def greedy_batch_translate(encdec, eos_idx, src_data, batch_size=80, gpu=None,
                           get_attention=False, nb_steps=50, reverse_src=False,
                           reverse_tgt=False, use_chainerx=False):
    """Greedily translate src_data, minibatch by minibatch.

    Args:
        encdec: encoder-decoder model.
        eos_idx: end-of-sequence token index used to cut translations.
        src_data: list of source sentences (index sequences).
        batch_size: sentences per minibatch.
        gpu: device passed to make_batch_src (None = CPU).
        get_attention: if True, also return attention values.
        nb_steps: maximum decoding length.
        reverse_src: if True, reverse each source sentence before encoding.
        reverse_tgt: if True, decoded outputs are produced reversed and are
            un-reversed before returning (trailing EOS kept last).
        use_chainerx: forwarded to make_batch_src.

    Returns:
        List of translations, or (translations, attentions) when
        get_attention is True.
    """
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        # Feed-forward models carry their own batch translator; they provide
        # no attention, so zero matrices are substituted when requested.
        if encdec.encdec_type() == "ff":
            result = encdec.greedy_batch_translate(
                src_data, mb_size=batch_size, nb_steps=nb_steps)
            if not get_attention:
                return result
            dummy_attention = [
                np.zeros((len(src), len(tgt)), dtype=np.float32)
                for src, tgt in six.moves.zip(src_data, result)]
            return result, dummy_attention

        nb_ex = len(src_data)
        nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
        translations = []
        attn_all = []
        for num_batch in six.moves.range(nb_batch):
            minibatch = src_data[num_batch * batch_size:
                                 (num_batch + 1) * batch_size]
            if reverse_src:
                minibatch = [sentence[::-1] for sentence in minibatch]
            src_batch, src_mask = make_batch_src(
                minibatch, gpu=gpu, use_chainerx=use_chainerx)
            sample_greedy, score, attn_list = encdec(
                src_batch, nb_steps, src_mask,
                use_best_for_sample=True, keep_attn_values=get_attention)
            translations += de_batch(sample_greedy, mask=None,
                                     eos_idx=eos_idx, is_variable=False)
            if get_attention:
                attn_all += de_batch(attn_list, mask=None, eos_idx=None,
                                     is_variable=True, raw=True,
                                     reverse=reverse_tgt)
        if reverse_tgt:
            # Un-reverse each decoded sequence; a trailing EOS stays last.
            reordered = []
            for seq in translations:
                if seq[-1] == eos_idx:
                    reordered.append(seq[:-1][::-1] + [seq[-1]])
                else:
                    reordered.append(seq[::-1])
            translations = reordered
        if get_attention:
            assert not reverse_tgt, "not implemented"
            return translations, attn_all
        return translations
def greedy_batch_translate(encdec, eos_idx, src_data, batch_size=80, gpu=None,
                           get_attention=False, nb_steps=50, reverse_src=False,
                           reverse_tgt=False):
    """Greedily translate src_data in minibatches (legacy volatile-mode API).

    Args:
        encdec: encoder-decoder model (old API taking mode="test").
        eos_idx: end-of-sequence token index used to cut translations.
        src_data: list of source sentences (index sequences).
        batch_size: sentences per minibatch.
        gpu: device passed to make_batch_src (None = CPU).
        get_attention: if True, also return attention values.
        nb_steps: maximum decoding length.
        reverse_src: if True, reverse each source sentence before encoding.
        reverse_tgt: if True, decoded outputs are produced reversed and are
            un-reversed before returning (trailing EOS kept last).

    Returns:
        List of translations, or (translations, attentions) when
        get_attention is True.
    """
    nb_ex = len(src_data)
    # '//' keeps nb_batch an int under Python 3 (the original '/' would
    # produce a float and make range() below fail).
    nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
    res = []
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_data[i * batch_size:(i + 1) * batch_size]
        if reverse_src:
            current_batch_raw_data = [src_side[::-1]
                                      for src_side in current_batch_raw_data]
        src_batch, src_mask = make_batch_src(current_batch_raw_data, gpu=gpu,
                                             volatile="on")
        sample_greedy, score, attn_list = encdec(
            src_batch, nb_steps, src_mask, use_best_for_sample=True,
            keep_attn_values=get_attention, mode="test")
        res += de_batch(sample_greedy, mask=None, eos_idx=eos_idx,
                        is_variable=False)
        if get_attention:
            attn_all += de_batch(attn_list, mask=None, eos_idx=None,
                                 is_variable=True, raw=True,
                                 reverse=reverse_tgt)
    if reverse_tgt:
        # Un-reverse each decoded sequence; a trailing EOS stays last.
        new_res = []
        for t in res:
            if t[-1] == eos_idx:
                new_res.append(t[:-1][::-1] + [t[-1]])
            else:
                new_res.append(t[::-1])
        res = new_res
    if get_attention:
        assert not reverse_tgt, "not implemented"
        return res, attn_all
    else:
        return res