コード例 #1
0
    def testOneHotMultiplyExactHard(self):
        # Feeds the same pair of one-hot rows as both `inputs` and `scale`
        # and checks that one_hot_multiply returns [0, 1, 0] for each row.
        # NOTE(review): this matches modular multiplication of the encoded
        # indices (1 * 1 and 2 * 2 are both 1 mod 3) -- confirm against the
        # one_hot_multiply implementation.
        one_hot_rows = [[0., 1., 0.], [0., 0., 1.]]
        expected = np.array([[0., 1., 0.], [0., 1., 0.]])

        product = reversible.one_hot_multiply(
            tf.constant(one_hot_rows), tf.constant(one_hot_rows))

        self.assertAllEqual(self.evaluate(product), expected)
コード例 #2
0
    def testOneHotMultiplyExactSoft(self):
        # Multiplies one-hot `inputs` by a soft (non-one-hot) `scale` and
        # compares the result with a hand-built mixture: each column k of
        # `scale` weights the result expected for a hard scale of k.
        inputs = tf.constant([[0., 1., 0.], [0., 0., 1.]])
        scale = tf.constant([[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]])

        def weight(k):
            # Column k of `scale`, kept as a trailing axis of size 1 so it
            # broadcasts across the vocabulary dimension.
            return scale[..., k][..., tf.newaxis]

        # Per-scale reference results: all zeros for a scale of zero, the
        # inputs themselves for a scale of one, and the rows below for a
        # scale of two.
        times_zero = np.array([[0., 0., 0.], [0., 0., 0.]])
        times_one = inputs
        times_two = np.array([[0., 0., 1.], [0., 1., 0.]])

        # Summation order is kept left-to-right because the final check is
        # an exact comparison.
        mixture = (weight(0) * times_zero +
                   weight(1) * times_one +
                   weight(2) * times_two)

        soft_product = reversible.one_hot_multiply(inputs, scale)

        actual, expected = self.evaluate([soft_product, mixture])
        self.assertAllEqual(actual, expected)