def make(env_name, frame_stack, action_repeat, seed):
    """Construct a pixel-based dm_control environment with standard wrappers.

    Args:
      env_name: '<domain>_<task>' name; the 'manip' domain is routed to the
        manipulation suite's '<task>_vision' loader, everything else to
        `suite.load`.
      frame_stack: number of consecutive frames stacked by FrameStackWrapper.
      action_repeat: number of times each action is repeated per step.
      seed: RNG seed forwarded to the task.

    Returns:
      The fully wrapped environment; its action spec is verified to lie
      within [-1, 1].
    """
    domain, task = split_env_name(env_name)
    if domain == 'manip':
        env = manipulation.load(f'{task}_vision', seed=seed)
    else:
        env = suite.load(
            domain, task, task_kwargs={'random': seed}, visualize_reward=False)
    # Action repeat first, then rescale actions into [-1, 1].
    env = ActionRepeatWrapper(env, action_repeat)
    env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
    # Flatten the low-dimensional feature observations.
    env = FlattenObservationWrapper(env)
    if domain != 'manip':
        # Manipulation vision tasks already render pixels; for suite tasks add
        # them here. Camera choice follows dreamer:
        # https://github.com/danijar/dreamer/blob/02f0210f5991c7710826ca7881f19c64a012290c/wrappers.py#L26
        cam = 2 if domain == 'quadruped' else 0
        env = pixels.Wrapper(
            env,
            pixels_only=False,
            render_kwargs={'height': 84, 'width': 84, 'camera_id': cam})
    env = FrameStackWrapper(env, frame_stack)
    spec = env.action_spec()
    assert np.all(spec.minimum >= -1.0)
    assert np.all(spec.maximum <= +1.0)
    return env
def test_invalid_action_spec_type(self):
    """Wrapper rejects an action spec that is not a single BoundedArray."""
    bad_spec = [make_action_spec()] * 2  # a list of specs, not a BoundedArray
    mock_env = make_mock_env(action_spec=bad_spec)
    expected_message = (
        action_scale._ACTION_SPEC_MUST_BE_BOUNDED_ARRAY.format(bad_spec))
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        action_scale.Wrapper(mock_env, minimum=0, maximum=1)
def test_non_finite_bounds(self, name, bounds):
    """Non-finite values for either bound raise a ValueError."""
    wrapper_kwargs = {'minimum': np.r_[-1.], 'maximum': np.r_[1.]}
    wrapper_kwargs[name] = bounds  # overwrite the parametrized bound
    mock_env = make_mock_env(action_spec=make_action_spec())
    expected_message = action_scale._MUST_BE_FINITE.format(
        name=name, bounds=bounds)
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        action_scale.Wrapper(mock_env, **wrapper_kwargs)
def test_method_delegated_to_underlying_env(self, method_name):
    """Methods not overridden by the wrapper are forwarded to the inner env."""
    inner_env = make_mock_env(action_spec=make_action_spec())
    wrapper = action_scale.Wrapper(inner_env, minimum=0, maximum=1)
    result = getattr(wrapper, method_name)()
    delegate = getattr(inner_env, method_name)
    delegate.assert_called_once_with()
    # The wrapper must return exactly what the underlying mock returns.
    self.assertIs(result, delegate())
def test_correct_action_spec(self, minimum, maximum):
    """The wrapped action spec reports the requested minimum/maximum."""
    inner_spec = make_action_spec(lower=np.r_[-2., -2.], upper=np.r_[2., 2.])
    mock_env = make_mock_env(action_spec=inner_spec)
    wrapper = action_scale.Wrapper(mock_env, minimum=minimum, maximum=maximum)
    scaled_spec = wrapper.action_spec()
    np.testing.assert_array_equal(scaled_spec.minimum, minimum)
    np.testing.assert_array_equal(scaled_spec.maximum, maximum)
def test_invalid_bounds_shape(self, name, bounds):
    """Bounds that cannot broadcast to the spec shape raise a ValueError."""
    spec_shape = (2,)
    wrapper_kwargs = {
        'minimum': np.zeros(spec_shape),
        'maximum': np.ones(spec_shape),
    }
    wrapper_kwargs[name] = bounds  # overwrite the parametrized bound
    mock_env = make_mock_env(
        action_spec=make_action_spec(lower=[-1, -1], upper=[2, 3]))
    expected_message = action_scale._MUST_BROADCAST.format(
        name=name, bounds=bounds, shape=spec_shape)
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        action_scale.Wrapper(mock_env, **wrapper_kwargs)
def test_step(self, minimum, maximum, scaled_minimum, scaled_maximum):
    """Actions at the scaled bounds map onto the underlying spec bounds."""
    inner_env = make_mock_env(
        action_spec=make_action_spec(lower=minimum, upper=maximum))
    wrapper = action_scale.Wrapper(
        inner_env, minimum=scaled_minimum, maximum=scaled_maximum)
    # Scaled minimum should be mapped to the spec minimum.
    step_result = wrapper.step(scaled_minimum)
    self.assertStepCalledOnceWithCorrectAction(inner_env, minimum)
    self.assertIs(step_result, inner_env.step(minimum))
    # Clear call history, then check the other end of the range.
    inner_env.reset_mock()
    step_result = wrapper.step(scaled_maximum)
    self.assertStepCalledOnceWithCorrectAction(inner_env, maximum)
    self.assertIs(step_result, inner_env.step(maximum))
def make_meta(env_name, episode_length, seed):
    """Build a multi-task ('meta') cartpole-balance environment.

    Creates the five balance task variants (balance_v1 .. balance_v5), wraps
    each with action scaling into [-1, 1], observation flattening, and a
    per-episode time limit, then bundles them into a MetaEnv.

    Args:
      env_name: must be exactly 'cartpole_balance'; no other env is supported.
      episode_length: step limit applied to every task via TimeLimitWrapper.
      seed: RNG seed forwarded to each task variant.

    Returns:
      A MetaEnv over the five wrapped balance variants.

    Raises:
      ValueError: if env_name is not 'cartpole_balance'.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `-O`,
    # which would silently allow unsupported env names through.
    if env_name != 'cartpole_balance':
        raise ValueError(
            f"Unsupported env_name {env_name!r}; "
            "only 'cartpole_balance' is supported.")
    # The five variants follow the naming scheme balance_v<i>.
    envs = [
        getattr(multi_task_cartpole, f'balance_v{i}')(random=seed)
        for i in range(1, 6)
    ]
    envs = [
        action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) for env in envs
    ]
    envs = [FlattenObservationWrapper(env) for env in envs]
    envs = [TimeLimitWrapper(env, episode_length) for env in envs]
    env = MetaEnv(envs)
    # env = TaskIdWrapper(env)  # NOTE(review): intentionally disabled.
    return env
def make(env_name, seed):
    """Load a flat-observation dm_control suite task from an underscored name.

    'ball_in_cup_catch' is special-cased because the domain name itself
    contains underscores; every other name is parsed as
    '<domain>_<task-with-possible-underscores>'.
    """
    if env_name == 'ball_in_cup_catch':
        domain_name, task_name = 'ball_in_cup', 'catch'
    else:
        # partition('_') splits on the first underscore only, leaving the
        # remainder (which may itself contain underscores) as the task name.
        domain_name, _, task_name = env_name.partition('_')
    env = suite.load(domain_name=domain_name,
                     task_name=task_name,
                     task_kwargs={'random': seed},
                     visualize_reward=False)
    env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
    env = FlattenObservationWrapper(env)
    spec = env.action_spec()
    assert np.all(spec.minimum >= -1.0)
    assert np.all(spec.maximum <= +1.0)
    return env
def __init__(self, domain, task, *args, env=None, normalize=True,
             observation_keys=(), goal_keys=(), unwrap_time_limit=True,
             pixel_wrapper_kwargs=None, **kwargs):
    """Adapt a dm_control environment to this adapter's interface.

    Either (domain, task) are given and the env is loaded via `suite.load`,
    or a pre-built `env` is passed in (in which case domain/task/kwargs must
    be empty).

    Args:
      domain: dm_control domain name, or None when `env` is given.
      task: dm_control task name, or None when `env` is given.
      *args: must be empty; gym-style positional args are not supported.
      env: optional pre-constructed environment to wrap directly.
      normalize: if True, rescale the action spec into [-1, 1].
      observation_keys: observation-spec keys to expose; empty means all.
      goal_keys: goal keys forwarded to the superclass and kept in the
        observation space.
      unwrap_time_limit: stored on the instance; not used in this method.
      pixel_wrapper_kwargs: if not None, kwargs for `pixels.Wrapper`.
      **kwargs: task kwargs forwarded to `suite.load` (only valid when
        `env` is None).

    Raises:
      NotImplementedError: if the resulting action space is not flat (rank 1).
    """
    assert not args, (
        "Gym environments don't support args. Use kwargs instead.")
    self.normalize = normalize
    self.unwrap_time_limit = unwrap_time_limit
    super(DmControlAdapter, self).__init__(
        domain, task, *args, goal_keys=goal_keys, **kwargs)
    if env is None:
        # Loading path: both names are required to look up the suite task.
        assert (domain is not None and task is not None), (domain, task)
        env = suite.load(
            domain_name=domain,
            task_name=task,
            task_kwargs=kwargs
            # TODO(hartikainen): Figure out how to pass kwargs to this guy.
            # Need to split into `task_kwargs`, `environment_kwargs`, and
            # `visualize_reward` bool. Check the suite.load(.) in:
            # https://github.com/deepmind/dm_control/blob/master/dm_control/suite/__init__.py
        )
        self._env_kwargs = kwargs
    else:
        # Pre-built-env path: names/kwargs are mutually exclusive with `env`.
        assert not kwargs
        assert domain is None and task is None, (domain, task)
    if normalize:
        # Only wrap when the spec is not already [-1, 1], then verify.
        if (np.any(env.action_spec().minimum != -1)
                or np.any(env.action_spec().maximum != 1)):
            env = action_scale.Wrapper(env, minimum=-1.0, maximum=1.0)
        np.testing.assert_equal(env.action_spec().minimum, -1)
        np.testing.assert_equal(env.action_spec().maximum, 1)
    if pixel_wrapper_kwargs is not None:
        # Pixels are added after action scaling so the wrapped spec holds.
        env = pixels.Wrapper(env, **pixel_wrapper_kwargs)
    self._env = env
    assert isinstance(env.observation_spec(), OrderedDict)
    # Empty `observation_keys` means "expose every key in the spec".
    self.observation_keys = (
        observation_keys or tuple(env.observation_spec().keys()))
    observation_space = convert_dm_control_to_gym_space(
        env.observation_spec())
    # Keep only the selected observation and goal subspaces (deep-copied so
    # later mutation of the converted space cannot leak through).
    self._observation_space = type(observation_space)([
        (name, copy.deepcopy(space))
        for name, space in observation_space.spaces.items()
        if name in self.observation_keys + self.goal_keys
    ])
    action_space = convert_dm_control_to_gym_space(self._env.action_spec())
    if len(action_space.shape) > 1:
        raise NotImplementedError(
            "Shape of the action space ({}) is not flat, make sure to"
            " check the implemenation.".format(action_space))
    self._action_space = action_space