Esempio n. 1
0
def parse_gen_mode(gen_mode_str):
    """Split a generation-mode string into ``(gen_mode, history)``.

    The expected format is ``'{AE|ASR|MT|ST}-{REF|HYP}'``; a bare mode such
    as ``'ST'`` gets the default history ``'HYP'``.

    Raises:
        ValueError: if the string has more than one ``'-'`` separator.
    """
    parts = gen_mode_str.split('-')
    if len(parts) == 1:
        return parts[0], 'HYP'
    if len(parts) == 2:
        return parts[0], parts[1]
    # Fail fast rather than leaving gen_mode/history unbound, which would
    # otherwise only show up later as a NameError in the translate branch.
    raise ValueError(
        "gen_mode must be 'MODE' or 'MODE-HISTORY', got {!r}".format(
            gen_mode_str))


def combine_output_dir(combine_path):
    """Return the directory the combined checkpoint is saved to.

    Only trailing slashes are removed before appending ``'-combine'``;
    ``str.strip('/')`` would also drop a leading ``'/'`` and silently turn an
    absolute path into a relative one.
    """
    return combine_path.rstrip('/') + '-combine'


def main():
    """Evaluate (or post-process) a trained checkpoint.

    Options come from ``load_arguments`` / ``validate_config``. The action is
    selected by ``config['eval_mode']``:

        1 -- ``translate`` the test set into ``test_path_out``
        2 -- save the model returned by ``combine_weights`` as a checkpoint
             under ``<combine_path>-combine/combine``
        3 -- ``plot_emb``
        4 -- ``gather_emb``
        5 -- ``compute_kl``

    Raises:
        ValueError: for an unsupported eval_mode, eval_mode 2 without a
            combine_path, or (eval_mode 1) a malformed gen_mode.
    """
    # load config
    parser = argparse.ArgumentParser(description='Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set test mode; validate before any expensive model/data loading
    MODE = config['eval_mode']
    if MODE not in (1, 2, 3, 4, 5):
        raise ValueError('Unsupported eval_mode: {}'.format(MODE))
    combine_path = config['combine_path']
    if MODE == 2 and combine_path is None:
        raise ValueError(
            'eval_mode 2 (save combined model) requires combine_path')
    if MODE == 1:
        # '{AE|ASR|MT|ST}-{REF|HYP}'; only the translate branch uses these
        gen_mode, history = parse_gen_mode(config['gen_mode'])

    # load src-tgt pair; without a target file the source doubles as target
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    if test_path_tgt is None:
        test_path_tgt = test_path_src

    test_path_out = config['test_path_out']
    test_acous_path = config['test_acous_path']
    acous_norm_path = config['acous_norm_path']

    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # mode 2 only writes a checkpoint, so it needs no output dir / eval.cfg
    if MODE != 2:
        os.makedirs(test_path_out, exist_ok=True)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model: replaces the loaded model with combine_weights' result
    if combine_path is not None:
        model = combine_weights(combine_path)

    # load test_set
    test_set = Dataset(
        path_src=test_path_src,
        path_tgt=test_path_tgt,
        vocab_src_list=vocab_src,
        vocab_tgt_list=vocab_tgt,
        use_type=use_type,
        acous_path=test_acous_path,
        seqrev=seqrev,
        acous_norm=config['acous_norm'],
        acous_norm_path=acous_norm_path,
        acous_max_len=6000,  # max 50k for mustc trainset
        max_seq_len_src=900,
        max_seq_len_tgt=900,  # max 2.5k for mustc trainset
        batch_size=batch_size,
        mode='ST',
        use_gpu=use_gpu)

    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # add external language model
    lm_mode = config['lm_mode']

    # run eval:
    if MODE == 1:
        translate(test_set,
                  model,
                  test_path_out,
                  use_gpu,
                  max_seq_len,
                  beam_width,
                  device,
                  seqrev=seqrev,
                  gen_mode=gen_mode,
                  lm_mode=lm_mode,
                  history=history)

    elif MODE == 2:  # save combined model
        combine_dir = combine_output_dir(combine_path)
        ckpt = Checkpoint(model=model,
                          optimizer=None,
                          epoch=0,
                          step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(combine_dir, 'combine'))
        log_ckpts(combine_path, combine_dir)
        print('saving at {} ... '.format(saved_path))

    elif MODE == 3:
        plot_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 4:
        gather_emb(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)

    elif MODE == 5:
        compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)
Esempio n. 2
0
def main():
    """Evaluate (or post-process) a trained seq2seq checkpoint.

    Options come from ``load_arguments`` / ``validate_config``. The action is
    selected by ``config['eval_mode']``:

        1 -- ``translate`` the test set into ``test_path_out``
        2 -- output posteriors into ``test_path_out`` (``translate_logp``)
        3 -- save the model returned by ``combine_weights`` as a checkpoint
             under ``<combine_path>-combine/combine``

    Raises:
        ValueError: for an unsupported eval_mode, or eval_mode 3 without a
            combine_path.
    """
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set test mode: 1 = translate; 2 = output posterior; 3 = save comb ckpt
    # (validated before any expensive model/data loading)
    MODE = config['eval_mode']
    if MODE not in (1, 2, 3):
        raise ValueError('Unsupported eval_mode: {}'.format(MODE))
    combine_path = config['combine_path']
    if MODE == 3 and combine_path is None:
        raise ValueError(
            'eval_mode 3 (save combined model) requires combine_path')

    # load src-tgt pair (the source file is also used as the target)
    test_path_src = config['test_path_src']
    test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # mode 3 only writes a checkpoint, so it needs no output dir / eval.cfg
    if MODE != 3:
        os.makedirs(test_path_out, exist_ok=True)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model: replaces the loaded model with combine_weights' result
    if combine_path is not None:
        model = combine_weights(combine_path)

    # load test_set
    test_set = Dataset(test_path_src, test_path_tgt,
                       vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
                       seqrev=seqrev,
                       max_seq_len=900,
                       batch_size=batch_size,
                       use_gpu=use_gpu,
                       use_type=use_type)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # run eval
    if MODE == 1:
        translate(test_set, model, test_path_out, use_gpu,
                  max_seq_len, beam_width, device, seqrev=seqrev)

    elif MODE == 2:  # output posterior
        translate_logp(test_set, model, test_path_out, use_gpu,
                       max_seq_len, device, seqrev=seqrev)

    elif MODE == 3:  # save combined model
        # rstrip, not strip: strip('/') would also remove a leading '/'
        # and turn an absolute combine_path into a relative one.
        combine_dir = combine_path.rstrip('/') + '-combine'
        ckpt = Checkpoint(model=model,
                          optimizer=None, epoch=0, step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(combine_dir, 'combine'))
        log_ckpts(combine_path, combine_dir)
        print('saving at {} ... '.format(saved_path))