def test_returns_computation(self, type_signature, count, value):
  """Tests that `create_replicate_input` yields a working replication comp.

  The factory is asked for a TensorFlow computation replicating an input of
  `type_signature` `count` times. The test checks three things:
  - the returned proto's Python type;
  - that its serialized function type is assignable to
    `FunctionType(type_signature, [type_signature] * count)`;
  - the value produced by running it on `value`.

  The arguments presumably come from a parameterization decorator outside
  this view (TODO confirm).
  """
  factory_output = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  # Only the first element (the computation proto) is used; the second
  # returned element is ignored.
  comp_proto, _ = factory_output
  self.assertIsInstance(comp_proto, pb.Computation)

  # The deserialized type must be assignable to a function from
  # `type_signature` to `count` copies of `type_signature`.
  deserialized = type_serialization.deserialize_type(comp_proto.type)
  replicated = [type_signature] * count
  wanted_fn_type = computation_types.FunctionType(type_signature, replicated)
  wanted_fn_type.check_assignable_from(deserialized)

  # Running the computation should produce `count` unnamed copies of `value`.
  observed = test_utils.run_tensorflow(comp_proto, value)
  unnamed_copies = [(None, value)] * count
  self.assertEqual(observed, structure.Struct(unnamed_copies))
def test_returns_computation(self, type_signature, count, value):
  """Tests that `create_replicate_input` yields a working replication comp.

  This variant treats the factory's return value as the computation proto
  itself. It requires the deserialized type to equal
  `FunctionType(type_signature, [type_signature] * count)` exactly. It
  compares the execution result against an `anonymous_tuple.AnonymousTuple`
  holding `count` unnamed copies of `value`.

  NOTE(review): another method with this exact name exists alongside this one,
  using `structure.Struct` and unpacking a 2-element factory result. If both
  are defined in the same class, the later definition silently replaces the
  earlier one. Confirm they live in separate classes/files.
  """
  comp_proto = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  self.assertIsInstance(comp_proto, pb.Computation)

  # Exact type equality: a function from `type_signature` to `count` copies.
  deserialized = type_serialization.deserialize_type(comp_proto.type)
  replicated = [type_signature] * count
  wanted_fn_type = computation_types.FunctionType(type_signature, replicated)
  self.assertEqual(deserialized, wanted_fn_type)

  # Running the computation should produce `count` unnamed copies of `value`.
  observed = test_utils.run_tensorflow(comp_proto, value)
  unnamed_copies = [(None, value)] * count
  self.assertEqual(observed, anonymous_tuple.AnonymousTuple(unnamed_copies))
def test_raises_type_error(self, type_signature, count):
  """Tests that `create_replicate_input` rejects these arguments.

  The `(type_signature, count)` pair is expected to make the factory raise
  `TypeError`. The arguments presumably come from a parameterization decorator
  outside this view (TODO confirm).
  """
  build_replicator = tensorflow_computation_factory.create_replicate_input
  # Callable form of assertRaises: invokes build_replicator(type_signature,
  # count) and fails the test unless it raises TypeError.
  self.assertRaises(TypeError, build_replicator, type_signature, count)