Exemplo n.º 1
0
def train_ddpg_all_with_lr_drop(num_frames, third=False):
    """Train DDPG in stages, dividing the learning rate by 10 between stages.

    Sequence of operations:
      1. Train for ``num_frames``.
      2. Divide ``Settings.LEARNING_RATE`` by 10, append ``"_extended"`` to
         ``Settings.LOG_DIR`` and resume training for ``num_frames``.
      3. If ``third`` is true: evaluate, switch ``Settings.TASK`` back to
         ``"TRAIN_DDPG"``, divide the learning rate by 10 again, append
         ``"2"`` to ``Settings.LOG_DIR`` and resume for ``num_frames``.
      4. Evaluate for ``Settings.NUM_EPISODES`` episodes.

    Args:
        num_frames: Passed unchanged to ``DDPGAgent.train`` and to every
            ``DDPGAgent.resume_training`` call.
        third: When true, run the additional (third) training stage.

    Side effects:
        Mutates the global ``Settings`` and does not restore it: on return
        ``LEARNING_RATE`` has been divided by 10 once per resume stage,
        ``LOG_DIR`` carries the appended suffix(es), ``TASK`` is
        ``"EVALUATE_DDPG"`` and ``MODEL_NAME`` equals ``FULL_LOG_DIR``.
    """

    def _resume_with_lower_lr(log_dir_suffix):
        # One resume stage: drop the learning rate tenfold and continue
        # training under a new log directory name.
        Settings.LEARNING_RATE /= 10
        # Read before LOG_DIR changes so resume_training is handed the
        # directory of the run that just finished.
        # NOTE(review): presumably setup_logging() recomputes FULL_LOG_DIR
        # from LOG_DIR -- confirm against Settings.
        previous_run_dir = Settings.FULL_LOG_DIR
        Settings.LOG_DIR = Settings.LOG_DIR + log_dir_suffix
        Settings.setup_logging()
        DDPGAgent.resume_training(previous_run_dir, num_frames)

    def _evaluate_latest_run():
        # Switch to evaluation mode and evaluate the agent loaded from the
        # current full log directory.
        Settings.TASK = "EVALUATE_DDPG"
        Settings.MODEL_NAME = Settings.FULL_LOG_DIR
        loaded_agent = DDPGAgent.load(Settings.FULL_LOG_DIR)
        loaded_agent.evaluate(Settings.NUM_EPISODES)

    DDPGAgent.train(num_frames)
    _resume_with_lower_lr("_extended")

    if third:
        # Intermediate evaluation, then back to training mode for the
        # extra stage.
        _evaluate_latest_run()
        Settings.TASK = "TRAIN_DDPG"
        _resume_with_lower_lr("2")

    _evaluate_latest_run()
Exemplo n.º 2
0
        for i, key in enumerate(keys):
            setattr(Settings, key, value_tuple[i])
        if not Settings.TEST_ROLLOUT_STATE and Settings.ST_TEST_ROLLOUTS != 2:
            continue
        if Settings.ROLLOUT_LENGTH == 1 and Settings.ST_TEST_ROLLOUTS != 2:
            continue
        if Settings.ST_TEST_ROLLOUTS > Settings.ROLLOUT_LENGTH:
            continue
        do_task()


# Script entry point: optionally load a settings file, set up logging and
# environments, seed the random number generators, then run the task.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Optional positional argument; when omitted it is None and no file is
    # loaded below.
    parser.add_argument("config", nargs='?', default=None)
    args = parser.parse_args()
    config_file = args.config
    if config_file is not None:
        # NOTE(review): presumably overrides Settings attributes with values
        # from the file -- confirm against Settings.load_from_file.
        Settings.load_from_file(config_file)
    # Runs after the optional config load and before any seeding below.
    Settings.setup_logging()
    merge_gym.register_environments()

    # The string "Random" is the sentinel for "do not seed"; any other value
    # is used as the seed for NumPy, PyTorch and the stdlib `random` module.
    if Settings.SEED != "Random":
        np.random.seed(Settings.SEED)
        torch.manual_seed(Settings.SEED)
        random.seed(Settings.SEED)
        if Settings.CUDA:
            # Per the PyTorch reproducibility notes: select deterministic
            # cuDNN algorithms and disable the cuDNN benchmark autotuner.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    # do_task is defined outside this block.
    do_task()