Пример #1
0
    def test_constructor(self, batch_shape, event_shape, num_values):
        """Checks shapes and batch slicing of a DiscreteValuedDistribution.

        Args:
          batch_shape: Batch shape tuple for the logits (presumably supplied by
            a parameterized-test decorator outside this view — confirm).
          event_shape: Event shape tuple; the support built below has shape
            `event_shape + (num_values,)`.
          num_values: Number of support points along the trailing axis.
        """
        full_shape = batch_shape + event_shape + (num_values, )
        # Distinct values 0, 1, 2, ... so a mis-sliced result would not compare
        # equal to the expected slice.
        logits = np.arange(np.prod(full_shape),
                           dtype=float).reshape(full_shape)
        ones = np.ones(event_shape, dtype=float)
        # Support evenly spaced in [-1, 1]; with axis=-1 linspace appends the
        # new axis last, giving shape event_shape + (num_values,).
        values = np.linspace(start=-ones, stop=ones, num=num_values, axis=-1)
        dist = distributions.DiscreteValuedDistribution(values=values,
                                                        logits=logits)

        # Shapes reported by the distribution must match the inputs.
        self.assertEqual(dist.batch_shape, batch_shape)
        self.assertEqual(dist.event_shape, event_shape)
        logits_dims = dist.logits_parameter().shape.as_list()
        self.assertEqual(logits_dims, list(logits.shape))
        self.assertEqual(logits_dims[-1], logits.shape[-1])

        # Slicing the distribution must slice its logits the same way. Only
        # batch ranks 1 and 2 are exercised; any other batch must be empty.
        batch_rank = len(batch_shape)
        if batch_rank not in (1, 2):
            assert not batch_shape
            return
        index = slice(1, 3) if batch_rank == 1 else (0, slice(1, 3))
        sliced = dist[index].logits_parameter().numpy()
        expected = dist.logits_parameter().numpy()[index]
        npt.assert_allclose(sliced, expected)
Пример #2
0
  def __call__(self, inputs: tf.Tensor) -> tfd.Distribution:
    """Maps `inputs` to a discrete-valued distribution over `self._values`.

    The layer output is reshaped to `[batch_size] + shape(self._values)`, so
    there is one logit per entry of `self._values` for every batch element.

    Args:
      inputs: Input tensor fed to `self._distributional_layer`.

    Returns:
      An `ad.DiscreteValuedDistribution` built from the reshaped logits, with
      `self._values` (cast to the logits' dtype) as its support.
    """
    raw_logits = self._distributional_layer(inputs)
    # Keep the leading (batch) dimension and lay the rest out in the shape of
    # the support. NOTE(review): this assumes the layer emits
    # prod(shape(self._values)) entries per batch element; otherwise
    # tf.reshape fails.
    batch_size = tf.shape(raw_logits)[:1]
    target_shape = tf.concat([batch_size, tf.shape(self._values)], axis=0)
    logits = tf.reshape(raw_logits, target_shape)
    support = tf.cast(self._values, logits.dtype)
    return ad.DiscreteValuedDistribution(values=support, logits=logits)
Пример #3
0
    def __call__(self, inputs: tf.Tensor) -> tfd.Distribution:
        """Maps `inputs` to a discrete-valued distribution over `self._values`.

        Args:
          inputs: Input tensor fed to `self._distributional_layer`.

        Returns:
          An `ad.DiscreteValuedDistribution` whose logits are the layer output,
          used as-is without reshaping, and whose support is `self._values`
          cast to the logits' dtype.
        """
        layer_logits = self._distributional_layer(inputs)
        # Give the support the same dtype as the logits.
        support = tf.cast(self._values, layer_logits.dtype)
        return ad.DiscreteValuedDistribution(values=support,
                                             logits=layer_logits)