Esempio n. 1
0
    def test_returns_computation(self, type_signature, count, value):
        """Checks the replicate-input computation's type and its output.

        The proto produced by `create_replicate_input` is expected to
        deserialize to a type assignable to a function from `type_signature`
        to `count` copies of it. Running it on `value` must yield an unnamed
        struct holding `value` exactly `count` times.
        """
        # The factory returns a pair; only the computation proto is needed.
        comp_proto, _ = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)
        self.assertIsInstance(comp_proto, pb.Computation)

        observed_type = type_serialization.deserialize_type(comp_proto.type)
        wanted_type = computation_types.FunctionType(
            type_signature, [type_signature for _ in range(count)])
        # Only assignability is verified. The return value is unused, so this
        # call presumably signals a mismatch by raising (TODO confirm).
        wanted_type.check_assignable_from(observed_type)

        unnamed_copies = [(None, value) for _ in range(count)]
        self.assertEqual(
            test_utils.run_tensorflow(comp_proto, value),
            structure.Struct(unnamed_copies))
Esempio n. 2
0
    def test_returns_computation(self, type_signature, count, value):
        """Checks the replicate-input computation's exact type and output.

        The proto must deserialize to a function type equal to one mapping
        `type_signature` to a list of `count` copies of it. Running it on
        `value` must produce an anonymous tuple of `count` unnamed elements,
        each equal to `value`.
        """
        comp_proto = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)
        self.assertIsInstance(comp_proto, pb.Computation)

        observed_type = type_serialization.deserialize_type(comp_proto.type)
        wanted_type = computation_types.FunctionType(
            type_signature, [type_signature for _ in range(count)])
        self.assertEqual(observed_type, wanted_type)

        observed_result = test_utils.run_tensorflow(comp_proto, value)
        unnamed_copies = [(None, value) for _ in range(count)]
        self.assertEqual(observed_result,
                         anonymous_tuple.AnonymousTuple(unnamed_copies))
Esempio n. 3
0
 def test_raises_type_error(self, type_signature, count):
     """Checks that `create_replicate_input` raises TypeError for these args."""
     factory_fn = tensorflow_computation_factory.create_replicate_input
     # Callable form: assertRaises invokes factory_fn(type_signature, count)
     # and fails unless a TypeError is raised.
     self.assertRaises(TypeError, factory_fn, type_signature, count)