def test(args):
    """Evaluate a saved parser checkpoint on ``args.test_file``.

    The checkpoint at ``args.load_model`` is read twice: once through
    ``torch.load`` to recover the transition system and the training-time
    args, and once through the registered parser class's ``load``. The test
    examples are scored with ``evaluation.evaluate``, the metrics are printed
    to stderr and, if requested, the decode results are pickled.

    Args:
        args: command-line namespace. Reads ``test_file``, ``load_model``,
            ``cuda``, ``parser``, ``evaluator``, ``verbose`` and
            ``save_decode_to``; overwrites ``args.lang`` with the value stored
            in the checkpoint.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model

    print('load model from [%s]' % args.load_model, file=sys.stderr)
    # map_location keeps every tensor on CPU regardless of the saving device.
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda)
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results = evaluation.evaluate(
        test_set.examples, parser, evaluator, args,
        verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)

    if args.save_decode_to:
        # Close the output handle deterministically instead of leaking it.
        with open(args.save_decode_to, 'wb') as f:
            pickle.dump(decode_results, f)
def pl_debug(args):
    """Decode a hand-picked subset of the test set and dump per-example debug info.

    Keeps only the examples of ``args.test_file`` whose position appears in
    the module-level ``debug_idx`` (defined elsewhere in this file). It runs
    ``evaluation.pl_evaluate`` with ``debug=True`` and writes a readable report
    to ``debug_info.txt`` in the current working directory. For each example
    the report contains:

    - the source sentence, the target code and every prediction tagged with
      its ``eval_streg_predictions`` result;
    - when debug info is available, two beam snapshots ranked by score.

    Args:
        args: command-line namespace. Reads ``test_file``, ``load_model``,
            ``cuda``, ``parser``, ``evaluator`` and ``verbose``; overwrites
            ``args.lang`` with the value stored in the checkpoint.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    # Keep each selected example paired with its original index so the
    # "----{idx}----" headers below always label the right example, even if
    # debug_idx is unsorted, contains duplicates or out-of-range values.
    # A set makes the membership test O(1) instead of O(len(debug_idx)).
    wanted = set(debug_idx)
    selected = [(i, x) for i, x in enumerate(test_set.examples) if i in wanted]
    selected_idx = [i for i, _ in selected]
    test_set.examples = [x for _, x in selected]

    assert args.load_model
    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda)
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results, debug_info = evaluation.pl_evaluate(
        test_set.examples, parser, evaluator, args, verbose=args.verbose,
        return_decode_result=True, debug=True)
    print(eval_results, file=sys.stderr)

    # dump_debug_info
    with open("debug_info.txt", "w") as f:
        for idx, ex, pred_hyps, info in zip(selected_idx, test_set.examples,
                                            decode_results, debug_info):
            predictions = [x.code.replace(" ", "") for x in pred_hyps]
            f.write("----------------{}------------------\n".format(idx))
            f.write("Src: {}\n".format(" ".join(ex.src_sent)))
            f.write("Tgt: {}\n".format(ex.tgt_code.replace(" ", "")))
            f.write("Predictions:\n")
            pred_results = eval_streg_predictions(predictions, ex)
            for p, r in zip(predictions, pred_results):
                # predictions were already stripped of spaces above
                f.write("\t{} {}\n".format(r, p))
            if info is None:
                f.write("\n")
                continue
            prev_beam, latter_beam = info
            _write_debug_beam(f, prev_beam)
            _write_debug_beam(f, latter_beam)
            f.write("\n")


def _write_debug_beam(f, beam):
    """Write one beam snapshot to the open text file ``f``, best score first.

    Empty beams are skipped; the original code raised IndexError on them.
    ``sorted`` is used so the caller's list is not reordered in place.
    """
    if not beam:
        return
    ranked = sorted(beam, key=lambda hyp: -hyp.score)
    f.write("Beam {}:\n".format(ranked[0].t))
    for p_hyp in ranked:
        _, partial_ast = partial_asdl_ast_to_streg_ast(p_hyp.tree)
        f.write("\t{:.2f} {}\n".format(p_hyp.score, partial_ast.debug_form()))
def __init__(self, parser_name, model_path, example_processor_name, beam_size=5, cuda=False):
    """Load a registered parser and build an example processor around it.

    Args:
        parser_name: registry name of the parser class whose ``load`` is used.
        model_path: checkpoint path passed to that ``load``.
        example_processor_name: registry name of the example-processor class;
            it is instantiated with the loaded parser's ``transition_system``.
        beam_size: stored on the instance as ``self.beam_size``.
        cuda: forwarded to the parser's ``load``.
    """
    print('load parser from [%s]' % model_path, file=sys.stderr)

    # The parser is switched to eval mode right after loading.
    parser_cls = Registrable.by_name(parser_name)
    loaded_parser = parser_cls.load(model_path, cuda=cuda).eval()
    self.parser = loaded_parser

    processor_cls = Registrable.by_name(example_processor_name)
    self.example_processor = processor_cls(loaded_parser.transition_system)

    self.beam_size = beam_size
def test(args):
    """Decode hypotheses with a loaded parser and print their surface code.

    Loads the parser registered as ``args.parser`` from ``args.load_model``
    and calls its ``parse`` with ``args.beam_size``. The surface code of every
    hypothesis that the transition system can convert is printed.

    Args:
        args: command-line namespace. Reads ``load_model``, ``parser``,
            ``cuda`` and ``beam_size``.

    Returns:
        list: the hypotheses whose ``code`` attribute was set successfully
        (the original returned ``None``; returning the list is additive).

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    assert args.load_model
    print('load model from [%s]' % args.load_model, file=sys.stderr)

    parser_cls = Registrable.by_name(args.parser)
    model = parser_cls.load(model_path=args.load_model, cuda=args.cuda)

    # The original passed an undefined name ``beam_size`` (NameError). The beam
    # width is taken from the command line, as elsewhere in this file.
    # NOTE(review): confirm this parser's ``parse`` takes the beam size as
    # its first positional argument.
    hyps = model.parse(args.beam_size)

    decoded_hyps = []
    for hyp in hyps:
        try:
            hyp.code = model.transition_system.ast_to_surface_code(hyp.tree)
        except Exception as e:
            # Best-effort: skip hypotheses that cannot be rendered, but report
            # why instead of silently swallowing every error (incl. Ctrl-C).
            print('failed to convert hypothesis to code: %s' % e, file=sys.stderr)
            continue
        print(hyp.code)
        decoded_hyps.append(hyp)
    return decoded_hyps
def test(args):
    """Annotate examples with dependency features, then evaluate a BERT-backed parser.

    The training examples are annotated first (with the ``True`` flag) and
    the test examples second (with ``False``), sharing the same feature
    vocabularies. Those vocabularies, a BERT model and a BERT tokenizer (both
    from hard-coded ``./pretrained_models`` paths) are handed to the parser's
    ``load``. The test set is then evaluated, metrics are printed to stderr
    and decode results are optionally pickled.

    Args:
        args: command-line namespace. Reads ``train_file``, ``test_file``,
            ``load_model``, ``cuda``, ``parser``, ``evaluator``, ``verbose``
            and ``save_decode_to``; overwrites ``args.lang`` from the
            checkpoint.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    tmpvocab = {'null': 0}
    tmpvocab1 = {'null': 0, 'unk': 1}
    tmpvocab2 = {'null': 0, 'unk': 1}
    tmpvocab3 = {'null': 0, 'unk': 1}

    train_set = Dataset.from_bin_file(args.train_file)
    from dependency import sentencetoadj, sentencetoextra_message

    # Train examples go first: presumably the True/False flag controls whether
    # the feature vocabularies may grow (TODO confirm in dependency.py), so the
    # order matters.
    valid = _annotate_dependency_features(
        train_set.examples, sentencetoadj, sentencetoextra_message,
        tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3, True)
    print('bukey', valid)

    test_set = Dataset.from_bin_file(args.test_file)
    _annotate_dependency_features(
        test_set.examples, sentencetoadj, sentencetoextra_message,
        tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3, False)

    bertmodels = BertModel.from_pretrained('./pretrained_models/base-uncased/')
    tokenizer = BertTokenizer.from_pretrained('./pretrained_models/bert-base-uncased-vocab.txt')

    assert args.load_model
    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda, bert=bertmodels,
                             tmpv=tmpvocab, v1=tmpvocab1, v2=tmpvocab2, v3=tmpvocab3)
    parser.tokenizer = tokenizer
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results = evaluation.evaluate(
        test_set.examples, parser, evaluator, args,
        verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)

    if args.save_decode_to:
        with open(args.save_decode_to, 'wb') as f:
            pickle.dump(decode_results, f)


def _annotate_dependency_features(examples, adj_fn, extra_fn,
                                  tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3, flag):
    """Set dependency-graph and extra token features on each example in place.

    ``adj_fn`` fills ``mainnode``/``adj``/``edge``. ``extra_fn`` fills
    ``contains``/``pos``/``ner``/``types``/``tins``/``F1`` from the sentence
    and the table header's tokens and types; ``flag`` is passed through to it
    unchanged. Returns the sum of the fifth value returned by ``adj_fn``.
    """
    valid = 0
    for example in tqdm.tqdm(examples):
        example.mainnode, example.adj, example.edge, _, isv = adj_fn(example.src_sent, tmpvocab)
        header = example.table.header
        (example.contains, example.pos, example.ner,
         example.types, example.tins, example.F1) = extra_fn(
            example.src_sent,
            [item.tokens for item in header],
            [item.type for item in header],
            tmpvocab1, tmpvocab2, tmpvocab3, flag)
        valid += isv
    return valid
def test(cfg: argparse.Namespace):
    """Evaluate a saved parser checkpoint on ``cfg.test_file`` with structured logging.

    Steps:

    - Calls ``prologue(cfg)`` to obtain the experiment directory.
    - Passes the checkpoint's training-time args to ``dump_cfg`` at
      ``<experiment dir>/was_trained_with.txt``.
    - Loads the parser and evaluator by registry name and runs
      ``evaluation.evaluate``.
    - Optionally pickles the decoded results.
    - Finishes with ``epilogue(cfg)``.

    Args:
        cfg: configuration namespace. Reads ``test_file``, ``load_model``,
            ``cuda``, ``parser``, ``evaluator``, ``verbose`` and
            ``save_decode_to``; overwrites ``cfg.lang`` from the checkpoint.
    """
    cli_logger.info("=== Testing ===")
    experiment_base_dir = prologue(cfg)

    test_set = Dataset.from_bin_file(cfg.test_file)
    cli_logger.info(f"Loaded test file from [{cfg.test_file}]")

    params = torch.load(cfg.load_model, map_location=lambda storage, loc: storage)
    cli_logger.info(f"Loaded model from [{cfg.load_model}]")

    transition_system = params['transition_system']
    cli_logger.info(f"Loaded transition system [{type(transition_system)}]")

    saved_args: argparse.Namespace = params['args']
    saved_args.cuda = cfg.cuda
    # FIXME ?? set the correct domain from saved arg
    cfg.lang = saved_args.lang
    dump_cfg(experiment_base_dir + "/was_trained_with.txt", cfg=saved_args.__dict__)

    parser_cls = Registrable.by_name(cfg.parser)
    parser = parser_cls.load(model_path=cfg.load_model, cuda=cfg.cuda)
    parser.eval()
    cli_logger.info(f"Loaded parser model [{cfg.parser}]")

    evaluator = Registrable.by_name(cfg.evaluator)(transition_system, args=cfg)
    cli_logger.info(f"Loaded evaluator [{cfg.evaluator}]")

    # Do the evaluation
    eval_results, decoded_results = evaluation.evaluate(
        examples=test_set.examples,
        model=parser,
        evaluator=evaluator,
        args=cfg,
        verbose=cfg.verbose,
        return_decoded_result=True,
    )
    cli_logger.info(eval_results)

    if cfg.save_decode_to:
        # Close the output handle deterministically instead of leaking it.
        with open(cfg.save_decode_to, 'wb') as f:
            pickle.dump(decoded_results, f)
        cli_logger.info(f"Saved decoded results to [{cfg.save_decode_to}]")

    epilogue(cfg)
    cli_logger.info("=== Done ===")
def train(args):
    """Maximum Likelihood Estimation

    Builds a fresh parser from ``args.vocab`` and ``args.asdl_file``,
    initialises it with Glorot init and trains it on ``args.train_file`` with
    the optimizer named by ``args.optimizer``. A checkpoint
    ``<save_to>.iter<N>.bin`` is saved after every epoch. The process
    terminates via ``sys.exit(0)`` when ``epoch == args.max_epoch``.

    Args:
        args: command-line namespace. Reads ``train_file``, ``vocab``,
            ``asdl_file``, ``transition_system``, ``parser``, ``cuda``,
            ``optimizer``, ``lr``, ``batch_size``, ``clip_grad``,
            ``log_every``, ``save_to`` and ``max_epoch``.
    """
    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)
    with open(args.vocab, 'rb') as f:
        vocab = pickle.load(f)

    with open(args.asdl_file) as f:
        grammar = ASDLGrammar.from_text(f.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()

    # Attribute lookup instead of eval(): same result for valid optimizer
    # names, but a malformed --optimizer value can no longer execute code.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    nn_utils.glorot_init(model.parameters())

    print('begin training, %d training examples' % len(train_set), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            # Examples without target actions cannot be scored.
            batch_examples = [e for e in batch_examples if len(e.tgt_actions)]
            if not batch_examples:
                # Skip empty batches; this also keeps report_examples > 0 when logging.
                continue

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # .item() turns the running total into a plain float.
            report_loss += torch.sum(loss).item()
            report_examples += len(batch_examples)

            loss = torch.mean(loss)
            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] loss=%.5f' % (train_iter, report_loss / report_examples)
                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        model_file = args.save_to + '.iter%d.bin' % train_iter
        print('save model to [%s]' % model_file, file=sys.stderr)
        model.save(model_file)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)
def train(args):
    """Maximum Likelihood Estimation

    Trains a fresh parser on ``args.train_file``. Every
    ``args.valid_every_epoch`` epochs the model is validated on
    ``args.dev_file`` (when given), and the best model is kept in
    ``<save_to>.bin`` together with its optimizer state in
    ``<save_to>.optim.bin``.

    After ``args.patience`` non-improving validations (counted only once
    ``epoch >= args.lr_decay_after_epoch``), the best checkpoint is restored
    and the learning rate is multiplied by ``args.lr_decay``. The process
    terminates via ``sys.exit(0)`` at ``args.max_epoch`` or after
    ``args.max_num_trial`` such restores.

    Args:
        args: command-line namespace holding the data paths, model,
            optimizer, initialisation, validation and patience settings read
            below.
    """
    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)
    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    with open(args.vocab, 'rb') as f:
        vocab = pickle.load(f)
    with open(args.asdl_file) as f:
        grammar = ASDLGrammar.from_text(f.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system)
    model.train()

    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    if args.cuda:
        model.cuda()

    # Attribute lookup instead of eval(); identical for valid optimizer names.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init, model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path, file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples
                              if len(e.tgt_actions) <= args.decode_max_time_step]
            if not batch_examples:
                # Nothing to score; also keeps report_examples > 0 when logging.
                continue

            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # .item() replaces .data[0], which raises on 0-dim tensors in
            # PyTorch >= 0.4.
            report_loss += torch.sum(loss).item()
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    report_sup_att_loss += sup_att_loss.item()
                    loss += sup_att_loss

            loss.backward()

            # clip gradient (clip_grad_norm without "_" is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # is_better stays None on epochs where a dev set exists but no
        # validation ran. Previously the value from an earlier epoch was
        # reused (or was unbound on the first epochs), which caused spurious
        # saves / patience updates or an UnboundLocalError.
        is_better = None
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(dev_set.examples, model, evaluator, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                    epoch, eval_results,
                    evaluator.default_metric,
                    dev_score,
                    time.time() - eval_start), file=sys.stderr)
                is_better = history_dev_scores == [] or dev_score > max(history_dev_scores)
                history_dev_scores.append(dev_score)
        else:
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif (is_better is not None and patience < args.patience
              and epoch >= args.lr_decay_after_epoch):
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # Rebuild the optimizer the user selected; this was
                # hard-coded to Adam before.
                optimizer = optimizer_cls(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
def train(args):
    """Maximum Likelihood Estimation

    BERT-augmented training loop.

    - Loads a BERT tokenizer and model from hard-coded ``./pretrained_models``
      paths and merges the tokenizer vocabulary into ``vocab.source`` via
      ``copyandmerge``.
    - Builds the parser with four feature vocabularies and attaches BERT as
      ``model.bert_model``.
    - Trains non-BERT parameters with ``args.lr`` and BERT parameters with a
      fixed lr of 1e-5, using two optimizers of the class named by
      ``args.optimizer``.
    - Skips examples whose question+header word-piece length reaches 130.
    - Validation runs only from epoch 6 on; earlier epochs are always treated
      as improvements and saved.
    - Training stops after at most 41 epochs (the ``epoch > 40`` check), at
      ``args.max_epoch``, or after ``args.max_num_trial`` restores.

    Args:
        args: command-line namespace holding the data paths, model,
            optimizer, initialisation, validation and patience settings read
            below.
    """
    tokenizer = BertTokenizer.from_pretrained('./pretrained_models/bert-base-uncased-vocab.txt')
    bertmodels = BertModel.from_pretrained('./pretrained_models/base-uncased/')
    print(len(tokenizer.vocab))

    # Feature vocabularies handed to the parser constructor below.
    tmpvocab = {'null': 0}
    tmpvocab1 = {'null': 0, 'unk': 1}
    tmpvocab2 = {'null': 0, 'unk': 1}
    tmpvocab3 = {'null': 0, 'unk': 1}

    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)
    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    with open(args.vocab, 'rb') as f:
        vocab = pickle.load(f)
    print(len(vocab.source))
    vocab.source.copyandmerge(tokenizer)
    print(len(vocab.source))

    with open(args.asdl_file) as f:
        grammar = ASDLGrammar.from_text(f.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system, tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3)
    model.train()
    model.tokenizer = tokenizer

    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    if args.cuda:
        model.cuda()

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init, model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path, file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print([name for name, _ in model.named_parameters()])
    # BERT is attached after the initialisers above ran, presumably so its
    # pre-trained weights are not re-initialised (TODO confirm the parser
    # constructor does not already own a bert_model).
    model.bert_model = bertmodels
    model.train()
    if args.cuda:
        model.cuda()

    # Attribute lookup instead of eval(); identical for valid optimizer names.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    parameters = [p for name, p in model.named_parameters() if 'bert_model' not in name]
    parameters1 = [p for name, p in model.named_parameters() if 'bert_model' in name]
    optimizer = optimizer_cls(parameters, lr=args.lr)
    optimizer1 = optimizer_cls(parameters1, lr=0.00001)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    report_loss1 = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        # Hard cap on the number of epochs, independent of args.max_epoch.
        if epoch > 40:
            break
        epoch += 1
        epoch_begin = time.time()
        model.train()

        for batch_examples in tqdm.tqdm(train_set.batch_iter(batch_size=args.batch_size, shuffle=True)):
            batch_examples = [e for e in batch_examples
                              if len(e.tgt_actions) <= args.decode_max_time_step
                              and _fits_bert_window(e.table.header, e.src_sent, tokenizer)]
            if not batch_examples:
                # Nothing to score; also keeps report_examples > 0 when logging.
                continue

            train_iter += 1
            optimizer.zero_grad()
            optimizer1.zero_grad()

            ret_val, _ = model.score(batch_examples)
            loss = -ret_val[0]
            loss1 = ret_val[1]

            report_loss += torch.sum(loss).item()
            # .item(): accumulating the raw tensor kept every batch's autograd
            # graph alive, so memory grew without bound.
            report_loss1 += torch.sum(ret_val[2]).item()
            report_examples += len(batch_examples)

            # The auxiliary terms are weighted by 0 (disabled) but are kept in
            # the graph exactly as before.
            loss = torch.mean(loss) + 0 * loss1 + 0 * torch.mean(ret_val[2])
            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    report_sup_att_loss += sup_att_loss.item()
                    loss += sup_att_loss

            loss.backward()

            # clip gradient (clip_grad_norm without "_" is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()
            optimizer1.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f,coverage loss=%.5f' % (
                    train_iter, report_loss / report_examples, report_loss1 / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.
                report_loss1 = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # is_better stays None on epochs where validation was due but did not
        # run. Previously a stale value from an earlier epoch was reused,
        # causing spurious saves / patience updates.
        is_better = None
        if args.dev_file and epoch >= 6:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(dev_set.examples, model, evaluator, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                    epoch, eval_results,
                    evaluator.default_metric,
                    dev_score,
                    time.time() - eval_start), file=sys.stderr)
                is_better = history_dev_scores == [] or dev_score > max(history_dev_scores)
                history_dev_scores.append(dev_score)
        else:
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif (is_better is not None and patience < args.patience
              and epoch >= args.lr_decay_after_epoch):
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                # A leftover input() prompt used to block here.
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # Rebuild over the non-BERT parameters only. BERT weights are
                # owned by optimizer1; adding them to a second optimizer made
                # them get updated twice per step. Also honour args.optimizer
                # instead of hard-coding Adam.
                optimizer = optimizer_cls(parameters, lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0


def _fits_bert_window(header, src_sent, tokenizer, max_len=130):
    """Return True if the word-piece length of question plus table header is below ``max_len``.

    The length is the number of word pieces of all question tokens, plus
    those of all header tokens, plus one per header column. The original
    inline version printed debug output and blocked on ``input()`` for every
    example.
    """
    src_pieces = [piece for word in src_sent for piece in tokenizer._tokenize(word)]
    head_pieces = [piece
                   for column in header
                   for token in column.tokens
                   for piece in tokenizer._tokenize(token)]
    return len(head_pieces) + len(header) + len(src_pieces) < max_len
def train_rl(args):
    """Fine-tune a loaded parser with per-example weights computed from samples.

    For each training batch:

    - draws one sampled action sequence per example (``getnew``);
    - takes the top beam-search hypothesis per example;
    - gets per-example weights from ``asdl.transition_system.get_scores``;
    - minimises the weighted negative log-likelihood returned by
      ``model.score(..., weights=...)``.

    The model is evaluated on the first 2000 test examples every 200
    iterations, and on the full test set at the very beginning and at the
    end.

    Args:
        args: command-line namespace. Reads ``test_file``, ``train_file``,
            ``load_model``, ``cuda``, ``parser``, ``evaluator``,
            ``optimizer``, ``lr``, ``batch_size``, ``decode_max_time_step``,
            ``beam_size``, ``clip_grad``, ``eval_top_pred_only``,
            ``verbose`` and ``save_decode_to``; overwrites ``args.lang``.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    import asdl  # hoisted out of the batch loop; provides transition_system.get_scores

    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model
    train_set = Dataset.from_bin_file(args.train_file)
    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    def getnew(model, e, shows):
        """Return a deep copy of ``e`` whose target actions come from the model's first sample.

        Returns ``e`` itself when sampling produced no hypothesis.
        """
        example = copy.deepcopy(e)
        hyps, extra = model.sample(example.src_sent, context=example.table, show=shows)
        if len(hyps) == 0:
            return e
        actions = hyps[0].actions
        example.tgt_actions = get_action_infos(example.src_sent, actions, force_copy=True)
        example.actions = actions
        example.extra = extra
        if shows:
            print('haha')
        return example

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda)
    parser.train()
    model = parser
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    if args.cuda:
        model.cuda()

    # Attribute lookup instead of eval(); identical for valid optimizer names.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    for e in train_set.examples:
        e.actions = [item.action for item in e.tgt_actions]

    # Baseline evaluation before any update (a blocking input() prompt that
    # followed it has been removed).
    train_iter = 1
    parser.eval()
    eval_start = time.time()
    eval_results = evaluation.evaluate(test_set.examples, model, evaluator, args,
                                       verbose=True, eval_top_pred_only=args.eval_top_pred_only)
    dev_score = eval_results[evaluator.default_metric]
    print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
        train_iter, eval_results,
        evaluator.default_metric,
        dev_score,
        time.time() - eval_start), file=sys.stderr)
    print(len(train_set.examples) / 64)

    bestscore = 0
    show = False
    for _ in range(100):
        model.train()
        batches = list(train_set.batch_iter(batch_size=args.batch_size, shuffle=False))
        for batch_examples in tqdm.tqdm(batches):
            model.eval()
            batch_examples = [e for e in batch_examples
                              if len(e.tgt_actions) <= args.decode_max_time_step]
            newbatch_examples = [getnew(parser, e, show) for e in batch_examples]

            # Top beam-search hypothesis per example (None if decoding found none).
            means = []
            for example in batch_examples:
                hyps = model.parse(example.src_sent, context=example.table, beam_size=args.beam_size)
                means.append(hyps[0] if len(hyps) else None)

            weight = asdl.transition_system.get_scores(newbatch_examples, batch_examples, means)

            train_iter += 1
            model.train()
            optimizer.zero_grad()
            ret_val, _ = model.score(batch_examples, weights=weight)
            # The unweighted second score() call was multiplied by 0.0; it
            # only cost an extra forward pass and was dropped.
            loss = torch.mean(-ret_val[0])
            loss.backward()

            # clip gradient (clip_grad_norm without "_" is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            # NOTE(review): this step was commented out before, so the
            # parameters were never updated and the loop trained nothing.
            optimizer.step()

            if train_iter % 200 == 0:
                parser.eval()
                eval_start = time.time()
                eval_results = evaluation.evaluate(test_set.examples[:2000], model, evaluator, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]
                if bestscore < dev_score:
                    bestscore = dev_score
                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                    train_iter, eval_results,
                    evaluator.default_metric,
                    dev_score,
                    time.time() - eval_start), file=sys.stderr)
                print('hhah' + str(bestscore))

    # Final evaluation on the full test set; make sure dropout etc. is off.
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results = evaluation.evaluate(
        test_set.examples, parser, evaluator, args,
        verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        with open(args.save_decode_to, 'wb') as f:
            pickle.dump(decode_results, f)