Example #1
def define_train(hparams):
    """Define the training setup."""
    train_hparams = copy.copy(hparams)
    train_hparams.add_hparam("eval_phase", False)
    train_hparams.add_hparam("policy_to_actions_lambda",
                             lambda policy: policy.sample())

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        train_env = hparams.env_fn(in_graph=True)
        memory, collect_summary, train_initialization = (
            collect.define_collect(train_env, train_hparams, "ppo_train"))
        ppo_summary = ppo.define_ppo_epoch(memory, hparams,
                                           train_env.action_space)
        train_summary = tf.summary.merge([collect_summary, ppo_summary])

        if hparams.eval_every_epochs:
            eval_hparams = copy.copy(hparams)
            eval_hparams.add_hparam("eval_phase", True)
            eval_hparams.add_hparam("policy_to_actions_lambda",
                                    lambda policy: policy.mode())
            eval_env = hparams.eval_env_fn(in_graph=True)
            eval_hparams.num_agents = hparams.num_eval_agents

            _, eval_collect_summary, eval_initialization = (
                collect.define_collect(eval_env, eval_hparams, "ppo_eval"))
            return train_summary, eval_collect_summary, (train_initialization,
                                                         eval_initialization)
        else:
            return train_summary, None, (train_initialization, )
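
The only real difference between the train and eval hooks set up above is how actions are drawn from the policy: policy.sample() gives stochastic actions for exploration during training, while policy.mode() gives the most likely action for deterministic evaluation. A minimal stand-alone illustration of that distinction, using a tfp.distributions.Categorical as a stand-in for whatever distribution object the policy actually returns:

import tensorflow as tf
import tensorflow_probability as tfp

# Stand-in policy: a categorical distribution over 3 discrete actions.
logits = tf.constant([[1.0, 2.0, 0.5]])
policy = tfp.distributions.Categorical(logits=logits)

train_actions = policy.sample()  # stochastic draw, as used during training
eval_actions = policy.mode()     # argmax action, as used during evaluation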
Example #2
def _define_train(train_env,
                  ppo_hparams,
                  eval_env_fn=None,
                  sampling_temp=1.0,
                  **collect_kwargs):
    """Define the training setup."""
    memory, collect_summary, train_initialization = (_define_collect(
        train_env,
        ppo_hparams,
        "ppo_train",
        eval_phase=False,
        sampling_temp=sampling_temp,
        **collect_kwargs))
    ppo_summary = ppo.define_ppo_epoch(memory, ppo_hparams,
                                       train_env.action_space,
                                       train_env.batch_size)
    train_summary = tf.summary.merge([collect_summary, ppo_summary])

    if ppo_hparams.eval_every_epochs:
        # TODO(koz4k): Do we need this at all?
        assert eval_env_fn is not None
        eval_env = eval_env_fn(in_graph=True)
        (_, eval_collect_summary,
         eval_initialization) = (_define_collect(eval_env,
                                                 ppo_hparams,
                                                 "ppo_eval",
                                                 eval_phase=True,
                                                 sampling_temp=0.0,
                                                 **collect_kwargs))
        return (train_summary, eval_collect_summary, (train_initialization,
                                                      eval_initialization))
    else:
        return (train_summary, None, (train_initialization, ))
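
Note that evaluation here reuses the same collect graph but with sampling_temp=0.0, which makes action selection effectively greedy, mirroring the policy.mode() hook used in the other variants. For context, the returned tuple is meant to be consumed by an outer training loop. A rough sketch under stated assumptions (TF 1.x graph mode, each initialization being a callable that takes a session, and names such as epochs_num being assumptions rather than something shown in this example):

train_summary, eval_summary, initializers = _define_train(
    train_env, ppo_hparams, eval_env_fn=eval_env_fn)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for initialize in initializers:
        initialize(sess)  # assumed interface: initializer called with the session
    for epoch in range(ppo_hparams.epochs_num):  # assumed hparam name
        sess.run(train_summary)  # one epoch of collection + PPO updates
        if eval_summary is not None and (epoch + 1) % ppo_hparams.eval_every_epochs == 0:
            sess.run(eval_summary)  # one round of evaluation rollouts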
Example #3
def _define_train(train_env, ppo_hparams, eval_env_fn=None, **collect_kwargs):
  """Define the training setup."""
  memory, collect_summary, train_initialization = (
      _define_collect(
          train_env,
          ppo_hparams,
          "ppo_train",
          eval_phase=False,
          policy_to_actions_lambda=(lambda policy: policy.sample()),
          **collect_kwargs))
  ppo_summary = ppo.define_ppo_epoch(
      memory, ppo_hparams, train_env.action_space, train_env.batch_size)
  train_summary = tf.summary.merge([collect_summary, ppo_summary])

  if ppo_hparams.eval_every_epochs:
    assert eval_env_fn is not None
    eval_env = eval_env_fn(in_graph=True)
    (_, eval_collect_summary, eval_initialization) = (
        _define_collect(
            eval_env,
            ppo_hparams,
            "ppo_eval",
            eval_phase=True,
            policy_to_actions_lambda=(lambda policy: policy.mode()),
            **collect_kwargs))
    return (train_summary, eval_collect_summary, (train_initialization,
                                                  eval_initialization))
  else:
    return (train_summary, None, (train_initialization,))
Example #4
def define_train(hparams, environment_spec, event_dir):
    """Define the training setup."""
    policy_lambda = hparams.network

    if environment_spec == "stacked_pong":
        environment_spec = lambda: gym.make("PongNoFrameskip-v4")
        wrappers = hparams.in_graph_wrappers if hasattr(
            hparams, "in_graph_wrappers") else []
        wrappers.append((tf_atari_wrappers.MaxAndSkipWrapper, {"skip": 4}))
        hparams.in_graph_wrappers = wrappers
    if isinstance(environment_spec, str):
        env_lambda = lambda: gym.make(environment_spec)
    else:
        env_lambda = environment_spec

    batch_env = utils.batch_env_factory(env_lambda,
                                        hparams,
                                        num_agents=hparams.num_agents)

    policy_factory = functools.partial(policy_lambda, batch_env.action_space,
                                       hparams)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        memory, collect_summary = collect.define_collect(
            policy_factory,
            batch_env,
            hparams,
            eval_phase=False,
            on_simulated=hparams.simulated_environment)
        ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
        summary = tf.summary.merge([collect_summary, ppo_summary])

    with tf.variable_scope("eval", reuse=tf.AUTO_REUSE):
        eval_env_lambda = env_lambda
        if event_dir and hparams.video_during_eval:
            # Some environments reset automatically when a done state is reached.
            # For those we record only every second episode.
            d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
            eval_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
                env_lambda(),
                event_dir,
                video_callable=lambda i: i % d == 0)
            # Bind the Monitor-wrapped lambda as a default argument; referencing
            # eval_env_lambda directly here would make the new lambda call itself.
            eval_env_lambda = (
                lambda env_fn=eval_env_lambda: utils.EvalVideoWrapper(env_fn()))
        eval_batch_env = utils.batch_env_factory(
            eval_env_lambda,
            hparams,
            num_agents=hparams.num_eval_agents,
            xvfb=hparams.video_during_eval)

        # TODO(blazej0): correct to the version below.
        corrected = True
        eval_summary = tf.no_op()
        if corrected:
            _, eval_summary = collect.define_collect(policy_factory,
                                                     eval_batch_env,
                                                     hparams,
                                                     eval_phase=True)
    return summary, eval_summary
Example #5
def define_train(hparams):
    """Define the training setup."""
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        memory, collect_summary, initialization = collect.define_collect(
            hparams, "ppo_train", eval_phase=False)
        ppo_summary = ppo.define_ppo_epoch(memory, hparams)
        summary = tf.summary.merge([collect_summary, ppo_summary])

    return summary, None, initialization
Example #6
def define_train(hparams, event_dir):
  """Define the training setup."""
  del event_dir
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    memory, collect_summary = collect.define_collect(
        hparams, "ppo_train", eval_phase=False,
        on_simulated=hparams.simulated_environment)
    ppo_summary = ppo.define_ppo_epoch(memory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

  return summary, None
Example #7
def define_train(hparams, event_dir):
  """Define the training setup."""
  del event_dir
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    memory, collect_summary, initialization = collect.define_collect(
        hparams, "ppo_train", eval_phase=False)
    ppo_summary = ppo.define_ppo_epoch(memory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

  return summary, None, initialization
Example #8
def define_train(hparams):
    """Define the training setup."""
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        memory, collect_summary, train_initialization = collect.define_collect(
            hparams, "ppo_train", eval_phase=False)
        ppo_summary = ppo.define_ppo_epoch(memory, hparams)
        train_summary = tf.summary.merge([collect_summary, ppo_summary])

        if hparams.eval_every_epochs:
            _, eval_collect_summary, eval_initialization = collect.define_collect(
                hparams, "ppo_eval", eval_phase=True)
            return train_summary, eval_collect_summary, (train_initialization,
                                                         eval_initialization)
        else:
            return train_summary, None, (train_initialization, )
Example #9
def define_train(hparams, environment_name):
    """Define the training setup."""
    env_lambda = lambda: gym.make(environment_name)
    policy_lambda = hparams.network
    env = env_lambda()
    action_space = env.action_space

    batch_env = utils.define_batch_env(env_lambda, hparams.num_agents)

    policy_factory = tf.make_template(
        "network", functools.partial(policy_lambda, action_space, hparams))

    memory, collect_summary = collect.define_collect(policy_factory, batch_env,
                                                     hparams)
    ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

    return summary
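
In this variant tf.make_template is what ties the collect and PPO graphs to a single set of policy variables: the first call to the template creates the variables, and every later call reuses them. A small self-contained illustration (the layer sizes and names here are made up for the example, not taken from the codebase):

import tensorflow as tf

def _policy_net(observations):
    # Tiny stand-in for the real policy network.
    hidden = tf.layers.dense(observations, 32, activation=tf.nn.relu, name="hidden")
    return tf.layers.dense(hidden, 4, name="logits")

policy_factory = tf.make_template("network", _policy_net)

logits_a = policy_factory(tf.zeros([8, 10]))   # creates network/hidden, network/logits
logits_b = policy_factory(tf.zeros([16, 10]))  # reuses the same variables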
Example #10
def define_train(hparams, environment_spec, event_dir):
    """Define the training setup."""
    if isinstance(environment_spec, str):
        env_lambda = lambda: gym.make(environment_spec)
    else:
        env_lambda = environment_spec
    policy_lambda = hparams.network
    env = env_lambda()
    action_space = env.action_space

    batch_env = utils.define_batch_env(env_lambda, hparams.num_agents)

    policy_factory = tf.make_template(
        "network", functools.partial(policy_lambda, action_space, hparams))

    with tf.variable_scope("train"):
        memory, collect_summary = collect.define_collect(policy_factory,
                                                         batch_env,
                                                         hparams,
                                                         eval_phase=False)
    ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
    summary = tf.summary.merge([collect_summary, ppo_summary])

    with tf.variable_scope("eval"):
        eval_env_lambda = env_lambda
        if event_dir and hparams.video_during_eval:
            # Some environments reset automatically when a done state is reached.
            # For those we record only every second episode.
            d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
            eval_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
                env_lambda(),
                event_dir,
                video_callable=lambda i: i % d == 0)
        wrapped_eval_env_lambda = lambda: utils.EvalVideoWrapper(
            eval_env_lambda())
        _, eval_summary = collect.define_collect(
            policy_factory,
            utils.define_batch_env(wrapped_eval_env_lambda,
                                   hparams.num_eval_agents,
                                   xvfb=hparams.video_during_eval),
            hparams,
            eval_phase=True)
    return summary, eval_summary
Example #11
def define_train(hparams, environment_spec, event_dir):
  """Define the training setup."""
  if isinstance(environment_spec, str):
    env_lambda = lambda: gym.make(environment_spec)
  else:
    env_lambda = environment_spec
  policy_lambda = hparams.network
  env = env_lambda()
  action_space = env.action_space

  batch_env = utils.define_batch_env(env_lambda, hparams.num_agents)

  policy_factory = tf.make_template(
      "network",
      functools.partial(policy_lambda, action_space, hparams))

  with tf.variable_scope("train"):
    memory, collect_summary = collect.define_collect(
        policy_factory, batch_env, hparams, eval_phase=False)
  ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
  summary = tf.summary.merge([collect_summary, ppo_summary])

  with tf.variable_scope("eval"):
    eval_env_lambda = env_lambda
    if event_dir and hparams.video_during_eval:
      # Some environments reset automatically when a done state is reached.
      # For those we record only every second episode.
      d = 2 if env_lambda().metadata.get("semantics.autoreset") else 1
      eval_env_lambda = lambda: gym.wrappers.Monitor(  # pylint: disable=g-long-lambda
          env_lambda(), event_dir, video_callable=lambda i: i % d == 0)
    wrapped_eval_env_lambda = lambda: utils.EvalVideoWrapper(eval_env_lambda())
    _, eval_summary = collect.define_collect(
        policy_factory,
        utils.define_batch_env(wrapped_eval_env_lambda, hparams.num_eval_agents,
                               xvfb=hparams.video_during_eval),
        hparams, eval_phase=True)
  return summary, eval_summary