def define_train(hparams, environment_spec, event_dir):
  """Define the PPO training and evaluation graphs.

  Args:
    hparams: HParams object; must provide `network`, `num_agents`,
      `num_eval_agents`, `simulated_environment` and `video_during_eval`
      (and optionally `in_graph_wrappers`).
    environment_spec: either a gym environment id (str), a callable
      returning a gym environment, or the special string "stacked_pong".
    event_dir: directory where eval videos are recorded (or None/"" to
      disable video recording).

  Returns:
    Pair `(summary, eval_summary)` of TF summary ops for the training
    and evaluation collect/optimize phases.
  """
  policy_lambda = hparams.network
  if environment_spec == "stacked_pong":
    environment_spec = lambda: gym.make("PongNoFrameskip-v4")
    wrappers = hparams.in_graph_wrappers if hasattr(
        hparams, "in_graph_wrappers") else []
    wrappers.append((tf_atari_wrappers.MaxAndSkipWrapper, {"skip": 4}))
    hparams.in_graph_wrappers = wrappers
  if isinstance(environment_spec, str):
    env_lambda = lambda: gym.make(environment_spec)
  else:
    env_lambda = environment_spec

  batch_env = utils.batch_env_factory(
      env_lambda, hparams, num_agents=hparams.num_agents)

  policy_factory = functools.partial(
      policy_lambda, batch_env.action_space, hparams)

  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    memory, collect_summary = collect.define_collect(
        policy_factory, batch_env, hparams, eval_phase=False,
        on_simulated=hparams.simulated_environment)
    ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

  with tf.variable_scope("eval", reuse=tf.AUTO_REUSE):
    eval_env_lambda = env_lambda
    if event_dir and hparams.video_during_eval:
      # Some environments reset environments automatically, when reached done
      # state. For them we shall record only every second episode.
      d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
      eval_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
          env_lambda(), event_dir, video_callable=lambda i: i % d == 0)
    # BUG FIX: bind the current callable as a default argument. The previous
    #   eval_env_lambda = lambda: utils.EvalVideoWrapper(eval_env_lambda())
    # closed over the very name being reassigned; Python lambdas bind names
    # late, so calling the wrapper resolved `eval_env_lambda` to itself and
    # recursed until RecursionError.
    eval_env_lambda = (
        lambda env_fn=eval_env_lambda: utils.EvalVideoWrapper(env_fn()))
    eval_batch_env = utils.batch_env_factory(
        eval_env_lambda, hparams, num_agents=hparams.num_eval_agents,
        xvfb=hparams.video_during_eval)

    # TODO(blazej0): correct to the version below.
    corrected = True
    eval_summary = tf.no_op()
    if corrected:
      _, eval_summary = collect.define_collect(
          policy_factory, eval_batch_env, hparams, eval_phase=True)
  return summary, eval_summary
def define_train(hparams, environment_spec, event_dir):
  """Define the training setup.

  Builds the PPO train graph (under the "train" variable scope) and an
  evaluation collect graph (under the "eval" scope), optionally recording
  eval videos with gym's Monitor wrapper.

  Args:
    hparams: HParams providing `network`, `num_agents`, `num_eval_agents`
      and `video_during_eval`.
    environment_spec: a gym environment id (str) or a callable returning
      a gym environment.
    event_dir: directory for recorded eval videos; falsy disables video.

  Returns:
    Pair `(summary, eval_summary)` of TF summary ops.
  """
  if isinstance(environment_spec, str):
    env_name = environment_spec
    make_env = lambda: gym.make(env_name)
  else:
    make_env = environment_spec

  # Instantiate one environment just to read the action space off it.
  sample_env = make_env()
  actions = sample_env.action_space
  batch_env = utils.define_batch_env(make_env, hparams.num_agents)
  network = tf.make_template(
      "network", functools.partial(hparams.network, actions, hparams))

  with tf.variable_scope("train"):
    memory, collect_summary = collect.define_collect(
        network, batch_env, hparams, eval_phase=False)
    ppo_summary = ppo.define_ppo_epoch(memory, network, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

  with tf.variable_scope("eval"):
    make_eval_env = make_env
    if event_dir and hparams.video_during_eval:
      # Environments with automatic reset reach `done` mid-recording, so
      # for those we only record every second episode.
      record_every = (
          2 if make_env().metadata.get("semantics.autoreset") else 1)
      make_eval_env = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
          make_env(), event_dir,
          video_callable=lambda i: i % record_every == 0)
    make_wrapped_eval_env = (
        lambda: utils.EvalVideoWrapper(make_eval_env()))
    eval_batch_env = utils.define_batch_env(
        make_wrapped_eval_env, hparams.num_eval_agents,
        xvfb=hparams.video_during_eval)
    _, eval_summary = collect.define_collect(
        network, eval_batch_env, hparams, eval_phase=True)
  return summary, eval_summary