コード例 #1
0
def train(num_timesteps, seed):
    """Train a CNN policy with PPO (``pposgd_simple``) on AntNavigateEnv.

    Intended to run under MPI: each rank builds its own environment and seeds
    it with a rank-specific seed; only rank 0 configures the logger with its
    default output formats.

    Args:
        num_timesteps: base training budget. The budget actually handed to
            ``pposgd_simple.learn`` is ``int(num_timesteps * 1.1 * 5)``.
        seed: base random seed; each MPI rank uses ``seed + 10000 * rank``.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    # The session is entered here and never exited in this function, so it
    # stays open after train() returns (presumably installed as the default
    # TF session); the policy below also receives it explicitly.
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        # Non-zero ranks get an empty list of output formats
        # (presumably suppressing their log output).
        logger.configure(format_strs=[])
    # Reuse the rank fetched above rather than querying MPI a second time.
    workerseed = seed + 10000 * rank
    set_global_seeds(workerseed)

    # Config is resolved relative to this file, independent of the CWD.
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'configs',
                               'ant_navigate.yaml')
    print(config_file)

    env = AntNavigateEnv(config=config_file)

    def policy_fn(name, ob_space, ac_space):
        """Build the small CNN policy bound to this function's session."""
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    save_per_acts=10000, session=sess, kind='small')

    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    try:
        pposgd_simple.learn(env, policy_fn,
            max_timesteps=int(num_timesteps * 1.1 * 5),
            timesteps_per_actorbatch=6000,
            clip_param=0.2, entcoeff=0.00,
            optim_epochs=4, optim_stepsize=1e-4, optim_batchsize=64,
            gamma=0.99, lam=0.95,
            schedule='linear',
            save_per_acts=500
        )
    finally:
        # Release the environment even when training raises.
        env.close()
コード例 #2
0
from gibson.envs.ant_env import AntNavigateEnv, AntClimbEnv
from gibson.utils.play import play
import os

# Locate the play configuration relative to this script rather than the
# current working directory.
_script_dir = os.path.dirname(os.path.realpath(__file__))
config_file = os.path.join(_script_dir, '..', 'configs', 'play',
                           'play_ant_nonviz.yaml')
print(config_file)


def _main():
    """Build the Ant navigation environment and run the interactive play loop."""
    play(AntNavigateEnv(config=config_file), zoom=4)


if __name__ == '__main__':
    _main()
コード例 #3
0
from gibson.envs.ant_env import AntNavigateEnv, AntClimbEnv
from gibson.utils.play import play
import os

# Timing constants; in this script they are used only to derive the
# playback rate passed to play().
timestep = 1.0 / (4 * 22)
frame_skip = 4
# Frames per second = 1 / (timestep * frame_skip). Round instead of
# truncating: the reciprocal of a float product can land a hair below the
# intended integer, and int() alone would then drop a whole frame per second.
fps = int(round(1.0 / (timestep * frame_skip)))

# Config is resolved relative to this file, independent of the CWD.
config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
                           'configs', 'ant_navigate_nonviz.yaml')
print(config_file)

if __name__ == '__main__':
    # is_discrete=True presumably selects a discrete action space for
    # keyboard play — confirm against AntNavigateEnv.
    env = AntNavigateEnv(is_discrete=True, config=config_file)
    play(env, zoom=4, fps=fps)