예제 #1
0
    def _log_prob_with_logits(self, x):
        """Compute ``log p(x)`` from ``self.logits`` via softmax cross entropy.

        The original author flattens the inputs to 2-D as required by
        ``tf.nn.softmax_cross_entropy_with_logits``.  ``x`` and
        ``self.logits`` are first explicitly broadcast against each other.
        If they are not already rank 2, they are collapsed into matrices that
        keep only their last axis.  The negated cross entropy is then
        evaluated, and the result is reshaped back to the leading dimensions
        of the broadcast ``x``.

        Parameters
        ----------
        x : tf.Tensor
            Observations, cast to ``self.param_dtype`` and broadcast against
            ``self.logits`` by ``maybe_explicit_broadcast``.  Presumably each
            slice along the last axis is a one-hot / probability vector
            (TODO confirm against callers).

        Returns
        -------
        tf.Tensor
            ``-softmax_cross_entropy(labels=x, logits=logits)``, i.e. one
            value per slice along the last axis of the broadcast ``x``.
        """
        x = tf.cast(x, dtype=self.param_dtype)
        x, logits = maybe_explicit_broadcast(x, self.logits)

        # Rank-2 inputs go straight into the cross-entropy op.
        if x.get_shape().ndims == 2:
            return -tf.nn.softmax_cross_entropy_with_logits(
                labels=x,
                logits=logits,
            )

        # Collapse every leading axis into one, keeping the last axis.
        flat_shape = tf.stack([-1, tf.shape(x)[-1]], axis=0)
        flat_x = tf.reshape(x, flat_shape)
        flat_logits = tf.reshape(logits, flat_shape)

        # Pin the static shape explicitly: it may not be inferable from the
        # dynamically computed `flat_shape`.
        flat_static_shape = tf.TensorShape([None]).concatenate(
            x.get_shape()[-1:])
        flat_x.set_shape(flat_static_shape)
        flat_logits.set_shape(flat_static_shape)

        flat_log_p = -tf.nn.softmax_cross_entropy_with_logits(
            labels=flat_x,
            logits=flat_logits,
        )

        # Undo the flattening: restore every axis of `x` except the last.
        log_p = tf.reshape(flat_log_p, tf.shape(x)[:-1])
        log_p.set_shape(x.get_shape()[:-1])
        return log_p
예제 #2
0
 def _log_prob(self, x):
     """Return the element-wise Bernoulli log-likelihood of ``x``.

     The formulation depends on how the distribution holds its parameter:

     * when ``self._probs_is_derived`` is false,
       ``x * log(p) + (1 - x) * log1p(-p)`` is evaluated with
       ``p = self._probs_clipped``;
     * otherwise ``-sigmoid_cross_entropy_with_logits`` is evaluated on
       ``self.logits`` after explicitly broadcasting it against ``x``.

     Either result is passed through ``self._check_numerics`` under the
     name ``'log_prob'``.
     """
     x = tf.cast(x, dtype=self.param_dtype)

     if not self._probs_is_derived:
         # TODO: check whether this is better than using derived logits.
         log_p = tf.log(self._probs_clipped)
         log_one_minus_p = tf.log1p(-self._probs_clipped)
         log_prob = x * log_p + (1. - x) * log_one_minus_p
         return self._check_numerics(log_prob, 'log_prob')

     # The cross-entropy op needs labels and logits of identical shape.
     x, logits = maybe_explicit_broadcast(x, self.logits)
     log_prob = -tf.nn.sigmoid_cross_entropy_with_logits(
         labels=x, logits=logits)
     return self._check_numerics(log_prob, 'log_prob')
예제 #3
0
    def variational(self, x, y=None, z=None, z_samples=NOT_SPECIFIED):
        """Build the variational network and return the `z` stochastic tensor.

        `y` is used whenever it is given; `y_in_variational_model` has no
        effect on this method.

        Parameters
        ----------
        x : tf.Tensor
            The model inputs.

        y : tf.Tensor
            Optional conditional input for CVAE.  When given, it is broadcast
            against `x` on all but the last axis and concatenated with `x`
            along the last axis.  See [1] for more details.

            [1] K. Sohn, H. Lee, and X. Yan, “Learning structured output
                representation using deep conditional generative models,”
                in Advances in Neural Information Processing Systems, 2015,
                pp. 3483–3491.

        z : tf.Tensor | StochasticTensor
            The latent variable observations.  (default None)

        z_samples : int | tf.Tensor | None
            If specified, override `z_samples` specified in the constructor.

        Returns
        -------
        StochasticTensor
            The derived z stochastic tensor.
        """
        if y is not None:
            # Conditional variant: append `y` to the input features.
            x_and_y = maybe_explicit_broadcast(x, y, tail_no_broadcast_ndims=1)
            features = tf.concat(x_and_y, axis=-1)
        else:
            features = x
        z_params = self._z_net(features)

        n_samples = (self._z_samples if z_samples is NOT_SPECIFIED
                     else z_samples)
        return self._z_layer(
            z_params,
            n_samples=n_samples,
            observed=z,
            group_event_ndims=1,
            validate_shape=self._validate_shape,
        )
예제 #4
0
    def model(self,
              z=None,
              y=None,
              x=None,
              z_samples=NOT_SPECIFIED,
              x_samples=None):
        """Build the generative network and return the `(z, x)` tensors.

        `y` is used whenever it is given; `y_in_generative_model` has no
        effect on this method.

        Parameters
        ----------
        z : tf.Tensor | StochasticTensor
            Optional latent variable samples, passed to the prior as the
            observed value.

        y : tf.Tensor
            Optional conditional input for CVAE.  When given, it is broadcast
            against `z` on all but the last axis and concatenated with `z`
            along the last axis.  See [1] for more details.

            [1] K. Sohn, H. Lee, and X. Yan, “Learning structured output
                representation using deep conditional generative models,”
                in Advances in Neural Information Processing Systems, 2015,
                pp. 3483–3491.

        x : tf.Tensor
            The observations for the output `StochasticTensor`. (default None)

        z_samples : int | tf.Tensor | None
            If specified, override `z_samples` specified in the constructor.

        x_samples : int | tf.Tensor | None
            Specify the number of samples to take for output x.
            (default None)

        Returns
        -------
        (StochasticTensor, StochasticTensor)
            The derived (z, x) stochastic tensors.
        """
        n_z_samples = (self._z_samples if z_samples is NOT_SPECIFIED
                       else z_samples)
        z_tensor = self._z_prior.sample_or_observe(
            n_z_samples,
            observed=z,
            validate_shape=self._validate_shape,
            group_event_ndims=1,
            name='z',
        )

        if y is not None:
            # Conditional variant: append `y` to the latent features.
            z_and_y = maybe_explicit_broadcast(
                z_tensor, y, tail_no_broadcast_ndims=1)
            features = tf.concat(z_and_y, axis=-1)
        else:
            features = z_tensor
        x_params = self._x_net(features)

        x_tensor = self._x_layer(
            x_params,
            n_samples=x_samples,
            observed=x,
            group_event_ndims=1,
            validate_shape=self._validate_shape,
        )
        return z_tensor, x_tensor
예제 #5
0
    def test_maybe_explicit_broadcast(self):
        """Exercise `maybe_explicit_broadcast` on static, dynamic and
        incompatible shapes, with and without `tail_no_broadcast_ndims`."""

        def check(xx, yy, x_shape, x_value, y_shape, y_value,
                  feed_dict=None):
            # For each output, verify the static shape, then the evaluated
            # value.  A `None` shape means "the static rank is unknown".
            for tensor, shape, value in ((xx, x_shape, x_value),
                                         (yy, y_shape, y_value)):
                if shape is None:
                    self.assertEqual(tensor.get_shape().ndims, None)
                else:
                    self.assertEqual(tensor.get_shape().as_list(), shape)
                np.testing.assert_equal(tensor.eval(feed_dict), value)

        with self.get_session() as session:
            # equal static shapes: the inputs come back untouched
            x = tf.convert_to_tensor(np.arange(2))
            y = tf.convert_to_tensor(np.arange(2, 4))
            xx, yy = maybe_explicit_broadcast(x, y)
            self.assertIs(xx, x)
            self.assertIs(yy, y)
            check(xx, yy, [2], [0, 1], [2], [2, 3])

            # fully static shapes of the same rank
            xx, yy = maybe_explicit_broadcast(
                tf.convert_to_tensor(np.arange(2).reshape([1, 2])),
                tf.convert_to_tensor(np.arange(2, 4).reshape([2, 1])))
            check(xx, yy,
                  [2, 2], [[0, 1], [0, 1]],
                  [2, 2], [[2, 2], [3, 3]])

            # fully static shapes of different ranks
            xx, yy = maybe_explicit_broadcast(
                tf.convert_to_tensor(np.arange(2).reshape([2, 1])),
                tf.convert_to_tensor(np.arange(2, 4)))
            check(xx, yy,
                  [2, 2], [[0, 0], [1, 1]],
                  [2, 2], [[2, 3], [2, 3]])

            # dynamic shapes of the same rank
            x = tf.placeholder(tf.float32, (None, 3))
            y = tf.placeholder(tf.float32, (2, None))
            xx, yy = maybe_explicit_broadcast(x, y)
            check(xx, yy,
                  [2, 3], [[0, 1, 2], [0, 1, 2]],
                  [2, 3], [[3, 3, 3], [4, 4, 4]],
                  feed_dict={x: np.arange(3).reshape((1, 3)),
                             y: np.arange(3, 5).reshape((2, 1))})

            # dynamic shapes of different ranks
            x = tf.placeholder(tf.float32, (None,))
            y = tf.placeholder(tf.float32, (None, 2))
            xx, yy = maybe_explicit_broadcast(x, y)
            check(xx, yy,
                  [None, 2], [[1, 1], [1, 1]],
                  [None, 2], [[2, 3], [4, 5]],
                  feed_dict={x: np.arange(1, 2),
                             y: np.arange(2, 6).reshape((2, 2))})

            # static shapes with `tail_no_broadcast_ndims` of 0, 1 and 2
            tail_cases = [
                (0, [1, 2, 2], [[[0, 1], [0, 1]]],
                 [1, 2, 2], [[[2, 2], [3, 3]]]),
                (1, [1, 2, 2], [[[0, 1], [0, 1]]],
                 [1, 2, 1], [[[2], [3]]]),
                (2, [1, 1, 2], [[[0, 1]]],
                 [1, 2, 1], [[[2], [3]]]),
            ]
            for ndims, x_shape, x_value, y_shape, y_value in tail_cases:
                xx, yy = maybe_explicit_broadcast(
                    tf.convert_to_tensor(np.arange(2).reshape([1, 1, 2])),
                    tf.convert_to_tensor(np.arange(2, 4).reshape([2, 1])),
                    tail_no_broadcast_ndims=ndims)
                check(xx, yy, x_shape, x_value, y_shape, y_value)

            # `tail_no_broadcast_ndims = 1` on partially dynamic shapes, then
            # on placeholders whose rank is unknown
            dynamic_cases = [
                ((None, None, None), (None, None),
                 [None, None, None], [None, None, None]),
                (None, None, None, None),
            ]
            for x_static, y_static, x_shape, y_shape in dynamic_cases:
                x = tf.placeholder(tf.float32, x_static)
                y = tf.placeholder(tf.float32, y_static)
                xx, yy = maybe_explicit_broadcast(
                    x, y, tail_no_broadcast_ndims=1)
                check(xx, yy,
                      x_shape, [[[0, 1], [0, 1]]],
                      y_shape, [[[2], [3]]],
                      feed_dict={x: np.arange(2).reshape([1, 1, 2]),
                                 y: np.arange(2, 4).reshape([2, 1])})

            # incompatible static shapes fail while building the graph
            with self.assertRaisesRegex(
                    ValueError, '.* and .* cannot broadcast to match.*'):
                _ = maybe_explicit_broadcast(
                    tf.convert_to_tensor(np.arange(2)),
                    tf.convert_to_tensor(np.arange(3))
                )

            # incompatible dynamic shapes fail when the graph is run
            x = tf.placeholder(tf.float32, (None, 3))
            y = tf.placeholder(tf.float32, (2, None))
            xx, yy = maybe_explicit_broadcast(x, y)
            bad_feed = {x: np.arange(3).reshape((1, 3)),
                        y: np.arange(3, 7).reshape((2, 2))}
            with self.assertRaisesRegex(Exception, r'.*Incompatible shapes.*'):
                session.run([xx, yy], feed_dict=bad_feed)