def enjoy_husky():
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    param_fname = os.path.join(args.reload_dir, 'param.json')
    with open(param_fname, 'r') as f:
        param = json.load(f)

    performance_fname = os.path.join(args.reload_dir, 'performance.p')

    if param['use_2D_env']:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2_2D.yaml')
        raw_env = Husky2DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     pos_interval=param['pos_interval'])
    else:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2.yaml')
        raw_env = Husky1DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     ob_space_range=[0.0, 40.0])

    # configure environment
    raw_env.reset_state_space(use_goal_info=param["use_goal_info"],
                              use_coords_and_orn=param["use_coords_and_orn"],
                              raycast_num=param["raycast_num"],
                              raycast_range=param["raycast_range"])
    raw_env.reset_goal_range(goal_range=param["goal_range"])

    env = Monitor(
        raw_env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    policy_fn = FeedbackPolicy

    enjoy(policy=policy_fn,
          env=env,
          total_timesteps=args.total_timesteps,
          base_path=args.reload_dir,
          ent_coef=param["ent_coef"],
          vf_coef=0.5,
          max_grad_norm=param['max_grad_norm'],
          gamma=param["gamma"],
          lam=param["lambda"],
          nsteps=param['nsteps'],
          ppo_minibatch_size=param['ppo_minibatch_size'],
          feedback_minibatch_size=param['feedback_minibatch_size'])
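
# A minimal sketch (not from the original source) of the module-level argparse
# setup these examples implicitly rely on through the global `args`. The flag
# names are inferred from the attributes accessed above; the defaults are
# illustrative only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_gpu', type=int, default=1)
parser.add_argument('--gpu_idx', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--reload_dir', type=str, default=None)
parser.add_argument('--total_timesteps', type=int, default=int(1e6))
args = parser.parse_args()
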
def enjoy(num_timesteps, seed):
    rank = MPI.COMM_WORLD.Get_rank()
    #sess = U.single_threaded_session()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()
    if args.meta != "":
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    use_filler = not args.disable_filler
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               '..', 'configs', 'husky_navigate_enjoy.yaml')
    print(config_file)
    raw_env = HuskyNavigateEnv(gpu_count=args.gpu_count, config=config_file)

    env = Monitor(raw_env,
                  logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    policy_fn = MlpPolicy if args.mode == "SENSOR" else CnnPolicy2
    print(args.mode, (args.mode == "SENSOR"))

    ppo2.enjoy(policy=policy_fn,
               env=env,
               nsteps=600,
               nminibatches=4,
               lam=0.95,
               gamma=0.996,
               noptepochs=4,
               log_interval=1,
               ent_coef=.01,
               lr=lambda f: f * 2.5e-4,
               cliprange=lambda f: f * 0.2,
               total_timesteps=int(num_timesteps * 1.1),
               save_interval=5,
               reload_name=args.reload_name,
               sensor=(args.mode == "SENSOR"))
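
# The lr and cliprange arguments above are schedules: callables of a progress
# fraction that, in OpenAI baselines' ppo2, decays from 1.0 toward 0.0 over
# training, so both values anneal linearly. A rough illustration of how the
# values evolve (the loop below is illustrative, not taken from ppo2 itself):
lr_schedule = lambda f: f * 2.5e-4
clip_schedule = lambda f: f * 0.2
nupdates = 10
for update in range(1, nupdates + 1):
    frac = 1.0 - (update - 1.0) / nupdates  # remaining-progress fraction
    print(update, lr_schedule(frac), clip_schedule(frac))
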
Example #3
def train(num_timesteps, seed):
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'configs',
                               'ant_climb.yaml')
    print(config_file)

    env = AntClimbEnv(config=config_file)
    
    env = Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    def mlp_policy_fn(name, sensor_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=sensor_space, ac_space=ac_space, hid_size=64, num_hid_layers=2)

    def fuse_policy_fn(name, ob_space, sensor_space, ac_space):
        return fuse_policy.FusePolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, hid_size=64, num_hid_layers=2, ac_space=ac_space, save_per_acts=10000, session=sess)

    if args.mode == "SENSOR":
        pposgd_sensor.learn(env, mlp_policy_fn,
            max_timesteps=int(num_timesteps * 1.1 * 5),
            timesteps_per_actorbatch=6000,
            clip_param=0.2, entcoeff=0.00,
            optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
            gamma=0.99, lam=0.95,
            schedule='linear',
            save_per_acts=100,
            save_name="ant_ppo_mlp"
        )
        env.close()        
    else:
        pposgd_fuse.learn(env, fuse_policy_fn,
            max_timesteps=int(num_timesteps * 1.1),
            timesteps_per_actorbatch=2000,
            clip_param=0.2, entcoeff=0.01,
            optim_epochs=4, optim_stepsize=LEARNING_RATE, optim_batchsize=64,
            gamma=0.99, lam=0.95,
            schedule='linear',
            save_per_acts=50,
            save_name="ant_ppo_fuse",
            reload_name=args.reload_name
        )
        env.close()

def train(num_timesteps, seed):
    rank = MPI.COMM_WORLD.Get_rank()
    #sess = U.single_threaded_session()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()
    if args.meta != "":
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    use_filler = not args.disable_filler
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'configs',
                               'husky_navigate_enjoy.yaml')
    print(config_file)

    env = HuskyNavigateEnv(gpu_idx=args.gpu_idx, config=config_file)

    def policy_fn(name, ob_space, ac_space):
        if args.mode == "SENSOR":
            return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, hid_size=64, num_hid_layers=2)
        else:
            return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space, save_per_acts=10000, session=sess, kind='small')

    env = Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    pposgd_simple.enjoy(env, policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=1024,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear',
        save_per_acts=50,
        sensor=(args.mode == "SENSOR"),
        reload_name=args.reload_name
    )
    env.close()
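
# The pposgd_* learners used in these examples take a policy *factory* rather
# than a policy instance: as in baselines' pposgd_simple, the learner calls
# policy_fn twice, once for the acting policy ("pi") and once for the frozen
# copy ("oldpi") used in PPO's clipped ratio. A rough sketch of that contract
# (the helper name below is illustrative, not the library's):
def make_policy_pair(policy_fn, ob_space, ac_space):
    pi = policy_fn("pi", ob_space, ac_space)        # policy being optimized
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # snapshot for the ratio
    return pi, oldpi
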
Example #5
def train(seed):
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if args.meta != "":
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    use_filler = not args.disable_filler

    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'configs', 'config_husky.yaml')
    print(config_file)

    raw_env = HuskyNavigateEnv(gpu_idx=args.gpu_idx, config=config_file)
    step = raw_env.config['n_step']
    episode = raw_env.config['n_episode']
    iteration = raw_env.config['n_iter']
    elm_policy = raw_env.config['elm_active']
    num_timesteps = step * episode * iteration
    tpa = step * episode
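    # For example, with hypothetical config values n_step=500, n_episode=50,
    # n_iter=150 this would give num_timesteps = 500 * 50 * 150 = 3,750,000
    # environment steps and tpa = 500 * 50 = 25,000 timesteps per actor batch.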

    if args.mode == "SENSOR": #Blind Mode
        def policy_fn(name, ob_space, ac_space):
            return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, hid_size=128, num_hid_layers=4,
                                        elm_mode=elm_policy)
    elif args.mode == "DEPTH" or args.mode == "RGB": #Fusing sensor space with image space
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return fuse_policy.FusePolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                          save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)

    elif args.mode == "RESNET":
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return resnet_policy.ResPolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                           save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)

    elif args.mode == "ODE":
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return ode_policy.ODEPolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                        save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)

    else: #Using only image space
        def policy_fn(name, ob_space, ac_space):
            return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space, session=sess, kind='small')

    env = Monitor(raw_env, logger.get_dir() and
                  osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    # NOTE: this hard-coded checkpoint path overrides whatever --reload_name
    # was passed on the command line.
    args.reload_name = '/home/berk/PycharmProjects/Gibson_Exercise/gibson/utils/models/PPO_CNN_2020-11-26_500_50_137_150.model'
    print(args.reload_name)

    modes_camera = ["DEPTH", "RGB", "RESNET", "ODE"]
    if args.mode in modes_camera:
        pposgd_fuse.enjoy(env, policy_fn,
                          max_timesteps=int(num_timesteps * 1.1),
                          timesteps_per_actorbatch=tpa,
                          clip_param=0.2, entcoeff=0.03,
                          optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
                          gamma=0.99, lam=0.95,
                          schedule='linear',
                          save_name="PPO_{}_{}_{}_{}_{}".format(args.mode, datetime.date.today(), step, episode,
                                                                iteration),
                          save_per_acts=15,
                          reload_name=args.reload_name
                          )
    else:
        sensor = (args.mode == "SENSOR")
        pposgd_simple.enjoy(env, policy_fn,
                            max_timesteps=int(num_timesteps * 1.1),
                            timesteps_per_actorbatch=tpa,
                            clip_param=0.2, entcoeff=0.03,
                            optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
                            gamma=0.996, lam=0.95,
                            schedule='linear',
                            save_name="PPO_{}_{}_{}_{}_{}".format(args.mode, datetime.date.today(), step, episode,
                                                                  iteration),
                            save_per_acts=15,
                            sensor=sensor,
                            reload_name=args.reload_name
                            )
    env.close()
Example #6
def train():
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    if args.use_2D_env:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2_2D.yaml')
    else:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2.yaml')

    if args.use_2D_env:
        raw_env = Husky2DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     pos_interval=args.pos_interval,
                                     use_other_room=args.use_other_room)
    else:
        raw_env = Husky1DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     ob_space_range=[0.0, 40.0])

    # configure environment
    raw_env.reset_state_space(use_goal_info=args.use_goal_info,
                              use_coords_and_orn=args.use_coords_and_orn,
                              raycast_num=args.raycast_num,
                              raycast_range=args.raycast_range,
                              start_varying_range=args.start_varying_range)
    raw_env.reset_goal_range(goal_range=args.goal_range)

    env = Monitor(
        raw_env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    policy_fn = FeedbackPolicy

    base_dirname = os.path.join(currentdir, "simulation_and_analysis", "rslts")

    if not os.path.exists(base_dirname):
        os.mkdir(base_dirname)
    dir_name = "husky_ppo2_"
    if args.use_feedback:
        dir_name += "hr"
        if args.use_real_feedback:
            dir_name += "_real_feedback"

    elif args.use_rich_reward:
        dir_name += "rl_rich"
    else:
        dir_name += "rl_sparse"
    dir_name = addDateTime(dir_name)
    dir_name = os.path.join(base_dirname, dir_name)
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)

    hyperparams = {
        "seed": args.seed,
        "nsteps": args.nsteps,
        "total_timesteps": args.total_timesteps,
        "use_2D_env": args.use_2D_env,
        "use_other_room": args.use_other_room,
        "use_rich_reward": args.use_rich_reward,
        "use_multiple_starts": args.use_multiple_starts,
        "use_goal_info": args.use_goal_info,
        "use_coords_and_orn": args.use_coords_and_orn,
        "raycast_num": args.raycast_num,
        "raycast_range": args.raycast_range,
        "goal_range": args.goal_range,
        "start_varying_range": args.start_varying_range,
        "use_feedback": args.use_feedback,
        "use_real_feedback": args.use_real_feedback,
        "only_use_hr_until": args.only_use_hr_until,
        "trans_to_rl_in": args.trans_to_rl_in,
        "good_feedback_acc": args.good_feedback_acc,
        "bad_feedback_acc": args.bad_feedback_acc,
        "ppo_lr": args.ppo_lr,
        "ppo_batch_size": args.ppo_batch_size,
        "ppo_minibatch_size": args.ppo_minibatch_size,
        "init_rl_importance": args.init_rl_importance,
        "ent_coef": args.ent_coef,
        "gamma": args.gamma,
        "lambda": args.lam,
        "cliprange": args.cliprange,
        "max_grad_norm": args.max_grad_norm,
        "ppo_noptepochs": args.ppo_noptepochs,
        "feedback_lr": args.feedback_lr,
        "feedback_batch_size": args.feedback_batch_size,
        "feedback_minibatch_size": args.feedback_minibatch_size,
        "feedback_noptepochs": args.feedback_noptepochs,
        "min_feedback_buffer_size": args.min_feedback_buffer_size,
        "feedback_training_prop": args.feedback_training_prop,
        "feedback_training_new_prop": args.feedback_training_new_prop,
        "hf_loss_type": args.hf_loss_type,
        "hf_loss_param": args.hf_loss_param,
        "feedback_use_mixup": args.feedback_use_mixup,
        "pos_interval": args.pos_interval,
        "use_embedding": raw_env._use_embedding,
        "use_raycast": raw_env._use_raycast,
        "offline": raw_env.config['offline']
    }

    param_fname = os.path.join(dir_name, "param.json")
    with open(param_fname, "w") as f:
        json.dump(hyperparams, f, indent=4, sort_keys=True)

    hf_loss_param = [
        float(x) for x in args.hf_loss_param.split(",") if x != ""
    ]

    video_name = os.path.join(dir_name, "video.mp4")
    p_logging = p.startStateLogging(p.STATE_LOGGING_VIDEO_MP4, video_name)

    performance = learn(
        policy=policy_fn,
        env=env,
        raw_env=raw_env,
        use_2D_env=args.use_2D_env,
        use_other_room=args.use_other_room,
        use_multiple_starts=args.use_multiple_starts,
        use_rich_reward=args.use_rich_reward,
        use_feedback=args.use_feedback,
        use_real_feedback=args.use_real_feedback,
        only_use_hr_until=args.only_use_hr_until,
        trans_to_rl_in=args.trans_to_rl_in,
        nsteps=args.nsteps,
        total_timesteps=args.total_timesteps,
        ppo_lr=args.ppo_lr,
        cliprange=args.cliprange,
        max_grad_norm=args.max_grad_norm,
        ent_coef=args.ent_coef,
        gamma=args.gamma,
        lam=args.lam,
        ppo_noptepochs=args.ppo_noptepochs,
        ppo_batch_size=args.ppo_batch_size,
        ppo_minibatch_size=args.ppo_minibatch_size,
        init_rl_importance=args.init_rl_importance,
        feedback_lr=args.feedback_lr,
        feedback_noptepochs=args.feedback_noptepochs,
        feedback_batch_size=args.feedback_batch_size,
        feedback_minibatch_size=args.feedback_minibatch_size,
        min_feedback_buffer_size=args.min_feedback_buffer_size,
        feedback_training_prop=args.feedback_training_prop,
        feedback_training_new_prop=args.feedback_training_new_prop,
        hf_loss_type=args.hf_loss_type,
        hf_loss_param=hf_loss_param,
        feedback_use_mixup=args.feedback_use_mixup,
        good_feedback_acc=args.good_feedback_acc,
        bad_feedback_acc=args.bad_feedback_acc,
        log_interval=1,
        save_interval=5,
        reload_name=None,
        base_path=dir_name)

    p.stopStateLogging(p_logging)

    performance_fname = os.path.join(dir_name, "performance.p")
    with open(performance_fname, "wb") as f:
        pickle.dump(performance, f)

    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(performance["train_acc"], label="train_acc")
    plt.plot(performance["train_true_acc"], label="train_true_acc")
    plt.plot(performance["valid_acc"], label="valid_acc")
    plt.title("feedback acc: {}".format(args.good_feedback_acc))
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(os.path.join(dir_name, "acc.jpg"), dpi=300)
    plt.close('all')
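
# A minimal sketch (standard library only) of reading back the two artifacts
# this example writes into dir_name: the hyperparameter dump (param.json) and
# the pickled performance dict (performance.p, whose "train_acc"/"valid_acc"
# keys are the ones plotted above). `run_dir` is a placeholder path.
import json
import os
import pickle

def load_run(run_dir):
    with open(os.path.join(run_dir, "param.json"), "r") as f:
        param = json.load(f)
    with open(os.path.join(run_dir, "performance.p"), "rb") as f:
        performance = pickle.load(f)
    return param, performance
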
Example #7
def train():
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    assert args.reload_dir is not None, "reload_dir cannot be None!"

    param_fname = os.path.join(args.reload_dir, 'param.json')
    with open(param_fname, 'r') as f:
        param = json.load(f)

    workerseed = param["seed"] + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    if param["use_2D_env"]:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2_2D.yaml')
        raw_env = Husky2DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     pos_interval=param["pos_interval"])
    else:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2.yaml')
        raw_env = Husky1DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     ob_space_range=[0.0, 40.0])

    # configure environment
    raw_env.reset_state_space(use_goal_info=param["use_goal_info"],
                              use_coords_and_orn=param["use_coords_and_orn"],
                              raycast_num=param["raycast_num"],
                              raycast_range=param["raycast_range"])
    raw_env.reset_goal_range(goal_range=param["goal_range"])

    env = Monitor(
        raw_env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    policy_fn = FeedbackPolicy

    print('here')
    base_dirname = os.path.join(currentdir, "simulation_and_analysis", "rslts")
    print(base_dirname)
    if not os.path.exists(base_dirname):
        os.mkdir(base_dirname)
    dir_name = "husky_ppo2_"
    if param["use_feedback"]:
        dir_name += "hr"
    elif param["use_rich_reward"]:
        dir_name += "rl_rich"
    else:
        dir_name += "rl_sparse"
    dir_name += "_reload"
    dir_name = addDateTime(dir_name)
    dir_name = os.path.join(base_dirname, dir_name)
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)

    hyperparams = {
        "seed": args.seed,
        "nsteps": param["nsteps"],
        "total_timesteps": args.total_timesteps,
        "use_2D_env": param["use_2D_env"],
        "use_rich_reward": param["use_rich_reward"],
        "use_multiple_starts": param["use_multiple_starts"],
        "use_goal_info": param["use_goal_info"],
        "use_coords_and_orn": param["use_coords_and_orn"],
        "raycast_num": param["raycast_num"],
        "raycast_range": param["raycast_range"],
        "goal_range": param["goal_range"],
        "use_feedback": args.use_feedback,
        "use_real_feedback": args.use_real_feedback,
        "trans_by_interpolate": args.trans_by_interpolate,
        "only_use_hr_until": args.only_use_hr_until,
        "trans_to_rl_in": args.trans_to_rl_in,
        "good_feedback_acc": param["good_feedback_acc"],
        "bad_feedback_acc": param["bad_feedback_acc"],
        "ppo_lr": args.ppo_lr,
        "ppo_batch_size": args.ppo_batch_size,
        "ppo_minibatch_size": param["ppo_minibatch_size"],
        "init_rl_importance": args.init_rl_importance,
        "ent_coef": args.ent_coef,
        "gamma": args.gamma,
        "lambda": args.lam,
        "cliprange": args.cliprange,
        "max_grad_norm": args.max_grad_norm,
        "ppo_noptepochs": args.ppo_noptepochs,
        "feedback_lr": param["feedback_lr"],
        "feedback_batch_size": param["feedback_batch_size"],
        "feedback_minibatch_size": param["feedback_minibatch_size"],
        "feedback_noptepochs": param["feedback_noptepochs"],
        "min_feedback_buffer_size": param["min_feedback_buffer_size"],
        "feedback_training_prop": param["feedback_training_prop"],
        "feedback_training_new_prop": param["feedback_training_new_prop"],
        "pos_interval": param["pos_interval"],
        "use_embedding": raw_env._use_embedding,
        "use_raycast": raw_env._use_raycast,
        "offline": raw_env.config['offline'],
        "reload_dir": args.reload_dir,
        "prev_total_timesteps": param["total_timesteps"]
    }

    param_fname = os.path.join(dir_name, "param.json")
    with open(param_fname, "w") as f:
        json.dump(hyperparams, f, indent=4, sort_keys=True)

    video_name = os.path.join(dir_name, "video.mp4")
    p_logging = p.startStateLogging(p.STATE_LOGGING_VIDEO_MP4, video_name)

    model_dir = os.path.join(args.reload_dir, 'models')
    max_model_iter = -1
    reload_name = None  # guard against a models directory with no checkpoints
    for fname in os.listdir(model_dir):
        if fname.isdigit():
            model_iter = int(fname)
            if model_iter > max_model_iter:
                max_model_iter = model_iter
                reload_name = os.path.join(model_dir, fname)

    performance = learn(
        policy=policy_fn,
        env=env,
        raw_env=raw_env,
        use_2D_env=param["use_2D_env"],
        use_multiple_starts=param["use_multiple_starts"],
        use_rich_reward=param["use_rich_reward"],
        use_feedback=args.use_feedback,
        use_real_feedback=args.use_real_feedback,
        trans_by_interpolate=args.trans_by_interpolate,
        only_use_hr_until=args.only_use_hr_until,
        trans_to_rl_in=args.trans_to_rl_in,
        nsteps=param["nsteps"],
        total_timesteps=args.total_timesteps,
        ppo_lr=args.ppo_lr,
        cliprange=args.cliprange,
        max_grad_norm=args.max_grad_norm,
        ent_coef=args.ent_coef,
        gamma=args.gamma,
        lam=args.lam,
        ppo_noptepochs=args.ppo_noptepochs,
        ppo_batch_size=args.ppo_batch_size,
        ppo_minibatch_size=param["ppo_minibatch_size"],
        init_rl_importance=args.init_rl_importance,
        feedback_lr=param["feedback_lr"],
        feedback_noptepochs=param["feedback_noptepochs"],
        feedback_batch_size=param["feedback_batch_size"],
        feedback_minibatch_size=param["feedback_minibatch_size"],
        min_feedback_buffer_size=param["min_feedback_buffer_size"],
        feedback_training_prop=param["feedback_training_prop"],
        feedback_training_new_prop=param["feedback_training_new_prop"],
        good_feedback_acc=param["good_feedback_acc"],
        bad_feedback_acc=param["bad_feedback_acc"],
        log_interval=1,
        save_interval=5,
        reload_name=reload_name,
        base_path=dir_name)

    p.stopStateLogging(p_logging)

    performance_fname = os.path.join(dir_name, "performance.p")
    with open(performance_fname, "wb") as f:
        pickle.dump(performance, f)
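
# A compact equivalent of the checkpoint-scanning loop above (a sketch; it
# assumes checkpoints are saved as digit-named entries under the models
# directory and returns None when none exist):
def latest_model(model_dir):
    iters = [f for f in os.listdir(model_dir) if f.isdigit()]
    return os.path.join(model_dir, max(iters, key=int)) if iters else None
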
def train():
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    if args.use_2D_env:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2_2D.yaml')
    else:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   'configs', 'husky_space7_ppo2.yaml')

    if args.use_2D_env:
        raw_env = Husky2DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     pos_interval=args.pos_interval)
    else:
        raw_env = Husky1DNavigateEnv(gpu_idx=args.gpu_idx,
                                     config=config_file,
                                     ob_space_range=[0.0, 40.0])

    env = Monitor(
        raw_env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    base_dirname = os.path.join(currentdir, "simulation_and_analysis_dqn",
                                "rslts")

    if not os.path.exists(base_dirname):
        os.makedirs(base_dirname)
    dir_name = "husky_dqn_"
    if args.use_feedback:
        dir_name += "hr"
    elif args.use_rich_reward:
        dir_name += "rl_rich"
    else:
        dir_name += "rl_sparse"
    dir_name = addDateTime(dir_name)
    dir_name = os.path.join(base_dirname, dir_name)
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)

    hyperparams = {
        "seed": args.seed,
        # env
        "use_2D_env": args.use_2D_env,
        "use_rich_reward": args.use_rich_reward,
        "use_multiple_starts": args.use_multiple_starts,
        "total_timesteps": args.total_timesteps,
        "pos_interval": args.pos_interval,
        # hr
        "use_feedback": args.use_feedback,
        "use_real_feedback": args.use_real_feedback,
        "trans_by_interpolate": args.trans_by_interpolate,
        "only_use_hr_until": args.only_use_hr_until,
        "trans_to_rl_in": args.trans_to_rl_in,
        "good_feedback_acc": args.good_feedback_acc,
        "bad_feedback_acc": args.bad_feedback_acc,
        # dqn
        "exploration_fraction": args.exploration_fraction,
        "exploration_final_eps": args.exploration_final_eps,
        "lr": args.lr,
        "batch_size": args.batch_size,
        "dqn_epochs": args.dqn_epochs,
        "train_freq": args.train_freq,
        "target_network_update_freq": args.target_network_update_freq,
        "learning_starts": args.learning_starts,
        "param_noise": args.param_noise,
        "gamma": args.gamma,
        # hr training
        "feedback_lr": args.feedback_lr,
        "feedback_epochs": args.feedback_epochs,
        "feedback_batch_size": args.feedback_batch_size,
        "feedback_minibatch_size": args.feedback_minibatch_size,
        "min_feedback_buffer_size": args.min_feedback_buffer_size,
        "feedback_training_prop": args.feedback_training_prop,
        "feedback_training_new_prop": args.feedback_training_new_prop,
        # dqn replay buffer
        "buffer_size": args.buffer_size,
        "prioritized_replay": args.prioritized_replay,
        "prioritized_replay_alpha": args.prioritized_replay_alpha,
        "prioritized_replay_beta0": args.prioritized_replay_beta0,
        "prioritized_replay_beta_iters": args.prioritized_replay_beta_iters,
        "prioritized_replay_eps": args.prioritized_replay_eps,
        #
        "checkpoint_freq": args.checkpoint_freq,
        "use_embedding": raw_env._use_embedding,
        "use_raycast": raw_env._use_raycast,
        "offline": raw_env.config['offline']
    }

    print_freq = 5

    param_fname = os.path.join(dir_name, "param.json")
    with open(param_fname, "w") as f:
        json.dump(hyperparams, f, indent=4, sort_keys=True)

    video_name = os.path.join(dir_name, "video.mp4")
    p_logging = p.startStateLogging(p.STATE_LOGGING_VIDEO_MP4, video_name)

    act, performance = learn(  # env flags
        env,
        raw_env,
        use_2D_env=args.use_2D_env,
        use_multiple_starts=args.use_multiple_starts,
        use_rich_reward=args.use_rich_reward,
        total_timesteps=args.total_timesteps,
        # dqn
        exploration_fraction=args.exploration_fraction,
        exploration_final_eps=args.exploration_final_eps,
        # hr
        use_feedback=args.use_feedback,
        use_real_feedback=args.use_real_feedback,
        only_use_hr_until=args.only_use_hr_until,
        trans_to_rl_in=args.trans_to_rl_in,
        good_feedback_acc=args.good_feedback_acc,
        bad_feedback_acc=args.bad_feedback_acc,
        # dqn training
        lr=args.lr,
        batch_size=args.batch_size,
        dqn_epochs=args.dqn_epochs,
        train_freq=args.train_freq,
        target_network_update_freq=args.target_network_update_freq,
        learning_starts=args.learning_starts,
        param_noise=args.param_noise,
        gamma=args.gamma,
        # hr training
        feedback_lr=args.feedback_lr,
        feedback_epochs=args.feedback_epochs,
        feedback_batch_size=args.feedback_batch_size,
        feedback_minibatch_size=args.feedback_minibatch_size,
        min_feedback_buffer_size=args.min_feedback_buffer_size,
        feedback_training_prop=args.feedback_training_prop,
        feedback_training_new_prop=args.feedback_training_new_prop,
        # replay buffer
        buffer_size=args.buffer_size,
        prioritized_replay=args.prioritized_replay,
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        prioritized_replay_beta0=args.prioritized_replay_beta0,
        prioritized_replay_beta_iters=args.prioritized_replay_beta_iters,
        prioritized_replay_eps=args.prioritized_replay_eps,
        # rslts saving and others
        checkpoint_freq=args.checkpoint_freq,
        print_freq=print_freq,
        checkpoint_path=None,
        load_path=None,
        callback=None,
        seed=args.seed)

    p.stopStateLogging(p_logging)

    performance_fname = os.path.join(dir_name, "performance.p")
    with open(performance_fname, "wb") as f:
        pickle.dump(performance, f)
    act.save(os.path.join(dir_name, "cartpole_model.pkl"))
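
# A rough sketch of how the exploration flags passed to learn() above are
# conventionally interpreted (as in baselines' deepq): epsilon is annealed
# linearly from 1.0 down to exploration_final_eps over the first
# exploration_fraction of total training. Values below are illustrative.
def epsilon_at(t, total_timesteps, exploration_fraction=0.1,
               exploration_final_eps=0.02):
    schedule_timesteps = int(exploration_fraction * total_timesteps)
    frac = min(float(t) / max(schedule_timesteps, 1), 1.0)
    return 1.0 + frac * (exploration_final_eps - 1.0)

print(epsilon_at(0, 100000), epsilon_at(5000, 100000), epsilon_at(20000, 100000))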