# Example no. 1
# 0
    def _print_hyp(self, out_count, src_ids, src_id2word, seqlist):

        if out_count < 3:
            srcwords = _convert_to_words_batchfirst(src_ids, src_id2word)
            seqwords = _convert_to_words(seqlist, src_id2word)
            outsrc = 'SRC: {}\n'.format(' '.join(srcwords[0])).encode('utf-8')
            outline = 'GEN: {}\n'.format(' '.join(seqwords[0])).encode('utf-8')
            sys.stdout.buffer.write(outsrc)
            sys.stdout.buffer.write(outline)
            out_count += 1
            sys.stdout.flush()
        return out_count
def translate(test_set, load_dir, test_path_out, use_gpu, max_seq_len, beam_width, seqrev=False):
    """Translate ``test_set`` with a saved checkpoint (no reference needed).

    One hypothesis per line is written to ``<test_path_out>/translate.txt``
    (UTF-8). Within a hypothesis, ``<pad>`` tokens are dropped, output stops
    at the first ``</s>``, and the kept tokens are reversed when ``seqrev``
    is set. Sentences whose source length is 0 (batch padding) are skipped.

    Args:
        test_set: test dataset; src and tgt use the same directory, and the
            output is decoded with ``test_set.src_id2word``.
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        test_path_out: output directory.
        use_gpu: run on GPU if True, CPU otherwise.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model's forward call.
        seqrev: reverse each output sequence before writing it.
    """
    checkpoint_path = load_dir
    model = Checkpoint.load(checkpoint_path).model.to(device)
    print('Model dir: {}'.format(checkpoint_path))
    print('Model laoded')

    # configure the restored model for this test set
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)
    model.check_var('ptr_net')
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # datasets with attention-key features need the ddfd-prob batch builder
    if test_set.attkey_path is None:
        build_batches = test_set.construct_batches
    else:
        build_batches = test_set.construct_batches_with_ddfd_prob
    test_batches, _ = build_batches(is_train=False)

    out_path = os.path.join(test_path_out, 'translate.txt')
    with open(out_path, 'w', encoding="utf8") as fout:
        model.eval()
        with torch.no_grad():
            for batch in test_batches:
                raw_src = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']

                key_feats = None
                if 'src_ddfd_probs' in batch:
                    key_feats = _convert_to_tensor(
                        batch['src_ddfd_probs'], use_gpu).unsqueeze(2)

                src_tensor = _convert_to_tensor(raw_src, use_gpu)
                _, _, extras = model(src=src_tensor,
                                     is_training=False,
                                     att_key_feats=key_feats,
                                     beam_width=beam_width)

                _, used_mb, _ = get_memory_alloc()
                print('Memory used: {0:.2f} MB'.format(round(used_mb, 2)))

                hyps = _convert_to_words(extras['sequence'], test_set.src_id2word)
                for sent_idx, tokens in enumerate(hyps):
                    if src_lengths[sent_idx] == 0:
                        continue
                    kept = []
                    for tok in tokens:
                        if tok == '</s>':
                            break
                        if tok != '<pad>':
                            kept.append(tok)
                    if seqrev:
                        kept.reverse()
                    fout.write('{}\n'.format(' '.join(kept)))
                sys.stdout.flush()
def att_plot(test_set, load_dir, plot_path, use_gpu, max_seq_len, beam_width):
    """Interactively inspect and plot attention alignments on ``test_set``.

    For every sentence the source, reference and greedy hypothesis are
    printed together with the attention matrix trimmed to the first
    ``</s>`` of the hypothesis (rows) and of the source (columns). The user
    is then asked whether to save the plot as ``<plot_path>/<count>.png``.
    Token accuracy over non-``PAD`` reference positions is printed at the end.

    Args:
        test_set: test dataset with reference targets.
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        plot_path: directory the PNG plots are written to.
        use_gpu: request GPU; forced to False when CUDA is unavailable.
        max_seq_len: maximum decoding length set on the model.
        beam_width: unused; beam search is always turned off here.

    Returns:
        The token accuracy (``nan`` when no non-padding target was seen).
    """
    # check device
    print('cuda available: {}'.format(torch.cuda.is_available()))
    use_gpu = use_gpu and torch.cuda.is_available()

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)

    # in plotting mode always turn off beam search
    model.set_beam_width(beam_width=0)
    model.check_var('ptr_net')
    print('max seq len {}'.format(model.max_seq_len))
    print('ptr_net {}'.format(model.ptr_net))

    # load test
    if test_set.attkey_path is None:
        test_batches, _ = test_set.construct_batches(is_train=False)
    else:
        test_batches, _ = test_set.construct_batches_with_ddfd_prob(is_train=False)

    model.eval()
    match = 0
    total = 0
    count = 0  # index of the next saved plot
    with torch.no_grad():
        for batch in test_batches:
            src_ids = batch['src_word_ids']
            tgt_ids = batch['tgt_word_ids']
            src_probs = None
            if 'src_ddfd_probs' in batch:
                src_probs = _convert_to_tensor(
                    batch['src_ddfd_probs'], use_gpu).unsqueeze(2)

            src_ids = _convert_to_tensor(src_ids, use_gpu)
            tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

            decoder_outputs, _, other = model(src_ids, tgt_ids,
                                              is_training=False,
                                              att_key_feats=src_probs,
                                              beam_width=0)

            # one attention tensor per decoding (query) step; [i] selects
            # sentence i and size(2) is the key (source) length
            attention = other['attention_score']
            seqlist = other['sequence']  # traverse over time not batch

            # token accuracy, ignoring PAD positions of the reference
            for step in range(len(decoder_outputs)):
                target = tgt_ids[:, step + 1]
                non_padding = target.ne(PAD)
                correct = (seqlist[step].view(-1).eq(target)
                           .masked_select(non_padding).sum().item())
                match += correct
                total += non_padding.sum().item()

            # Print sentence by sentence
            srcwords = _convert_to_words_batchfirst(src_ids, test_set.src_id2word)
            refwords = _convert_to_words_batchfirst(tgt_ids[:, 1:], test_set.tgt_id2word)
            seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)

            n_q = len(attention)
            n_k = attention[0].size(2)
            att_score = torch.empty(n_q, n_k, dtype=torch.float)

            for i in range(len(seqwords)):  # loop over sentences
                print('SRC: {}'.format(' '.join(srcwords[i])))
                print('REF: {}'.format(' '.join(refwords[i])))
                print('GEN: {}'.format(' '.join(seqwords[i])))

                # gather sentence i's attention row from every query step
                for j in range(n_q):
                    att_score[j] = attention[j][i]

                loc_eos_k = srcwords[i].index('</s>') + 1
                loc_eos_q = seqwords[i].index('</s>') + 1
                print('eos_k: {}, eos_q: {}'.format(loc_eos_k, loc_eos_q))
                # each row (each query) sum up to 1
                att_score_trim = att_score[:loc_eos_q, :loc_eos_k]
                print(att_score_trim)
                print('\n')

                choice = input('Plot or not ? - y/n\n')
                if choice and choice.lower()[0] == 'y':
                    print('plotting ...')
                    plot_dir = os.path.join(plot_path, '{}.png'.format(count))
                    src = srcwords[i][:loc_eos_k]
                    hyp = seqwords[i][:loc_eos_q]
                    # x-axis: src; y-axis: hyp; the reference is not drawn
                    plot_alignment(att_score_trim.numpy(), plot_dir,
                                   src=src, hyp=hyp, ref=None)
                    count += 1
                    input('Press enter to continue ...')

    # previously printed the undefined name ``saccuracy`` (NameError)
    accuracy = match / total if total else float('nan')
    print(accuracy)
    return accuracy
def evaluate(test_set, load_dir, test_path_out, use_gpu, max_seq_len, beam_width, seqrev=False):
    """Translate ``test_set`` and score the output against its references.

    One hypothesis per line is written to ``<test_path_out>/translate.txt``
    (UTF-8). Within a hypothesis, ``<pad>`` tokens are dropped, output stops
    at the first ``</s>``, and the kept tokens are reversed when ``seqrev``
    is set. Sentences with source length 0 (batch padding) are skipped. The
    first sentence of each batch is also echoed to stdout.

    Args:
        test_set: test dataset with reference targets.
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        test_path_out: output directory.
        use_gpu: run on GPU if True, CPU otherwise.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width set on the model and passed to forward.
        seqrev: reverse each output sequence before writing it.

    Returns:
        Token accuracy over non-``PAD`` reference positions (``nan`` when
        there are none).
    """
    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)
    model.set_beam_width(beam_width)
    model.check_var('ptr_net')
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    if test_set.attkey_path is None:
        test_batches, _ = test_set.construct_batches(is_train=False)
    else:
        test_batches, _ = test_set.construct_batches_with_ddfd_prob(is_train=False)

    match = 0
    total = 0
    with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for batch in test_batches:
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                tgt_ids = batch['tgt_word_ids']
                src_probs = None
                if 'src_ddfd_probs' in batch:
                    src_probs = _convert_to_tensor(
                        batch['src_ddfd_probs'], use_gpu).unsqueeze(2)

                src_ids = _convert_to_tensor(src_ids, use_gpu)
                tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

                decoder_outputs, _, other = model(src_ids, tgt_ids,
                                                  is_training=False,
                                                  att_key_feats=src_probs,
                                                  beam_width=beam_width)

                # Evaluation
                seqlist = other['sequence']  # traverse over time not batch
                if beam_width > 1:
                    # with beam search the last decoder output is not scored
                    decoder_outputs = decoder_outputs[:-1]
                for step in range(len(decoder_outputs)):
                    target = tgt_ids[:, step + 1]
                    non_padding = target.ne(PAD)
                    # parenthesised: the original bare line break here was a
                    # syntax error
                    correct = (seqlist[step].view(-1).eq(target)
                               .masked_select(non_padding).sum().item())
                    match += correct
                    total += non_padding.sum().item()

                # write to file
                seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)
                for i in range(len(seqwords)):
                    # skip padding sentences in batch (num_sent % batch_size != 0)
                    if src_lengths[i] == 0:
                        continue
                    words = []
                    for word in seqwords[i]:
                        if word == '<pad>':
                            continue
                        elif word == '</s>':
                            break
                        else:
                            words.append(word)
                    if seqrev:
                        words = words[::-1]
                    outline = ' '.join(words)
                    f.write('{}\n'.format(outline))
                    if i == 0:
                        print(outline)
                sys.stdout.flush()

    accuracy = match / total if total else float('nan')
    return accuracy
# Example no. 5
# 0
def translate(test_set, model, test_path_out, use_gpu,
              max_seq_len, beam_width, device, seqrev=False):
    """Translate ``test_set`` with an already-loaded ``model`` (no reference).

    One hypothesis per line is written to ``<test_path_out>/translate.txt``
    (UTF-8). Within a hypothesis:
      * ``<pad>`` tokens are dropped;
      * ``<spc>`` becomes a literal space;
      * output stops at the first ``</s>``;
      * tokens are reversed when ``seqrev`` is set.
    Tokens are joined with a space for ``use_type == 'word'`` and with
    nothing for ``'char'``. Sentences whose source length is 0 are skipped.

    Args:
        test_set: test dataset (src, tgt using the same dir).
        model: model to decode with; put into eval mode here.
        test_path_out: output directory.
        use_gpu: forwarded to the model's forward call.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model.
        device: device the source ids are moved to.
        seqrev: reverse each output sequence before writing it.

    Raises:
        ValueError: if ``test_set.use_type`` is neither 'word' nor 'char'.
            Previously such a value left the output line unbound, or silently
            reused the previous sentence's line.
    """
    model.max_seq_len = max_seq_len
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    if test_set.use_type == 'word':
        joiner = ' '
    elif test_set.use_type == 'char':
        joiner = ''
    else:
        raise ValueError('Unrecognised use_type {!r} - choose from word/char'
                         .format(test_set.use_type))

    evaliter = iter(test_set.iter_loader)
    n_batches = len(evaliter)
    print('num batches: {}'.format(n_batches))

    with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            # iterate the loader directly (``evaliter.next()`` is Python 2 style)
            for idx, batch_items in enumerate(evaliter, start=1):
                print(idx, n_batches)

                # load data; trim padding columns beyond the longest source
                src_lengths = batch_items['srclen']
                src_ids = batch_items['srcid'][0]
                src_ids = src_ids[:, :max(src_lengths)].to(device=device)

                _, _, other = model(src=src_ids,
                                    src_lens=src_lengths, is_training=False,
                                    beam_width=beam_width, use_gpu=use_gpu)

                # write to file
                seqwords = _convert_to_words(other['sequence'], test_set.tgt_id2word)
                for i, sentence in enumerate(seqwords):
                    if src_lengths[i] == 0:
                        continue
                    words = []
                    for word in sentence:
                        if word == '<pad>':
                            continue
                        elif word == '<spc>':
                            words.append(' ')
                        elif word == '</s>':
                            break
                        else:
                            words.append(word)
                    if seqrev:
                        words = words[::-1]
                    f.write('{}\n'.format(joiner.join(words)))

                sys.stdout.flush()
# Example no. 6
# 0
def att_plot(test_set, model, plot_path, use_gpu, max_seq_len, beam_width, device):
    """Interactively inspect and plot attention alignments on ``test_set``.

    For every sentence the source, reference and hypothesis are printed
    together with the attention matrix trimmed to the first ``</s>`` of the
    hypothesis (rows) and of the source (columns). The user is then asked
    whether to save the plot as ``<plot_path>/<count>.png``.

    Args:
        test_set: test dataset with reference targets.
        model: already-loaded model; put into eval mode here.
        plot_path: directory the PNG plots are written to.
        use_gpu: unused (kept for interface compatibility).
        max_seq_len: maximum decoding length set on the model.
        beam_width: ignored; decoding always uses beam width 1 here.
        device: device the source/target ids are moved to.
    """
    beam_width = 1  # plotting always decodes with beam width 1
    model.max_seq_len = max_seq_len
    print('max seq len {}'.format(model.max_seq_len))

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)

    count = 0  # index of the next saved plot (was never initialised)
    model.eval()
    with torch.no_grad():
        for batch_items in evaliter:
            # load data; trim padding columns beyond the longest sentence
            src_ids = batch_items['srcid'][0]
            src_lengths = batch_items['srclen']
            tgt_ids = batch_items['tgtid'][0]
            tgt_lengths = batch_items['tgtlen']
            src_ids = src_ids[:, :max(src_lengths)].to(device=device)
            tgt_ids = tgt_ids[:, :max(tgt_lengths)].to(device=device)

            _, _, other = model(src_ids,
                                src_lens=src_lengths, tgt=tgt_ids,
                                is_training=False, beam_width=beam_width)

            # one attention tensor per decoding (query) step; [i] selects
            # sentence i and size(2) is the key (source) length
            attention = other['attention_score']
            seqlist = other['sequence']  # traverse over time not batch

            # Print sentence by sentence
            srcwords = _convert_to_words_batchfirst(src_ids, test_set.src_id2word)
            refwords = _convert_to_words_batchfirst(tgt_ids[:, 1:], test_set.tgt_id2word)
            seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)

            n_q = len(attention)
            n_k = attention[0].size(2)
            att_score = torch.empty(n_q, n_k, dtype=torch.float)

            for i in range(len(seqwords)):  # loop over sentences
                print('SRC: {}'.format(' '.join(srcwords[i])))
                print('REF: {}'.format(' '.join(refwords[i])))
                print('GEN: {}'.format(' '.join(seqwords[i])))

                # gather sentence i's attention row from every query step
                for j in range(n_q):
                    att_score[j] = attention[j][i]

                loc_eos_k = srcwords[i].index('</s>') + 1
                loc_eos_q = seqwords[i].index('</s>') + 1
                print('eos_k: {}, eos_q: {}'.format(loc_eos_k, loc_eos_q))
                # each row (each query) sum up to 1
                att_score_trim = att_score[:loc_eos_q, :loc_eos_k]
                print(att_score_trim)
                print('\n')

                choice = input('Plot or not ? - y/n\n')
                if choice and choice.lower()[0] == 'y':
                    print('plotting ...')
                    plot_dir = os.path.join(plot_path, '{}.png'.format(count))
                    src = srcwords[i][:loc_eos_k]
                    hyp = seqwords[i][:loc_eos_q]
                    # x-axis: src; y-axis: hyp; the reference is not drawn
                    plot_alignment(att_score_trim.numpy(),
                                   plot_dir, src=src, hyp=hyp, ref=None)
                    count += 1
                    input('Press enter to continue ...')
# Example no. 7
# 0
def translate_batch(test_set, model, test_path_out, use_gpu,
                    max_seq_len, beam_width, device, seqrev=False,
                    iter_idx=0, per_iter=500):
    """Translate one slice of ``test_set``'s batches into its own file.

    Batches ``[iter_idx * per_iter, (iter_idx + 1) * per_iter)`` are decoded,
    with the upper bound clipped to the number of batches. Hypotheses are
    written one per line to ``<test_path_out>/<iter_idx:04d>.txt`` (UTF-8).
    This allows a large test set to be split across several runs.

    Within a hypothesis:
      * ``<pad>`` tokens are dropped;
      * ``<spc>`` becomes a literal space;
      * output stops at the first ``</s>``;
      * tokens are reversed when ``seqrev`` is set.
    Tokens are joined with a space for ``use_type == 'word'`` and with
    nothing for ``'char'``.

    Args:
        test_set: test dataset (src, tgt using the same dir).
        model: model to decode with; put into eval mode here.
        test_path_out: output directory.
        use_gpu: forwarded to the model's forward call.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model.
        device: device the source ids are moved to.
        seqrev: reverse each output sequence before writing it.
        iter_idx: index of the slice to translate (default 0, the
            previously hard-coded value).
        per_iter: number of batches per slice (default 500, the previously
            hard-coded value).

    Raises:
        ValueError: if ``test_set.use_type`` is neither 'word' nor 'char'.
    """
    model.max_seq_len = max_seq_len
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    if test_set.use_type == 'word':
        joiner = ' '
    elif test_set.use_type == 'char':
        joiner = ''
    else:
        raise ValueError('Unrecognised use_type {!r} - choose from word/char'
                         .format(test_set.use_type))

    evaliter = iter(test_set.iter_loader)
    n_total = len(evaliter)
    print('num batches: {}'.format(n_total))
    print('batch_size: {}'.format(test_set.batch_size))

    # select batch range for this slice
    st = iter_idx * per_iter
    ed = min((iter_idx + 1) * per_iter, n_total)
    out_path = os.path.join(test_path_out, '{:04d}.txt'.format(iter_idx))

    model.eval()
    # ``with`` guarantees the output file is flushed and closed (it was
    # previously left open)
    with torch.no_grad(), open(out_path, 'w', encoding="utf8") as f:
        for idx, batch_items in enumerate(evaliter):
            if idx < st:
                continue
            if idx >= ed:
                break
            print(idx, ed)

            # load data; trim padding columns beyond the longest source
            src_lengths = batch_items['srclen']
            src_ids = batch_items['srcid'][0]
            src_ids = src_ids[:, :max(src_lengths)].to(device=device)

            _, _, other = model(src=src_ids, src_lens=src_lengths,
                                is_training=False, beam_width=beam_width,
                                use_gpu=use_gpu)

            # memory usage
            _, mem_mb, _ = get_memory_alloc()
            print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

            # write to file
            seqwords = _convert_to_words(other['sequence'], test_set.tgt_id2word)
            for i, sentence in enumerate(seqwords):
                if src_lengths[i] == 0:
                    continue
                words = []
                for word in sentence:
                    if word == '<pad>':
                        continue
                    elif word == '<spc>':
                        words.append(' ')
                    elif word == '</s>':
                        break
                    else:
                        words.append(word)
                if seqrev:
                    words = words[::-1]
                f.write('{}\n'.format(joiner.join(words)))

            sys.stdout.flush()
# Example no. 8
# 0
def translate(test_set, load_dir, test_path_out,
              use_gpu, max_seq_len, beam_width, mode='gec', seqrev=False):
    """Translate ``test_set`` with a saved gec/dd checkpoint (no reference).

    One hypothesis per line is written to ``<test_path_out>/translate.txt``
    (UTF-8). Within a hypothesis, ``<pad>`` tokens are dropped, output stops
    at the first ``</s>``, and the kept tokens are reversed when ``seqrev``
    is set. Sentences with source length 0 (batch padding) are skipped.

    Args:
        test_set: test dataset (src, tgt using the same dir).
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        test_path_out: output directory.
        use_gpu: run on GPU if True, CPU otherwise.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to ``gec_eval``/``dd_eval``.
        mode: 'gec' or 'dd' (case-insensitive); selects the eval entry point.
        seqrev: reverse each output sequence before writing it.

    Raises:
        ValueError: if ``mode`` is not 'gec' or 'dd'. This is now checked
            before anything is loaded. Previously it was an ``assert False``
            inside the batch loop, which ``python -O`` strips.
    """
    eval_mode = mode.lower()
    if eval_mode not in ('gec', 'dd'):
        raise ValueError('Unrecognised eval mode - choose from gec/dd')

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)
    # fix compatibility with checkpoints saved by older model versions
    for classvar in ('shared_embed', 'additional_key_size', 'gec_num_bilstm_dec',
                     'num_unilstm_enc', 'residual', 'add_discriminator'):
        model.check_classvar(classvar)
    if model.ptr_net == 'none':
        model.ptr_net = 'null'
    model.to(device)

    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    print('--- constrcut gec test set ---')
    if test_set.tsv_path is None:
        test_batches, _ = test_set.construct_batches(is_train=False)
    else:
        test_batches, _ = test_set.construct_batches_with_ddfd_prob(is_train=False)

    with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for batch in test_batches:
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                src_probs = None
                if 'src_ddfd_probs' in batch and model.dd_additional_key_size > 0:
                    src_probs = _convert_to_tensor(
                        batch['src_ddfd_probs'], use_gpu).unsqueeze(2)

                src_ids = _convert_to_tensor(src_ids, use_gpu)
                tgt_ids = None  # no reference available

                if eval_mode == 'gec' and beam_width <= 1:
                    # greedy gec_eval returns six values; only the last
                    # (the output dict) is used here
                    _, _, _, _, _, ret_dict = model.gec_eval(
                        src_ids, tgt_ids, is_training=False,
                        gec_dd_att_key_feats=src_probs, beam_width=beam_width)
                elif eval_mode == 'gec':
                    _, _, ret_dict = model.gec_eval(
                        src_ids, tgt_ids, is_training=False,
                        gec_dd_att_key_feats=src_probs, beam_width=beam_width)
                else:  # 'dd'
                    _, _, ret_dict = model.dd_eval(
                        src_ids, tgt_ids, is_training=False,
                        dd_att_key_feats=src_probs, beam_width=beam_width)

                # memory usage
                _, mem_mb, _ = get_memory_alloc()
                print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

                # output is decoded with the source vocabulary (presumably
                # shared with the target - verify against the dataset)
                seqwords = _convert_to_words(ret_dict['sequence'], test_set.src_id2word)
                for i in range(len(seqwords)):
                    # skip padding sentences in batch (num_sent % batch_size != 0)
                    if src_lengths[i] == 0:
                        continue
                    words = []
                    for word in seqwords[i]:
                        if word == '<pad>':
                            continue
                        elif word == '</s>':
                            break
                        else:
                            words.append(word)
                    if seqrev:
                        words = words[::-1]
                    f.write('{}\n'.format(' '.join(words)))
                sys.stdout.flush()
# Example no. 9
# 0
def acous_att_plot(test_set, load_dir, plot_path, use_gpu, max_seq_len,
                   beam_width):
    """Interactively inspect and plot acoustic attention on ``test_set``.

    Everything runs on the CPU regardless of ``use_gpu``. The model is moved
    with ``model.cpu()``, and the inputs are now placed on the CPU too; they
    were previously moved to the global ``device``, which breaks the forward
    pass when that is a GPU. Only sentence 0 of each batch is shown. After
    printing, the user is asked whether to save the plot as
    ``<plot_path>/<count>.png``.

    Args:
        test_set: dataset whose ``iter_loader`` items are indexed as in the
            loop below (0: src ids, 2: acoustic feats, 5: acoustic times).
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        plot_path: directory the PNG plots are written to.
        use_gpu: ignored; plotting is CPU-only.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model.
    """
    use_gpu = False
    cpu = torch.device('cpu')

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    model.cpu()
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    print('total #batches: {}'.format(len(evaliter)))

    count = 0  # index of the next saved plot
    i = 0  # only the first sentence of each batch is inspected
    model.eval()
    with torch.no_grad():
        for batch_items in evaliter:
            # keep inputs on the same (CPU) device as the model
            src_ids = batch_items[0][0].to(device=cpu)
            acous_feats = batch_items[2][0].to(device=cpu)
            acous_times = batch_items[5]

            _, _, ret_dict = model(
                src_ids,
                acous_feats=acous_feats,
                acous_times=acous_times,
                is_training=False,
                use_gpu=use_gpu,
                beam_width=beam_width)

            srcwords = _convert_to_words_batchfirst(src_ids, test_set.src_id2word)

            if not model.add_times:
                attention = torch.cat(ret_dict['attention_score'], dim=1)[i]
                seqwords = _convert_to_words(ret_dict['sequence'], test_set.src_id2word)
                print('SRC: {}'.format(' '.join(srcwords[i])))
                print('GEN: {}'.format(' '.join(seqwords[i])))

                loc_eos_k = srcwords[i].index('</s>') + 1
                print('eos_k: {}'.format(loc_eos_k))
                loc_eos_m = len(seqwords[i])
                print('eos_m: {}'.format(loc_eos_m))

                att_score_trim = attention[:loc_eos_m, :]  # each row (each query) sum up to 1
                print('att size: {}'.format(att_score_trim.size()))
                row_words = seqwords[i][:loc_eos_m]
                right_words = srcwords[i][:loc_eos_m]
            else:
                attention = torch.cat(ret_dict['attention_score'], dim=0)
                print('SRC: {}'.format(' '.join(srcwords[i])))

                loc_eos_k = srcwords[i].index('</s>') + 1
                print('eos_k: {}'.format(loc_eos_k))

                att_score_trim = attention[:loc_eos_k, :]  # each row (each query) sum up to 1
                print('att size: {}'.format(att_score_trim.size()))
                row_words = srcwords[i][:loc_eos_k]
                right_words = row_words

            choice = input('Plot or not ? - y/n\n')
            if choice and choice.lower()[0] == 'y':
                print('plotting ...')
                plot_dir = os.path.join(plot_path, '{}.png'.format(count))
                # x-axis: acous; y-axis: src
                plot_attention(att_score_trim.numpy(),
                               plot_dir,
                               row_words,
                               words_right=right_words)
                count += 1
                input('Press enter to continue ...')
# Example no. 10
# 0
def debug_beam_search(test_set, load_dir, use_gpu, max_seq_len, beam_width):
    """Print every beam hypothesis next to its reference, one sentence at a time.

    After each sentence, the function waits for Enter (``input('...')``).
    Token accuracy of the top hypothesis is accumulated along the way.

    Args:
        test_set: test dataset with reference targets.
        load_dir: checkpoint path handed to ``Checkpoint.load``.
        use_gpu: run on GPU if True, CPU otherwise.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model. When it is 1 the
            greedy output is shown as the single hypothesis; this previously
            raised NameError because ``seqlists`` was undefined.

    Returns:
        Token accuracy over non-``PAD`` reference positions (``nan`` when
        there are none).
    """
    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    if test_set.attkey_path is None:
        test_batches, _ = test_set.construct_batches(is_train=False)
    else:
        test_batches, _ = test_set.construct_batches_with_ddfd_prob(
            is_train=False)

    model.eval()
    match = 0
    total = 0
    with torch.no_grad():
        for batch in test_batches:
            src_ids = batch['src_word_ids']
            tgt_ids = batch['tgt_word_ids']
            src_probs = None
            if 'src_ddfd_probs' in batch:
                src_probs = _convert_to_tensor(
                    batch['src_ddfd_probs'], use_gpu).unsqueeze(2)

            src_ids = _convert_to_tensor(src_ids, use_gpu)
            tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

            decoder_outputs, _, other = model(
                src_ids,
                tgt_ids,
                is_training=False,
                att_key_feats=src_probs,
                beam_width=beam_width)

            seqlist = other['sequence']  # traverse over time not batch
            if beam_width > 1:
                # split the top-k output into one per-step list per beam
                full_seqlist = other['topk_sequence']
                seqlists = [[seq[:, k] for seq in full_seqlist]
                            for k in range(beam_width)]
                # with beam search the last decoder output is not scored
                decoder_outputs = decoder_outputs[:-1]
            else:
                # greedy decoding: the single hypothesis is the only "beam"
                seqlists = [seqlist]

            for step in range(len(decoder_outputs)):  # loop over time steps
                target = tgt_ids[:, step + 1]
                non_padding = target.ne(PAD)
                correct = seqlist[step].view(-1).eq(target).masked_select(
                    non_padding).sum().item()
                match += correct
                total += non_padding.sum().item()

            refwords = _convert_to_words_batchfirst(tgt_ids[:, 1:],
                                                    test_set.tgt_id2word)
            seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)
            seqwords_list = [_convert_to_words(seqlists[k], test_set.tgt_id2word)
                             for k in range(beam_width)]

            for i in range(len(seqwords)):
                print('REF', ' '.join(refwords[i]))
                for j in range(beam_width):
                    print('{}th'.format(j), ' '.join(seqwords_list[j][i]))
                input('...')

            sys.stdout.flush()

    accuracy = match / total if total else float('nan')
    return accuracy
# Example no. 11
# 0
def _hyp_tokens_to_line(tokens, joiner, seqrev=False, spc_to_space=False):
    """Turn one decoded token sequence into a single output line.

    '<pad>' tokens are skipped and reading stops at the first '</s>'.
    When ``spc_to_space`` is true, '<spc>' tokens are emitted as a literal
    space (character-level output); otherwise they are kept verbatim.
    If ``seqrev`` is true the kept tokens are reversed before being joined
    with ``joiner``. An empty token list yields ''.
    """
    words = []
    for word in tokens:
        if word == '<pad>':
            continue
        if word == '</s>':
            break
        if spc_to_space and word == '<spc>':
            words.append(' ')
        else:
            words.append(word)
    if seqrev:
        words.reverse()
    return joiner.join(words)


def translate_acous(test_set,
                    load_dir,
                    test_path_out,
                    use_gpu,
                    max_seq_len,
                    beam_width,
                    seqrev=False):
    """
    Run translation with acoustic features; no reference tgt is used.

    Writes two files into ``test_path_out``:
      * ``translate.txt`` - one decoded hypothesis per input sentence
        (only when the model's ``add_times`` flag is false);
      * ``translate.tsv`` - for every source token, ``token<TAB>value``
        taken from ``other['classify_prob']``, one blank line per sentence.

    Args:
        test_set: test dataset (src, tgt using the same dir); must provide
            ``construct_batches``, ``iter_loader`` and ``src_id2word``
        load_dir: checkpoint path the model is loaded from
        test_path_out: output dir
        use_gpu: on gpu/cpu, forwarded to the model call
        max_seq_len: maximum decoding length set on the model
        beam_width: beam width forwarded to the model call
        seqrev: if True, hypothesis tokens are reversed before writing
    """
    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset max decoding length
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0
    print('total #batches: {}'.format(len(evaliter)))

    # (joiner, map '<spc>' to ' ') per output unit type; other types write nothing
    hyp_formats = {'char': ('', True), 'word': (' ', False), 'bpe': (' ', False)}

    txt_path = os.path.join(test_path_out, 'translate.txt')
    tsv_path = os.path.join(test_path_out, 'translate.tsv')
    # Both outputs are context-managed and UTF-8 encoded, so the .tsv is
    # closed on error and non-ASCII tokens do not depend on the locale.
    with open(txt_path, 'w', encoding="utf8") as f, \
            open(tsv_path, 'w', encoding="utf8") as f2:
        model.eval()
        with torch.no_grad():
            # Plain iteration: Python 3 iterators have no ``.next()`` method.
            for batch_items in evaliter:
                src_ids = batch_items[0][0].to(device=device)
                src_lengths = batch_items[1]
                acous_feats = batch_items[2][0].to(device=device)
                acous_times = batch_items[5]

                _, _, other = model(
                    src_ids,
                    acous_feats=acous_feats,
                    acous_times=acous_times,
                    is_training=False,
                    use_gpu=use_gpu,
                    beam_width=beam_width)

                # memory usage
                _, mem_mb, _ = get_memory_alloc()
                print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

                model.check_var('add_times', False)
                if not model.add_times:
                    # NOTE(review): hypotheses are decoded with the *source*
                    # vocabulary; presumably src/tgt share one - confirm.
                    seqwords = _convert_to_words(other['sequence'],
                                                 test_set.src_id2word)

                    model.check_var('use_type', 'char')
                    total += len(seqwords)
                    fmt = hyp_formats.get(model.use_type)
                    if fmt is not None:
                        joiner, spc_to_space = fmt
                        for i, tokens in enumerate(seqwords):
                            # skip padding sentences in batch (num_sent % batch_size != 0)
                            if src_lengths[i] == 0:
                                continue
                            outline = _hyp_tokens_to_line(
                                tokens, joiner, seqrev, spc_to_space)
                            f.write('{}\n'.format(outline))

                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                # assumes classify_prob is indexable as [sent][token][0] - TODO confirm
                dd_ps = other['classify_prob']

                # print dd output: one token per line until padding / EOS
                for i, sent in enumerate(srcwords):
                    for j, word in enumerate(sent):
                        if word in ('<pad>', '</s>'):
                            break
                        f2.write('{}\t{}\n'.format(word, dd_ps[i][j].data[0]))
                    f2.write('\n')

                sys.stdout.flush()

    print('total #sent: {}'.format(total))
Ejemplo n.º 12
0
    def _evaluate_batches(self, model, batches, dataset):
        """Evaluate ``model`` on ``batches`` with the reference target given.

        Accumulates NLL loss and token accuracy over the non-PAD positions of
        ``tgt_ids[:, 1:]`` (the first target token is not predicted). It also
        prints up to three SRC/REF/GEN samples (first sentence of the first
        three batches) to stdout.

        Args:
            model: called as ``model(src_ids, tgt_ids, is_training=False,
                att_key_feats=...)``; must expose ``additional_key_size``
            batches: iterable of dict batches with 'src_word_ids' and
                'tgt_word_ids', optionally 'src_ddfd_probs'
            dataset: provides ``tgt_id2word`` used to print samples

        Returns:
            (resloss, accuracy, losses): ``loss.get_loss()``, match/total
            token accuracy (``nan`` when there are no target tokens), and a
            dict whose 'att_loss' / 'attcls_loss' entries are always 0 here.
        """
        model.eval()

        loss = NLLLoss()
        loss.reset()

        match = 0
        total = 0
        out_count = 0

        with torch.no_grad():
            for batch in batches:
                # optional extra attention-key features
                src_probs = None
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = _convert_to_tensor(
                        batch['src_ddfd_probs'], self.use_gpu).unsqueeze(2)

                src_ids = _convert_to_tensor(batch['src_word_ids'], self.use_gpu)
                tgt_ids = _convert_to_tensor(batch['tgt_word_ids'], self.use_gpu)

                non_padding_mask_tgt = tgt_ids.data.ne(PAD)

                decoder_outputs, _, other = model(
                    src_ids,
                    tgt_ids,
                    is_training=False,
                    att_key_feats=src_probs)

                # Predictions are aligned with tgt_ids[:, 1:]
                tgt_flat = tgt_ids[:, 1:].reshape(-1)
                mask_flat = non_padding_mask_tgt[:, 1:].reshape(-1)

                # Evaluation
                logps = torch.stack(decoder_outputs, dim=1).to(device=device)
                if not self.eval_with_mask:
                    loss.eval_batch(logps.reshape(-1, logps.size(-1)), tgt_flat)
                else:
                    loss.eval_batch_with_mask(
                        logps.reshape(-1, logps.size(-1)), tgt_flat, mask_flat)

                seqlist = other['sequence']
                seqres = torch.stack(seqlist, dim=1).to(device=device)
                correct = seqres.view(-1).eq(tgt_flat) \
                    .masked_select(mask_flat).sum().item()
                match += correct
                total += non_padding_mask_tgt[:, 1:].sum().item()

                # NOTE(review): norm_term is overwritten and normalise() is
                # called on every batch, although the loss accumulates across
                # batches. Check NLLLoss.normalise semantics to confirm
                # this is intended.
                if not self.eval_with_mask:
                    loss.norm_term = 1.0 * tgt_ids.size(
                        0) * tgt_ids[:, 1:].size(1)
                else:
                    loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
                loss.normalise()

                if out_count < 3:
                    # NOTE(review): source ids are decoded with the *target*
                    # vocabulary; only correct if src/tgt share a vocab.
                    srcwords = _convert_to_words_batchfirst(
                        src_ids, dataset.tgt_id2word)
                    refwords = _convert_to_words_batchfirst(
                        tgt_ids[:, 1:], dataset.tgt_id2word)
                    seqwords = _convert_to_words(seqlist, dataset.tgt_id2word)
                    outsrc = 'SRC: {}\n'.format(' '.join(
                        srcwords[0])).encode('utf-8')
                    outref = 'REF: {}\n'.format(' '.join(
                        refwords[0])).encode('utf-8')
                    outline = 'GEN: {}\n'.format(' '.join(
                        seqwords[0])).encode('utf-8')
                    # Flush pending text-layer output first so these raw byte
                    # writes are not emitted ahead of earlier print() calls.
                    sys.stdout.flush()
                    sys.stdout.buffer.write(outsrc)
                    sys.stdout.buffer.write(outref)
                    sys.stdout.buffer.write(outline)
                    sys.stdout.flush()
                    out_count += 1

        att_resloss = 0
        attcls_resloss = 0
        resloss = loss.get_loss()

        if total == 0:
            accuracy = float('nan')
        else:
            accuracy = match / total
        torch.cuda.empty_cache()

        losses = {}
        losses['att_loss'] = att_resloss
        losses['attcls_loss'] = attcls_resloss

        return resloss, accuracy, losses