    def test_mixed_actor_distributions(self, lstm_hidden_size):
        action_spec = dict(discrete=BoundedTensorSpec((), dtype="int64"),
                           continuous=BoundedTensorSpec((3, )))

        network_ctor, state = self._init(lstm_hidden_size)

        actor_dist_net = network_ctor(
            self._input_spec,
            action_spec,
            input_preprocessors=self._input_preprocessors,
            preprocessing_combiner=self._preprocessing_combiner,
            conv_layer_params=self._conv_layer_params)

        act_dist, state = actor_dist_net(self._image, state)

        self.assertTrue(
            isinstance(actor_dist_net.output_spec["discrete"],
                       DistributionSpec))
        self.assertTrue(
            isinstance(actor_dist_net.output_spec["continuous"],
                       DistributionSpec))

        self.assertTrue(isinstance(act_dist["discrete"], td.Categorical))
        self.assertTrue(isinstance(act_dist["continuous"].base_dist,
                                   td.Normal))

        if lstm_hidden_size is None:
            self.assertEqual(state, ())
        else:
            self.assertEqual(len(state), len(lstm_hidden_size))
Example #2
 def testIntegerSamplesIncludeUpperBound(self, dtype):
     if dtype.is_floating_point:  # Only test on integer dtypes.
         return
     spec = BoundedTensorSpec(self._shape, dtype, 3, 3)
     sample = spec.sample()
     self.assertEqual(sample.shape, self._shape)
     self.assertTrue(torch.all(sample == 3))
Example #3
    def test_sac_algorithm_init(self):
        observation_spec = BoundedTensorSpec((10, ))
        discrete_action_spec = BoundedTensorSpec((), dtype='int64')
        continuous_action_spec = [
            BoundedTensorSpec((3, )),
            BoundedTensorSpec((10, ))
        ]

        universal_q_network = partial(
            QNetwork, preprocessing_combiner=NestConcat())
        critic_network = partial(
            CriticNetwork, action_preprocessing_combiner=NestConcat())

        # q_network instead of critic_network is needed
        self.assertRaises(
            AssertionError,
            SacAlgorithm,
            observation_spec=observation_spec,
            action_spec=discrete_action_spec,
            q_network_cls=None)

        sac = SacAlgorithm(
            observation_spec=observation_spec,
            action_spec=discrete_action_spec,
            q_network_cls=QNetwork)
        self.assertEqual(sac._act_type, SacActionType.Discrete)
        self.assertEqual(sac.train_state_spec.actor, ())
        self.assertEqual(sac.train_state_spec.action.actor_network, ())

        # critic_network instead of q_network is needed
        self.assertRaises(
            AssertionError,
            SacAlgorithm,
            observation_spec=observation_spec,
            action_spec=continuous_action_spec,
            critic_network_cls=None)

        sac = SacAlgorithm(
            observation_spec=observation_spec,
            action_spec=continuous_action_spec,
            critic_network_cls=critic_network)
        self.assertEqual(sac._act_type, SacActionType.Continuous)
        self.assertEqual(sac.train_state_spec.action.critic, ())

        # action_spec order is incorrect
        self.assertRaises(
            AssertionError,
            SacAlgorithm,
            observation_spec=observation_spec,
            action_spec=(continuous_action_spec, discrete_action_spec),
            q_network_cls=universal_q_network)

        sac = SacAlgorithm(
            observation_spec=observation_spec,
            action_spec=(discrete_action_spec, continuous_action_spec),
            q_network_cls=universal_q_network)
        self.assertEqual(sac._act_type, SacActionType.Mixed)
        self.assertEqual(sac.train_state_spec.actor, ())
Example #4
def tensor_spec_from_gym_space(space, simplify_box_bounds=True):
    """
    Mostly adapted from ``spec_from_gym_space`` in
    ``tf_agents.environments.gym_wrapper``. Instead of using a ``dtype_map``
    as default data types, it always uses dtypes of gym spaces since gym is now
    updated to support this.
    """

    # Simplify redundant bound arrays so that logged and printed specs are
    # less verbose and easier to read.
    def try_simplify_array_to_value(np_array):
        """If given numpy array has all the same values, returns that value."""
        first_value = np_array.item(0)
        if np.all(np_array == first_value):
            return np.array(first_value, dtype=np_array.dtype)
        else:
            return np_array

    if isinstance(space, gym.spaces.Discrete):
        # Discrete spaces span the set {0, 1, ..., n-1}, while
        # BoundedTensorSpec bounds are inclusive.
        maximum = space.n - 1
        return BoundedTensorSpec(shape=(),
                                 dtype=space.dtype.name,
                                 minimum=0,
                                 maximum=maximum)
    elif isinstance(space, gym.spaces.MultiDiscrete):
        maximum = try_simplify_array_to_value(
            np.asarray(space.nvec - 1, dtype=space.dtype))
        return BoundedTensorSpec(shape=space.shape,
                                 dtype=space.dtype.name,
                                 minimum=0,
                                 maximum=maximum)
    elif isinstance(space, gym.spaces.MultiBinary):
        shape = (space.n, )
        return BoundedTensorSpec(shape=shape,
                                 dtype=space.dtype.name,
                                 minimum=0,
                                 maximum=1)
    elif isinstance(space, gym.spaces.Box):
        minimum = np.asarray(space.low, dtype=space.dtype)
        maximum = np.asarray(space.high, dtype=space.dtype)
        if simplify_box_bounds:
            minimum = try_simplify_array_to_value(minimum)
            maximum = try_simplify_array_to_value(maximum)
        return BoundedTensorSpec(shape=space.shape,
                                 dtype=space.dtype.name,
                                 minimum=minimum,
                                 maximum=maximum)
    elif isinstance(space, gym.spaces.Tuple):
        return tuple([tensor_spec_from_gym_space(s) for s in space.spaces])
    elif isinstance(space, gym.spaces.Dict):
        return collections.OrderedDict([(key, tensor_spec_from_gym_space(s))
                                        for key, s in space.spaces.items()])
    else:
        raise ValueError(
            'The gym space {} is currently not supported.'.format(space))
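
For reference, a minimal usage sketch of the converter above (the spaces below are illustrative assumptions, not taken from the original source):

import gym
import numpy as np

# A Box whose bounds are identical everywhere; with simplify_box_bounds=True
# the (4,)-shaped bound arrays are collapsed into 0-d arrays.
box = gym.spaces.Box(low=0.0, high=1.0, shape=(4, ), dtype=np.float32)
box_spec = tensor_spec_from_gym_space(box)

# Nested spaces are converted recursively: Tuple -> tuple of specs,
# Dict -> OrderedDict of specs.
nested = gym.spaces.Dict({'act': gym.spaces.Discrete(5), 'obs': box})
nested_spec = tensor_spec_from_gym_space(nested)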
Example #5
 def testIntegerSamplesExcludeMaxOfDtype(self, dtype):
     # Exclude non-integer dtypes and uint8 (which has special sampling logic).
     if dtype.is_floating_point or dtype == torch.uint8:
         return
     info = np.iinfo(torch_dtype_to_str(dtype))
     spec = BoundedTensorSpec(self._shape, dtype, info.max - 1,
                              info.max - 1)
     sample = spec.sample(outer_dims=(1, ))
     self.assertEqual(sample.shape, (1, ) + self._shape)
     self.assertTrue(torch.all(sample == info.max - 1))
Example #6
    def testResetSavesCurrentTimeStep(self):
        obs_spec = BoundedTensorSpec((1, ), torch.int32)
        action_spec = BoundedTensorSpec((1, ), torch.int64)

        random_env = RandomAlfEnvironment(observation_spec=obs_spec,
                                          action_spec=action_spec)

        time_step = random_env.reset()
        current_time_step = random_env.current_time_step()
        nest.map_structure(self.assertEqual, time_step, current_time_step)
Example #7
 def _create_action_spec(act_type):
     if act_type == ActionType.Discrete:
         return BoundedTensorSpec(shape=(),
                                  dtype=torch.int64,
                                  minimum=0,
                                  maximum=1)
     else:
         return BoundedTensorSpec(shape=(1, ),
                                  dtype=torch.float32,
                                  minimum=[0],
                                  maximum=[1])
Example #8
 def testBatchSize(self):
     batch_size = 3
     obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
     action_spec = BoundedTensorSpec((1, ), torch.int64)
     env = RandomAlfEnvironment(obs_spec,
                                action_spec,
                                batch_size=batch_size)
     time_step = env.step(torch.tensor(0, dtype=torch.int64))
     self.assertEqual(time_step.observation.shape, (3, 2, 3))
     self.assertEqual(time_step.reward.shape[0], batch_size)
     self.assertEqual(time_step.discount.shape[0], batch_size)
Example #9
 def testCustomRewardFn(self):
     obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
     action_spec = BoundedTensorSpec((1, ), torch.int64)
     batch_size = 3
     env = RandomAlfEnvironment(obs_spec,
                                action_spec,
                                reward_fn=lambda *_: np.ones(batch_size),
                                batch_size=batch_size)
     env._done = False
     env.reset()
     action = torch.ones(batch_size)
     time_step = env.step(action)
     self.assertSequenceAlmostEqual([1.0] * 3, time_step.reward)
Example #10
 def testRewardCheckerBatchSizeOne(self):
     # Ensure batch size 1 with scalar reward works
     obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
     action_spec = BoundedTensorSpec((1, ), torch.int64)
     env = RandomAlfEnvironment(obs_spec,
                                action_spec,
                                reward_fn=lambda *_: np.array([1.0]),
                                batch_size=1)
     env._done = False
     env.reset()
     action = torch.tensor([0], dtype=torch.int64)
     time_step = env.step(action)
     self.assertEqual(time_step.reward, 1.0)
Example #11
    def testBoundedTensorSpecSample(self, dtype):
        if not dtype.is_floating_point:
            return
        # minimum and maximum shape broadcasting
        spec = BoundedTensorSpec(self._shape, dtype, (0, ) * 30, 3)
        sample = spec.sample()
        self.assertEqual(self._shape, sample.shape)
        self.assertTrue(torch.all(sample <= 3))
        self.assertTrue(torch.all(0 <= sample))

        # last minimum is greater than last maximum
        self.assertRaises(AssertionError, BoundedTensorSpec, self._shape,
                          dtype, (0, ) * 29 + (2, ), (1, ) * 30)
Example #12
    def testRendersImage(self):
        action_spec = BoundedTensorSpec((1, ), torch.int64, -10, 10)
        observation_spec = BoundedTensorSpec((1, ), torch.int32, -10, 10)
        env = RandomAlfEnvironment(observation_spec,
                                   action_spec,
                                   render_size=(4, 4, 3))

        env.reset()
        img = env.render()

        self.assertTrue(np.all(img < 256))
        self.assertTrue(np.all(img >= 0))
        self.assertEqual((4, 4, 3), img.shape)
        self.assertEqual(np.uint8, img.dtype)
Example #13
 def testRewardCheckerSizeMismatch(self):
     # Ensure custom scalar reward with batch_size greater than 1 raises
     # ValueError
     obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
     action_spec = BoundedTensorSpec((1, ), torch.int64)
     env = RandomAlfEnvironment(obs_spec,
                                action_spec,
                                reward_fn=lambda *_: np.array([1.0]),
                                batch_size=5)
     env.reset()
     env._done = False
     action = torch.tensor(0, dtype=torch.int64)
     with self.assertRaises(ValueError):
         env.step(action)
Example #14
    def testRewardFnCalled(self):
        def reward_fn(unused_step_type, action, unused_observation):
            return action

        action_spec = BoundedTensorSpec((1, ), torch.int64, -10, 10)
        observation_spec = BoundedTensorSpec((1, ), torch.int32, -10, 10)
        env = RandomAlfEnvironment(observation_spec,
                                   action_spec,
                                   reward_fn=reward_fn)

        action = np.array(1, dtype=np.int64)
        time_step = env.step(action)  # No reward in first time_step
        self.assertEqual(np.zeros((), dtype=np.float32), time_step.reward)
        time_step = env.step(action)
        self.assertEqual(np.ones((), dtype=np.float32), time_step.reward)
Example #15
    def test_continuous_actor_distribution(self, lstm_hidden_size):
        action_spec = BoundedTensorSpec((3, ), torch.float32)

        network_ctor, state = self._init(lstm_hidden_size)

        actor_dist_net = network_ctor(
            self._input_spec,
            action_spec,
            input_preprocessors=self._input_preprocessors,
            preprocessing_combiner=self._preprocessing_combiner,
            conv_layer_params=self._conv_layer_params,
            continuous_projection_net_ctor=functools.partial(
                NormalProjectionNetwork, scale_distribution=True))
        act_dist, _ = actor_dist_net(self._image, state)
        actions = act_dist.sample((100, ))

        self.assertTrue(
            isinstance(actor_dist_net.output_spec, DistributionSpec))

        # (num_samples, batch_size, action_spec_shape)
        self.assertEqual(actions.shape, (100, 1) + action_spec.shape)

        self.assertTrue(
            torch.all(actions >= torch.as_tensor(action_spec.minimum)))
        self.assertTrue(
            torch.all(actions <= torch.as_tensor(action_spec.maximum)))
Example #16
    def test_discrete_actor_distribution(self, lstm_hidden_size):
        action_spec = TensorSpec((), torch.int32)
        network_ctor, state = self._init(lstm_hidden_size)

        # action_spec is not bounded
        self.assertRaises(AssertionError,
                          network_ctor,
                          self._input_spec,
                          action_spec,
                          conv_layer_params=self._conv_layer_params)

        action_spec = BoundedTensorSpec((), torch.int32)
        actor_dist_net = network_ctor(
            self._input_spec,
            action_spec,
            input_preprocessors=self._input_preprocessors,
            preprocessing_combiner=self._preprocessing_combiner,
            conv_layer_params=self._conv_layer_params)

        act_dist, _ = actor_dist_net(self._image, state)
        actions = act_dist.sample((100, ))

        self.assertTrue(
            isinstance(actor_dist_net.output_spec, DistributionSpec))

        # (num_samples, batch_size)
        self.assertEqual(actions.shape, (100, 1))

        self.assertTrue(
            torch.all(actions >= torch.as_tensor(action_spec.minimum)))
        self.assertTrue(
            torch.all(actions <= torch.as_tensor(action_spec.maximum)))
Example #17
    def test_mixed_actions(self, net_ctor):
        obs_spec = TensorSpec((20, ))
        action_spec = dict(x=BoundedTensorSpec((), dtype='int64'),
                           y=BoundedTensorSpec((3, )))

        input_preprocessors = dict(x=EmbeddingPreprocessor(action_spec['x'],
                                                           embedding_dim=10),
                                   y=None)

        net_ctor = functools.partial(
            net_ctor, action_input_processors=input_preprocessors)

        # doesn't support mixed actions
        self.assertRaises(AssertionError, net_ctor, (obs_spec, action_spec))

        # ... unless a combiner is specified
        net_ctor((obs_spec, action_spec),
                 action_preprocessing_combiner=NestConcat())
Example #18
    def testEnvMaxDuration(self, max_duration):
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec([], torch.int32)
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   episode_end_probability=0.1,
                                   max_duration=max_duration)
        num_episodes = 100

        action = torch.tensor(0, dtype=torch.int64)
        for _ in range(num_episodes):
            time_step = env.step(action)
            self.assertTrue(time_step.is_first())
            num_steps = 0
            while not time_step.is_last():
                time_step = env.step(action)
                num_steps += 1
            self.assertLessEqual(num_steps, max_duration)
Example #19
    def test_agent_steps(self):
        batch_size = 1
        observation_spec = TensorSpec((10, ))
        action_spec = BoundedTensorSpec((), dtype='int64')
        time_step = TimeStep(
            observation=observation_spec.zeros(outer_dims=(batch_size, )),
            prev_action=action_spec.zeros(outer_dims=(batch_size, )))

        actor_net = functools.partial(ActorDistributionNetwork,
                                      fc_layer_params=(100, ))
        value_net = functools.partial(ValueNetwork, fc_layer_params=(100, ))

        # TODO: add a goal generator and an entropy target algorithm once they
        # are implemented.
        agent = Agent(observation_spec=observation_spec,
                      action_spec=action_spec,
                      rl_algorithm_cls=functools.partial(
                          ActorCriticAlgorithm,
                          actor_network_ctor=actor_net,
                          value_network_ctor=value_net),
                      intrinsic_reward_module=ICMAlgorithm(
                          action_spec=action_spec,
                          observation_spec=observation_spec))

        predict_state = agent.get_initial_predict_state(batch_size)
        rollout_state = agent.get_initial_rollout_state(batch_size)
        train_state = agent.get_initial_train_state(batch_size)

        pred_step = agent.predict_step(time_step,
                                       predict_state,
                                       epsilon_greedy=0.1)
        self.assertEqual(pred_step.state.irm, ())

        rollout_step = agent.rollout_step(time_step, rollout_state)
        self.assertNotEqual(rollout_step.state.irm, ())

        exp = make_experience(time_step, rollout_step, rollout_state)

        train_step = agent.train_step(exp, train_state)
        self.assertNotEqual(train_step.state.irm, ())

        self.assertTensorEqual(rollout_step.state.irm, train_step.state.irm)
Example #20
    def testEnvResetAutomatically(self):
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec([], torch.int32)
        env = RandomAlfEnvironment(obs_spec, action_spec)

        action = torch.tensor(0, dtype=torch.int64)
        time_step = env.step(action)
        self.assertTrue(np.all(time_step.observation >= -10))
        self.assertTrue(np.all(time_step.observation <= 10))
        self.assertTrue(time_step.is_first())

        while not time_step.is_last():
            time_step = env.step(action)
            self.assertTrue(np.all(time_step.observation >= -10))
            self.assertTrue(np.all(time_step.observation <= 10))

        time_step = env.step(action)
        self.assertTrue(np.all(time_step.observation >= -10))
        self.assertTrue(np.all(time_step.observation <= 10))
        self.assertTrue(time_step.is_first())
Example #21
    def test_discrete_action(self):
        action_spec = BoundedTensorSpec((),
                                        dtype=torch.int64,
                                        minimum=0,
                                        maximum=3)
        alg = ICMAlgorithm(action_spec=action_spec,
                           observation_spec=self._input_tensor_spec,
                           hidden_size=self._hidden_size)
        state = self._input_tensor_spec.zeros(outer_dims=(1, ))

        alg_step = alg.train_step(
            self._time_step._replace(prev_action=action_spec.zeros(
                outer_dims=(1, ))), state)

        # the inverse net should predict a uniform distribution
        self.assertTensorClose(
            torch.sum(alg_step.info.loss.extra['inverse_loss']),
            torch.as_tensor(
                math.log(action_spec.maximum - action_spec.minimum + 1)),
            epsilon=1e-4)
Example #22
    def test_discrete_skill_loss(self):
        skill_spec = BoundedTensorSpec((),
                                       dtype=torch.int64,
                                       minimum=0,
                                       maximum=3)
        alg = DIAYNAlgorithm(skill_spec=skill_spec,
                             encoding_net=self._encoding_net)
        skill = state = torch.nn.functional.one_hot(
            skill_spec.zeros(outer_dims=(1, )),
            int(skill_spec.maximum - skill_spec.minimum + 1)).to(torch.float32)

        alg_step = alg.train_step(
            self._time_step._replace(
                observation=[self._time_step.observation, skill]), state)

        # the discriminator should predict a uniform distribution
        self.assertTensorClose(torch.sum(alg_step.info.loss),
                               torch.as_tensor(
                                   math.log(skill_spec.maximum -
                                            skill_spec.minimum + 1)),
                               epsilon=1e-4)
Example #23
    def test_discrete_action(self, net_ctor):
        obs_spec = TensorSpec((20, ))
        action_spec = BoundedTensorSpec((), dtype='int64')

        # doesn't support discrete action spec ...
        self.assertRaises(AssertionError, net_ctor, (obs_spec, action_spec))

        # ... unless a preprocessor is specified
        net_ctor(
            (obs_spec, action_spec),
            action_input_processors=EmbeddingPreprocessor(action_spec,
                                                          embedding_dim=10))
Example #24
def get_low_rl_input_spec(observation_spec, action_spec, num_steps_per_skill,
                          skill_spec):
    assert observation_spec.ndim == 1 and action_spec.ndim == 1
    concat_observation_spec = TensorSpec(
        (num_steps_per_skill * observation_spec.shape[0], ))
    concat_action_spec = TensorSpec(
        (num_steps_per_skill * action_spec.shape[0], ))
    traj_spec = SubTrajectory(observation=concat_observation_spec,
                              prev_action=concat_action_spec)
    step_spec = BoundedTensorSpec(shape=(),
                                  maximum=num_steps_per_skill,
                                  dtype='int64')
    return alf.nest.flatten(traj_spec) + [step_spec, skill_spec]
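
A hypothetical call of the helper above (the concrete specs are assumptions for illustration):

observation_spec = TensorSpec((8, ))
action_spec = TensorSpec((2, ))
skill_spec = BoundedTensorSpec((), maximum=3, dtype='int64')

# Returns the flattened sub-trajectory specs (observations and actions
# concatenated over num_steps_per_skill steps), followed by the step-counter
# spec and the skill spec.
low_rl_input_spec = get_low_rl_input_spec(observation_spec,
                                          action_spec,
                                          num_steps_per_skill=5,
                                          skill_spec=skill_spec)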
Example #25
    def test_uniform_projection_net(self):
        """A zero-weight net generates uniform actions."""
        input_spec = TensorSpec((10, ), torch.float32)
        embedding = input_spec.ones(outer_dims=(1, ))

        net = CategoricalProjectionNetwork(input_size=input_spec.shape[0],
                                           action_spec=BoundedTensorSpec(
                                               (1, ), minimum=0, maximum=4),
                                           logits_init_output_factor=0)
        dist, _ = net(embedding)
        self.assertTrue(isinstance(net.output_spec, DistributionSpec))
        self.assertEqual(dist.batch_shape, (1, ))
        self.assertEqual(dist.base_dist.batch_shape, (1, 1))
        self.assertTrue(torch.all(dist.base_dist.probs == 0.2))
Example #26
    def test_same_action_prior_actor(self):
        action_spec = dict(a=BoundedTensorSpec(shape=()),
                           b=BoundedTensorSpec((3, ),
                                               minimum=(-1, 0, -2),
                                               maximum=(2, 2, 3)),
                           c=BoundedTensorSpec((2, 3), minimum=-1, maximum=1))
        actor = SameActionPriorActor(observation_spec=(),
                                     action_spec=action_spec)
        batch = TimeStep(step_type=torch.tensor([StepType.FIRST,
                                                 StepType.MID]),
                         prev_action=dict(a=torch.tensor([0., 1.]),
                                          b=torch.tensor([[-1., 0., -2.],
                                                          [2., 2., 3.]]),
                                          c=action_spec['c'].sample((2, ))))
        alg_step = actor.predict_step(batch, ())
        self.assertAlmostEqual(
            alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[0],
            alg_step.output['a'].log_prob(torch.tensor([1., 1.]))[0],
            delta=1e-6)
        self.assertAlmostEqual(
            alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[1],
            alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[0] +
            math.log(0.1),
            delta=1e-6)

        self.assertAlmostEqual(alg_step.output['b'].log_prob(
            torch.tensor([[-1., 0., -2.]] * 2))[0],
                               alg_step.output['b'].log_prob(
                                   torch.tensor([[2., 2., 3.]] * 2))[0],
                               delta=1e-6)

        self.assertAlmostEqual(alg_step.output['b'].log_prob(
            torch.tensor([[-1., 0., -2.]] * 2))[1],
                               alg_step.output['b'].log_prob(
                                   torch.tensor([[-1., 0., -2.]] * 2))[0] +
                               3 * math.log(0.1),
                               delta=1e-6)
Example #27
    def test_close_uniform_projection_net(self):
        """A random-weight net generates close-uniform actions on average."""
        input_spec = TensorSpec((10, ), torch.float32)
        embeddings = input_spec.ones(outer_dims=(100, ))

        net = CategoricalProjectionNetwork(input_size=input_spec.shape[0],
                                           action_spec=BoundedTensorSpec(
                                               (3, 2), minimum=0, maximum=4),
                                           logits_init_output_factor=1.0)
        dists, _ = net(embeddings)
        self.assertEqual(dists.batch_shape, (100, ))
        self.assertEqual(dists.base_dist.batch_shape, (100, 3, 2))
        self.assertTrue(dists.base_dist.probs.std() > 0)
        self.assertTrue(
            torch.isclose(dists.base_dist.probs.mean(), torch.as_tensor(0.2)))
Example #28
    def test_uniform_prior_actor(self):
        action_spec = dict(a=BoundedTensorSpec(shape=()),
                           b=BoundedTensorSpec((3, ),
                                               minimum=(-1, 0, -2),
                                               maximum=(2, 2, 3)),
                           c=BoundedTensorSpec((2, 3), minimum=-1, maximum=1))
        actor = UniformPriorActor(observation_spec=(), action_spec=action_spec)
        batch = TimeStep(step_type=torch.tensor([StepType.FIRST,
                                                 StepType.MID]),
                         prev_action=dict(a=torch.tensor([0., 1.]),
                                          b=torch.tensor([[-1., 0., -2.],
                                                          [2., 2., 3.]]),
                                          c=action_spec['c'].sample((2, ))))

        alg_step = actor.predict_step(batch, ())
        self.assertEqual(
            alg_step.output['a'].log_prob(action_spec['a'].sample()),
            torch.tensor(0.))
        self.assertEqual(
            alg_step.output['b'].log_prob(action_spec['b'].sample()),
            -torch.tensor(30.).log())
        self.assertEqual(
            alg_step.output['c'].log_prob(action_spec['c'].sample()),
            -torch.tensor(64.).log())
Example #29
    def test_actor_networks(self, lstm_hidden_size):
        obs_spec = TensorSpec((3, 20, 20), torch.float32)
        action_spec = BoundedTensorSpec((5, ), torch.float32, 2., 3.)
        conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
        fc_layer_params = (10, 8)

        image = obs_spec.zeros(outer_dims=(1, ))

        network_ctor, state = self._init(lstm_hidden_size)

        actor_net = network_ctor(obs_spec,
                                 action_spec,
                                 conv_layer_params=conv_layer_params,
                                 fc_layer_params=fc_layer_params)

        action, state = actor_net(image, state)

        # (batch_size, num_actions)
        self.assertEqual(action.shape, (1, 5))
Example #30
    def _init(self, lstm_hidden_size):
        self._action_spec = BoundedTensorSpec((), torch.int64, 0, 2)
        self._num_actions = self._action_spec.maximum - self._action_spec.minimum + 1

        if lstm_hidden_size is not None:
            network_ctor = functools.partial(QRNNNetwork,
                                             lstm_hidden_size=lstm_hidden_size)
            if isinstance(lstm_hidden_size, int):
                lstm_hidden_size = [lstm_hidden_size]
            state = []
            for size in lstm_hidden_size:
                # Each LSTM layer's initial state is an (h, c) pair of shape
                # (batch_size=1, size).
                state.append(
                    (torch.randn((1, size), dtype=torch.float32), ) * 2)
        else:
            network_ctor = QNetwork
            state = ()
        return network_ctor, state
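
For context, a minimal sketch of how a parameterized test might consume ``_init`` (the test body below is an assumption, not from the original source):

    def test_q_networks(self, lstm_hidden_size):
        obs_spec = TensorSpec((3, 20, 20), torch.float32)
        image = obs_spec.zeros(outer_dims=(1, ))

        network_ctor, state = self._init(lstm_hidden_size)
        q_net = network_ctor(obs_spec,
                             self._action_spec,
                             conv_layer_params=((8, 3, 1), ))

        q_values, state = q_net(image, state)
        # (batch_size, num_actions)
        self.assertEqual(q_values.shape, (1, self._num_actions))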