예제 #1
0
 def test_output_is_as_expected(self):
     """A zero float32 batch should map to a zero (3, 5) float32 result.

     Feeds a (3, 1) all-zero float32 tensor through CategoricalAttribute(2, 5)
     and compares both the values and the dtype of the result against an
     all-zero (3, 5) float32 tensor.
     """
     # NOTE(review): an all-zero output presumably relies on the embedding
     # layer being mocked elsewhere in this test case — confirm.
     zeros_in = tf.zeros((3, 1), dtype=tf.float32)
     want = tf.zeros((3, 5), dtype=tf.float32)
     layer = CategoricalAttribute(2, 5)
     got = layer(zeros_in)
     self.assertAllClose(want, got)
     self.assertEqual(want.dtype, got.dtype)
예제 #2
0
 def test_embed_instance_called_correctly(self):
     """The wrapped embed instance should receive the input as int32 zeros.

     Calls CategoricalAttribute(2, 5) on a (3, 1) float32 zero tensor, then
     checks that the call args recorded on ``self._mock_embed_instance``
     match ``[[int32 zeros of shape (3, 1)]]`` and that the first recorded
     argument really has dtype tf.int32 (assertAllClose alone would not
     catch a float-vs-int dtype mismatch).
     """
     # NOTE(review): the nested-list layout presumably means "one call, one
     # positional arg" — confirm against get_call_args.
     float_zeros = tf.zeros((3, 1), dtype=tf.float32)
     layer = CategoricalAttribute(2, 5)
     layer(float_zeros)
     self.assertAllClose(
         get_call_args(self._mock_embed_instance),
         [[tf.zeros((3, 1), dtype=tf.int32)]],
     )
     first_arg = get_call_args(self._mock_embed_instance)[0][0]
     self.assertEqual(first_arg.dtype, tf.int32)
예제 #3
0
 def make_embedder():
     """Build a CategoricalAttribute embedder for the enclosing attribute type.

     Reads ``category_values``, ``attr_embedding_dim`` and ``attribute_type``
     from the enclosing scope (not visible here). The first constructor
     argument is ``len(category_values)`` (presumably the category count),
     and the layer is named ``<attribute_type>_cat_embedder``.
     """
     n_categories = len(category_values)
     embedding_dim = attr_embedding_dim
     layer_name = attribute_type + '_cat_embedder'
     return CategoricalAttribute(n_categories, embedding_dim, name=layer_name)
예제 #4
0
 def test_output_tensorspec(self):
     """CategoricalAttribute(2, 5) on a (3, 1) batch yields a [3, 5] float32."""
     layer = CategoricalAttribute(2, 5)
     result = layer(tf.zeros((3, 1), dtype=tf.float32))
     expected_shape = tf.TensorShape([3, 5])
     np.testing.assert_array_equal(expected_shape, result.shape)
     np.testing.assert_equal(result.dtype, tf.float32)
예제 #5
0
 def make_embedder():
     """Build a CategoricalAttribute embedder named after ``attr_typ``.

     ``num_categories``, ``attr_embedding_dim`` and ``attr_typ`` are taken
     from the enclosing scope (not visible here); the resulting layer is
     named ``<attr_typ>_cat_embedder``.
     """
     embedder_name = attr_typ + '_cat_embedder'
     return CategoricalAttribute(
         num_categories,
         attr_embedding_dim,
         name=embedder_name,
     )
예제 #6
0
 def test_embed_invoked_correctly(self):
     """The embed class should be called once with (categories, embed dim).

     Builds a CategoricalAttribute, runs it on a (3, 1) float32 zero batch and
     asserts that ``self._mock_embed_class`` was called exactly once with the
     same category count and embedding dimension given to the constructor.
     """
     # NOTE(review): _mock_embed_class is presumably a patched embedding class
     # installed elsewhere in this test case — confirm.
     # Bind both values once so the setup and the expectation cannot drift
     # apart (the original hard-coded 5 next to an unused-in-setup variable).
     num_categories = 2
     attr_embedding_dim = 5
     cat = CategoricalAttribute(num_categories, attr_embedding_dim)
     cat(tf.zeros((3, 1), tf.float32))
     self._mock_embed_class.assert_called_once_with(num_categories,
                                                    attr_embedding_dim)