def ppo_pixel(log_name='ppo-dmlab-image', render=False):
    """Build a PPO configuration for pixel-based DMLab observations and launch training.

    Args:
        log_name: Tag used for the logger and ``config.tag``; must be unique,
            since anything logged under the same tag gets overwritten.
        render: Forwarded to the environment constructors to toggle rendering.
    """
    cfg = Config()
    # NOTE(review): eval log dir is keyed on the function name, not on
    # ``log_name`` — presumably get_default_log_dir appends a timestamp,
    # so runs don't collide; confirm against its implementation.
    eval_log_dir = get_default_log_dir(ppo_pixel.__name__)

    cfg.num_workers = 8

    # Vectorized training envs: vision input only, no positional features.
    cfg.task_fn = lambda: Task(
        use_vision=True,
        use_pos=False,
        num_envs=cfg.num_workers,
        render=render,
    )
    # Evaluation env mirrors the training setup but also logs to disk.
    cfg.eval_env = Task(
        use_vision=True,
        use_pos=False,
        num_envs=cfg.num_workers,
        log_dir=eval_log_dir,
        render=render,
    )

    # RMSprop with the standard Atari/Nature-CNN hyperparameters.
    cfg.optimizer_fn = lambda params: torch.optim.RMSprop(
        params,
        lr=0.00025,
        alpha=0.99,
        eps=1e-5,
    )
    # Shared Nature-CNN torso feeding categorical actor and value heads.
    cfg.network_fn = lambda: CategoricalActorCriticNet(
        cfg.state_dim,
        cfg.action_dim,
        NatureConvBody(in_channels=3),
    )

    # Input/reward preprocessing: scale pixels, clip rewards to their sign.
    cfg.state_normalizer = ImageNormalizer()
    cfg.reward_normalizer = SignNormalizer()

    # PPO hyperparameters.
    cfg.discount = 0.99
    cfg.use_gae = True
    cfg.gae_tau = 0.95
    cfg.entropy_weight = 0.01
    cfg.gradient_clip = 0.5
    cfg.rollout_length = 128
    cfg.optimization_epochs = 3
    cfg.mini_batch_size = 32 * 8
    cfg.ppo_ratio_clip = 0.1

    # Logging / run bookkeeping.
    cfg.log_interval = 128 * 8
    cfg.logger = get_logger(tag=log_name)
    cfg.tag = log_name  # must be unique — same-name logs get overwritten
    cfg.max_steps = int(2e7)

    run_steps(PPOAgent(cfg))