# Example #1
def ppo_pixel(log_name='ppo-dmlab-image', render=False):
    """Launch PPO training on pixel (vision) observations.

    Builds a Config with the standard Atari-style PPO hyperparameters
    (RMSprop, GAE, clipped ratio) and a NatureConv actor-critic network,
    then hands it to run_steps. Blocks until training finishes.

    Args:
        log_name: tag used for the logger and config.tag; must be unique,
            since runs sharing a tag overwrite each other's logs.
        render: forwarded to every Task so environments render on screen.
    """
    config = Config()
    # NOTE(review): the eval-env log dir is derived from the function name,
    # not from log_name — presumably intentional (upstream convention), but
    # confirm if per-run log separation is wanted.
    log_dir = get_default_log_dir(ppo_pixel.__name__)

    config.num_workers = 8

    # Both the training task factory and the eval env share the same
    # vision-only Task settings; only the extra kwargs differ.
    def make_task(**extra_kwargs):
        return Task(
            use_vision=True,
            use_pos=False,
            num_envs=config.num_workers,
            render=render,
            **extra_kwargs,
        )

    config.task_fn = make_task
    config.eval_env = make_task(log_dir=log_dir)

    # Model and optimizer.
    config.optimizer_fn = lambda params: torch.optim.RMSprop(
        params, lr=0.00025, alpha=0.99, eps=1e-5)
    config.network_fn = lambda: CategoricalActorCriticNet(
        config.state_dim, config.action_dim, NatureConvBody(in_channels=3))

    # Input / reward preprocessing.
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    # PPO hyperparameters.
    config.discount = 0.99
    config.use_gae = True
    config.gae_tau = 0.95
    config.entropy_weight = 0.01
    config.gradient_clip = 0.5
    config.rollout_length = 128
    config.optimization_epochs = 3
    config.mini_batch_size = 32 * 8
    config.ppo_ratio_clip = 0.1

    # Logging and run length.
    config.log_interval = 128 * 8
    config.logger = get_logger(tag=log_name)
    config.tag = log_name  # this name must be unique. Anything with the same name will be overwritten
    config.max_steps = int(2e7)

    run_steps(PPOAgent(config))