Example #1
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 filters=256,
                 hidden_size=256,
                 initial_game_over_bias=0.):
        """Prediction network with value, reward, game-over and action heads.

        Args:
            observation_spec (TensorSpec): spec of the input. Its shape is
                unpacked as ``(in_channels, h, w)``, so a 3-D spec is required.
            action_spec (BoundedTensorSpec): spec of the action, forwarded to
                ``CategoricalProjectionNetwork``.
            filters (int): number of output channels of the 3x3 convolution
                in the action head.
            hidden_size (int): size of the hidden FC layer shared by the
                layout of the value, reward and game-over heads.
            initial_game_over_bias (float): initial bias of the final FC layer
                of the game-over head.
        """
        super().__init__(observation_spec, name="PredictionNet")
        in_channels, h, w = observation_spec.shape

        # The final FC kernel of every scalar head starts at zero, so each
        # head's initial output depends only on that layer's bias (passed
        # through its activation, if any).
        output_weight_initializer = torch.nn.init.zeros_

        def make_scalar_head(**output_fc_kwargs):
            # Common layout of the three scalar heads: 1x1 conv down to one
            # channel, flatten the h*w map, one hidden ReLU FC layer, then a
            # 1-unit FC whose extra arguments differ per head, and finally
            # Reshape(()) (presumably dropping the trailing unit dimension).
            # Layers are created in the listed order, same as spelling the
            # nn.Sequential out by hand.
            return nn.Sequential(
                alf.layers.Conv2D(in_channels, 1, kernel_size=1),
                alf.layers.Reshape([-1]),
                alf.layers.FC(input_size=h * w,
                              output_size=hidden_size,
                              activation=torch.relu_,
                              use_bn=False),
                alf.layers.FC(input_size=hidden_size,
                              output_size=1,
                              kernel_initializer=output_weight_initializer,
                              **output_fc_kwargs),
                alf.layers.Reshape(()))

        # Value and reward are squashed by tanh into (-1, 1).
        self._value_head = make_scalar_head(activation=torch.tanh)
        self._reward_head = make_scalar_head(activation=torch.tanh)
        # No activation is passed here, so the FC layer's default is used;
        # presumably the output is an unsquashed game-over logit.
        self._game_over_head = make_scalar_head(
            bias_init_value=initial_game_over_bias)

        # A 3x3 conv with padding=1 keeps the h x w spatial size (assuming the
        # default stride of 1), hence the h * w * filters projection input.
        self._action_head = nn.Sequential(
            alf.layers.Conv2D(in_channels, filters, kernel_size=3, padding=1),
            alf.layers.Reshape([-1]))

        # A tiny logits_init_output_factor presumably keeps the initial action
        # distribution close to uniform.
        self._action_proj = CategoricalProjectionNetwork(
            input_size=h * w * filters,
            action_spec=action_spec,
            logits_init_output_factor=1e-10)
Example #2
0
    def test_uniform_projection_net(self):
        """A zero-weight net generates uniform actions."""
        input_spec = TensorSpec((10, ), torch.float32)
        embedding = input_spec.ones(outer_dims=(1, ))

        # Actions 0..4 give 5 categories, so a uniform distribution assigns
        # probability 1/5 to each.
        net = CategoricalProjectionNetwork(input_size=input_spec.shape[0],
                                           action_spec=BoundedTensorSpec(
                                               (1, ), minimum=0, maximum=4),
                                           logits_init_output_factor=0)
        dist, _ = net(embedding)
        self.assertIsInstance(net.output_spec, DistributionSpec)
        self.assertEqual(dist.batch_shape, (1, ))
        self.assertEqual(dist.base_dist.batch_shape, (1, 1))
        probs = dist.base_dist.probs
        # Compare with a tolerance rather than exact float equality, so the
        # test does not depend on how softmax rounds.
        self.assertTrue(torch.allclose(probs, torch.full_like(probs, 0.2)))
Example #3
0
    def test_close_uniform_projection_net(self):
        """Random weights give non-constant probs averaging to 1/5."""
        spec = TensorSpec((10, ), torch.float32)
        batch = spec.ones(outer_dims=(100, ))
        action_spec = BoundedTensorSpec((3, 2), minimum=0, maximum=4)

        net = CategoricalProjectionNetwork(input_size=spec.shape[0],
                                           action_spec=action_spec,
                                           logits_init_output_factor=1.0)
        dists, _ = net(batch)
        self.assertEqual(dists.batch_shape, (100, ))
        self.assertEqual(dists.base_dist.batch_shape, (100, 3, 2))

        probs = dists.base_dist.probs
        # Random projection weights should make the probabilities vary.
        self.assertTrue(probs.std() > 0)
        # NOTE(review): each 5-way categorical sums to 1, so the overall mean
        # is 1/5 whatever the weights are; this only guards normalization.
        expected_mean = torch.as_tensor(0.2)
        self.assertTrue(torch.isclose(probs.mean(), expected_mean))
Example #4
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 trunk_net_ctor,
                 initial_game_over_bias=0.0,
                 game_over_logit_thresh=1.0):
        """
        Args:
            observation_spec (TensorSpec): describing the observation.
            action_spec (BoundedTensorSpec): describing the action. A
                continuous spec gets a ``StableNormalProjectionNetwork``;
                otherwise a ``CategoricalProjectionNetwork`` is used.
            trunk_net_ctor (Callable): called as
                ``trunk_net_ctor(input_tensor_spec=observation_spec)`` to
                create a network which takes the observation as input and
                outputs a hidden representation. That representation is the
                input for predicting value, reward, action distribution and
                game_over logit. Only ``output_spec.shape[0]`` is used as the
                feature size, so a 1-D output is assumed.
            initial_game_over_bias (float): initial bias for predicting the
                logit of game_over. Suggest using
                ``log(game_over_prob / (1 - game_over_prob))``.
            game_over_logit_thresh (float): stored as
                ``self._game_over_logit_thresh``; presumably the threshold
                the game_over logit is compared against elsewhere in this
                class (TODO confirm).
        """
        super().__init__(observation_spec, name="SimplePredictionNet")

        self._trunk_net = trunk_net_ctor(input_tensor_spec=observation_spec)
        dim = self._trunk_net.output_spec.shape[0]
        # Zero-initialized kernels: initial value/reward predictions depend
        # only on the layer biases.
        self._value_layer = alf.layers.FC(
            dim, 1, kernel_initializer=torch.nn.init.zeros_)
        self._reward_layer = alf.layers.FC(
            dim, 1, kernel_initializer=torch.nn.init.zeros_)

        if action_spec.is_continuous:
            self._action_net = StableNormalProjectionNetwork(
                input_size=dim,
                action_spec=action_spec,
                state_dependent_std=True,
                scale_distribution=True,
                dist_squashing_transform=dist_utils.Softsign())
        else:
            # A tiny logits_init_output_factor presumably keeps the initial
            # action distribution close to uniform.
            self._action_net = CategoricalProjectionNetwork(
                input_size=dim,
                action_spec=action_spec,
                logits_init_output_factor=1e-10)

        self._game_over_logit_thresh = game_over_logit_thresh
        self._game_over_layer = alf.layers.FC(
            dim,
            1,
            kernel_initializer=torch.nn.init.zeros_,
            bias_init_value=initial_game_over_bias)