def eval_bs(test_set: Dataset, vocab: Vocab, model: Seq2Seq, params: Params):
    """Run beam-search evaluation of `model` over a sample of `test_set`.

    Iterates one example at a time, accumulating ROUGE scores (1, 2, L, SU4)
    and showing running averages on a progress bar.  If
    `params.test_save_results` is set and a model path prefix is available,
    each example's detailed output is stored as `NNNNNN.txt` inside a
    `<prefix>.results.tgz` archive.

    :param test_set: dataset providing `generator(...)` and `pairs`
    :param vocab: vocabulary used for encoding/decoding
    :param model: the seq2seq model to evaluate (switched to eval mode here)
    :param params: hyper-parameters / runtime options
    """
    # Batch size 1; target vocab is None because we only decode.
    test_gen = test_set.generator(1, vocab, None, bool(params.pointer))
    n_samples = int(params.test_sample_ratio * len(test_set.pairs))

    if params.test_save_results and params.model_path_prefix:
        result_file = tarfile.open(params.model_path_prefix + ".results.tgz", 'w:gz')
    else:
        result_file = None

    model.eval()
    r1, r2, rl, rsu4 = 0, 0, 0, 0
    prog_bar = tqdm(range(1, n_samples + 1))
    try:
        for i in prog_bar:
            batch = next(test_gen)
            scores, file_content = eval_bs_batch(
                batch, model, vocab,
                pack_seq=params.pack_seq,
                beam_size=params.beam_size,
                min_out_len=params.min_out_len,
                max_out_len=params.max_out_len,
                len_in_words=params.out_len_in_words,
                # Details are only produced when we have somewhere to save them.
                details=result_file is not None)
            if file_content:
                file_content = file_content.encode('utf-8')
                file_info = tarfile.TarInfo(name='%06d.txt' % i)
                # TarInfo.size must match the payload length exactly.
                file_info.size = len(file_content)
                result_file.addfile(file_info, fileobj=BytesIO(file_content))
            if scores:
                r1 += scores[0]['1_f']
                r2 += scores[0]['2_f']
                rl += scores[0]['l_f']
                rsu4 += scores[0]['su4_f']
                # Running averages as percentages.
                prog_bar.set_postfix(R1='%.4g' % (r1 / i * 100), R2='%.4g' % (r2 / i * 100),
                                     RL='%.4g' % (rl / i * 100), RSU4='%.4g' % (rsu4 / i * 100))
    finally:
        # Fix: the archive was previously never closed; for 'w:gz' close()
        # writes the end-of-archive blocks and flushes the gzip stream, so
        # skipping it can leave a truncated/corrupt .tgz.
        if result_file is not None:
            result_file.close()
# NOTE(review): this chunk appears to be the tail of a larger scope — `p`,
# `m`, and `train_status` are defined elsewhere.  As visible here the
# `arg_name` warning can never fire (it is set to None on the line before);
# presumably an argument-parsing loop between the two lines was lost in
# extraction — confirm against the full source.
arg_name = None
if arg_name is not None:
    print("Warning: Argument %s lacks a value and is ignored." % arg_name)

# Build the training dataset with the configured truncation limits.
dataset = Dataset(p.data_path,
                  max_src_len=p.max_src_len,
                  max_tgt_len=p.max_tgt_len,
                  truncate_src=p.truncate_src,
                  truncate_tgt=p.truncate_tgt)

# Fresh run: build the vocab (with pretrained embeddings) and a new model.
# Resumed run: `m` already exists, so only rebuild the vocab.
if m is None:
    v = dataset.build_vocab(p.vocab_size, embed_file=p.embed_file)
    m = Seq2Seq(v, p)
else:
    v = dataset.build_vocab(p.vocab_size)

train_gen = dataset.generator(p.batch_size, v, v, bool(p.pointer))

# Optional validation generator, mirroring the training setup.
if p.val_data_path:
    val_dataset = Dataset(p.val_data_path,
                          max_src_len=p.max_src_len,
                          max_tgt_len=p.max_tgt_len,
                          truncate_src=p.truncate_src,
                          truncate_tgt=p.truncate_tgt)
    val_gen = val_dataset.generator(p.val_batch_size, v, v, bool(p.pointer))
else:
    val_gen = None

train(train_gen, v, m, p, val_gen, train_status)