# Imports assumed from the tensor2tensor layout these snippets come from
# (the exact module paths are a best guess; they are not shown in the source).
import functools

import gym
import tensorflow as tf

from tensor2tensor.rl import collect
from tensor2tensor.rl import ppo
from tensor2tensor.rl.envs import tf_atari_wrappers
from tensor2tensor.rl.envs import utils


def define_train(hparams, environment_spec, event_dir):
    """Define the training setup."""
    policy_lambda = hparams.network

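    # "stacked_pong" is shorthand for Pong plus an in-graph MaxAndSkipWrapper
    # (each action repeated across 4 frames).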
    if environment_spec == "stacked_pong":
        environment_spec = lambda: gym.make("PongNoFrameskip-v4")
        wrappers = hparams.in_graph_wrappers if hasattr(
            hparams, "in_graph_wrappers") else []
        wrappers.append((tf_atari_wrappers.MaxAndSkipWrapper, {"skip": 4}))
        hparams.in_graph_wrappers = wrappers
    if isinstance(environment_spec, str):
        env_lambda = lambda: gym.make(environment_spec)
    else:
        env_lambda = environment_spec

    batch_env = utils.batch_env_factory(env_lambda,
                                        hparams,
                                        num_agents=hparams.num_agents)

    policy_factory = functools.partial(policy_lambda, batch_env.action_space,
                                       hparams)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        memory, collect_summary = collect.define_collect(
            policy_factory,
            batch_env,
            hparams,
            eval_phase=False,
            on_simulated=hparams.simulated_environment)
        ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
        summary = tf.summary.merge([collect_summary, ppo_summary])

    with tf.variable_scope("eval", reuse=tf.AUTO_REUSE):
        eval_env_lambda = env_lambda
        if event_dir and hparams.video_during_eval:
            # Some environments reset automatically when they reach the done
            # state. For those we record only every second episode.
            d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
            monitored_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
                env_lambda(),
                event_dir,
                video_callable=lambda i: i % d == 0)
            # Bind the Monitor lambda under a new name: reassigning
            # eval_env_lambda to a lambda that closes over eval_env_lambda
            # would recurse at call time (lambdas bind names late).
            eval_env_lambda = (
                lambda: utils.EvalVideoWrapper(monitored_env_lambda()))
        eval_batch_env = utils.batch_env_factory(
            eval_env_lambda,
            hparams,
            num_agents=hparams.num_eval_agents,
            xvfb=hparams.video_during_eval)
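        # xvfb presumably runs the eval workers under a virtual framebuffer so
        # Monitor can render on headless machines.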

        # TODO(blazej0): correct to the version below.
        corrected = True
        eval_summary = tf.no_op()
        if corrected:
            _, eval_summary = collect.define_collect(policy_factory,
                                                     eval_batch_env,
                                                     hparams,
                                                     eval_phase=True)
    return summary, eval_summary
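A minimal usage sketch for the graph built above, assuming TF1 sessions and
an hparams object that also carries an illustrative epochs_num field (which
define_train itself never reads):

def train_sketch(hparams, environment_spec, event_dir):
    summary_op, eval_summary_op = define_train(
        hparams, environment_spec, event_dir)
    writer = tf.summary.FileWriter(event_dir)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(hparams.epochs_num):  # epochs_num is an assumed field
            # Running the merged summary drives the in-graph collect and
            # PPO ops it was built from.
            writer.add_summary(sess.run(summary_op), epoch)
        # One eval collect at the end; eval_summary_op may be tf.no_op().
        sess.run(eval_summary_op)
    writer.close()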
Example #2
def define_train(hparams, environment_spec, event_dir):
    """Define the training setup."""
    if isinstance(environment_spec, str):
        env_lambda = lambda: gym.make(environment_spec)
    else:
        env_lambda = environment_spec
    policy_lambda = hparams.network
    env = env_lambda()
    action_space = env.action_space

    batch_env = utils.define_batch_env(env_lambda, hparams.num_agents)

    policy_factory = tf.make_template(
        "network", functools.partial(policy_lambda, action_space, hparams))

    with tf.variable_scope("train"):
        memory, collect_summary = collect.define_collect(policy_factory,
                                                         batch_env,
                                                         hparams,
                                                         eval_phase=False)
    ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

    with tf.variable_scope("eval"):
        eval_env_lambda = env_lambda
        if event_dir and hparams.video_during_eval:
            # Some environments reset automatically when they reach the done
            # state. For those we record only every second episode.
            d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
            eval_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
                env_lambda(),
                event_dir,
                video_callable=lambda i: i % d == 0)
        wrapped_eval_env_lambda = lambda: utils.EvalVideoWrapper(
            eval_env_lambda())
        _, eval_summary = collect.define_collect(
            policy_factory,
            utils.define_batch_env(wrapped_eval_env_lambda,
                                   hparams.num_eval_agents,
                                   xvfb=hparams.video_during_eval),
            hparams,
            eval_phase=True)
    return summary, eval_summary
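For reference, a minimal sketch of the hparams fields the two versions read;
any attribute-bearing object works, e.g. types.SimpleNamespace. The values and
my_policy_fun are illustrative assumptions, and collect and ppo read further
fields internally that are not visible in these snippets:

import types

def my_policy_fun(action_space, hparams, *args, **kwargs):
    """Hypothetical policy constructor; the real one is configured elsewhere."""
    raise NotImplementedError

hparams = types.SimpleNamespace(
    network=my_policy_fun,        # called as network(action_space, hparams, ...)
    num_agents=8,                 # parallel training environments
    num_eval_agents=2,            # parallel eval environments
    video_during_eval=False,      # record eval videos via gym.wrappers.Monitor
    simulated_environment=False,  # Example #1 only: collect on a simulated env
    # in_graph_wrappers is optional; Example #1 appends MaxAndSkipWrapper
    # to it for "stacked_pong".
)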