Example #1
def create_out_path(self, args):
    # Output directory: either next to a loaded model (evaluation) or a new,
    # timestamped directory named after the main hyperparameters (training).
    if args.model_path is not None:
        out_path = os.path.join(args.model_path, "eval_from_loaded_model")
        self.out_path = os.path.join(
            out_path,
            "{}".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
    else:
        out_path = '{}_{}_{}_layers_{}_emb_{}_hidden_{}_pdrop_{}_gradclip_{}_bs_{}_lr_{}'.format(
            args.dataset, args.task, args.model, args.num_layers,
            args.emb_size, args.hidden_size, args.p_drop, args.grad_clip,
            args.bs, args.lr)
        if args.task == 'policy':
            out_path = out_path + '_cond-answer_{}'.format(
                args.condition_answer)
        self.out_path = os.path.join(
            args.out_path, out_path,
            "{}".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
    if not os.path.exists(self.out_path):
        os.makedirs(self.out_path)
    # Paths for the logger, csv history, language-model metrics and model checkpoint.
    out_file_log = os.path.join(self.out_path, 'training_log.log')
    self.logger = create_logger(out_file_log)
    self.out_csv = os.path.join(self.out_path, 'train_history.csv')
    self.out_lm_metrics = os.path.join(
        self.out_path, 'lm_metrics_sf{}.csv'.format(args.bleu_sf))
    self.model_path = os.path.join(self.out_path, 'model.pt')
    # Log hyperparameters and dataset statistics.
    self.logger.info("hparams: {}".format(vars(args)))
    self.logger.info('train dataset length: {}'.format(len(self.train_dataset)))
    self.logger.info("val dataset length: {}".format(len(self.val_dataset)))
    if self.dataset_name == "vqa":
        self.logger.info("number of filtered entries:{}".format(
            len(self.train_dataset.filtered_entries)))
    self.logger.info('number of tokens: {}'.format(
        self.train_dataset.len_vocab))
    self._save_hparams(args, self.out_path)
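
A note on the create_logger helper: all four examples call it, but none of them shows its definition. The sketch below is one plausible, minimal implementation, assuming only the out_file_log and level arguments that appear in the calls; the handler setup and format string are guesses, not the original code.

import logging

def create_logger(out_file_log, level=logging.INFO):
    # Write log records both to the given file and to the console.
    logger = logging.getLogger(out_file_log)
    logger.setLevel(level)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    file_handler = logging.FileHandler(out_file_log)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
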
Example #2
        os.makedirs(output_path)
    # Infer sequence length, number of features and output size from one batch
    # (see the standalone sketch after this example).
    if not args.cv:
        for inp, tar in train_dataset.take(1):
            seq_len = tf.shape(inp)[1].numpy()
            num_features = tf.shape(inp)[-1].numpy()
            output_size = tf.shape(tar)[-1].numpy()
    else:
        for inp, tar in list_train_dataset[0].take(1):
            seq_len = tf.shape(inp)[1].numpy()
            num_features = tf.shape(inp)[-1].numpy()
            output_size = tf.shape(tar)[-1].numpy()

    # ------------------- create logger and checkpoint saver -------------------

    out_file_log = os.path.join(output_path, 'training_log.log')
    logger = create_logger(out_file_log=out_file_log)
    #  creating the checkpoint manager:
    checkpoint_path = os.path.join(output_path, "checkpoints")
    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)

    # ------------------- Build the RNN model -------------------
    model = build_LSTM_for_regression(shape_input_1=seq_len,
                                      shape_input_2=num_features,
                                      shape_output=output_size,
                                      rnn_units=rnn_units,
                                      dropout_rate=args.p_drop,
                                      rnn_drop_rate=args.rnn_drop,
                                      training=True)
    if not args.cv:
        train_LSTM(model=model,
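
The shape-inference idiom at the top of Example #2 (reading seq_len, num_features and output_size from a single batch) can be reproduced in isolation. The toy tensors below are assumptions made for illustration; only the take(1) / tf.shape pattern comes from the snippet.

import tensorflow as tf

# Toy (batch, time, features) inputs and (batch, targets) outputs, illustration only.
inputs = tf.random.normal((32, 10, 4))
targets = tf.random.normal((32, 3))
train_dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(8)

# Same pattern as the snippet: peek at one batch to recover the shapes.
for inp, tar in train_dataset.take(1):
    seq_len = tf.shape(inp)[1].numpy()
    num_features = tf.shape(inp)[-1].numpy()
    output_size = tf.shape(tar)[-1].numpy()

print(seq_len, num_features, output_size)  # 10 4 3
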
Example #3
def main(args):
    type_folder = "train" if args.pretrain == 0 else "pretrain"
    output_path = os.path.join(
        args.out_path, "experiments", type_folder,
        "{}".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    out_file_log = os.path.join(output_path, 'RL_training_log.log')
    out_policy_file = os.path.join(output_path, 'model.pth')

    logger = create_logger(out_file_log, level=args.logger_level)
    truncated = "basic" if args.pretrained_path is None else "truncated"
    writer = SummaryWriter(log_dir=os.path.join(
        output_path,
        "runs_{}_{}_{}_{}_{}".format(truncated, args.max_len, args.debug,
                                     args.entropy_coeff, args.num_truncated)))

    env = ClevrEnv(args.data_path,
                   args.max_len,
                   reward_type=args.reward,
                   mode="train",
                   debug=args.debug,
                   num_questions=args.num_questions)

    # Optionally load a pretrained language model (kept in eval mode).
    pretrained_lm = None
    if args.pretrained_path is not None:
        pretrained_lm = torch.load(args.pretrained_path)
        pretrained_lm.eval()

    models = {"lstm": PolicyLSTMBatch, "lstm_word": PolicyLSTMWordBatch}

    generic_kwargs = {
        "pretrained_lm": pretrained_lm,
        "pretrain": args.pretrain,
        "word_emb_size": args.word_emb_size,
        "hidden_size": args.hidden_size,
        "kernel_size": args.conv_kernel,
        "stride": args.stride,
        "num_filters": args.num_filters,
        "num_truncated": args.num_truncated,
        "writer": writer
    }

    ppo_kwargs = {
        "policy": models[args.model],
        "env": env,
        "gamma": args.gamma,
        "K_epochs": args.K_epochs,
        "update_every": args.update_every,
        "entropy_coeff": args.entropy_coeff,
        "eps_clip": args.eps_clip
    }
    reinforce_kwargs = {
        "env": env,
        "policy": models[args.model],
        "gamma": args.gamma,
        "lr": args.lr,
        "word_emb_size": args.word_emb_size,
        "hidden_size": args.hidden_size
    }
    algo_kwargs = {"PPO": ppo_kwargs, "REINFORCE": reinforce_kwargs}
    kwargs = {**algo_kwargs[args.agent], **generic_kwargs}

    agents = {"PPO": PPO, "REINFORCE": REINFORCE}

    agent = agents[args.agent](**kwargs)

    agent.learn(log_interval=args.log_interval,
                num_episodes=args.num_episodes_train)
    agent.save(out_policy_file)
    agent.test(log_interval=args.log_interval,
               num_episodes=args.num_episodes_test)
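
The agent construction at the end of Example #3 uses a dictionary-dispatch pattern: per-algorithm kwargs are merged with kwargs shared by every agent, then unpacked into the selected constructor. The stripped-down sketch below shows just that pattern; the stub classes and values are assumptions, not the PPO/REINFORCE implementations.

# Stub agents standing in for PPO and REINFORCE, for illustration only.
class PPOStub:
    def __init__(self, gamma, eps_clip, writer):
        self.gamma, self.eps_clip, self.writer = gamma, eps_clip, writer

class ReinforceStub:
    def __init__(self, gamma, lr, writer):
        self.gamma, self.lr, self.writer = gamma, lr, writer

# Per-algorithm kwargs plus kwargs shared by every agent, as in the snippet above.
algo_kwargs = {"PPO": {"gamma": 0.99, "eps_clip": 0.2},
               "REINFORCE": {"gamma": 0.99, "lr": 1e-3}}
generic_kwargs = {"writer": None}
agents = {"PPO": PPOStub, "REINFORCE": ReinforceStub}

selected = "PPO"
kwargs = {**algo_kwargs[selected], **generic_kwargs}
agent = agents[selected](**kwargs)
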
Example #4
def run(args):
    # check consistency hparams
    if args.reward == "vqa":
        assert args.condition_answer is not None, "VQA task should be conditioned on the answer"

    # create out_folder, config file, logger, writer
    output_path = get_output_path(args)
    conf_file = os.path.join(output_path, 'conf.ini')
    out_file_log = os.path.join(output_path, 'RL_training_log.log')
    out_policy_file = os.path.join(output_path, 'model.pth')
    cmd_file = os.path.join(output_path, 'cmd.txt')
    create_config_file(conf_file, args)
    create_cmd_file(cmd_file)
    logger = create_logger(out_file_log, level=args.logger_level)
    writer = SummaryWriter(log_dir=os.path.join(output_path, "runs"))

    device = torch.device(
        "cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
    logger.info("this experiment is in {}".format(output_path))
    # log hparams:
    log_hparams(logger, args)

    # upload env & pretrained lm, policy network.
    env, test_envs = get_rl_env(args, device)
    pretrained_lm = get_pretrained_lm(args, env, device)
    # dataset statistics
    logger.info('-' * 20 + 'Dataset statistics' + '-' * 20)
    logger.info("number of training questions:{}".format(len(env.dataset)))
    logger.info("vocab size:{}".format(len(env.dataset.vocab_questions)))

    models = {"lstm": PolicyLSTMBatch}
    # creating the policy model.
    policy = models[args.model](env.dataset.len_vocab,
                                args.word_emb_size,
                                args.hidden_size,
                                kernel_size=args.conv_kernel,
                                stride=args.stride,
                                num_filters=args.num_filters,
                                fusion=args.fusion,
                                env=env,
                                condition_answer=args.condition_answer,
                                device=device,
                                attention_dim=args.attention_dim)
    if args.policy_path is not None:
        # Load pretrained policy weights; the checkpoint may be a raw state_dict,
        # a dict holding "model_state_dict", or a fully pickled model
        # (see the standalone sketch after this example).
        pretrained = torch.load(args.policy_path, map_location=device)
        if pretrained.__class__ != OrderedDict:
            if pretrained.__class__ == dict:
                pretrained = pretrained["model_state_dict"]
            else:
                pretrained = pretrained.state_dict()
        policy.load_state_dict(pretrained, strict=False)
        policy.device = device
    optimizer, scheduler = get_optimizer(policy, args)
    agent = get_agent(pretrained_lm=pretrained_lm,
                      writer=writer,
                      output_path=output_path,
                      env=env,
                      test_envs=test_envs,
                      policy=policy,
                      optimizer=optimizer,
                      args_=args)

    # start training
    if args.resume_training is not None:
        epoch, loss = agent.load_ckpt(
            os.path.join(args.resume_training, "checkpoints"))
        logger.info(
            'resume training after {} episodes... current loss: {}'.format(
                epoch, loss))
        agent.start_episode = epoch + 1
    if args.num_episodes_train > 0:  # trick to avoid a bug inside the agent.learn function in case of no training.
        agent.learn(num_episodes=args.num_episodes_train)
        agent.save(out_policy_file)
    else:
        logger.info("skipping training...")

    # start evaluation
    logger.info(
        '---------------------------------- STARTING EVALUATION --------------------------------------------------------------------------'
    )
    #eval_modes = ["greedy", "sampling", "sampling_ranking_lm"]
    for mode in args.eval_modes:
        logger.info(
            "----------------------------- Starting evaluation for {} action selection -------------------------"
            .format(mode))
        agent.test(num_episodes=args.num_episodes_test,
                   test_mode=mode,
                   test_seed=args.test_seed)
    # write to csv test scalar metrics:
    agent.compute_write_all_metrics(output_path=output_path, logger=logger)
    logger.info(
        '------------------------------------ DONE ---------------------------------------------------------------'
    )
    return agent
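
The policy-loading branch in Example #4 accepts three checkpoint formats: a raw state_dict (an OrderedDict), a training checkpoint dict holding a "model_state_dict" entry, or a fully pickled model object. Below is a minimal sketch of that handling, with a hypothetical extract_state_dict helper written with isinstance rather than the exact class comparison used above.

from collections import OrderedDict

def extract_state_dict(checkpoint):
    # Raw state_dict: return it as is (checked first, OrderedDict is a dict subclass).
    if isinstance(checkpoint, OrderedDict):
        return checkpoint
    # Training checkpoint dict: pull out the model weights.
    if isinstance(checkpoint, dict):
        return checkpoint["model_state_dict"]
    # Otherwise assume a pickled model object and read its state_dict.
    return checkpoint.state_dict()

# Hypothetical usage:
# checkpoint = torch.load("model.pth", map_location="cpu")
# policy.load_state_dict(extract_state_dict(checkpoint), strict=False)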