def test_zero_normal_projection_net(self, state_dependent_std):
    """A zero-weight net generates zero actions.

    The network is built with ``projection_output_init_gain=0``,
    ``std_bias_initializer_value=0``, ``squash_mean=False`` and an identity
    ``std_transform``; every sampled action is then expected to be exactly
    zero.

    Args:
        state_dependent_std: forwarded unchanged to
            ``NormalProjectionNetwork``. Presumably supplied by a
            parameterized-test decorator outside this block -- TODO confirm.
    """
    input_spec = TensorSpec((10, ), torch.float32)
    action_spec = TensorSpec((8, ), torch.float32)
    # A batch of two all-ones embeddings.
    embedding = input_spec.ones(outer_dims=(2, ))

    net = NormalProjectionNetwork(
        input_size=input_spec.shape[0],
        action_spec=action_spec,
        projection_output_init_gain=0,
        std_bias_initializer_value=0,
        squash_mean=False,
        state_dependent_std=state_dependent_std,
        std_transform=math_ops.identity)

    # The first element of the network output is the action distribution;
    # draw 10 samples from it.
    out = net(embedding)[0].sample((10, ))

    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(net.output_spec, DistributionSpec)
    # Expected layout: (samples, batch) followed by the action shape.
    self.assertEqual(tuple(out.size()), (10, 2) + action_spec.shape)
    self.assertTrue(torch.all(out == 0))
def test_squash_mean_normal_projection_net(self, network_ctor):
    """A net with `squash_mean=True` should generate means within the action
    spec.

    ``squash_mean`` is not passed explicitly below, so the test relies on the
    constructor's default -- presumably ``True``; TODO confirm.

    Args:
        network_ctor: callable that builds the projection network under
            test from ``(input_size, action_spec, **kwargs)``. Presumably
            supplied by a parameterized-test decorator outside this block --
            TODO confirm.
    """
    input_spec = TensorSpec((10, ), torch.float32)
    embedding = input_spec.ones(outer_dims=(100, ))

    # For squashing mean, we need a bounded action spec. The constructor is
    # called with the same kind of ``input_size`` as the valid construction
    # below, so the unbounded spec is the only difference between the two.
    unbounded_spec = TensorSpec((8, ), torch.float32)
    with self.assertRaises(AssertionError):
        network_ctor(input_spec.shape[0], unbounded_spec)

    action_spec = BoundedTensorSpec(
        (2, ), torch.float32, minimum=(0, -0.01), maximum=(0.01, 0))
    net = network_ctor(
        input_spec.shape[0], action_spec, projection_output_init_gain=1.0)

    dist, _ = net(embedding)

    # The means must not all be identical ...
    self.assertTrue(dist.mean.std() > 0)
    # ... and must lie strictly inside the per-dimension bounds.
    lower = torch.as_tensor(action_spec.minimum)
    upper = torch.as_tensor(action_spec.maximum)
    self.assertTrue(torch.all(dist.mean > lower))
    self.assertTrue(torch.all(dist.mean < upper))