예제 #1
0
def make_env(batch_size=8):
    """Creates the batched environment selected by the command-line flags.

    Args:
        batch_size: Forwarded as `batch_size` to the environment constructor.

    Returns:
        An `env_problem.EnvProblem` when `FLAGS.resize` is falsy; otherwise a
        `rendered_env_problem.RenderedEnvProblem` whose `env_wrapper_fn` is
        `gym_utils.gym_env_wrapper` with its rendering/resizing options
        pre-bound. Both are constructed with `reward_range=(-1, 1)`.
    """
    reward_bounds = (-1, 1)

    # With FLAGS.resize unset (None or False) nothing is wrapped: return the
    # plain EnvProblem.
    if not FLAGS.resize:
        return env_problem.EnvProblem(
            base_env_name=FLAGS.env_problem_name,
            batch_size=batch_size,
            reward_range=reward_bounds)

    # Any failure to read or convert the flag leaves the step limit as None.
    try:
        step_limit = int(FLAGS.max_timestep)
    except Exception:  # pylint: disable=broad-except
        step_limit = None

    # Bind the wrapper options now; RenderedEnvProblem receives the partial.
    wrap_env = functools.partial(
        gym_utils.gym_env_wrapper,
        rl_env_max_episode_steps=step_limit,
        maxskip_env=True,
        rendered_env=True,
        rendered_env_resize_to=(FLAGS.resized_height, FLAGS.resized_width),
        sticky_actions=False,
        output_dtype=onp.int32 if FLAGS.use_tpu else None)

    return rendered_env_problem.RenderedEnvProblem(
        base_env_name=FLAGS.env_problem_name,
        batch_size=batch_size,
        env_wrapper_fn=wrap_env,
        reward_range=reward_bounds)
예제 #2
0
def make_env(name, batch_size, max_timestep, clip_rewards, rendered_env,
             resize_dims, **env_kwargs):
    """Creates a batched gym environment problem.

    Args:
        name: Forwarded as `base_env_name` to the environment constructor.
        batch_size: Forwarded as `batch_size` to the environment constructor.
        max_timestep: Bound as `rl_env_max_episode_steps` on the gym wrapper;
            only used when `rendered_env` is truthy.
        clip_rewards: If truthy, `reward_range=(-1, 1)` and
            `discrete_rewards=True` are added to `env_kwargs`; otherwise
            `discrete_rewards=False` is added.
        rendered_env: If falsy, a `GymEnvProblem` is returned; otherwise a
            `RenderedEnvProblem` with a wrapped environment.
        resize_dims: Bound as `rendered_env_resize_to` on the gym wrapper;
            only used when `rendered_env` is truthy.
        **env_kwargs: Extra keyword arguments passed to the constructor.

    Returns:
        A `gym_env_problem.GymEnvProblem` or a
        `rendered_env_problem.RenderedEnvProblem`.
    """
    if clip_rewards:
        env_kwargs["reward_range"] = (-1, 1)
        env_kwargs["discrete_rewards"] = True
    else:
        env_kwargs["discrete_rewards"] = False

    # TODO(afrozm): Should we leave out some cores?
    if FLAGS.parallelize_envs:
        num_workers = multiprocessing.cpu_count()
    else:
        num_workers = 1

    # No rendered env requested, so use the plain GymEnvProblem.
    if not rendered_env:
        return gym_env_problem.GymEnvProblem(
            base_env_name=name,
            batch_size=batch_size,
            parallelism=num_workers,
            **env_kwargs)

    # Bind the wrapper options now; RenderedEnvProblem receives the partial.
    wrap_env = functools.partial(
        gym_utils.gym_env_wrapper,
        rl_env_max_episode_steps=max_timestep,
        maxskip_env=True,
        rendered_env=True,
        rendered_env_resize_to=resize_dims,
        sticky_actions=False,
        output_dtype=onp.int32 if FLAGS.use_tpu else None)

    return rendered_env_problem.RenderedEnvProblem(
        base_env_name=name,
        batch_size=batch_size,
        parallelism=num_workers,
        env_wrapper_fn=wrap_env,
        **env_kwargs)
예제 #3
0
def make_env():
    """Creates the environment named by the command-line flags.

    Returns:
        `gym.make(FLAGS.env_name)` when `FLAGS.env_name` is set. Otherwise an
        `env_problem.EnvProblem` (when `FLAGS.resize` is falsy) or a
        `rendered_env_problem.RenderedEnvProblem` with a wrapped environment,
        both built for `FLAGS.env_problem_name` with `reward_range=(-1, 1)`.

    Raises:
        AssertionError: If neither `FLAGS.env_name` nor
            `FLAGS.env_problem_name` is set.
    """
    gym_env_name = FLAGS.env_name
    if gym_env_name:
        return gym.make(gym_env_name)

    assert FLAGS.env_problem_name

    reward_bounds = (-1, 1)

    # With FLAGS.resize unset (None or False) nothing is wrapped: return the
    # plain EnvProblem.
    if not FLAGS.resize:
        return env_problem.EnvProblem(
            base_env_name=FLAGS.env_problem_name,
            batch_size=FLAGS.batch_size,
            reward_range=reward_bounds)

    # Bind the wrapper options now; RenderedEnvProblem receives the partial.
    wrap_env = functools.partial(
        gym_utils.gym_env_wrapper,
        rl_env_max_episode_steps=FLAGS.max_timestep,
        maxskip_env=True,
        rendered_env=True,
        rendered_env_resize_to=(FLAGS.resized_height, FLAGS.resized_width),
        sticky_actions=False)

    return rendered_env_problem.RenderedEnvProblem(
        base_env_name=FLAGS.env_problem_name,
        batch_size=FLAGS.batch_size,
        env_wrapper_fn=wrap_env,
        reward_range=reward_bounds)
예제 #4
0
def make_env(batch_size=1,
             env_problem_name="",
             resize=True,
             resize_dims=(105, 80),
             max_timestep="None",
             clip_rewards=True,
             parallelism=1,
             use_tpu=False,
             num_actions=None,
             rendered_env=True,
             **env_kwargs):
    """Creates a batched gym environment problem.

    Args:
        batch_size: Forwarded as `batch_size` to the environment constructor.
        env_problem_name: Forwarded as `base_env_name` to the constructor.
        resize: If falsy, a plain `GymEnvProblem` is returned and
            `resize_dims`, `max_timestep`, `use_tpu`, `num_actions` and
            `rendered_env` are ignored.
        resize_dims: Bound as `rendered_env_resize_to` on the gym wrapper.
        max_timestep: Converted with `int()` and bound as
            `rl_env_max_episode_steps` on the gym wrapper. Values that
            `int()` rejects (e.g. `None` or the string "None") become `None`.
        clip_rewards: If truthy, `reward_range=(-1, 1)` and
            `discrete_rewards=True` are added to `env_kwargs`; otherwise
            `discrete_rewards=False` is added.
        parallelism: Forwarded as `parallelism` to the constructor.
        use_tpu: If truthy, the wrapper's `output_dtype` is `np.int32`,
            otherwise `None`.
        num_actions: Bound as `num_actions` on the gym wrapper.
        rendered_env: Bound as `rendered_env` on the gym wrapper and also
            passed to `RenderedEnvProblem`.
        **env_kwargs: Extra keyword arguments passed to the constructor.

    Returns:
        A `gym_env_problem.GymEnvProblem` when `resize` is falsy, otherwise a
        `rendered_env_problem.RenderedEnvProblem`.
    """

    if clip_rewards:
        env_kwargs.update({"reward_range": (-1, 1), "discrete_rewards": True})
    else:
        env_kwargs.update({"discrete_rewards": False})

    # TODO(henrykm) - below someone linked "resize" with "abnormality"
    # Probably we need more nuanced concept of "abnormality"
    # decoupled from "resize". Currently the resize flag implies
    # that we switch from a generic env to a wrapped env.
    # Overall this file and gym_utils.py look like good candidates
    # for a refactor.

    # No resizing needed, so let's be on the normal EnvProblem.
    if not resize:  # None or False
        return gym_env_problem.GymEnvProblem(base_env_name=env_problem_name,
                                             batch_size=batch_size,
                                             parallelism=parallelism,
                                             **env_kwargs)

    # Only swallow the errors int() raises for unusable input (None, a
    # non-numeric string such as "None", an infinite float); anything else is
    # a genuine bug and should propagate.
    try:
        max_timestep = int(max_timestep)
    except (TypeError, ValueError, OverflowError):
        max_timestep = None

    # Bind the wrapper options now; RenderedEnvProblem receives the partial.
    wrapper_fn = functools.partial(
        gym_utils.gym_env_wrapper, **{
            "rl_env_max_episode_steps": max_timestep,
            "maxskip_env": True,
            "rendered_env": rendered_env,
            "rendered_env_resize_to": resize_dims,
            "sticky_actions": False,
            "output_dtype": np.int32 if use_tpu else None,
            "num_actions": num_actions,
        })

    return rendered_env_problem.RenderedEnvProblem(
        base_env_name=env_problem_name,
        batch_size=batch_size,
        parallelism=parallelism,
        rendered_env=rendered_env,
        env_wrapper_fn=wrapper_fn,
        **env_kwargs)
예제 #5
0
def make_env(batch_size=1,
             env_problem_name="",
             resize=True,
             resized_height=105,
             resized_width=80,
             max_timestep="None",
             clip_rewards=True,
             parallelism=1,
             use_tpu=False,
             **env_kwargs):
  """Creates a batched gym environment problem.

  Args:
    batch_size: Forwarded as `batch_size` to the environment constructor.
    env_problem_name: Forwarded as `base_env_name` to the constructor.
    resize: If falsy, a plain `GymEnvProblem` is returned and
      `resized_height`, `resized_width`, `max_timestep` and `use_tpu` are
      ignored.
    resized_height: First element of the `rendered_env_resize_to` pair bound
      on the gym wrapper.
    resized_width: Second element of the `rendered_env_resize_to` pair.
    max_timestep: Converted with `int()` and bound as
      `rl_env_max_episode_steps` on the gym wrapper. Values that `int()`
      rejects (e.g. `None` or the string "None") become `None`.
    clip_rewards: If truthy, `reward_range=(-1, 1)` and
      `discrete_rewards=True` are added to `env_kwargs`; otherwise
      `discrete_rewards=False` is added.
    parallelism: Forwarded as `parallelism` to the constructor.
    use_tpu: If truthy, the wrapper's `output_dtype` is `np.int32`, otherwise
      `None`.
    **env_kwargs: Extra keyword arguments passed to the constructor.

  Returns:
    A `gym_env_problem.GymEnvProblem` when `resize` is falsy, otherwise a
    `rendered_env_problem.RenderedEnvProblem`.
  """

  if clip_rewards:
    env_kwargs.update({"reward_range": (-1, 1), "discrete_rewards": True})
  else:
    env_kwargs.update({"discrete_rewards": False})

  # No resizing needed, so let's be on the normal EnvProblem.
  if not resize:  # None or False
    return gym_env_problem.GymEnvProblem(
        base_env_name=env_problem_name,
        batch_size=batch_size,
        parallelism=parallelism,
        **env_kwargs)

  # Only swallow the errors int() raises for unusable input (None, a
  # non-numeric string such as "None", an infinite float); anything else is a
  # genuine bug and should propagate.
  try:
    max_timestep = int(max_timestep)
  except (TypeError, ValueError, OverflowError):
    max_timestep = None

  # Bind the wrapper options now; RenderedEnvProblem receives the partial.
  wrapper_fn = functools.partial(
      gym_utils.gym_env_wrapper, **{
          "rl_env_max_episode_steps": max_timestep,
          "maxskip_env": True,
          "rendered_env": True,
          "rendered_env_resize_to": (resized_height, resized_width),
          "sticky_actions": False,
          "output_dtype": np.int32 if use_tpu else None,
      })

  return rendered_env_problem.RenderedEnvProblem(
      base_env_name=env_problem_name,
      batch_size=batch_size,
      parallelism=parallelism,
      env_wrapper_fn=wrapper_fn,
      **env_kwargs)
예제 #6
0
파일: ppo_main.py 프로젝트: tianhai123/-
def make_env(batch_size=8, **env_kwargs):
    """Creates the batched environment selected by the command-line flags.

    Args:
        batch_size: Forwarded as `batch_size` to the environment constructor.
        **env_kwargs: Extra keyword arguments passed to the constructor. When
            `FLAGS.clip_rewards` is truthy, `reward_range=(-1, 1)` and
            `discrete_rewards=True` are added; otherwise
            `discrete_rewards=False` is added.

    Returns:
        A `gym_env_problem.GymEnvProblem` when `FLAGS.resize` is falsy,
        otherwise a `rendered_env_problem.RenderedEnvProblem` whose
        `env_wrapper_fn` is `gym_utils.gym_env_wrapper` with its
        rendering/resizing options pre-bound.
    """
    if FLAGS.clip_rewards:
        env_kwargs["reward_range"] = (-1, 1)
        env_kwargs["discrete_rewards"] = True
    else:
        env_kwargs["discrete_rewards"] = False

    # TODO(afrozm): Should we leave out some cores?
    if FLAGS.parallelize_envs:
        num_workers = multiprocessing.cpu_count()
    else:
        num_workers = 1

    # With FLAGS.resize unset (None or False) nothing is wrapped: return the
    # plain GymEnvProblem.
    if not FLAGS.resize:
        return gym_env_problem.GymEnvProblem(
            base_env_name=FLAGS.env_problem_name,
            batch_size=batch_size,
            parallelism=num_workers,
            **env_kwargs)

    # Any failure to read or convert the flag leaves the step limit as None.
    try:
        step_limit = int(FLAGS.max_timestep)
    except Exception:  # pylint: disable=broad-except
        step_limit = None

    # Bind the wrapper options now; RenderedEnvProblem receives the partial.
    wrap_env = functools.partial(
        gym_utils.gym_env_wrapper,
        rl_env_max_episode_steps=step_limit,
        maxskip_env=True,
        rendered_env=True,
        rendered_env_resize_to=(FLAGS.resized_height, FLAGS.resized_width),
        sticky_actions=False,
        output_dtype=onp.int32 if FLAGS.use_tpu else None)

    return rendered_env_problem.RenderedEnvProblem(
        base_env_name=FLAGS.env_problem_name,
        batch_size=batch_size,
        parallelism=num_workers,
        env_wrapper_fn=wrap_env,
        **env_kwargs)