Example #1
# Imports assumed from OpenAI Baselines; lstm, pposgd_async, and NasEnvEmb
# are project-local modules from the surrounding code base.
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds
import baselines.common.tf_util as U

def train(num_episodes, seed, space, evaluator, num_episodes_per_batch,
          reward_rule):

    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    # Give each MPI worker a distinct seed so rollouts differ across ranks.
    workerseed = seed + 10000 * rank if seed is not None else None
    set_global_seeds(workerseed)

    # Build the NAS search structure; the batch sizes below assume one
    # timestep per node, i.e. num_nodes timesteps per episode.
    structure = space['create_structure']['func'](
        **space['create_structure']['kwargs'])

    num_nodes = structure.num_nodes
    timesteps_per_actorbatch = num_nodes * num_episodes_per_batch
    num_timesteps = timesteps_per_actorbatch * num_episodes

    env = NasEnvEmb(space, evaluator, structure)

    def policy_fn(name, ob_space, ac_space):  #pylint: disable=W0613
        return lstm.LstmPolicy(name=name,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               num_units=32,
                               async_update=True)

    pposgd_async.learn(env,
                       policy_fn,
                       max_timesteps=num_timesteps,
                       timesteps_per_actorbatch=timesteps_per_actorbatch,
                       clip_param=0.2,
                       entcoeff=0.01,
                       optim_epochs=4,
                       optim_stepsize=1e-3,
                       optim_batchsize=15,
                       gamma=0.99,
                       lam=0.95,
                       schedule='linear',
                       reward_rule=reward_rule)
    env.close()
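
A minimal launch sketch for the function above, assuming it lives in a runnable script; the problem module, the argument values, and the reward-rule placeholder are all hypothetical, and the real space/evaluator objects come from the surrounding project:

# Hypothetical driver; run under MPI so each rank trains a worker, e.g.:
#   mpirun -np 4 python nas_ppo_driver.py
if __name__ == '__main__':
    from nas_problem import space, evaluator  # hypothetical problem module
    train(num_episodes=100,
          seed=2018,
          space=space,
          evaluator=evaluator,
          num_episodes_per_batch=4,
          reward_rule=None)  # substitute the project's reward rule here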
Example #2
# Imports assumed; common_arg_parser, parse_cmdline_kwargs, train, and
# build_env are helpers defined alongside main (as in baselines/run.py).
import os.path as osp
import numpy as np
from baselines import logger
from baselines.common.cmd_util import common_arg_parser

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

def main():
    # Configure the logger; disable logging in child MPI processes (rank > 0).

    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()

        # Zero-init the recurrent state (cell and hidden parts, hence 2*nlstm)
        # and the episode-done mask that recurrent policies expect.
        def initialize_placeholders(nlstm=128, **kwargs):
            return np.zeros((args.num_env or 1, 2 * nlstm)), np.zeros((1,))

        state, dones = initialize_placeholders(**extra_args)
        # Play until interrupted (Ctrl+C); the finally block guarantees the
        # environment is closed on exit.
        try:
            while True:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
                obs, _, done, _ = env.step(actions)
                env.render()
                done = done.any() if isinstance(done, np.ndarray) else done
                if done:
                    obs = env.reset()
        finally:
            env.close()
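
For completeness, a standard entry-point guard, assuming the script is invoked directly; the flags shown (--alg, --env, --num_timesteps, --play) are the ones common_arg_parser defines in Baselines, but the script name is an assumption:

# Typical invocation (assumed):
#   python run.py --alg=ppo2 --env=CartPole-v0 --num_timesteps=1e5 --play
if __name__ == '__main__':
    main()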
Example #3
# Imports assumed from OpenAI Baselines; lstm_ph, pposgd_sync_ph, MathEnv,
# and reward_for_final_timestep come from the surrounding project.
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds
import baselines.common.tf_util as U

def train(num_iter, seed, evaluator, num_episodes_per_iter):

    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    # Give each MPI worker a distinct seed so rollouts differ across ranks.
    workerseed = seed + 10000 * rank if seed is not None else None
    set_global_seeds(workerseed)

    # Episodes in MathEnv have a fixed length; derive the batch and total
    # timestep counts from it.
    timesteps_per_episode = 10
    timesteps_per_actorbatch = timesteps_per_episode * num_episodes_per_iter
    num_timesteps = timesteps_per_actorbatch * num_iter

    env = MathEnv(evaluator)

    def policy_fn(name, ob_space, ac_space):  # pylint: disable=W0613
        return lstm_ph.LstmPolicy(name=name,
                                  ob_space=ob_space,
                                  ac_space=ac_space,
                                  num_units=64)

    pposgd_sync_ph.learn(env,
                         policy_fn,
                         max_timesteps=int(num_timesteps),
                         timesteps_per_actorbatch=timesteps_per_actorbatch,
                         clip_param=0.2,
                         entcoeff=0.01,
                         optim_epochs=4,
                         optim_stepsize=1e-3,
                         optim_batchsize=10,
                         gamma=0.99,
                         lam=0.95,
                         schedule='linear',
                         reward_rule=reward_for_final_timestep)
    env.close()
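
The reward rule passed to learn above is not defined in the example; judging only by its name, a plausible sketch, assuming it maps an episode's per-timestep rewards to shaped rewards:

def reward_for_final_timestep(rewards):
    # Hypothetical sketch: zero out intermediate rewards and credit the
    # episode's final reward to its last timestep. The real signature
    # expected by pposgd_sync_ph may differ.
    shaped = [0.0] * len(rewards)
    if rewards:
        shaped[-1] = rewards[-1]
    return shaped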