def main():
    """Preprocess entry point: count the extra features on each input
    stream, build the `Fields` object, then write the training data,
    the vocabulary, and the validation data to disk."""
    opt = parse_args()
    logger = get_logger(opt.log_file)

    logger.info("Extracting features...")
    n_src_feats = onmt.io.get_num_features(opt.data_type, opt.train_src, 'src')
    n_conv_feats = onmt.io.get_num_features(
        opt.data_type, opt.train_conv, 'conversation')
    n_tgt_feats = onmt.io.get_num_features(opt.data_type, opt.train_tgt, 'tgt')
    logger.info(" * number of source features: %d." % n_src_feats)
    logger.info(" * number of conversation features: %d." % n_conv_feats)
    logger.info(" * number of target features: %d." % n_tgt_feats)

    logger.info("Building `Fields` object...")
    fields = onmt.io.get_fields(
        opt.data_type, n_src_feats, n_conv_feats, n_tgt_feats)

    logger.info("Building & saving training data...")
    train_shards = build_save_dataset('train', fields, opt, logger)

    # Vocabulary is built from the shards just written for the train split.
    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_shards, fields, opt, logger)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt, logger)
# NOTE(review): tail of a function whose header lies outside this chunk (the
# line begins mid-call with `opt)`); left byte-identical rather than guessed at.
# Two suspected defects to confirm against the full file:
#   * `logger.info("\t* enc: ", filtered_enc_embeddings.size())` passes the
#     size as a lazy %-argument but the format string has no placeholder, so
#     the tensor size is never rendered in the log (same for the dec line).
#   * the stray `"""` before the __main__ guard suggests a docstring or
#     commented-out region whose opening quotes are outside this view.
opt) logger.info("\nMatching: ") match_percent = [_['match'] / (_['match'] + _['miss']) * 100 for _ in [enc_count, dec_count]] logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % (enc_count['match'], enc_count['miss'], match_percent[0])) logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % (dec_count['match'], dec_count['miss'], match_percent[1])) logger.info("\nFiltered embeddings:") logger.info("\t* enc: ", filtered_enc_embeddings.size()) logger.info("\t* dec: ", filtered_dec_embeddings.size()) enc_output_file = opt.output_file + ".enc.pt" dec_output_file = opt.output_file + ".dec.pt" logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" % (enc_output_file, dec_output_file)) torch.save(filtered_enc_embeddings, enc_output_file) torch.save(filtered_dec_embeddings, dec_output_file) logger.info("\nDone.") """ if __name__ == "__main__": logger = get_logger('embeddings_to_torch.log') main()
from __future__ import division, unicode_literals import argparse from onmt.translate.Translator import make_translator from onmt.Utils import get_logger import onmt.io import onmt.translate import onmt import onmt.ModelConstructor import onmt.modules import onmt.opts def main(opt): translator = make_translator(opt, report_score=True, logger=logger) translator.translate(opt.src_dir, opt.src, opt.tgt, opt.batch_size, opt.attn_debug) if __name__ == "__main__": parser = argparse.ArgumentParser( description='translate.py', formatter_class=argparse.ArgumentDefaultsHelpFormatter) onmt.opts.add_md_help_argument(parser) onmt.opts.translate_opts(parser) opt = parser.parse_args() logger = get_logger(opt.log_file) main(opt)
# NOTE(review): the leading `pass if os.path.isdir(tmp_dir): ...` is the tail
# of a cleanup clause (presumably try/finally) whose header lies outside this
# chunk; left byte-identical rather than reconstructed. The rest of the line
# is complete: `rouge_results_to_str` formats the five ROUGE F-scores as
# percentages ("rouge_su*_f_score" is the literal pyrouge result key — do not
# "fix" the asterisk), and the __main__ guard runs ROUGE on a candidate /
# reference file pair, reading the candidate from stdin when -c is "STDIN".
pass if os.path.isdir(tmp_dir): shutil.rmtree(tmp_dir) def rouge_results_to_str(results_dict): return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format( results_dict["rouge_1_f_score"] * 100, results_dict["rouge_2_f_score"] * 100, results_dict["rouge_3_f_score"] * 100, results_dict["rouge_l_f_score"] * 100, results_dict["rouge_su*_f_score"] * 100) if __name__ == "__main__": logger = get_logger('test_rouge.log') parser = argparse.ArgumentParser() parser.add_argument('-c', type=str, default="candidate.txt", help='candidate file') parser.add_argument('-r', type=str, default="reference.txt", help='reference file') args = parser.parse_args() if args.c.upper() == "STDIN": args.c = sys.stdin results_dict = test_rouge(args.c, args.r) logger.info(rouge_results_to_str(results_dict))
# Constructor arguments for DatasetParser, keyed by the lower-cased value of
# opt.dataset: (train path, dev path, test path (False when absent), tag).
_DATASET_SPECS = {
    'e2e': ('data/e2e/trainset.csv', 'data/e2e/devset.csv',
            'data/e2e/testset_w_refs.csv', 'E2E'),
    'webnlg': ('data/webNLG_challenge_data/train',
               'data/webNLG_challenge_data/dev', False, 'webNLG'),
    'sfhotel': ('data/sfx_data/sfxhotel/train.json',
                'data/sfx_data/sfxhotel/valid.json',
                'data/sfx_data/sfxhotel/test.json', 'SFHotel'),
}


def _load_parser(opt):
    """Build the DatasetParser for opt.dataset, apply the optional name
    override, and shuffle the training instances of every predicate.

    Raises:
        ValueError: if opt.dataset is not one of the supported datasets
            (the original code fell through and crashed later with a
            NameError on the unbound `parser`).
    """
    try:
        train, dev, test, tag = _DATASET_SPECS[opt.dataset.lower()]
    except KeyError:
        raise ValueError("Unknown dataset: %s" % opt.dataset)
    parser = DatasetParser(train, dev, test, tag, opt)
    if opt.name.lower() != 'def':
        parser.dataset_name = opt.name.lower()
    # Shuffling relies on the random.seed(13) call in main() for determinism.
    for predicate in parser.trainingInstances:
        random.shuffle(parser.trainingInstances[predicate])
    return parser


def _report_sizes(parser):
    """Print per-predicate instance counts for each non-empty split."""
    if parser.trainingInstances:
        for predicate in parser.trainingInstances:
            print("Training data size for {}: {}".format(
                predicate, len(parser.trainingInstances[predicate])))
    if parser.developmentInstances:
        for predicate in parser.developmentInstances:
            print("Validation data size for {}: {}".format(
                predicate, len(parser.developmentInstances[predicate])))
    if parser.testingInstances:
        for predicate in parser.testingInstances:
            print("Test data size for {}: {}".format(
                predicate, len(parser.testingInstances[predicate])))
    print("-----------------------")


def main():
    """Per-predicate ONMT preprocessing: for every predicate of the chosen
    dataset, point opt at that predicate's files, then build and save
    fields, training data, vocabulary, and validation data."""
    opt = parse_args()
    logger = get_logger(opt.log_file)
    logger.info("Extracting features...")
    random.seed(13)  # fixed seed: deterministic training-instance shuffle

    parser = _load_parser(opt)

    gen_templ = parser.get_onmt_file_templ(opt)
    (train_src_templ, train_tgt_templ, train_eval_refs_templ,
     valid_src_templ, valid_tgt_templ, valid_eval_refs_templ,
     test_src_templ, test_tgt_templ,
     test_eval_refs_templ) = parser.get_onmt_file_templs(gen_templ)

    for predicate in parser.predicates:
        # Rewrite the file options in-place for this predicate's shard.
        opt.file_templ = gen_templ.format(predicate)
        opt.train_src = train_src_templ.format(predicate)
        opt.train_tgt = train_tgt_templ.format(predicate)
        opt.valid_src = valid_src_templ.format(predicate)
        opt.valid_tgt = valid_tgt_templ.format(predicate)

        src_nfeats = onmt.io.get_num_features(
            opt.data_type, opt.train_src, 'src')
        tgt_nfeats = onmt.io.get_num_features(
            opt.data_type, opt.train_tgt, 'tgt')
        logger.info(" * number of source features: %d." % src_nfeats)
        logger.info(" * number of target features: %d." % tgt_nfeats)

        logger.info("Building {} `Fields` object...".format(predicate))
        fields = onmt.io.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

        logger.info("Building & saving {} training data...".format(predicate))
        train_dataset_files = build_save_dataset('train', fields, opt, logger)

        logger.info("Building & saving {} vocabulary...".format(predicate))
        build_save_vocab(train_dataset_files, fields, opt, logger)

        logger.info(
            "Building & saving {} validation data...".format(predicate))
        build_save_dataset('valid', fields, opt, logger)

    _report_sizes(parser)
# NOTE(review): collapsed source. `forward` takes `self`, and the __main__
# block below instantiates MatrixTree and calls .forward, so this def
# presumably belongs inside the MatrixTree class — restore the class-level
# indentation when re-integrating into the original file.
def forward(self, input):
    """Compute per-edge marginals from a batch of score matrices via
    inverses of Laplacian-like matrices (a matrix-tree-theorem style
    computation — confirm exact semantics against the source paper).

    Args:
        input: tensor indexed as (batch, n, n); each batch item is
            processed independently. NOTE(review): `.cuda()` below
            hard-requires a GPU; fails on CPU-only installs.

    Returns:
        Tensor shaped like `input`; off-diagonal entries hold edge
        marginals, root marginals are written onto the diagonal.

    Assumes `self.eps` is a numeric smoothing constant set in __init__
    (outside this view).
    """
    laplacian = input.exp() + self.eps
    output = input.clone()
    for b in range(input.size(0)):
        # Zero the diagonal of the exponentiated scores before building
        # the Laplacian.
        lap = laplacian[b].masked_fill(
            torch.eye(input.size(1)).cuda().ne(0), 0)
        # Negated off-diagonal weights with column sums on the diagonal.
        lap = -lap + torch.diag(lap.sum(0))
        # store roots on diagonal
        lap[0] = input[b].diag().exp()
        inv_laplacian = lap.inverse()
        # Broadcast the inverse's diagonal across rows (transposed view).
        factor = inv_laplacian.diag().unsqueeze(1)\
            .expand_as(input[b]).transpose(0, 1)
        term1 = input[b].exp().mul(factor).clone()
        term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone()
        # First column / first row are reserved for the root scores.
        term1[:, 0] = 0
        term2[0] = 0
        output[b] = term1 - term2
        roots_output = input[b].diag().exp().mul(
            inv_laplacian.transpose(0, 1)[0])
        output[b] = output[b] + torch.diag(roots_output)
    return output

if __name__ == "__main__":
    # Smoke test on random scores; presumably each column of marginals
    # should sum to ~1 — verify against the module's tests.
    logger = get_logger('StructuredAttention.log')
    dtree = MatrixTree()
    q = torch.rand(1, 5, 5).cuda()
    marg = dtree.forward(q)
    logger.info(marg.sum(1))
# NOTE(review): tail of a function whose header (and the definitions of
# `checkpoint`, `dummy_opt`, `src_dict`, `tgt_dict`, `opt`) lies outside this
# chunk; left byte-identical rather than reconstructed. It back-fills missing
# options from dummy_opt, rebuilds the model from the checkpoint, and dumps
# encoder/decoder embedding tables to text files. Suspicious: the final
# "Converting model..." is logged AFTER "... done." and is followed by no
# code — looks like a leftover; confirm against the full file.
fields = onmt.io.load_fields_from_vocab(checkpoint['vocab']) model_opt = checkpoint['opt'] for arg in dummy_opt.__dict__: if arg not in model_opt: model_opt.__dict__[arg] = dummy_opt.__dict__[arg] model = onmt.ModelConstructor.make_base_model(model_opt, fields, use_gpu(opt), checkpoint) encoder = model.encoder decoder = model.decoder encoder_embeddings = encoder.embeddings.word_lut.weight.data.tolist() decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist() logger.info("Writing source embeddings") write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) logger.info("Writing target embeddings") write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) logger.info('... done.') logger.info('Converting model...') if __name__ == "__main__": logger = get_logger('extract_embeddings.log') main()