def main():
    print('Loading data from "%s"' % opt.data)
    dataset = torch.load(opt.data)

    supervised_data = lib.Dataset(dataset["train_xe"], opt.batch_size, opt.cuda, eval=False)
    bandit_data = lib.Dataset(dataset["train_pg"], opt.batch_size, opt.cuda, eval=False)
    valid_data = lib.Dataset(dataset["valid"], opt.batch_size, opt.cuda, eval=True)
    test_data = lib.Dataset(dataset["test"], opt.batch_size, opt.cuda, eval=True)

    dicts = dataset["dicts"]
    print(" * vocabulary size. source = %d; target = %d" %
          (dicts["src"].size(), dicts["tgt"].size()))
    print(" * number of XENT training sentences. %d" % len(dataset["train_xe"]["src"]))
    print(" * number of PG training sentences. %d" % len(dataset["train_pg"]["src"]))
    print(" * maximum batch size. %d" % opt.batch_size)

    print("Building model...")
    use_critic = opt.start_reinforce is not None

    if opt.load_from is None:
        model, optim = create_model(lib.NMTModel, dicts, dicts["tgt"].size())
        checkpoint = None
    else:
        print("Loading from checkpoint at %s" % opt.load_from)
        checkpoint = torch.load(opt.load_from)
        model = checkpoint["model"]
        optim = checkpoint["optim"]
        opt.start_epoch = checkpoint["epoch"] + 1

    # GPU.
    if opt.cuda:
        model.cuda(opt.gpus[0])

    # Start reinforce training immediately.
    if opt.start_reinforce == -1:
        opt.start_decay_at = opt.start_epoch
        opt.start_reinforce = opt.start_epoch

    # Check if end_epoch is large enough.
    if use_critic:
        assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= opt.end_epoch, \
            "Please increase -end_epoch to perform pretraining!"

    nParams = sum([p.nelement() for p in model.parameters()])
    print("* number of parameters: %d" % nParams)

    # Metrics.
    metrics = {}
    metrics["nmt_loss"] = lib.Loss.weighted_xent_loss
    metrics["critic_loss"] = lib.Loss.weighted_mse
    metrics["sent_reward"] = lib.Reward.sentence_bleu
    metrics["corp_reward"] = lib.Reward.corpus_bleu
    if opt.pert_func is not None:
        opt.pert_func = lib.PertFunction(opt.pert_func, opt.pert_param)

    # Evaluate model on heldout dataset.
    if opt.eval:
        evaluator = lib.Evaluator(model, metrics, dicts, opt)
        # On validation set.
        pred_file = opt.load_from.replace(".pt", ".valid.pred")
        evaluator.eval(valid_data, pred_file)
        # On test set.
        pred_file = opt.load_from.replace(".pt", ".test.pred")
        evaluator.eval(test_data, pred_file)
    elif opt.eval_sample:
        opt.no_update = True
        critic, critic_optim = create_critic(checkpoint, dicts, opt)
        reinforce_trainer = lib.ReinforceTrainer(model, critic, bandit_data, test_data,
                                                 metrics, dicts, optim, critic_optim, opt)
        reinforce_trainer.train(opt.start_epoch, opt.start_epoch, False)
    elif opt.sup_train_on_bandit:
        optim.set_lr(opt.reinforce_lr)
        xent_trainer = lib.Trainer(model, bandit_data, test_data, metrics, dicts, optim, opt)
        xent_trainer.train(opt.start_epoch, opt.start_epoch)
    else:
        print("okay")
        xent_trainer = lib.Trainer(model, supervised_data, valid_data, metrics, dicts, optim, opt)
        if use_critic:
            start_time = time.time()
            # Supervised training.
            xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time)
            # Create critic here to not affect random seed.
            critic, critic_optim = create_critic(checkpoint, dicts, opt)
            # Pretrain critic.
            if opt.critic_pretrain_epochs > 0:
                reinforce_trainer = lib.ReinforceTrainer(model, critic, supervised_data, test_data,
                                                         metrics, dicts, optim, critic_optim, opt)
                reinforce_trainer.train(opt.start_reinforce,
                                        opt.start_reinforce + opt.critic_pretrain_epochs - 1,
                                        True, start_time)
            # Reinforce training.
            reinforce_trainer = lib.ReinforceTrainer(model, critic, bandit_data, test_data,
                                                     metrics, dicts, optim, critic_optim, opt)
            reinforce_trainer.train(opt.start_reinforce + opt.critic_pretrain_epochs,
                                    opt.end_epoch, False, start_time)
        else:
            # Supervised training only.
            xent_trainer.train(opt.start_epoch, opt.end_epoch)
def main(): print("Start...") global opt opt = get_opt() # Set seed torch.manual_seed(opt.seed) np.random.seed(opt.seed) random.seed(opt.seed) opt.cuda = len(opt.gpus) if opt.save_dir and not os.path.exists(opt.save_dir): os.makedirs(opt.save_dir) if torch.cuda.is_available() and not opt.cuda: print("WARNING: You have a CUDA device, so you should probably run with -gpus 1") if opt.cuda: cuda.set_device(opt.gpus[0]) torch.cuda.manual_seed(opt.seed) dicts, supervised_data, rl_data, valid_data, test_data, vis_data = load_data(opt) print("Building model...") use_critic = opt.start_reinforce is not None print("use_critic: ", use_critic) if opt.load_from is None: if opt.data_type == 'code': model, optim = create_model(lib.Tree2SeqModel, dicts, dicts["tgt"].size()) elif opt.data_type == 'text': model, optim = create_model(lib.Seq2SeqModel, dicts, dicts["tgt"].size()) elif opt.data_type == 'hybrid': model, optim = create_model(lib.Hybrid2SeqModel, dicts, dicts["tgt"].size()) checkpoint = None print("model: ", model) print("optim: ", optim) else: print("Loading from checkpoint at %s" % opt.load_from) checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage) model = checkpoint["model"] optim = checkpoint["optim"] opt.start_epoch = checkpoint["epoch"] + 1 # GPU. if opt.cuda: model.cuda(opt.gpus[0]) # Start reinforce training immediately. print("opt.start_reinforce: ", opt.start_reinforce) if opt.start_reinforce == -1: opt.start_decay_at = opt.start_epoch opt.start_reinforce = opt.start_epoch # Check if end_epoch is large enough. if use_critic: assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= \ opt.end_epoch, "Please increase -end_epoch to perform pretraining!" nParams = sum([p.nelement() for p in model.parameters()]) print("* number of parameters: %d" % nParams) # Metrics. metrics = {} metrics["xent_loss"] = lib.Loss.weighted_xent_loss metrics["critic_loss"] = lib.Loss.weighted_mse metrics["sent_reward"] = lib.Reward.sentence_bleu metrics["corp_reward"] = lib.Reward.corpus_bleu if opt.pert_func is not None: opt.pert_func = lib.PertFunction(opt.pert_func, opt.pert_param) print("opt.eval: ", opt.eval) print("opt.eval_sample: ", opt.eval_sample) # Evaluate model on heldout dataset. if opt.eval: evaluator = lib.Evaluator(model, metrics, dicts, opt) # On validation set. if opt.var_length: pred_file = opt.load_from.replace(".pt", ".valid.pred.var"+opt.var_type) else: pred_file = opt.load_from.replace(".pt", ".valid.pred") evaluator.eval(valid_data, pred_file) # On test set. if opt.var_length: pred_file = opt.load_from.replace(".pt", ".test.pred.var"+opt.var_type) else: pred_file = opt.load_from.replace(".pt", ".test.pred") evaluator.eval(test_data, pred_file) elif opt.eval_one: print("eval_one..") evaluator = lib.Evaluator(model, metrics, dicts, opt) # On test set. 
pred_file = opt.load_from.replace(".pt", ".test_one.pred") evaluator.eval(vis_data, pred_file) elif opt.eval_sample: opt.no_update = True critic, critic_optim = create_critic(checkpoint, dicts, opt) reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data, test_data, metrics, dicts, optim, critic_optim, opt) reinforce_trainer.train(opt.start_epoch, opt.start_epoch, False) else: print("supervised_data.src: ", len(supervised_data.src)) print("supervised_data.tgt: ", len(supervised_data.tgt)) print("supervised_data.trees: ", len(supervised_data.trees)) print("supervised_data.leafs: ", len(supervised_data.leafs)) xent_trainer = lib.Trainer(model, supervised_data, valid_data, metrics, dicts, optim, opt) if use_critic: start_time = time.time() # Supervised training. print("supervised training..") print("start_epoch: ", opt.start_epoch) xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time) # Create critic here to not affect random seed. critic, critic_optim = create_critic(checkpoint, dicts, opt) # Pretrain critic. print("pretrain critic...") if opt.critic_pretrain_epochs > 0: reinforce_trainer = lib.ReinforceTrainer(model, critic, supervised_data, test_data, metrics, dicts, optim, critic_optim, opt) reinforce_trainer.train(opt.start_reinforce, opt.start_reinforce + opt.critic_pretrain_epochs - 1, True, start_time) # Reinforce training. print("reinforce training...") reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data, test_data, metrics, dicts, optim, critic_optim, opt) reinforce_trainer.train(opt.start_reinforce + opt.critic_pretrain_epochs, opt.end_epoch, False, start_time) # Supervised training only. else: xent_trainer.train(opt.start_epoch, opt.end_epoch)
def main(): print("Start...") global opt opt = get_opt() # Set seed torch.manual_seed(opt.seed) np.random.seed(opt.seed) random.seed(opt.seed) opt.cuda = torch.cuda.is_available() and len(opt.gpus) if opt.save_dir and not os.path.exists(opt.save_dir): os.makedirs(opt.save_dir) if torch.cuda.is_available() and not opt.cuda: print( "WARNING: You have a CUDA device, so you should probably run with -gpus 1" ) if opt.cuda: cuda.set_device(opt.gpus[0]) torch.cuda.manual_seed(opt.seed) dicts, supervised_data, rl_data, valid_data, test_data, DEV, EVAL = load_data( opt) print("Building model...") use_critic = opt.start_reinforce is not None print("use_critic: ", use_critic) print("has_baseline: ", opt.has_baseline) if not opt.has_baseline: assert opt.critic_pretrain_epochs == 0 if opt.load_from is None: model, optim = create_model(lib.Seq2SeqModel, dicts, dicts["tgt"].size()) checkpoint = None else: print("Loading from checkpoint at %s" % opt.load_from) checkpoint = torch.load( opt.load_from) #, map_location=lambda storage, loc: storage) model = checkpoint["model"] # config testing for attribute in ["predict_mask", "max_predict_length"]: model.opt.__dict__[attribute] = opt.__dict__[attribute] optim = checkpoint["optim"] optim.start_decay_at = opt.start_decay_at if optim.start_decay_at > opt.end_epoch: print("No decay!") opt.start_epoch = checkpoint["epoch"] + 1 print("model: ", model) print("optim: ", optim) # GPU. if opt.cuda: model.cuda(opt.gpus[0]) # Start reinforce training immediately. print("opt.start_reinforce: ", opt.start_reinforce) # Check if end_epoch is large enough. if use_critic: assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= \ opt.end_epoch, "Please increase -end_epoch to perform pretraining!" nParams = sum([p.nelement() for p in model.parameters()]) print("* number of parameters: %d" % nParams) if opt.sent_reward == "cr": lib.RetReward.cr = code_retrieval.CrCritic() # Metrics. print("sent_reward: %s" % opt.sent_reward) metrics = {} metrics["xent_loss"] = lib.Loss.weighted_xent_loss metrics["critic_loss"] = lib.Loss.weighted_mse if opt.sent_reward == "bleu": metrics["sent_reward"] = { "train": lib.Reward.wrapped_sentence_bleu, "eval": lib.Reward.wrapped_sentence_bleu } else: metrics["sent_reward"] = { "train": lib.RetReward.retrieval_mrr_train, "eval": lib.RetReward.retrieval_mrr_eval } print("opt.eval: ", opt.eval) print("opt.eval_codenn: ", opt.eval_codenn) print("opt.eval_codenn_all: ", opt.eval_codenn_all) print("opt.collect_anno: ", opt.collect_anno) # Evaluate model if opt.eval: if opt.sent_reward == "cr" and (opt.eval_codenn or opt.eval_codenn_all): raise Exception( "Currently we do not support evaluating MRR on codenn!") if False: # On training set. if opt.sent_reward == "cr": metrics["sent_reward"][ "eval"] = lib.RetReward.retrieval_mrr_train #if opt.collect_anno: # metrics["sent_reward"] = {"train": None, "eval": None} evaluator = lib.Evaluator(model, metrics, dicts, opt) pred_file = opt.load_from.replace(".pt", ".train.pred") if opt.eval_codenn or opt.eval_codenn_all: raise Exception("Invalid eval_codenn!") print("train_data.src: ", len(supervised_data.src)) if opt.predict_mask: pred_file += ".masked" pred_file += ".metric%s" % opt.sent_reward evaluator.eval(supervised_data, pred_file) if True: # On validation set. 
if opt.sent_reward == "cr": metrics["sent_reward"][ "eval"] = lib.RetReward.retrieval_mrr_eval #if opt.collect_anno: # metrics["sent_reward"] = {"train": None, "eval": None} evaluator = lib.Evaluator(model, metrics, dicts, opt) pred_file = opt.load_from.replace(".pt", ".valid.pred") if opt.eval_codenn: pred_file = pred_file.replace("valid", "DEV") valid_data = DEV elif opt.eval_codenn_all: pred_file = pred_file.replace("valid", "DEV_all") print("* Please input valid data = DEV_all") print("valid_data.src: ", len(valid_data.src)) if opt.predict_mask: pred_file += ".masked" pred_file += ".metric%s" % opt.sent_reward evaluator.eval(valid_data, pred_file) if False: # On test set. if opt.sent_reward == "cr": metrics["sent_reward"][ "eval"] = lib.RetReward.retrieval_mrr_eval #if opt.collect_anno: # metrics["sent_reward"] = {"train": None, "eval": None} evaluator = lib.Evaluator(model, metrics, dicts, opt) pred_file = opt.load_from.replace(".pt", ".test.pred") if opt.eval_codenn: pred_file = pred_file.replace("test", "EVAL") test_data = EVAL elif opt.eval_codenn_all: pred_file = pred_file.replace("test", "EVAL_all") print("* Please input test data = EVAL_all") print("test_data.src: ", len(test_data.src)) if opt.predict_mask: pred_file += ".masked" pred_file += ".metric%s" % opt.sent_reward evaluator.eval(test_data, pred_file) else: print("supervised_data.src: ", len(supervised_data.src)) print("supervised_data.tgt: ", len(supervised_data.tgt)) xent_trainer = lib.Trainer(model, supervised_data, valid_data, metrics, dicts, optim, opt, DEV=DEV) if use_critic: start_time = time.time() # Supervised training. print("supervised training..") print("start_epoch: ", opt.start_epoch) xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time) if opt.sent_reward == "bleu": _valid_data = DEV else: _valid_data = valid_data if opt.has_baseline: # Create critic here to not affect random seed. critic, critic_optim = create_critic(checkpoint, dicts, opt) print("Building critic...") print("Critic: ", critic) print("Critic optim: ", critic_optim) # Pretrain critic. print("pretrain critic...") if opt.critic_pretrain_epochs > 0: reinforce_trainer = lib.ReinforceTrainer( model, critic, supervised_data, _valid_data, metrics, dicts, optim, critic_optim, opt) reinforce_trainer.train( opt.start_reinforce, opt.start_reinforce + opt.critic_pretrain_epochs - 1, True, start_time) else: print("NOTE: do not have a baseline model") critic, critic_optim = None, None # Reinforce training. print("reinforce training...") reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data, _valid_data, metrics, dicts, optim, critic_optim, opt) reinforce_trainer.train( opt.start_reinforce + opt.critic_pretrain_epochs, opt.end_epoch, False, start_time) else: # Supervised training only. Set opt.start_reinforce to None xent_trainer.train(opt.start_epoch, opt.end_epoch)
def main():
    assert opt.start_epoch <= opt.end_epoch, 'The start epoch should be <= End Epoch'

    log('Loading data from "%s"' % opt.data)
    dataset = torch.load(opt.data)

    supervised_data = lib.Dataset(dataset["train_xe"], opt.batch_size, opt.cuda, eval=False)
    bandit_data = lib.Dataset(dataset["train_pg"], opt.batch_size, opt.cuda, eval=False)
    sup_valid_data = lib.Dataset(dataset["sup_valid"], opt.eval_batch_size, opt.cuda, eval=True)
    bandit_valid_data = lib.Dataset(dataset["bandit_valid"], opt.eval_batch_size, opt.cuda, eval=True)
    test_data = lib.Dataset(dataset["test"], opt.eval_batch_size, opt.cuda, eval=True)

    dicts = dataset["dicts"]
    log(" * vocabulary size. source = %d; target = %d" %
        (dicts["src"].size(), dicts["tgt"].size()))
    log(" * number of XENT training sentences. %d" % len(dataset["train_xe"]["src"]))
    log(" * number of PG training sentences. %d" % len(dataset["train_pg"]["src"]))
    log(" * number of bandit valid sentences. %d" % len(dataset["bandit_valid"]["src"]))
    log(" * number of test sentences. %d" % len(dataset["test"]["src"]))
    log(" * maximum batch size. %d" % opt.batch_size)

    log("Building model...")
    use_critic = opt.start_reinforce is not None

    if opt.load_from is None:
        model, optim = create_model(lib.NMTModel, dicts, dicts["tgt"].size())
        checkpoint = None
    else:
        log("Loading from checkpoint at %s" % opt.load_from)
        checkpoint = torch.load(opt.load_from)
        model = checkpoint["model"]
        optim = checkpoint["optim"]
        opt.start_epoch = checkpoint["epoch"] + 1

    # GPU.
    if opt.cuda:
        model.cuda(opt.gpus[0])

    # Start reinforce training immediately.
    if opt.start_reinforce == -1:
        opt.start_decay_at = opt.start_epoch
        opt.start_reinforce = opt.start_epoch

    nParams = sum([p.nelement() for p in model.parameters()])
    log("* number of parameters: %d" % nParams)

    # Metrics.
    metrics = {}
    metrics["nmt_loss"] = lib.Loss.weighted_xent_loss
    metrics["critic_loss"] = lib.Loss.weighted_mse
    log(" Simulated Feedback: charF score\nEvaluation: charF and Corpus BLEU")
    instance_charF = lib.Reward.charFEvaluator(dict_tgt=dicts["tgt"])
    metrics["sent_reward"] = instance_charF.sentence_charF
    metrics["corp_reward"] = lib.Reward.corpus_bleu

    # Evaluate model on heldout dataset.
    if opt.eval:
        evaluator = lib.Evaluator(model, metrics, dicts, opt, trpro_logger)
        # On Bandit test data.
        pred_file = opt.load_from.replace(".pt", ".test.pred")
        tgt_file = opt.load_from.replace(".pt", ".test.tgt")
        evaluator.eval(test_data, pred_file)
        evaluator.eval(test_data, pred_file=None, tgt_file=tgt_file)
    else:
        xent_trainer = lib.Trainer(model, supervised_data, sup_valid_data, metrics, dicts,
                                   optim, opt, trainprocess_logger=trpro_logger)
        if use_critic:
            start_time = time.time()
            # Supervised training: used when running pretrain+bandit together.
            xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time)
            # Actor-Critic.
            critic, critic_optim = create_critic(checkpoint, dicts, opt)
            reinforce_trainer = lib.ReinforceTrainer(
                model, critic, bandit_data, bandit_valid_data, test_data, metrics, dicts,
                optim, critic_optim, opt, trainprocess_logger=trpro_logger,
                stat_logger=stat_logger, samples_logger=samples_logger)
            reinforce_trainer.train(opt.start_reinforce, opt.end_epoch, start_time)
            if opt.use_bipnmt:
                stat_logger.close_file()
                samples_logger.close_file()
        else:
            # Supervised training only.
            xent_trainer.train(opt.start_epoch, opt.end_epoch)

    trpro_logger.close_file()