def main():
    """Evaluate a speech-translation (ST) checkpoint.

    Behavior is selected by config['eval_mode'] (MODE):
        1 - translate the test set
        2 - save a combined (weight-averaged) checkpoint
        3 - plot embeddings
        4 - gather embeddings
        5 - compute KL

    All settings come from the command line via load_arguments /
    validate_config; no arguments, no return value.
    """
    # load config
    parser = argparse.ArgumentParser(description='Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair; fall back to the src path when no target is given
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    if test_path_tgt is None:  # fix: idiomatic None check (was type() comparison)
        test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    test_acous_path = config['test_acous_path']
    acous_norm_path = config['acous_norm_path']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode; MODE 2 only saves a combined ckpt, so it needs no
    # output directory or saved eval config
    MODE = config['eval_mode']
    if MODE != 2:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model from checkpoint
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')  # fix: typo 'laoded' -> 'loaded'

    # combine model: replace the loaded model with averaged weights from
    # several checkpoints when a combine_path is given
    if config['combine_path'] is not None:  # fix: idiomatic None check
        model = combine_weights(config['combine_path'])
    # import pdb; pdb.set_trace()

    # load test_set
    test_set = Dataset(
        path_src=test_path_src,
        path_tgt=test_path_tgt,
        vocab_src_list=vocab_src,
        vocab_tgt_list=vocab_tgt,
        use_type=use_type,
        acous_path=test_acous_path,
        seqrev=seqrev,
        acous_norm=config['acous_norm'],
        acous_norm_path=config['acous_norm_path'],
        acous_max_len=6000,     # max 50k for mustc trainset
        max_seq_len_src=900,
        max_seq_len_tgt=900,    # max 2.5k for mustc trainset
        batch_size=batch_size,
        mode='ST',
        use_gpu=use_gpu)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # generation mode: '{AE|ASR|MT|ST}-{REF|HYP}', history defaults to HYP
    gen_parts = config['gen_mode'].split('-')
    if len(gen_parts) == 2:
        gen_mode, history = gen_parts
    elif len(gen_parts) == 1:
        gen_mode = gen_parts[0]
        history = 'HYP'
    else:
        # fix: previously fell through with gen_mode/history unbound,
        # raising NameError later; fail early with a clear message instead
        raise ValueError('invalid gen_mode: {}'.format(config['gen_mode']))

    # add external language model
    lm_mode = config['lm_mode']

    # run eval:
    if MODE == 1:
        translate(test_set, model, test_path_out, use_gpu,
                  max_seq_len, beam_width, device, seqrev=seqrev,
                  gen_mode=gen_mode, lm_mode=lm_mode, history=history)

    elif MODE == 2:
        # save combined model
        ckpt = Checkpoint(model=model, optimizer=None, epoch=0, step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].strip('/') + '-combine',
                         'combine'))
        log_ckpts(config['combine_path'],
                  config['combine_path'].strip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))

    elif MODE == 3:
        plot_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 4:
        gather_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 5:
        compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len, device)
def main():
    """Evaluate a seq2seq (text-to-text) checkpoint.

    Behavior is selected by config['eval_mode'] (MODE):
        1 - translate the test set
        2 - output posterior log-probabilities (translate_logp)
        3 - save a combined (weight-averaged) checkpoint

    All settings come from the command line via load_arguments /
    validate_config; no arguments, no return value.
    """
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair; the target side is not needed for decoding,
    # so the src path is reused as a placeholder
    test_path_src = config['test_path_src']
    test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode: 1 = translate; 2 = output posterior; 3 = save comb ckpt
    # (fix: the old comment said '2 = plot', but MODE 2 runs translate_logp)
    # MODE 3 only saves a combined ckpt, so it needs no output directory
    MODE = config['eval_mode']
    if MODE != 3:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model from checkpoint
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')  # fix: typo 'laoded' -> 'loaded'

    # combine model: replace the loaded model with averaged weights from
    # several checkpoints when a combine_path is given
    if config['combine_path'] is not None:  # fix: idiomatic None check
        model = combine_weights(config['combine_path'])

    # load test_set
    test_set = Dataset(test_path_src, test_path_tgt,
                       vocab_src_list=vocab_src,
                       vocab_tgt_list=vocab_tgt,
                       seqrev=seqrev,
                       max_seq_len=900,
                       batch_size=batch_size,
                       use_gpu=use_gpu,
                       use_type=use_type)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # run eval
    if MODE == 1:
        translate(test_set, model, test_path_out, use_gpu,
                  max_seq_len, beam_width, device, seqrev=seqrev)

    elif MODE == 2:
        # output posterior
        translate_logp(test_set, model, test_path_out, use_gpu,
                       max_seq_len, device, seqrev=seqrev)

    elif MODE == 3:
        # save combined model
        ckpt = Checkpoint(model=model, optimizer=None, epoch=0, step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].strip('/') + '-combine',
                         'combine'))
        log_ckpts(config['combine_path'],
                  config['combine_path'].strip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))