def env_load_fn(environment_name,
                max_episode_steps=None,
                resize_factor=1,
                action_range=1.0,
                action_noise=0.0,
                threshold_dist=1.0,
                terminate_on_timeout=False):
  gym_env = PointEnv(walls=environment_name,
                     resize_factor=resize_factor,
                     action_range=action_range,
                     action_noise=action_noise)
  gym_env = GoalConditionedPointWrapper(gym_env, threshold_dist=threshold_dist)
  env = gym_wrapper.GymWrapper(gym_env, discount=1.0, auto_reset=True)
  if max_episode_steps is not None and max_episode_steps > 0:
    if terminate_on_timeout:
      # Eval: on timeout, the time step is converted to LAST.
      env = wrappers.TimeLimit(env, max_episode_steps)
    else:
      # Train: on timeout, the time step is NOT converted to LAST.
      env = NonTerminatingTimeLimit(env, max_episode_steps)
  # tf_env.pyenv.envs[0].gym -> <GoalConditionedPointWrapper<PointEnv instance>>
  # tf_env.pyenv.envs        -> [<__main__.NonTerminatingTimeLimit at 0x7f9329d18080>]
  # tf_env.pyenv             -> <tf_agents.environments.batched_py_environment.BatchedPyEnvironment at 0x7f9329d180f0>
  return tf_py_environment.TFPyEnvironment(env)
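NonTerminatingTimeLimit is used above but not defined in this excerpt. A minimal sketch of what such a wrapper might look like, assuming it subclasses wrappers.PyEnvironmentBaseWrapper and simply forces a reset after `duration` steps without ever emitting a LAST time step; the real implementation may differ.

# Hedged sketch only; assumes tf_agents' PyEnvironmentBaseWrapper interface.
class NonTerminatingTimeLimit(wrappers.PyEnvironmentBaseWrapper):
  """Resets the env after `duration` steps without setting done = True."""

  def __init__(self, env, duration):
    super(NonTerminatingTimeLimit, self).__init__(env)
    self._duration = duration
    self._step_count = None

  def _reset(self):
    self._step_count = 0
    return self._env.reset()

  def _step(self, action):
    if self._step_count is None:
      return self.reset()
    ts = self._env.step(action)
    self._step_count += 1
    # Schedule a reset for the next step, but leave this time step's
    # step_type unchanged (it is not converted to LAST).
    if self._step_count >= self._duration or ts.is_last():
      self._step_count = None
    return ts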
def test_pad_short_episode_upto_fixed_length(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.FixedLength(wrappers.TimeLimit(env, 2), 3)
  time_step = env.reset()
  self.assertTrue(time_step.is_first())
  self.assertEqual(1.0, time_step.discount)
  # Normal step.
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_mid())
  self.assertEqual(1.0, time_step.discount)
  # TimeLimit truncated.
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_last())
  self.assertEqual(1.0, time_step.discount)
  # Padded with discount 0.
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_last())
  self.assertEqual(0.0, time_step.discount)
  # Restart episode after the fixed length.
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_first())
  self.assertEqual(1.0, time_step.discount)
def load(environment_name,
         discount=1.0,
         max_episode_steps=None,
         gym_env_wrappers=(),
         env_wrappers=(),
         spec_dtype_map=None,
         gym_kwargs=None,
         auto_reset=True):
  """Loads the selected environment and wraps it with the specified wrappers.

  Note that by default a TimeLimit wrapper is used to limit episode lengths
  to the default benchmarks defined by the registered environments.

  Args:
    environment_name: Name for the environment to load.
    discount: Discount to use for the environment.
    max_episode_steps: If None the max_episode_steps will be set to the
      default step limit defined in the environment's spec. No limit is
      applied if set to 0 or if there is no max_episode_steps set in the
      environment's spec.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    env_wrappers: Iterable with references to wrapper classes to use on the
      gym_wrapped environment.
    spec_dtype_map: A dict that maps gym specs to tf dtypes to use as the
      default dtype for the tensors. An easy way to configure a custom
      mapping through Gin is to define a gin-configurable function that
      returns the desired mapping and call it in your Gin config file, for
      example: `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.
    gym_kwargs: The kwargs to pass to the Gym environment class.
    auto_reset: If True (default), reset the environment automatically after
      a terminal state is reached.

  Returns:
    A PyEnvironment instance.
  """
  gym_kwargs = gym_kwargs if gym_kwargs else {}
  gym_spec = gym.spec(environment_name)
  gym_env = gym_spec.make(**gym_kwargs)

  if max_episode_steps is None and gym_spec.max_episode_steps is not None:
    max_episode_steps = gym_spec.max_episode_steps

  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)
  env = AdversarialGymWrapper(
      gym_env,
      discount=discount,
      spec_dtype_map=spec_dtype_map,
      auto_reset=auto_reset,
  )

  if max_episode_steps is not None and max_episode_steps > 0:
    env = wrappers.TimeLimit(env, max_episode_steps)

  for wrapper in env_wrappers:
    env = wrapper(env)

  return env
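A hedged usage sketch of the load() function above; the environment name and the override values below are illustrative assumptions, not taken from the original code.

# Illustrative only: 'CartPole-v1' and the overrides are assumptions.
env = load(
    'CartPole-v1',
    discount=0.99,
    max_episode_steps=200,   # overrides the limit registered in the gym spec
    gym_env_wrappers=(),     # applied to the raw gym env before wrapping
    env_wrappers=(),         # applied after the AdversarialGymWrapper/TimeLimit
)
time_step = env.reset()
print(time_step.step_type, time_step.observation)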
def test_extra_env_methods_work(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)
  self.assertIsNone(env.get_info())
  env.reset()
  env.step(np.array(0, dtype=np.int32))
  self.assertEqual({}, env.get_info())
def test_extra_env_methods_work(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)
  self.assertEqual(None, env.get_info())
  env.reset()
  env.step(0)
  self.assertEqual({}, env.get_info())
def get_env(add_curiosity_reward=True):
  """Return a copy of the environment."""
  env = VectorIncrementEnvironmentTFAgents(v_n=v_n, v_k=v_k, v_seed=v_seed,
                                           do_transform=do_transform)
  env = wrappers.TimeLimit(env, time_limit)
  if add_curiosity_reward:
    env = CuriosityWrapper(env, env_model, alpha=alpha)
  env = tf_py_environment.TFPyEnvironment(env)
  return env
def test_limit_duration_stops_after_duration(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)
  env.reset()
  env.step(np.array(0, dtype=np.int32))
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_last())
  self.assertNotEqual(None, time_step.discount)
  self.assertNotEqual(0.0, time_step.discount)
def test_episode_count_with_time_limit(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)
  env = wrappers.RunStats(env)

  env.reset()
  self.assertEqual(0, env.episodes)

  env.step(np.array(0, dtype=np.int32))
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(time_step.is_last())
  self.assertEqual(1, env.episodes)
def load(domain_name,
         task_name,
         task_kwargs=None,
         environment_kwargs=None,
         env_load_fn=suite.load,  # use custom_suite.load for a customized env
         action_repeat_wrapper=wrappers.ActionRepeat,
         action_repeat=1,
         frame_stack=4,
         episode_length=1000,
         actions_in_obs=True,
         rewards_in_obs=False,
         pixels_obs=True,
         # Render params
         grayscale=False,
         visualize_reward=False,
         render_kwargs=None):
  """Returns an environment from a domain name, task name."""
  env = env_load_fn(domain_name,
                    task_name,
                    task_kwargs=task_kwargs,
                    environment_kwargs=environment_kwargs,
                    visualize_reward=visualize_reward)

  if pixels_obs:
    env = pixel_wrapper.Wrapper(env, pixels_only=False,
                                render_kwargs=render_kwargs)
  env = dm_control_wrapper.DmControlWrapper(env, render_kwargs)
  if pixels_obs and grayscale:
    env = GrayscaleWrapper(env)
  if action_repeat > 1:
    env = action_repeat_wrapper(env, action_repeat)
  if pixels_obs:
    env = FrameStack(env, frame_stack, actions_in_obs, rewards_in_obs)
  else:
    env = FlattenState(env)

  # Adjust episode length based on action_repeat.
  max_episode_steps = (episode_length + action_repeat - 1) // action_repeat

  # Apply a time limit wrapper at the end to properly trigger all reset().
  env = wrappers.TimeLimit(env, max_episode_steps)
  return env
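A hedged usage sketch of this loader; the domain/task names and render size are assumptions. Note that the TimeLimit duration becomes ceil(episode_length / action_repeat), e.g. 1000 environment steps with action_repeat=4 gives a 250-step limit.

# Illustrative call; 'cheetah'/'run' and the render size are assumptions.
env = load(
    'cheetah', 'run',
    action_repeat=4,
    frame_stack=3,
    episode_length=1000,  # -> TimeLimit of (1000 + 4 - 1) // 4 = 250 steps
    render_kwargs=dict(width=84, height=84, camera_id=0),
)
print(env.observation_spec())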
def test_limit_duration_wrapped_env_forwards_calls(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 10)

  action_spec = env.action_spec()
  self.assertEqual((), action_spec.shape)
  self.assertEqual(0, action_spec.minimum)
  self.assertEqual(1, action_spec.maximum)

  observation_spec = env.observation_spec()
  self.assertEqual((4,), observation_spec.shape)
  high = np.array([
      4.8,
      np.finfo(np.float32).max,
      2 / 15.0 * math.pi,
      np.finfo(np.float32).max
  ])
  np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
  np.testing.assert_array_almost_equal(high, observation_spec.maximum)
def test_duration_applied_after_episode_terminates_early(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 10000)

  # Episode 1 stepped until termination occurs.
  time_step = env.step(np.array(1, dtype=np.int32))
  while not time_step.is_last():
    time_step = env.step(np.array(1, dtype=np.int32))
  self.assertTrue(time_step.is_last())
  env._duration = 2

  # Episode 2 short duration hits step limit.
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())
def test_automatic_reset(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)

  # Episode 1
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())

  # Episode 2
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())
def load_dm_env(env_name,
                frame_shape=(84, 84, 3),
                episode_length=1000,
                action_repeat=4,
                frame_stack=3,
                task_kwargs=None,
                render_kwargs=None,
                camera_kwargs=None,
                background_kwargs=None,
                color_kwargs=None,
                stack_within_repeat=False):
  """Returns an environment from a domain name, task name."""
  domain_name, task_name = env_name.split('-')

  logging.info('Loading environment.')
  render_kwargs = render_kwargs or {}
  render_kwargs['width'] = frame_shape[0]
  render_kwargs['height'] = frame_shape[1]
  if 'camera_id' not in render_kwargs:
    render_kwargs['camera_id'] = 2 if domain_name == 'quadruped' else 0
  if camera_kwargs and 'camera_id' not in camera_kwargs:
    camera_kwargs['camera_id'] = 2 if domain_name == 'quadruped' else 0

  env = load_pixels(domain_name,
                    task_name,
                    task_kwargs=task_kwargs,
                    render_kwargs=render_kwargs,
                    camera_kwargs=camera_kwargs,
                    background_kwargs=background_kwargs,
                    color_kwargs=color_kwargs)
  env = FrameStackActionRepeatWrapper(
      env,
      action_repeat=action_repeat,
      stack_size=frame_stack,
      stack_within_repeat=stack_within_repeat)

  # Shorten episode length based on action_repeat.
  max_episode_steps = (episode_length + action_repeat - 1) // action_repeat
  env = wrappers.TimeLimit(env, max_episode_steps)
  return env
def env_load_fn(environment_name,
                max_episode_steps=None,
                resize_factor=1,
                gym_env_wrappers=(GoalConditionedPointWrapper,),
                terminate_on_timeout=False):
  """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: Maximum number of steps per episode. No limit is
      applied if set to None or 0.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    terminate_on_timeout: Whether to set done = True when the max episode
      steps is reached.

  Returns:
    A TFPyEnvironment instance.
  """
  gym_env = PointEnv(walls=environment_name, resize_factor=resize_factor)
  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)
  env = gym_wrapper.GymWrapper(
      gym_env,
      discount=1.0,
      auto_reset=True,
  )
  if max_episode_steps is not None and max_episode_steps > 0:
    if terminate_on_timeout:
      env = wrappers.TimeLimit(env, max_episode_steps)
    else:
      env = NonTerminatingTimeLimit(env, max_episode_steps)
  return tf_py_environment.TFPyEnvironment(env)
eval_duration = 200  # @param

# Split data into training and test sets.
prices = pd.read_csv(CSV_PATH, parse_dates=True, index_col=0)
train = prices[:date_split]
test = prices[date_split:]

# Create a feature list:
# push the start date forward by features_length so the initial features are non-zero.
# TODO: clean up this logic.
features_length = 20
train_features = Features(train, features_length)
test_features = Features(test, features_length)

# Create environments.
train_py_env = wrappers.TimeLimit(
    TradingEnvironment(initial_balance, train_features),
    duration=training_duration)
eval_py_env = wrappers.TimeLimit(
    TradingEnvironment(initial_balance, test_features),
    duration=eval_duration)
test_py_env = wrappers.TimeLimit(
    TradingEnvironment(initial_balance, test_features),
    duration=len(test) - features_length - 1)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
test_env = tf_py_environment.TFPyEnvironment(test_py_env)

# Initialize the Q network.
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
train_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)
eval_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)

train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

RL_train(train_env, eval_env, fc_layer_params=(48, 64), name='_train')

# Setting num_iterations to 50000+ lets the agent converge to fewer than 110 steps.
iterations = range(len(returns))
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')

iterations = range(len(steps))
else:
    num_walls = 0

# Check if the string from input is a valid digit before interpreting it (str.isdigit).
# Got help from https://stackoverflow.com/questions/1265665/how-can-i-check-if-a-string-represents-an-int-without-using-try-except
if str.isdigit(sys.argv[3]):
    num_agents = int(sys.argv[3])
else:
    num_agents = 0

if str.isdigit(sys.argv[4]):
    def_agents = int(sys.argv[4])
else:
    def_agents = 0

py_env = wrappers.TimeLimit(
    CTFEnv(grid_size, 512, num_walls, num_agents, def_agents), duration=100)
env = tf_py_environment.TFPyEnvironment(py_env)

policy = tf.saved_model.load(model_path)

# Video of 5 simulations, written by Josh Gendein, borrowed and modified from
# https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
with imageio.get_writer(filename, fps=15) as video:
    for _ in range(5):
        time_step = env.reset()
        video.append_data(py_env.render())
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = env.step(action_step.action)
            video.append_data(py_env.render())
from tf_agents.experimental.train.utils import train_utils
from tf_agents.metrics import py_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils

tempdir = tempfile.gettempdir()

fc_layer_params = (100,)
importance_ratio_clipping
lambda_value

train_timed_env = wrappers.TimeLimit(ArmEnv(), 1000)
eval_timed_env = wrappers.TimeLimit(ArmEnv(), 1000)
train_env = tf_py_environment.TFPyEnvironment(train_timed_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_timed_env)

observation_tensor_spec, action_spec, time_step_tensor_spec = (
    spec_utils.get_tensor_specs(train_env))

normalized_observation_tensor_spec = tf.nest.map_structure(
    lambda s: tf.TensorSpec(dtype=tf.float32, shape=s.shape, name=s.name),
    observation_tensor_spec)

actor_net = actor_distribution_network.ActorDistributionNetwork(
    normalized_observation_tensor_spec, ...)
from ai.PastureEnvironment import PastureEnvironment
from pasture.animal.sheep.Sheep import Sheep
from pasture.animal.shepherd.Shepherd import Shepherd

sheep_list = [Sheep(2, 2, 2), Sheep(4, 4, 2)]
shepherd_list = [Shepherd(6, 6)]
pasture_engine = PastureEngine(size=8,
                               starting_shepherds_list=shepherd_list,
                               starting_sheep_list=sheep_list,
                               target=(1, 1))
pasture_environment = PastureEnvironment(pasture_engine)

utils.validate_py_environment(pasture_environment, episodes=5)

pasture_env_wrapped = wrappers.TimeLimit(pasture_environment, duration=15)
print(pasture_env_wrapped)

train_tf_env = tf_py_environment.TFPyEnvironment(pasture_env_wrapped)
print(train_tf_env)
eval_tf_env = tf_py_environment.TFPyEnvironment(pasture_env_wrapped)
print(eval_tf_env)

fc_layer_params = [32, 64, 128]
q_net = q_network.QNetwork(
    train_tf_env.observation_spec(),  # input
    train_tf_env.action_spec(),       # output
    fc_layer_params=fc_layer_params   # hidden layers
)
import numpy as np
import rospy
import tensorflow as tf
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

from arm_pyenv import ArmEnv

# source devel/setup.bash
# roslaunch arm_bringup sim_bringup.launch world:=empty

rospy.init_node("test")
tf.compat.v1.enable_v2_behavior()

environment = ArmEnv()
timed_env = wrappers.TimeLimit(environment, 900)

utils.validate_py_environment(timed_env, episodes=5)

print('action_spec:', environment.action_spec())
print('time_step_spec:', environment.time_step_spec())
print('time_step_spec.observation:', environment.time_step_spec().observation)
print('time_step_spec.step_type:', environment.time_step_spec().step_type)
print('time_step_spec.discount:', environment.time_step_spec().discount)
print('time_step_spec.reward:', environment.time_step_spec().reward)
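A hedged follow-up sketch showing how the time-limited environment above could be exercised with a random policy; the extra import of random_py_policy is an assumption, not part of the original script.

from tf_agents.policies import random_py_policy  # assumed import

# Drive the wrapped env with random actions until the TimeLimit ends the episode.
policy = random_py_policy.RandomPyPolicy(
    time_step_spec=timed_env.time_step_spec(),
    action_spec=timed_env.action_spec())

time_step = timed_env.reset()
total_reward = 0.0
while not time_step.is_last():
    action_step = policy.action(time_step)
    time_step = timed_env.step(action_step.action)
    total_reward += time_step.reward
print('episode reward:', total_reward)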
tf_env = tf_py_environment.TFPyEnvironment(env)

print(isinstance(tf_env, tf_environment.TFEnvironment))
print("TimeStep Specs:", tf_env.time_step_spec())
print("Action Specs:", tf_env.action_spec())

# Usually two environments are instantiated: one for training and one for
# evaluation. First we load the Gridworld environments into a TimeLimit
# wrapper, which ends the episode once the step limit (here 100) is reached.
# The results are then wrapped in the TF environment handlers.
train_py_env = wrappers.TimeLimit(env, duration=100)  # change duration later if needed
eval_py_env = wrappers.TimeLimit(env, duration=100)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

"""The DQN agent can be used in any environment which has a discrete action
space. At the heart of a DQN agent is a QNetwork, a neural network model that
can learn to predict Q-values (expected returns) for all actions, given an
observation from the environment. Use tf_agents.networks.q_network to create a
QNetwork, passing in the observation_spec, action_spec, and a tuple describing
the number and size of the model's hidden layers.
"""

# Next, set up an agent, a learning model, a replay buffer, and a driver, and
# connect them together.
# For now, create a network with a single hidden layer of 100 nodes.
fc_layer_params = (100,)
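Following the comment above, a minimal sketch of the QNetwork and agent construction it describes; the optimizer choice, learning rate, and loss are assumptions, and dqn_agent/common are assumed to be imported from tf_agents.agents.dqn and tf_agents.utils.

# Sketch only; optimizer, learning rate, and loss are assumptions.
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
train_step_counter = tf.compat.v2.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()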
        if action == 2:  # left
            if col - 1 >= 0:
                self._state[1] -= 1
        if action == 3:  # right
            if col + 1 < 6:
                self._state[1] += 1

    def game_over(self):
        row, col, frow, fcol = (self._state[0], self._state[1],
                                self._state[2], self._state[3])
        return row == frow and col == fcol


if __name__ == '__main__':
    env = GridWorldEnv()
    utils.validate_py_environment(env, episodes=5)

    tl_env = wrappers.TimeLimit(env, duration=50)
    time_step = tl_env.reset()
    print(time_step)

    rewards = time_step.reward
    for i in range(100):
        action = np.random.choice([0, 1, 2, 3])
        time_step = tl_env.step(action)
        print(time_step)
        rewards += time_step.reward
    print(rewards)
num_iterations = 40000  # @param
initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param

fc_layer_params = (100,)

batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

train_py_env = wrappers.TimeLimit(ClusterEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(ClusterEnv(), duration=100)
# train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
# eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-5  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}
num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

# Parameters for agents and walls.
grid_size = int(sys.argv[1])
num_walls = int(sys.argv[2])
num_agents = int(sys.argv[3])
def_agents = int(sys.argv[4])

c = CTFEnv(grid_size, 512, num_walls, num_agents, def_agents)

# Training environments, written by Josh Gendein, borrowed and modified from
# the tutorial at https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
# Each simulation lasts 200 steps.
train_py_env = wrappers.TimeLimit(c, duration=200)
eval_py_env = wrappers.TimeLimit(c, duration=200)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Q network, written by Josh Gendein, borrowed and modified from the same tutorial.
fc_layer_params = (2000,)
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

# Q network optimizer, written by Josh Gendein, borrowed from the same tutorial.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
num_iterations = 10000  # @param
initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param

fc_layer_params = (100,)

batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
def _load_dm_env(domain_name,
                 task_name,
                 pixels,
                 action_repeat,
                 max_episode_steps=None,
                 obs_type='pixels',
                 distractor=False):
  """Load a Deepmind control suite environment."""
  try:
    if not pixels:
      env = suite_dm_control.load(domain_name=domain_name, task_name=task_name)
      if action_repeat > 1:
        env = wrappers.ActionRepeat(env, action_repeat)
    else:
      def wrap_repeat(env):
        return ActionRepeatDMWrapper(env, action_repeat)

      camera_id = 2 if domain_name == 'quadruped' else 0
      pixels_only = obs_type == 'pixels'
      if distractor:
        render_kwargs = dict(width=84, height=84, camera_id=camera_id)
        env = distractor_suite.load(
            domain_name,
            task_name,
            difficulty='hard',
            dynamic=False,
            background_dataset_path='DAVIS/JPEGImages/480p/',
            task_kwargs={},
            environment_kwargs={},
            render_kwargs=render_kwargs,
            visualize_reward=False,
            env_state_wrappers=[wrap_repeat])
        # env = wrap_repeat(env)
        # env = suite.wrappers.pixels.Wrapper(
        #     env,
        #     pixels_only=pixels_only,
        #     render_kwargs=render_kwargs,
        #     observation_key=obs_type)
        env = dm_control_wrapper.DmControlWrapper(env, render_kwargs)
      else:
        env = suite_dm_control.load_pixels(
            domain_name=domain_name,
            task_name=task_name,
            render_kwargs=dict(width=84, height=84, camera_id=camera_id),
            env_state_wrappers=[wrap_repeat],
            observation_key=obs_type,
            pixels_only=pixels_only)
    if action_repeat > 1 and max_episode_steps is not None:
      # Shorten episode length.
      max_episode_steps = (max_episode_steps + action_repeat - 1) // action_repeat
      env = wrappers.TimeLimit(env, max_episode_steps)
    return env
  except ValueError as e:
    logging.warning('cannot instantiate dm env: domain_name=%s, task_name=%s',
                    domain_name, task_name)
    logging.warning(
        'Supported domains and tasks: %s',
        str({key: list(val.SUITE.keys())
             for key, val in suite._DOMAINS.items()}))  # pylint: disable=protected-access
    raise e