コード例 #1
0
def main():
    """Train or evaluate an NMT model according to the global ``opt``.

    Loads the serialized dataset at ``opt.data``, builds a fresh
    ``lib.NMTModel`` (or restores model and optimizer from
    ``opt.load_from``), then runs exactly one of these modes:

    * ``opt.eval``: write predictions for the validation and test sets to
      files named after the checkpoint;
    * ``opt.eval_sample``: one reinforce pass over the bandit data with
      ``opt.no_update`` set;
    * ``opt.sup_train_on_bandit``: supervised training on the bandit data
      at ``opt.reinforce_lr`` for the epoch range
      ``(opt.start_epoch, opt.start_epoch)``;
    * otherwise: supervised training, followed by critic pretraining and
      reinforce training when ``opt.start_reinforce`` is set.

    Relies on module-level names defined elsewhere in the file (``opt``,
    ``lib``, ``create_model``, ``create_critic``).

    Raises:
        ValueError: if ``opt.eval`` is set without ``opt.load_from``, since
            the prediction file names are derived from the checkpoint path.
    """
    print('Loading data from "%s"' % opt.data)

    dataset = torch.load(opt.data)

    supervised_data = lib.Dataset(dataset["train_xe"], opt.batch_size, opt.cuda, eval=False)
    bandit_data = lib.Dataset(dataset["train_pg"], opt.batch_size, opt.cuda, eval=False)
    valid_data = lib.Dataset(dataset["valid"], opt.batch_size, opt.cuda, eval=True)
    test_data = lib.Dataset(dataset["test"], opt.batch_size, opt.cuda, eval=True)

    dicts = dataset["dicts"]
    print(" * vocabulary size. source = %d; target = %d" %
          (dicts["src"].size(), dicts["tgt"].size()))
    print(" * number of XENT training sentences. %d" %
          len(dataset["train_xe"]["src"]))
    print(" * number of PG training sentences. %d" %
          len(dataset["train_pg"]["src"]))
    print(" * maximum batch size. %d" % opt.batch_size)
    print("Building model...")

    # Decided before the -1 rewrite below: None means "supervised only".
    use_critic = opt.start_reinforce is not None

    if opt.load_from is None:
        model, optim = create_model(lib.NMTModel, dicts, dicts["tgt"].size())
        checkpoint = None
    else:
        print("Loading from checkpoint at %s" % opt.load_from)
        checkpoint = torch.load(opt.load_from)
        model = checkpoint["model"]
        optim = checkpoint["optim"]
        # Resume numbering right after the epoch stored in the checkpoint.
        opt.start_epoch = checkpoint["epoch"] + 1

    # GPU.
    if opt.cuda:
        model.cuda(opt.gpus[0])

    # start_reinforce == -1: begin reinforce training (and LR decay) at the
    # very first epoch of this run.
    if opt.start_reinforce == -1:
        opt.start_decay_at = opt.start_epoch
        opt.start_reinforce = opt.start_epoch

    # Check if end_epoch is large enough.
    # NOTE(review): critic pretraining is run from start_reinforce, not
    # start_epoch, so this bound looks too loose when supervised epochs
    # precede reinforce training — confirm intent.
    if use_critic:
        assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= \
            opt.end_epoch, "Please increase -end_epoch to perform pretraining!"

    n_params = sum(p.nelement() for p in model.parameters())
    print("* number of parameters: %d" % n_params)

    # Metrics.
    metrics = {
        "nmt_loss": lib.Loss.weighted_xent_loss,
        "critic_loss": lib.Loss.weighted_mse,
        "sent_reward": lib.Reward.sentence_bleu,
        "corp_reward": lib.Reward.corpus_bleu,
    }
    if opt.pert_func is not None:
        # Wrap the configured perturbation setting in a lib.PertFunction.
        opt.pert_func = lib.PertFunction(opt.pert_func, opt.pert_param)

    # Evaluate model on heldout dataset.
    if opt.eval:
        if opt.load_from is None:
            raise ValueError("-eval requires -load_from: prediction files are "
                             "named after the checkpoint")
        evaluator = lib.Evaluator(model, metrics, dicts, opt)
        # On validation set.
        pred_file = opt.load_from.replace(".pt", ".valid.pred")
        evaluator.eval(valid_data, pred_file)
        # On test set.
        pred_file = opt.load_from.replace(".pt", ".test.pred")
        evaluator.eval(test_data, pred_file)
    elif opt.eval_sample:
        opt.no_update = True
        critic, critic_optim = create_critic(checkpoint, dicts, opt)
        reinforce_trainer = lib.ReinforceTrainer(model, critic, bandit_data, test_data,
                                                 metrics, dicts, optim, critic_optim, opt)
        reinforce_trainer.train(opt.start_epoch, opt.start_epoch, False)
    elif opt.sup_train_on_bandit:
        optim.set_lr(opt.reinforce_lr)
        xent_trainer = lib.Trainer(model, bandit_data, test_data, metrics, dicts, optim, opt)
        xent_trainer.train(opt.start_epoch, opt.start_epoch)
    else:
        xent_trainer = lib.Trainer(model, supervised_data, valid_data, metrics, dicts, optim, opt)
        if use_critic:
            start_time = time.time()
            # Supervised training.
            xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time)
            # Create critic here to not affect random seed.
            critic, critic_optim = create_critic(checkpoint, dicts, opt)
            # Pretrain critic.
            if opt.critic_pretrain_epochs > 0:
                reinforce_trainer = lib.ReinforceTrainer(model, critic, supervised_data, test_data,
                                                         metrics, dicts, optim, critic_optim, opt)
                reinforce_trainer.train(opt.start_reinforce,
                                        opt.start_reinforce + opt.critic_pretrain_epochs - 1,
                                        True, start_time)
            # Reinforce training.
            reinforce_trainer = lib.ReinforceTrainer(model, critic, bandit_data, test_data,
                                                     metrics, dicts, optim, critic_optim, opt)
            reinforce_trainer.train(opt.start_reinforce + opt.critic_pretrain_epochs,
                                    opt.end_epoch, False, start_time)
        # Supervised training only.
        else:
            xent_trainer.train(opt.start_epoch, opt.end_epoch)
コード例 #2
0
def main():
    """Parse options, then train or evaluate a sequence-generation model.

    Sets the module-level ``opt`` from ``get_opt()`` and seeds ``torch``,
    ``numpy`` and ``random`` with ``opt.seed``. It then loads data via
    ``load_data``. Next it builds a model whose class is chosen by
    ``opt.data_type`` (``'code'`` -> ``lib.Tree2SeqModel``, ``'text'`` ->
    ``lib.Seq2SeqModel``, ``'hybrid'`` -> ``lib.Hybrid2SeqModel``), or
    restores one from ``opt.load_from``. Finally it runs one of:

    * heldout evaluation (``opt.eval``);
    * evaluation on ``vis_data`` (``opt.eval_one``);
    * a sampling pass (``opt.eval_sample``);
    * supervised training, followed by critic pretraining and reinforce
      training when ``opt.start_reinforce`` is set.

    Raises:
        ValueError: if a fresh model is requested with an unsupported
            ``opt.data_type`` (this used to surface as an UnboundLocalError
            on ``model``), or if an evaluation mode is selected without
            ``opt.load_from`` (prediction files are named after it).
    """
    print("Start...")
    global opt
    opt = get_opt()

    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    # Number of requested GPUs; truthy when -gpus was given.
    opt.cuda = len(opt.gpus)

    if opt.save_dir and not os.path.exists(opt.save_dir):
        os.makedirs(opt.save_dir)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with -gpus 1")

    if opt.cuda:
        cuda.set_device(opt.gpus[0])
        torch.cuda.manual_seed(opt.seed)

    dicts, supervised_data, rl_data, valid_data, test_data, vis_data = load_data(opt)

    print("Building model...")

    use_critic = opt.start_reinforce is not None
    print("use_critic: ", use_critic)

    if opt.load_from is None:
        if opt.data_type == 'code':
            model_class = lib.Tree2SeqModel
        elif opt.data_type == 'text':
            model_class = lib.Seq2SeqModel
        elif opt.data_type == 'hybrid':
            model_class = lib.Hybrid2SeqModel
        else:
            raise ValueError("Unsupported data_type %r; expected 'code', 'text' or 'hybrid'"
                             % (opt.data_type,))
        model, optim = create_model(model_class, dicts, dicts["tgt"].size())
        checkpoint = None
        print("model: ", model)
        print("optim: ", optim)
    else:
        print("Loading from checkpoint at %s" % opt.load_from)
        # Load every storage onto the CPU; the model is moved to the GPU below.
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        model = checkpoint["model"]
        optim = checkpoint["optim"]
        # Resume numbering right after the epoch stored in the checkpoint.
        opt.start_epoch = checkpoint["epoch"] + 1

    # GPU.
    if opt.cuda:
        model.cuda(opt.gpus[0])

    # start_reinforce == -1: begin reinforce training (and LR decay) at the
    # very first epoch of this run.
    print("opt.start_reinforce: ", opt.start_reinforce)
    if opt.start_reinforce == -1:
        opt.start_decay_at = opt.start_epoch
        opt.start_reinforce = opt.start_epoch

    # Check if end_epoch is large enough.
    if use_critic:
        assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= \
               opt.end_epoch, "Please increase -end_epoch to perform pretraining!"

    n_params = sum(p.nelement() for p in model.parameters())
    print("* number of parameters: %d" % n_params)

    # Metrics.
    metrics = {
        "xent_loss": lib.Loss.weighted_xent_loss,
        "critic_loss": lib.Loss.weighted_mse,
        "sent_reward": lib.Reward.sentence_bleu,
        "corp_reward": lib.Reward.corpus_bleu,
    }
    if opt.pert_func is not None:
        # Wrap the configured perturbation setting in a lib.PertFunction.
        opt.pert_func = lib.PertFunction(opt.pert_func, opt.pert_param)

    print("opt.eval: ", opt.eval)
    print("opt.eval_sample: ", opt.eval_sample)

    # Both evaluation modes derive their output paths from the checkpoint.
    if (opt.eval or opt.eval_one) and opt.load_from is None:
        raise ValueError("-eval/-eval_one require -load_from: prediction files "
                         "are named after the checkpoint")

    # Evaluate model on heldout dataset.
    if opt.eval:
        evaluator = lib.Evaluator(model, metrics, dicts, opt)
        # On validation set.
        if opt.var_length:
            pred_file = opt.load_from.replace(".pt", ".valid.pred.var"+opt.var_type)
        else:
            pred_file = opt.load_from.replace(".pt", ".valid.pred")
        evaluator.eval(valid_data, pred_file)

        # On test set.
        if opt.var_length:
            pred_file = opt.load_from.replace(".pt", ".test.pred.var"+opt.var_type)
        else:
            pred_file = opt.load_from.replace(".pt", ".test.pred")
        evaluator.eval(test_data, pred_file)
    elif opt.eval_one:
        print("eval_one..")
        evaluator = lib.Evaluator(model, metrics, dicts, opt)
        # On test set.
        pred_file = opt.load_from.replace(".pt", ".test_one.pred")
        evaluator.eval(vis_data, pred_file)
    elif opt.eval_sample:
        opt.no_update = True
        critic, critic_optim = create_critic(checkpoint, dicts, opt)
        reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data, test_data,
                                                 metrics, dicts, optim, critic_optim, opt)
        reinforce_trainer.train(opt.start_epoch, opt.start_epoch, False)

    else:
        print("supervised_data.src: ", len(supervised_data.src))
        print("supervised_data.tgt: ", len(supervised_data.tgt))
        print("supervised_data.trees: ", len(supervised_data.trees))
        print("supervised_data.leafs: ", len(supervised_data.leafs))
        xent_trainer = lib.Trainer(model, supervised_data, valid_data, metrics, dicts, optim, opt)
        if use_critic:
            start_time = time.time()
            # Supervised training.
            print("supervised training..")
            print("start_epoch: ", opt.start_epoch)

            xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1, start_time)
            # Create critic here to not affect random seed.
            critic, critic_optim = create_critic(checkpoint, dicts, opt)
            # Pretrain critic.
            print("pretrain critic...")
            if opt.critic_pretrain_epochs > 0:
                reinforce_trainer = lib.ReinforceTrainer(model, critic, supervised_data, test_data,
                                                         metrics, dicts, optim, critic_optim, opt)
                reinforce_trainer.train(opt.start_reinforce,
                                        opt.start_reinforce + opt.critic_pretrain_epochs - 1,
                                        True, start_time)
            # Reinforce training.
            print("reinforce training...")
            reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data, test_data,
                                                     metrics, dicts, optim, critic_optim, opt)
            reinforce_trainer.train(opt.start_reinforce + opt.critic_pretrain_epochs,
                                    opt.end_epoch, False, start_time)

        # Supervised training only.
        else:
            xent_trainer.train(opt.start_epoch, opt.end_epoch)
コード例 #3
0
ファイル: a2c-train.py プロジェクト: sawan16/CoaCor
def main():
    """Train or evaluate a ``lib.Seq2SeqModel`` with optional RL fine-tuning.

    Setup:

    * Sets the module-level ``opt`` from ``get_opt()`` and seeds ``torch``,
      ``numpy`` and ``random``.
    * Enables CUDA only when a device is available *and* ``-gpus`` was
      given.
    * Loads data with ``load_data``.
    * Builds the model fresh, or restores it from ``opt.load_from``. On
      restore, the checkpoint's ``predict_mask`` / ``max_predict_length``
      and the optimizer's ``start_decay_at`` are overridden by the current
      options.

    The sentence reward is wrapped sentence BLEU when ``opt.sent_reward ==
    "bleu"`` and retrieval MRR otherwise. ``"cr"`` additionally installs a
    ``code_retrieval.CrCritic``.

    With ``opt.eval`` the model is evaluated on the validation split, or on
    ``DEV`` with ``opt.eval_codenn``. Otherwise it is trained with
    cross-entropy. When ``opt.start_reinforce`` is set, that is followed by
    critic pretraining (only with ``opt.has_baseline``) and then reinforce
    training on ``rl_data``.

    Raises:
        Exception: when evaluating MRR (``sent_reward == "cr"``) on codenn.
        AssertionError: on inconsistent epoch / baseline options.
    """
    print("Start...")
    global opt
    opt = get_opt()

    # Set seed
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    # Falsy on CPU-only machines even when -gpus was given.
    opt.cuda = torch.cuda.is_available() and len(opt.gpus)

    if opt.save_dir and not os.path.exists(opt.save_dir):
        os.makedirs(opt.save_dir)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with -gpus 1"
        )

    if opt.cuda:
        cuda.set_device(opt.gpus[0])
        torch.cuda.manual_seed(opt.seed)

    dicts, supervised_data, rl_data, valid_data, test_data, DEV, EVAL = load_data(
        opt)

    print("Building model...")

    use_critic = opt.start_reinforce is not None
    print("use_critic: ", use_critic)
    print("has_baseline: ", opt.has_baseline)
    # Without a baseline no critic is built during training, so critic
    # pretraining epochs cannot be requested.
    if not opt.has_baseline:
        assert opt.critic_pretrain_epochs == 0

    if opt.load_from is None:
        model, optim = create_model(lib.Seq2SeqModel, dicts,
                                    dicts["tgt"].size())
        checkpoint = None

    else:
        print("Loading from checkpoint at %s" % opt.load_from)
        # A GPU-saved checkpoint cannot be deserialized as-is on a CPU-only
        # run, so remap storages to the CPU when CUDA is off. With CUDA, keep
        # torch's default so tensors (incl. optimizer state) return to the
        # device they were saved from.
        map_location = None if opt.cuda else (lambda storage, loc: storage)
        checkpoint = torch.load(opt.load_from, map_location=map_location)
        model = checkpoint["model"]
        # Override the checkpoint's test-time config with the current options.
        for attribute in ["predict_mask", "max_predict_length"]:
            setattr(model.opt, attribute, getattr(opt, attribute))
        optim = checkpoint["optim"]
        optim.start_decay_at = opt.start_decay_at
        if optim.start_decay_at > opt.end_epoch:
            print("No decay!")
        # Resume numbering right after the epoch stored in the checkpoint.
        opt.start_epoch = checkpoint["epoch"] + 1

    print("model: ", model)
    print("optim: ", optim)

    # GPU.
    if opt.cuda:
        model.cuda(opt.gpus[0])

    print("opt.start_reinforce: ", opt.start_reinforce)

    # Check if end_epoch is large enough.
    if use_critic:
        assert opt.start_epoch + opt.critic_pretrain_epochs - 1 <= \
               opt.end_epoch, "Please increase -end_epoch to perform pretraining!"

    n_params = sum(p.nelement() for p in model.parameters())
    print("* number of parameters: %d" % n_params)

    if opt.sent_reward == "cr":
        lib.RetReward.cr = code_retrieval.CrCritic()

    # Metrics.
    print("sent_reward: %s" % opt.sent_reward)
    metrics = {
        "xent_loss": lib.Loss.weighted_xent_loss,
        "critic_loss": lib.Loss.weighted_mse,
    }
    if opt.sent_reward == "bleu":
        metrics["sent_reward"] = {
            "train": lib.Reward.wrapped_sentence_bleu,
            "eval": lib.Reward.wrapped_sentence_bleu
        }
    else:
        metrics["sent_reward"] = {
            "train": lib.RetReward.retrieval_mrr_train,
            "eval": lib.RetReward.retrieval_mrr_eval
        }

    print("opt.eval: ", opt.eval)
    print("opt.eval_codenn: ", opt.eval_codenn)
    print("opt.eval_codenn_all: ", opt.eval_codenn_all)
    print("opt.collect_anno: ", opt.collect_anno)

    # Evaluate model
    if opt.eval:
        if opt.sent_reward == "cr" and (opt.eval_codenn
                                        or opt.eval_codenn_all):
            raise Exception(
                "Currently we do not support evaluating MRR on codenn!")

        # Manual switches for which splits to evaluate (formerly hard-coded
        # `if False:` / `if True:`); only the validation split is enabled.
        eval_on_train = False
        eval_on_valid = True
        eval_on_test = False

        if eval_on_train:
            # On training set.
            if opt.sent_reward == "cr":
                metrics["sent_reward"][
                    "eval"] = lib.RetReward.retrieval_mrr_train

            evaluator = lib.Evaluator(model, metrics, dicts, opt)
            pred_file = opt.load_from.replace(".pt", ".train.pred")
            if opt.eval_codenn or opt.eval_codenn_all:
                raise Exception("Invalid eval_codenn!")
            print("train_data.src: ", len(supervised_data.src))
            if opt.predict_mask:
                pred_file += ".masked"
            pred_file += ".metric%s" % opt.sent_reward
            evaluator.eval(supervised_data, pred_file)

        if eval_on_valid:
            # On validation set.
            if opt.sent_reward == "cr":
                metrics["sent_reward"][
                    "eval"] = lib.RetReward.retrieval_mrr_eval

            evaluator = lib.Evaluator(model, metrics, dicts, opt)
            pred_file = opt.load_from.replace(".pt", ".valid.pred")
            if opt.eval_codenn:
                pred_file = pred_file.replace("valid", "DEV")
                valid_data = DEV
            elif opt.eval_codenn_all:
                pred_file = pred_file.replace("valid", "DEV_all")
                print("* Please input valid data = DEV_all")
            print("valid_data.src: ", len(valid_data.src))
            if opt.predict_mask:
                pred_file += ".masked"
            pred_file += ".metric%s" % opt.sent_reward
            evaluator.eval(valid_data, pred_file)

        if eval_on_test:
            # On test set.
            if opt.sent_reward == "cr":
                metrics["sent_reward"][
                    "eval"] = lib.RetReward.retrieval_mrr_eval

            evaluator = lib.Evaluator(model, metrics, dicts, opt)
            pred_file = opt.load_from.replace(".pt", ".test.pred")
            if opt.eval_codenn:
                pred_file = pred_file.replace("test", "EVAL")
                test_data = EVAL
            elif opt.eval_codenn_all:
                pred_file = pred_file.replace("test", "EVAL_all")
                print("* Please input test data = EVAL_all")
            print("test_data.src: ", len(test_data.src))
            if opt.predict_mask:
                pred_file += ".masked"
            pred_file += ".metric%s" % opt.sent_reward
            evaluator.eval(test_data, pred_file)

    else:
        print("supervised_data.src: ", len(supervised_data.src))
        print("supervised_data.tgt: ", len(supervised_data.tgt))
        xent_trainer = lib.Trainer(model,
                                   supervised_data,
                                   valid_data,
                                   metrics,
                                   dicts,
                                   optim,
                                   opt,
                                   DEV=DEV)

        if use_critic:
            start_time = time.time()
            # Supervised training.
            print("supervised training..")
            print("start_epoch: ", opt.start_epoch)

            xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1,
                               start_time)

            # RL-phase validation uses DEV for the BLEU reward.
            if opt.sent_reward == "bleu":
                _valid_data = DEV
            else:
                _valid_data = valid_data

            if opt.has_baseline:
                # Create critic here to not affect random seed.
                critic, critic_optim = create_critic(checkpoint, dicts, opt)
                print("Building critic...")
                print("Critic: ", critic)
                print("Critic optim: ", critic_optim)

                # Pretrain critic.
                print("pretrain critic...")
                if opt.critic_pretrain_epochs > 0:
                    reinforce_trainer = lib.ReinforceTrainer(
                        model, critic, supervised_data, _valid_data, metrics,
                        dicts, optim, critic_optim, opt)
                    reinforce_trainer.train(
                        opt.start_reinforce,
                        opt.start_reinforce + opt.critic_pretrain_epochs - 1,
                        True, start_time)
            else:
                print("NOTE: do not have a baseline model")
                critic, critic_optim = None, None

            # Reinforce training.
            print("reinforce training...")
            reinforce_trainer = lib.ReinforceTrainer(model, critic, rl_data,
                                                     _valid_data, metrics,
                                                     dicts, optim,
                                                     critic_optim, opt)
            reinforce_trainer.train(
                opt.start_reinforce + opt.critic_pretrain_epochs,
                opt.end_epoch, False, start_time)

        else:  # Supervised training only. Set opt.start_reinforce to None
            xent_trainer.train(opt.start_epoch, opt.end_epoch)
コード例 #4
0
ファイル: train.py プロジェクト: heidelkin/BIPNMT
def main():
    """Train or evaluate a bandit NMT model according to the global ``opt``.

    Loads the serialized dataset at ``opt.data``. It holds supervised and
    bandit training splits, supervised and bandit validation splits, and a
    test split. Builds a fresh ``lib.NMTModel`` or restores one from
    ``opt.load_from``.

    With ``opt.eval``, runs the evaluator on the test split. Otherwise runs
    supervised training, followed by actor-critic training on the bandit
    data when ``opt.start_reinforce`` is set. Simulated feedback is
    sentence-level charF; corpus BLEU is also used for evaluation.

    Relies on module-level names defined elsewhere in the file: ``opt``,
    ``lib``, ``log``, ``create_model``, ``create_critic`` and the
    ``trpro_logger`` / ``stat_logger`` / ``samples_logger`` loggers.

    ``trpro_logger`` is closed on exit even if an exception is raised. The
    stat and samples loggers are closed when the actor-critic phase ends,
    normally or not, if ``opt.use_bipnmt`` is set.

    Raises:
        AssertionError: if ``opt.start_epoch > opt.end_epoch`` on entry.
        ValueError: if ``opt.eval`` is set without ``opt.load_from``.
    """
    try:
        assert (opt.start_epoch <=
                opt.end_epoch), 'The start epoch should be <= End Epoch'
        log('Loading data from "%s"' % opt.data)
        dataset = torch.load(opt.data)

        supervised_data = lib.Dataset(dataset["train_xe"],
                                      opt.batch_size,
                                      opt.cuda,
                                      eval=False)
        bandit_data = lib.Dataset(dataset["train_pg"],
                                  opt.batch_size,
                                  opt.cuda,
                                  eval=False)

        sup_valid_data = lib.Dataset(dataset["sup_valid"],
                                     opt.eval_batch_size,
                                     opt.cuda,
                                     eval=True)
        bandit_valid_data = lib.Dataset(dataset["bandit_valid"],
                                        opt.eval_batch_size,
                                        opt.cuda,
                                        eval=True)
        test_data = lib.Dataset(dataset["test"],
                                opt.eval_batch_size,
                                opt.cuda,
                                eval=True)

        dicts = dataset["dicts"]
        log(" * vocabulary size. source = %d; target = %d" %
            (dicts["src"].size(), dicts["tgt"].size()))
        log(" * number of XENT training sentences. %d" %
            len(dataset["train_xe"]["src"]))
        log(" * number of PG training sentences. %d" %
            len(dataset["train_pg"]["src"]))
        log(" * number of bandit valid sentences. %d" %
            len(dataset["bandit_valid"]["src"]))
        log(" * number of  test sentences. %d" % len(dataset["test"]["src"]))
        log(" * maximum batch size. %d" % opt.batch_size)
        log("Building model...")

        # Decided before the -1 rewrite below: None means "supervised only".
        use_critic = opt.start_reinforce is not None

        if opt.load_from is None:
            model, optim = create_model(lib.NMTModel, dicts, dicts["tgt"].size())
            checkpoint = None
        else:
            log("Loading from checkpoint at %s" % opt.load_from)
            checkpoint = torch.load(opt.load_from)
            model = checkpoint["model"]
            optim = checkpoint["optim"]
            # Resume numbering right after the epoch stored in the checkpoint.
            opt.start_epoch = checkpoint["epoch"] + 1

        # GPU.
        if opt.cuda:
            model.cuda(opt.gpus[0])

        # start_reinforce == -1: begin reinforce training (and LR decay) at
        # the very first epoch of this run.
        if opt.start_reinforce == -1:
            opt.start_decay_at = opt.start_epoch
            opt.start_reinforce = opt.start_epoch

        n_params = sum(p.nelement() for p in model.parameters())
        log("* number of parameters: %d" % n_params)

        # Metrics.
        metrics = {
            "nmt_loss": lib.Loss.weighted_xent_loss,
            "critic_loss": lib.Loss.weighted_mse,
        }
        log(" Simulated Feedback: charF score\nEvaluation: charF and Corpus BLEU")
        instance_charF = lib.Reward.charFEvaluator(dict_tgt=dicts["tgt"])
        metrics["sent_reward"] = instance_charF.sentence_charF
        metrics["corp_reward"] = lib.Reward.corpus_bleu

        # Evaluate model on heldout dataset.
        if opt.eval:
            if opt.load_from is None:
                raise ValueError("-eval requires -load_from: prediction files "
                                 "are named after the checkpoint")
            evaluator = lib.Evaluator(model, metrics, dicts, opt, trpro_logger)

            # On Bandit test data. The second pass passes only tgt_file
            # (presumably writing the references) — confirm in lib.Evaluator.
            pred_file = opt.load_from.replace(".pt", ".test.pred")
            tgt_file = opt.load_from.replace(".pt", ".test.tgt")
            evaluator.eval(test_data, pred_file)
            evaluator.eval(test_data, pred_file=None, tgt_file=tgt_file)

        else:
            xent_trainer = lib.Trainer(model,
                                       supervised_data,
                                       sup_valid_data,
                                       metrics,
                                       dicts,
                                       optim,
                                       opt,
                                       trainprocess_logger=trpro_logger)
            if use_critic:
                start_time = time.time()
                try:
                    # Supervised training: used when running pretrain+bandit together
                    xent_trainer.train(opt.start_epoch, opt.start_reinforce - 1,
                                       start_time)
                    # Actor-Critic
                    critic, critic_optim = create_critic(checkpoint, dicts, opt)
                    reinforce_trainer = lib.ReinforceTrainer(
                        model,
                        critic,
                        bandit_data,
                        bandit_valid_data,
                        test_data,
                        metrics,
                        dicts,
                        optim,
                        critic_optim,
                        opt,
                        trainprocess_logger=trpro_logger,
                        stat_logger=stat_logger,
                        samples_logger=samples_logger)
                    reinforce_trainer.train(opt.start_reinforce, opt.end_epoch,
                                            start_time)
                finally:
                    # Close the BIPNMT-specific logs even if training fails.
                    if opt.use_bipnmt:
                        stat_logger.close_file()
                        samples_logger.close_file()
            else:
                # Supervised training only.
                xent_trainer.train(opt.start_epoch, opt.end_epoch)
    finally:
        # Always flush/close the training-process log, even on failure.
        trpro_logger.close_file()