def train(num_timesteps, seed):
    rank = MPI.COMM_WORLD.Get_rank()
    #sess = U.single_threaded_session()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()
    # Optionally restore a previously exported graph/checkpoint before training
    if args.meta != "":
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Only the rank-0 worker writes full logs; the others are silenced
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    # Give each MPI worker its own seed so rollouts are decorrelated
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)

    use_filler = not args.disable_filler

    # Non-visual (SENSOR) runs use the nonviz config; all other modes use the RGB config
    if args.mode == "SENSOR":
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   '..', 'configs',
                                   'husky_navigate_nonviz_train.yaml')
    else:
        config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   '..', 'configs',
                                   'husky_navigate_rgb_train.yaml')
    print(config_file)

    raw_env = HuskyNavigateEnv(gpu_idx=args.gpu_idx, config=config_file)

    env = Monitor(raw_env,
                  logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    gym.logger.setLevel(logging.WARN)

    # Sensor-only runs use an MLP policy; visual runs use the CNN policy
    policy_fn = MlpPolicy if args.mode == "SENSOR" else CnnPolicy2
    # Checkpoint to resume from (relative to the gibson package)
    args.reload_name = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), '..', '..', 'gibson',
        'utils', 'models', '00100')

    ppo2.learn(policy=policy_fn,
               env=env,
               nsteps=500,
               nminibatches=4,
               lam=0.95,
               gamma=0.99,
               noptepochs=4,
               log_interval=1,
               ent_coef=.1,
               # f anneals from 1 to 0 over training, so lr and cliprange decay linearly
               lr=lambda f: f * 2e-4,
               cliprange=lambda f: f * 0.2,
               total_timesteps=int(num_timesteps * 1.1),
               save_interval=10,
               sensor=(args.mode == "SENSOR"),
               reload_name=args.reload_name)

    env.close()
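
# Hypothetical entry point for the train() above (not part of the original
# snippet): train() relies on a module-level `args` object, so a minimal driver
# could look like this. The flag names (--mode, --num_gpu, --gpu_idx, --meta,
# --disable_filler) are assumptions inferred from how `args` is used above.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='RGB')        # "SENSOR" selects the non-visual config and MLP policy
    parser.add_argument('--num_gpu', type=int, default=1)         # forwarded to utils.make_gpu_session
    parser.add_argument('--gpu_idx', type=int, default=0)         # GPU used by HuskyNavigateEnv
    parser.add_argument('--meta', type=str, default='')           # optional .meta graph to restore before training
    parser.add_argument('--disable_filler', action='store_true')
    parser.add_argument('--num_timesteps', type=int, default=int(1e6))
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    train(num_timesteps=args.num_timesteps, seed=args.seed)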
Example #2
def train(seed):
    rank = MPI.COMM_WORLD.Get_rank()
    sess = utils.make_gpu_session(args.num_gpu)
    sess.__enter__()

    if args.meta != "":
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    use_filler = not args.disable_filler

    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'configs', 'config_husky.yaml')
    print(config_file)

    raw_env = HuskyNavigateEnv(gpu_idx=args.gpu_idx, config=config_file)
    # Episode/iteration structure is read from the env config; the total number of
    # timesteps and the per-actor batch size are derived from it
    step = raw_env.config['n_step']
    episode = raw_env.config['n_episode']
    iteration = raw_env.config['n_iter']
    elm_policy = raw_env.config['elm_active']
    num_timesteps = step * episode * iteration
    tpa = step * episode

    if args.mode == "SENSOR":  # Blind mode: proprioceptive sensors only
        def policy_fn(name, ob_space, ac_space):
            return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, hid_size=128, num_hid_layers=4,
                                        elm_mode=elm_policy)
    elif args.mode == "DEPTH" or args.mode == "RGB":  # Fusing sensor space with image space
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return fuse_policy.FusePolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                          save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)
    elif args.mode == "RESNET":
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return resnet_policy.ResPolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                           save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)
    elif args.mode == "ODE":
        def policy_fn(name, ob_space, sensor_space, ac_space):
            return ode_policy.ODEPolicy(name=name, ob_space=ob_space, sensor_space=sensor_space, ac_space=ac_space,
                                        save_per_acts=10000, hid_size=128, num_hid_layers=4, session=sess, elm_mode=elm_policy)
    else:  # Using only image space
        def policy_fn(name, ob_space, ac_space):
            return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space, session=sess, kind='small')

    env = Monitor(raw_env, logger.get_dir() and
                  osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    # Hard-coded path to the pretrained model that enjoy() will reload
    args.reload_name = '/home/berk/PycharmProjects/Gibson_Exercise/gibson/utils/models/PPO_CNN_2020-11-26_500_50_137_150.model'
    print(args.reload_name)

    # Camera-based modes go through the fused-observation PPO (pposgd_fuse);
    # SENSOR and anything else uses the plain version (pposgd_simple)
    modes_camera = ["DEPTH", "RGB", "RESNET", "ODE"]
    if args.mode in modes_camera:
        pposgd_fuse.enjoy(env, policy_fn,
                          max_timesteps=int(num_timesteps * 1.1),
                          timesteps_per_actorbatch=tpa,
                          clip_param=0.2, entcoeff=0.03,
                          optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
                          gamma=0.99, lam=0.95,
                          schedule='linear',
                          save_name="PPO_{}_{}_{}_{}_{}".format(args.mode, datetime.date.today(), step, episode,
                                                                iteration),
                          save_per_acts=15,
                          reload_name=args.reload_name
                          )
    else:
        sensor = (args.mode == "SENSOR")
        pposgd_simple.enjoy(env, policy_fn,
                            max_timesteps=int(num_timesteps * 1.1),
                            timesteps_per_actorbatch=tpa,
                            clip_param=0.2, entcoeff=0.03,
                            optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
                            gamma=0.996, lam=0.95,
                            schedule='linear',
                            save_name="PPO_{}_{}_{}_{}_{}".format(args.mode, datetime.date.today(), step, episode,
                                                                  iteration),
                            save_per_acts=15,
                            sensor=sensor,
                            reload_name=args.reload_name
                            )
    env.close()
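
# Hypothetical sketch (not part of the original snippet) of the config fields the
# second train() reads from config_husky.yaml, and of how the run length is derived
# from them. The key names mirror the raw_env.config lookups above; the values are
# invented purely for illustration.
example_config = {
    'n_step': 500,        # env steps per episode
    'n_episode': 50,      # episodes per actor batch
    'n_iter': 150,        # number of training iterations
    'elm_active': False,  # switches the policies into ELM mode
}
tpa = example_config['n_step'] * example_config['n_episode']   # timesteps_per_actorbatch -> 25000
num_timesteps = tpa * example_config['n_iter']                 # total timesteps -> 3750000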