Example #1
0
def env_load_fn(environment_name,
                max_episode_steps=None,
                resize_factor=1,
                action_range=1.0,
                action_noise=0.0,
                threshold_dist=1.0,
                terminate_on_timeout=False):
    """Build a goal-conditioned PointEnv as a TF-Agents TF environment.

    Args:
      environment_name: Passed to PointEnv as its `walls` argument.
      max_episode_steps: Per-episode step limit. None (the default) or a
        non-positive value means no time-limit wrapper is applied.
      resize_factor: Forwarded to PointEnv.
      action_range: Forwarded to PointEnv.
      action_noise: Forwarded to PointEnv.
      threshold_dist: Forwarded to GoalConditionedPointWrapper.
      terminate_on_timeout: If True (used for testing), wrap with
        wrappers.TimeLimit so the time step that hits the limit is changed to
        LAST. If False (used for training), wrap with NonTerminatingTimeLimit,
        which does not change that time step to LAST.

    Returns:
      A tf_py_environment.TFPyEnvironment. Its `pyenv` is a
      BatchedPyEnvironment whose `envs` list holds the outermost py
      environment built here (e.g. the NonTerminatingTimeLimit), and
      `pyenv.envs[0].gym` reaches the GoalConditionedPointWrapper around the
      PointEnv.
    """
    gym_env = PointEnv(walls=environment_name,
                       resize_factor=resize_factor,
                       action_range=action_range,
                       action_noise=action_noise)
    gym_env = GoalConditionedPointWrapper(gym_env,
                                          threshold_dist=threshold_dist)
    env = gym_wrapper.GymWrapper(gym_env, discount=1.0, auto_reset=True)

    # None means "no limit"; check it before comparing, because `None > 0`
    # raises TypeError in Python 3 (the default argument used to crash here).
    if max_episode_steps is not None and max_episode_steps > 0:
        if terminate_on_timeout:
            # Test: the time step at the timeout is changed to LAST.
            env = wrappers.TimeLimit(env, max_episode_steps)
        else:
            # Train: the time step at the timeout is not changed to LAST.
            env = NonTerminatingTimeLimit(env, max_episode_steps)

    return tf_py_environment.TFPyEnvironment(env)
Example #2
0
  def test_pad_short_episode_upto_fixed_length(self):
    """FixedLength pads a TimeLimit-truncated episode with a discount-0 step.

    TimeLimit ends the CartPole episode after 2 steps. FixedLength(..., 3)
    then emits one extra LAST step with discount 0, and the following step
    starts a new episode.
    """
    padded_env = wrappers.FixedLength(
        wrappers.TimeLimit(gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 2),
        3)

    first = padded_env.reset()
    self.assertTrue(first.is_first())
    self.assertEqual(1.0, first.discount)

    # (step-type predicate, expected discount) for each subsequent step.
    expected_steps = [
        ('is_mid', 1.0),    # normal step
        ('is_last', 1.0),   # TimeLimit truncation
        ('is_last', 0.0),   # FixedLength padding
        ('is_first', 1.0),  # restart after the fixed length
    ]
    for predicate, discount in expected_steps:
      step = padded_env.step(np.array(0, dtype=np.int32))
      self.assertTrue(getattr(step, predicate)())
      self.assertEqual(discount, step.discount)
def load(environment_name,
         discount=1.0,
         max_episode_steps=None,
         gym_env_wrappers=(),
         env_wrappers=(),
         spec_dtype_map=None,
         gym_kwargs=None,
         auto_reset=True):
    """Load a registered Gym environment and wrap it for TF-Agents.

    The Gym environment is created from its registered spec and passed
    through each of `gym_env_wrappers` in order. It is then converted with
    AdversarialGymWrapper, optionally limited with wrappers.TimeLimit, and
    finally passed through each of `env_wrappers` in order.

    Args:
      environment_name: Registered Gym id passed to `gym.spec`.
      discount: Discount passed to AdversarialGymWrapper.
      max_episode_steps: Step limit for wrappers.TimeLimit. If None, the
        spec's `max_episode_steps` is used instead. No limit is applied when
        the resulting value is None or not positive.
      gym_env_wrappers: Iterable of wrapper classes applied directly to the
        Gym environment.
      env_wrappers: Iterable of wrapper classes applied to the wrapped py
        environment after the optional TimeLimit.
      spec_dtype_map: A dict that maps gym specs to tf dtypes to use as the
        default dtype for the tensors. To configure a custom mapping through
        Gin, define a gin-configurable function that returns the mapping and
        bind it in your Gin config file, for example:
        `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.
      gym_kwargs: Keyword arguments used when making the Gym environment.
      auto_reset: Passed to AdversarialGymWrapper; if True (default) the
        environment is reset automatically after a terminal state.

    Returns:
      A PyEnvironment instance.
    """
    registered_spec = gym.spec(environment_name)
    raw_env = registered_spec.make(**(gym_kwargs or {}))

    # Fall back to the spec's own limit (which may itself be None).
    if max_episode_steps is None:
        max_episode_steps = registered_spec.max_episode_steps

    for gym_wrapper_cls in gym_env_wrappers:
        raw_env = gym_wrapper_cls(raw_env)

    py_env = AdversarialGymWrapper(
        raw_env,
        discount=discount,
        spec_dtype_map=spec_dtype_map,
        auto_reset=auto_reset,
    )

    # None, 0 and negative values all mean "no time limit".
    if (max_episode_steps or 0) > 0:
        py_env = wrappers.TimeLimit(py_env, max_episode_steps)

    for env_wrapper_cls in env_wrappers:
        py_env = env_wrapper_cls(py_env)

    return py_env
Example #4
0
  def test_extra_env_methods_work(self):
    """get_info() on a TimeLimit-wrapped CartPole: None, then {} after a step."""
    limited_env = wrappers.TimeLimit(
        gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 2)

    # Nothing has been stepped yet, so there is no info.
    self.assertIsNone(limited_env.get_info())

    limited_env.reset()
    limited_env.step(np.array(0, dtype=np.int32))
    self.assertEqual({}, limited_env.get_info())
Example #5
0
  def test_extra_env_methods_work(self):
    """get_info() is None before any step and {} after one plain-int step."""
    wrapped = gym_wrapper.GymWrapper(gym.make('CartPole-v1'))
    limited = wrappers.TimeLimit(wrapped, 2)

    self.assertIsNone(limited.get_info())
    limited.reset()
    limited.step(0)  # action given as a plain Python int
    self.assertEqual({}, limited.get_info())
 def get_env(add_curiosity_reward=True):
     """Build and return a new time-limited TF environment.

     Each call constructs a fresh VectorIncrementEnvironmentTFAgents from the
     free variables v_n, v_k, v_seed and do_transform (defined outside this
     function). It limits episodes to `time_limit` steps and, when
     `add_curiosity_reward` is true, wraps the result in
     CuriosityWrapper(env_model, alpha). The result is returned as a
     TFPyEnvironment.
     """
     py_env = wrappers.TimeLimit(
         VectorIncrementEnvironmentTFAgents(v_n=v_n,
                                            v_k=v_k,
                                            v_seed=v_seed,
                                            do_transform=do_transform),
         time_limit)
     if add_curiosity_reward:
         py_env = CuriosityWrapper(py_env, env_model, alpha=alpha)
     return tf_py_environment.TFPyEnvironment(py_env)
Example #7
0
  def test_limit_duration_stops_after_duration(self):
    """The 2nd step under TimeLimit(2) is LAST with a non-None, non-zero discount."""
    limited_env = wrappers.TimeLimit(
        gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 2)
    limited_env.reset()

    limited_env.step(np.array(0, dtype=np.int32))
    final_step = limited_env.step(np.array(0, dtype=np.int32))

    self.assertTrue(final_step.is_last())
    # Truncation by the time limit must not zero out the discount.
    self.assertIsNotNone(final_step.discount)
    self.assertNotEqual(0.0, final_step.discount)
Example #8
0
  def test_episode_count_with_time_limit(self):
    """RunStats counts one episode once TimeLimit(2) ends the first episode."""
    stats_env = wrappers.RunStats(
        wrappers.TimeLimit(gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 2))

    stats_env.reset()
    self.assertEqual(0, stats_env.episodes)

    for _ in range(2):
      last_step = stats_env.step(np.array(0, dtype=np.int32))

    self.assertTrue(last_step.is_last())
    self.assertEqual(1, stats_env.episodes)
Example #9
0
def load(
        domain_name,
        task_name,
        task_kwargs=None,
        environment_kwargs=None,
        env_load_fn=suite.load,  # use custom_suite.load for customized env
        action_repeat_wrapper=wrappers.ActionRepeat,
        action_repeat=1,
        frame_stack=4,
        episode_length=1000,
        actions_in_obs=True,
        rewards_in_obs=False,
        pixels_obs=True,
        # Render params
        grayscale=False,
        visualize_reward=False,
        render_kwargs=None):
    """Load a dm_control domain/task and apply the standard wrapper stack.

    Wrappers are applied in this order:
      1. pixel_wrapper.Wrapper (pixels added, state kept) if pixels_obs;
      2. dm_control_wrapper.DmControlWrapper;
      3. GrayscaleWrapper if pixels_obs and grayscale;
      4. action_repeat_wrapper if action_repeat > 1;
      5. FrameStack if pixels_obs, otherwise FlattenState;
      6. wrappers.TimeLimit, outermost, so that reset() reaches all wrappers.

    `episode_length` counts underlying environment steps. It is converted to
    agent steps as ceil(episode_length / action_repeat).
    """
    raw_env = env_load_fn(domain_name,
                          task_name,
                          task_kwargs=task_kwargs,
                          environment_kwargs=environment_kwargs,
                          visualize_reward=visualize_reward)
    if pixels_obs:
        raw_env = pixel_wrapper.Wrapper(raw_env,
                                        pixels_only=False,
                                        render_kwargs=render_kwargs)

    py_env = dm_control_wrapper.DmControlWrapper(raw_env, render_kwargs)

    if grayscale and pixels_obs:
        py_env = GrayscaleWrapper(py_env)
    if action_repeat > 1:
        py_env = action_repeat_wrapper(py_env, action_repeat)
    py_env = (FrameStack(py_env, frame_stack, actions_in_obs, rewards_in_obs)
              if pixels_obs else FlattenState(py_env))

    # Ceiling division: environment steps -> agent steps.
    steps_per_episode = (episode_length + action_repeat - 1) // action_repeat

    # TimeLimit goes last (outermost) so every inner wrapper sees reset().
    return wrappers.TimeLimit(py_env, steps_per_episode)
Example #10
0
  def test_limit_duration_wrapped_env_forwards_calls(self):
    """TimeLimit exposes the wrapped CartPole action and observation specs."""
    limited_env = wrappers.TimeLimit(
        gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make()), 10)

    actions = limited_env.action_spec()
    self.assertEqual((), actions.shape)
    self.assertEqual(0, actions.minimum)
    self.assertEqual(1, actions.maximum)

    observations = limited_env.observation_spec()
    self.assertEqual((4,), observations.shape)
    float_max = np.finfo(np.float32).max
    expected_high = np.array(
        [4.8, float_max, 2 / 15.0 * math.pi, float_max])
    np.testing.assert_array_almost_equal(-expected_high, observations.minimum)
    np.testing.assert_array_almost_equal(expected_high, observations.maximum)
Example #11
0
  def test_duration_applied_after_episode_terminates_early(self):
    """A shortened _duration governs the episode after a natural termination."""
    limited_env = wrappers.TimeLimit(
        gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 10000)

    # Episode 1: keep taking action 1 until CartPole terminates by itself.
    step = limited_env.step(np.array(1, dtype=np.int32))
    while not step.is_last():
      step = limited_env.step(np.array(1, dtype=np.int32))
    self.assertTrue(step.is_last())

    # Shrink the limit between episodes.
    limited_env._duration = 2

    # Episode 2: FIRST, MID, then LAST once the new limit is reached.
    for predicate in ('is_first', 'is_mid', 'is_last'):
      step = limited_env.step(np.array(0, dtype=np.int32))
      self.assertTrue(getattr(step, predicate)())
Example #12
0
  def test_automatic_reset(self):
    """Without reset(), TimeLimit(2) yields FIRST, MID, LAST for two episodes."""
    limited_env = wrappers.TimeLimit(
        gym_wrapper.GymWrapper(gym.make('CartPole-v1')), 2)

    for _ in range(2):  # two consecutive episodes
      for predicate in ('is_first', 'is_mid', 'is_last'):
        step = limited_env.step(np.array(0, dtype=np.int32))
        self.assertTrue(getattr(step, predicate)())
def load_dm_env(env_name,
                frame_shape=(84, 84, 3),
                episode_length=1000,
                action_repeat=4,
                frame_stack=3,
                task_kwargs=None,
                render_kwargs=None,
                camera_kwargs=None,
                background_kwargs=None,
                color_kwargs=None,
                stack_within_repeat=False):
    """Returns an environment from a '<domain>-<task>' name.

    Args:
      env_name: '<domain>-<task>' string, split on '-'.
      frame_shape: frame_shape[0] is used as the render width and
        frame_shape[1] as the render height. NOTE(review): if frame_shape is
        meant as (height, width, channels) these are swapped; this only
        matters for non-square frames. Confirm before changing.
      episode_length: Episode length in underlying environment steps; the
        TimeLimit uses ceil(episode_length / action_repeat) agent steps.
      action_repeat: Forwarded to FrameStackActionRepeatWrapper.
      frame_stack: Forwarded to FrameStackActionRepeatWrapper as stack_size.
      task_kwargs: Forwarded to load_pixels.
      render_kwargs: Extra render options. 'width' and 'height' are always
        overwritten from frame_shape. 'camera_id' defaults to 2 for the
        'quadruped' domain and 0 otherwise. The caller's dict is not modified.
      camera_kwargs: Forwarded to load_pixels. When non-empty, 'camera_id'
        gets the same default. The caller's dict is not modified.
      background_kwargs: Forwarded to load_pixels.
      color_kwargs: Forwarded to load_pixels.
      stack_within_repeat: Forwarded to FrameStackActionRepeatWrapper.

    Returns:
      The environment wrapped with FrameStackActionRepeatWrapper and
      wrappers.TimeLimit.

    Raises:
      ValueError: If env_name does not contain exactly one '-'.
    """
    name_parts = env_name.split('-')
    if len(name_parts) != 2:
        raise ValueError(
            "env_name must have the form '<domain>-<task>', got %r" %
            (env_name,))
    domain_name, task_name = name_parts
    logging.info('Loading environment.')

    default_camera_id = 2 if domain_name == 'quadruped' else 0

    # Copy before filling in defaults so the caller's dicts are not mutated
    # (and a dict reused across calls does not accumulate stale settings).
    render_kwargs = dict(render_kwargs or {})
    render_kwargs['width'] = frame_shape[0]
    render_kwargs['height'] = frame_shape[1]
    render_kwargs.setdefault('camera_id', default_camera_id)
    if camera_kwargs:
        camera_kwargs = dict(camera_kwargs)
        camera_kwargs.setdefault('camera_id', default_camera_id)

    env = load_pixels(domain_name,
                      task_name,
                      task_kwargs=task_kwargs,
                      render_kwargs=render_kwargs,
                      camera_kwargs=camera_kwargs,
                      background_kwargs=background_kwargs,
                      color_kwargs=color_kwargs)

    env = FrameStackActionRepeatWrapper(
        env,
        action_repeat=action_repeat,
        stack_size=frame_stack,
        stack_within_repeat=stack_within_repeat)

    # Shorten episode length: environment steps -> agent steps (ceiling).
    max_episode_steps = (episode_length + action_repeat - 1) // action_repeat
    env = wrappers.TimeLimit(env, max_episode_steps)
    return env
Example #14
0
def env_load_fn(environment_name,
                max_episode_steps=None,
                resize_factor=1,
                gym_env_wrappers=(GoalConditionedPointWrapper,),
                terminate_on_timeout=False):
    """Loads a PointEnv maze and wraps it with the specified wrappers.

    Args:
        environment_name: Passed to PointEnv as its `walls` argument.
        max_episode_steps: Per-episode step limit. No limit is applied if it
            is None (the default) or not positive.
        resize_factor: Forwarded to PointEnv.
        gym_env_wrappers: Iterable with references to wrapper classes applied,
            in order, directly to the PointEnv before GymWrapper.
        terminate_on_timeout: Whether to set done = True when the max episode
            steps is reached. True uses wrappers.TimeLimit; False uses
            NonTerminatingTimeLimit.

    Returns:
        A tf_py_environment.TFPyEnvironment wrapping the resulting
        py environment.
    """
    gym_env = PointEnv(walls=environment_name,
                       resize_factor=resize_factor)

    for wrapper in gym_env_wrappers:
        gym_env = wrapper(gym_env)
    env = gym_wrapper.GymWrapper(
        gym_env,
        discount=1.0,
        auto_reset=True,
    )

    # Guard None first: with the default argument, `None > 0` raises
    # TypeError in Python 3.
    if max_episode_steps is not None and max_episode_steps > 0:
        if terminate_on_timeout:
            env = wrappers.TimeLimit(env, max_episode_steps)
        else:
            env = NonTerminatingTimeLimit(env, max_episode_steps)

    return tf_py_environment.TFPyEnvironment(env)
Example #15
0
# Episode length (in steps) of the evaluation environment.
# NOTE: CSV_PATH, date_split, training_duration, initial_balance,
# fc_layer_params and learning_rate are defined earlier (not visible here).
eval_duration = 200 # @param

# Split data into training and test set at date_split.
prices = pd.read_csv(CSV_PATH, parse_dates=True, index_col=0)
train = prices[:date_split]
test = prices[date_split:]

# Create a feature list with a look-back window of features_length.
# Push start date forward by features_length to have non-zero initial features
# Cleanup logic here (TODO from the original author; scope unclear)
features_length = 20
train_features = Features(train, features_length)
test_features = Features(test, features_length)

# Create Environments. Each TradingEnvironment starts from initial_balance and
# is capped by a TimeLimit. The test environment's duration is
# len(test)-features_length-1, presumably one full pass over the test data
# after the feature warm-up window — confirm.
train_py_env = wrappers.TimeLimit(TradingEnvironment(initial_balance, train_features), duration=training_duration)
eval_py_env = wrappers.TimeLimit(TradingEnvironment(initial_balance, test_features), duration=eval_duration)
test_py_env = wrappers.TimeLimit(TradingEnvironment(initial_balance, test_features), duration=len(test)-features_length-1)

# Wrap the Python environments as TensorFlow environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
test_env = tf_py_environment.TFPyEnvironment(test_py_env)

# Initialize Q Network from the training environment's specs.
q_net = q_network.QNetwork(
  train_env.observation_spec(),
  train_env.action_spec(),
  fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
# Step counter variable, presumably handed to the agent later — confirm.
train_step_counter = tf.compat.v2.Variable(0)
Example #16
0
# Separate ChangeRewardMountainCarEnv instances (defined elsewhere) for
# training and evaluation, wrapped with discount=1 and auto_reset=True.
train_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
eval_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
# Cap every episode at 200 steps.
train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)

# Wrap the Python environments as TensorFlow environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# RL_train is defined elsewhere. Presumably it trains with a (48, 64)
# hidden-layer network and fills the `returns` and `steps` lists used
# below — confirm.
RL_train(train_env, eval_env, fc_layer_params = (48,64,), name = '_train')

"""Set num_iterations to 50000+ will let agent converge to less than 110 steps"""

# Plot `returns` against the iteration index.
iterations = range(len(returns))
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')

# The snippet is truncated here; presumably `steps` is plotted the same way.
iterations = range(len(steps))
Example #17
0
else:
    num_walls = 0

def _int_arg(position, default=0):
    """Return sys.argv[position] as an int, or `default` if it is missing or
    not a string of decimal digits.

    The original isdigit() check (idea from
    https://stackoverflow.com/questions/1265665/how-can-i-check-if-a-string-represents-an-int-without-using-try-except)
    did not guard against a missing argument, which raised IndexError.
    isdecimal() is used here because isdigit() also accepts characters such
    as superscripts that int() rejects.
    """
    if len(sys.argv) > position and sys.argv[position].isdecimal():
        return int(sys.argv[position])
    return default


# Agent counts passed to CTFEnv; 0 when the argument is absent or invalid.
num_agents = _int_arg(3)
def_agents = _int_arg(4)

# Each episode is capped at 100 steps.
py_env = wrappers.TimeLimit(CTFEnv(grid_size, 512, num_walls, num_agents, def_agents), duration=100)

env = tf_py_environment.TFPyEnvironment(py_env)

policy = tf.saved_model.load(model_path)

#Video of 5 simulations, written by Josh Gendein, borrowed and modified from https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
# Frames are rendered from the Python env (py_env), while the policy acts on
# the TF wrapper (env) around it.
with imageio.get_writer(filename, fps=15) as video:
    for _ in range(5):
        time_step = env.reset()
        video.append_data(py_env.render())
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = env.step(action_step.action)
            video.append_data(py_env.render())
Example #18
0
from tf_agents.experimental.train.utils import train_utils
from tf_agents.metrics import py_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils

tempdir = tempfile.gettempdir()

# Hidden-layer sizes for the actor network.
fc_layer_params = (100, )
# TODO: define `importance_ratio_clipping` and `lambda_value` (presumably PPO
# agent hyperparameters). The snippet previously listed them as bare names,
# which raises NameError unless they were assigned earlier.

# Separate ArmEnv instances for training and evaluation, each capped at 1000
# steps per episode.
train_timed_env = wrappers.TimeLimit(ArmEnv(), 1000)

eval_timed_env = wrappers.TimeLimit(ArmEnv(), 1000)

# `tf_py_environment` is a module; calling it directly raised TypeError.
# The environment class is tf_py_environment.TFPyEnvironment.
train_env = tf_py_environment.TFPyEnvironment(train_timed_env)

eval_env = tf_py_environment.TFPyEnvironment(eval_timed_env)

observation_tensor_spec, action_spec, time_step_tensor_spec = (
    spec_utils.get_tensor_specs(train_env))
# Same observation specs (shape and name kept), but declared as float32.
normalized_observation_tensor_spec = tf.nest.map_structure(
    lambda s: tf.TensorSpec(dtype=tf.float32, shape=s.shape, name=s.name),
    observation_tensor_spec)

# NOTE(review): the original passed a literal `...` placeholder as the output
# spec. action_spec and fc_layer_params are the visible intended arguments;
# add any further options as needed.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    normalized_observation_tensor_spec,
    action_spec,
    fc_layer_params=fc_layer_params)
Example #19
0
from ai.PastureEnvironment import PastureEnvironment
from pasture.animal.sheep.Sheep import Sheep
from pasture.animal.shepherd.Shepherd import Shepherd

def _make_pasture_engine(sheep, shepherds):
    """Return a size-8 PastureEngine with the given animals and target (1, 1)."""
    return PastureEngine(size=8,
                         starting_shepherds_list=shepherds,
                         starting_sheep_list=sheep,
                         target=(1, 1))


sheep_list = [Sheep(2, 2, 2), Sheep(4, 4, 2)]
shepherd_list = [Shepherd(6, 6)]

pasture_engine = _make_pasture_engine(sheep_list, shepherd_list)

pasture_environment = PastureEnvironment(pasture_engine)
utils.validate_py_environment(pasture_environment, episodes=5)

# Training episodes are capped at 15 steps.
pasture_env_wrapped = wrappers.TimeLimit(pasture_environment, duration=15)
print(pasture_env_wrapped)

# Evaluation gets its own engine, environment and animal objects. If both TF
# environments wrapped the same py environment, evaluation steps would
# advance (and reset) the training episode in progress.
eval_pasture_environment = PastureEnvironment(
    _make_pasture_engine([Sheep(2, 2, 2), Sheep(4, 4, 2)], [Shepherd(6, 6)]))
eval_env_wrapped = wrappers.TimeLimit(eval_pasture_environment, duration=15)

train_tf_env = tf_py_environment.TFPyEnvironment(pasture_env_wrapped)
print(train_tf_env)
eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env_wrapped)
print(eval_tf_env)

# Hidden-layer sizes for the Q-network.
fc_layer_params = [32, 64, 128]

q_net = q_network.QNetwork(
    train_tf_env.observation_spec(),  # input
    train_tf_env.action_spec(),  # output
    fc_layer_params=fc_layer_params  # hidden layer sizes
)
import numpy as np
import rospy
import tensorflow as tf
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
from arm_pyenv import ArmEnv

# Run these in the shell before starting this script:
#   source devel/setup.bash
#   roslaunch arm_bringup sim_bringup.launch world:=empty
rospy.init_node("test")
tf.compat.v1.enable_v2_behavior()

environment = ArmEnv()

# Validate the environment over 5 episodes of at most 900 steps each.
timed_env = wrappers.TimeLimit(
    environment,
    900
)
utils.validate_py_environment(timed_env, episodes=5)
print('action_spec:', environment.action_spec())
print('time_step_spec:', environment.time_step_spec())
print('time_step_spec.observation:', environment.time_step_spec().observation)
print('time_step_spec.step_type:', environment.time_step_spec().step_type)
print('time_step_spec.discount:', environment.time_step_spec().discount)
print('time_step_spec.reward:', environment.time_step_spec().reward)

# Inspect the specs through a TF environment wrapper. The ArmEnv instance is
# `environment`; the name `env` used here before was never defined.
tf_env = tf_py_environment.TFPyEnvironment(environment)
print(isinstance(tf_env, tf_environment.TFEnvironment))
print("TimeStep Specs:", tf_env.time_step_spec())
print("Action Specs:", tf_env.action_spec())

# Usually two environments are instantiated:
# one for training and one for evaluation.

# Wrap the arm environment in TimeLimit wrappers that end an episode after
# 100 steps, then wrap the results in TF environment handlers.
# NOTE(review): both wrappers share the single `environment` instance, so
# stepping one also steps the other's underlying ArmEnv. This is presumably
# deliberate (one simulated arm) — confirm before interleaving training
# collection with evaluation.
train_py_env = wrappers.TimeLimit(environment, duration=100)  # change duration later if needed
eval_py_env = wrappers.TimeLimit(environment, duration=100)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

"""The DQN agent can be used in any environment which has a discrete action space.
At the heart of a DQN Agent is a QNetwork, a neural network model that can learn to predict QValues (expected returns) for all actions, given an observation from the environment.
Use tf_agents.networks.q_network to create a QNetwork, passing in the observation_spec, action_spec, and a tuple describing the number and size of the model's hidden layers.
"""

# Next, set up an agent, a learning model, a replay buffer and a driver,
# and connect them together.

# For now, the network has a single hidden layer of 100 nodes.
fc_layer_params = (100,)
Example #22
0
        if action == 2:  #left
            if col - 1 >= 0:
                self._state[1] -= 1
        if action == 3:  #right
            if col + 1 < 6:
                self._state[1] += 1

    def game_over(self):
        """Return True when the first coordinate pair in `self._state`
        equals the second pair.

        `_state[0]`, `_state[1]` are compared with `_state[2]`, `_state[3]`
        respectively. Presumably these are the agent cell and the finish
        cell — confirm against the rest of the class.
        """
        state = self._state
        agent_row, agent_col = state[0], state[1]
        target_row, target_col = state[2], state[3]
        return agent_row == target_row and agent_col == target_col


if __name__ == '__main__':
    # Smoke test: validate GridWorldEnv, then drive it with random actions.
    grid_env = GridWorldEnv()
    utils.validate_py_environment(grid_env, episodes=5)

    # Episodes are capped at 50 steps.
    limited_env = wrappers.TimeLimit(grid_env, duration=50)

    step = limited_env.reset()
    print(step)
    total_reward = step.reward

    # Take 100 uniformly random actions, printing each time step and summing
    # all rewards.
    moves = [0, 1, 2, 3]
    for _ in range(100):
        step = limited_env.step(np.random.choice(moves))
        print(step)
        total_reward += step.reward

    print(total_reward)
Example #23
0
# Training-loop hyperparameters; `# @param` marks Colab form fields.
num_iterations = 40000  # @param

initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param

# Hidden-layer sizes for the Q-network (one layer of 100 units).
fc_layer_params = (100, )

batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

# Separate ClusterEnv instances for training and evaluation, each capped at
# 100 steps per episode.
train_py_env = wrappers.TimeLimit(ClusterEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(ClusterEnv(), duration=100)
# Alternative setup using GridWorldEnv instead of ClusterEnv:
# train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
# eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)

# Wrap the Python environments as TensorFlow environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Q-network built from the training environment's observation/action specs.
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Step counter variable, presumably handed to the agent later — confirm.
train_step_counter = tf.compat.v2.Variable(0)
Example #24
0
batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-5  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}
num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

# Parameters for agents and walls, read positionally from the command line:
# argv[1] grid size, argv[2] number of walls, argv[3] and argv[4] the agent
# counts passed to CTFEnv.
grid_size = int(sys.argv[1])
num_walls = int(sys.argv[2])
num_agents = int(sys.argv[3])
def_agents = int(sys.argv[4])
c = CTFEnv(grid_size, 512, num_walls, num_agents, def_agents)

#Training environment, written by Josh Gendein, borrowed and modified from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
#Simulation will last 200 steps
# Evaluation uses its own CTFEnv instance. Wrapping `c` twice would make both
# wrappers drive the same game, so evaluation episodes would step the
# training episode in progress.
train_py_env = wrappers.TimeLimit(c, duration=200)
eval_py_env = wrappers.TimeLimit(
    CTFEnv(grid_size, 512, num_walls, num_agents, def_agents), duration=200)

#TF environment wrappers, written by Josh Gendein, borrowed and modified from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

#Q Network, written by Josh Gendein, borrowed and modified from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
fc_layer_params = (2000, )
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

#Q Network Optimizer, written by Josh Gendein, borrowed from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
Example #25
0
# Training-loop hyperparameters; `# @param` marks Colab form fields.
num_iterations = 10000  # @param

initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param

# Hidden-layer sizes for the Q-network (one layer of 100 units).
fc_layer_params = (100, )

batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

# Separate GridWorldEnv instances for training and evaluation, each capped at
# 100 steps per episode.
train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)

# Wrap the Python environments as TensorFlow environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Q-network built from the training environment's observation/action specs.
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Step counter variable, passed to the DqnAgent constructed next.
train_step_counter = tf.compat.v2.Variable(0)

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
Example #26
0
def _load_dm_env(domain_name,
                 task_name,
                 pixels,
                 action_repeat,
                 max_episode_steps=None,
                 obs_type='pixels',
                 distractor=False):
    """Load a Deepmind control suite environment.

    Args:
      domain_name: dm_control domain name.
      task_name: Task name within the domain.
      pixels: If False, load state observations via suite_dm_control.load and
        apply wrappers.ActionRepeat when action_repeat > 1. If True, load
        84x84 rendered observations with ActionRepeatDMWrapper applied as an
        env-state wrapper. The camera is 2 for 'quadruped' and 0 otherwise.
      action_repeat: Number of times each action is repeated.
      max_episode_steps: Episode length in underlying environment steps. When
        given and positive, it is converted to agent steps
        (ceil(max_episode_steps / action_repeat)) and enforced with
        wrappers.TimeLimit. It is no longer ignored when action_repeat == 1.
        None (default) keeps only the environment's own limit.
      obs_type: Observation key for load_pixels; pixels_only is True exactly
        when this is 'pixels'. Unused on the distractor path.
      distractor: If True (pixels only), load from distractor_suite with the
        'hard' difficulty and the DAVIS background dataset.

    Returns:
      The wrapped environment.

    Raises:
      ValueError: Re-raised from the loader when the environment cannot be
        instantiated, after logging the supported domains and tasks.
    """
    try:
        if not pixels:
            env = suite_dm_control.load(domain_name=domain_name,
                                        task_name=task_name)
            if action_repeat > 1:
                env = wrappers.ActionRepeat(env, action_repeat)

        else:

            def wrap_repeat(env):
                return ActionRepeatDMWrapper(env, action_repeat)

            camera_id = 2 if domain_name == 'quadruped' else 0
            render_kwargs = dict(width=84, height=84, camera_id=camera_id)

            if distractor:
                env = distractor_suite.load(
                    domain_name,
                    task_name,
                    difficulty='hard',
                    dynamic=False,
                    background_dataset_path='DAVIS/JPEGImages/480p/',
                    task_kwargs={},
                    environment_kwargs={},
                    render_kwargs=render_kwargs,
                    visualize_reward=False,
                    env_state_wrappers=[wrap_repeat])
                env = dm_control_wrapper.DmControlWrapper(env, render_kwargs)

            else:
                env = suite_dm_control.load_pixels(
                    domain_name=domain_name,
                    task_name=task_name,
                    render_kwargs=render_kwargs,
                    env_state_wrappers=[wrap_repeat],
                    observation_key=obs_type,
                    pixels_only=(obs_type == 'pixels'))

        # Apply the requested limit whenever one is given. The old guard
        # `action_repeat > 1` silently dropped it for action_repeat == 1, and
        # a zero limit would have produced TimeLimit(env, 0).
        if max_episode_steps is not None and max_episode_steps > 0:
            # Shorten episode length: environment steps -> agent steps.
            max_episode_steps = (max_episode_steps + action_repeat -
                                 1) // action_repeat
            env = wrappers.TimeLimit(env, max_episode_steps)

        return env

    except ValueError:
        logging.warning(
            'cannot instantiate dm env: domain_name=%s, task_name=%s',
            domain_name, task_name)
        logging.warning('Supported domains and tasks: %s',
                        str({
                            key: list(val.SUITE.keys())
                            for key, val in suite._DOMAINS.items()  # pylint: disable=protected-access
                        }))
        # Bare raise keeps the original traceback unchanged.
        raise