def testConvert(self, param_names, input_specs, attr_specs, inputs, outputs,
                inferred):
  """Runs `Convert` on freshly built params and checks results and mutation.

  Args:
    param_names: Parameter names for the synthetic "TestFunc" API.
    input_specs: Input specs forwarded to `makeApiInfoFromParamSpecs`.
    attr_specs: Attribute specs forwarded to `makeApiInfoFromParamSpecs`.
    inputs: Zero-argument callable producing the parameter values handed to
      `Convert`. (Presumably a factory so each run gets fresh values; confirm
      against the parameterization.)
    outputs: Zero-argument callable producing the parameter values expected
      after `Convert` has run.
    inferred: Expected inferred attributes, compared by `assertInferredEqual`.
  """
  info = self.makeApiInfoFromParamSpecs(
      "TestFunc", param_names, input_specs, attr_specs)
  converter = self.makeTensorConverter()
  values = inputs()
  result = Convert(info, converter, values)
  self.assertInferredEqual(info, result, inferred)
  # `values` is checked after the call because Convert updates it in place.
  expected_values = outputs()
  self.assertParamsEqual(values, expected_values)
def testConvertError(self, param_names, input_specs, attr_specs, inputs,
                     message, exception=TypeError):
  """Asserts that `Convert` rejects the given parameter values.

  Args:
    param_names: Parameter names for the synthetic "TestFunc" API.
    input_specs: Input specs forwarded to `makeApiInfoFromParamSpecs`.
    attr_specs: Attribute specs forwarded to `makeApiInfoFromParamSpecs`.
    inputs: Zero-argument callable producing the parameter values to convert.
    message: Regex that the raised exception's message must match.
    exception: Exception type `Convert` is expected to raise (TypeError by
      default).
  """
  info = self.makeApiInfoFromParamSpecs(
      "TestFunc", param_names, input_specs, attr_specs)
  converter = self.makeTensorConverter()
  values = inputs()
  with self.assertRaisesRegex(exception, message):
    Convert(info, converter, values)
def testConvertMultipleAttributes(self):
  """Converts "list(int)", "shape" and "float" attributes in a single call.

  `Convert` rewrites `params` in place. The test expects:
  - every element of the "list(int)" value becomes an int,
  - the "shape" value becomes a `TensorShape`,
  - the "float" value becomes a float.
  No tensor inputs are declared, so nothing should be inferred.
  """
  attr_specs = {"x": "list(int)", "y": "shape", "z": "float"}
  api_info = self.makeApiInfoFromParamSpecs(
      "ConvertAttributes", ["x", "y", "z"], {}, attr_specs)
  tensor_converter = self.makeTensorConverter()
  params = [[1, 2.0, np.array(3.0)], [1, 2], 10]
  inferred = Convert(api_info, tensor_converter, params)
  self.assertEqual(inferred.types, [])
  self.assertEqual(inferred.type_lists, [])
  self.assertEqual(inferred.lengths, [])
  self.assertLen(params, 3)
  self.assertEqual(
      params, [[1, 2, 3], tensor_shape.as_shape([1, 2]), 10.0])
  # Equality alone can't prove the list elements were converted: 2.0 == 2,
  # and np.array(3.0) == 3 is truthy. The first element was already an int
  # on input, so check the type of every element, not just params[0][0].
  self.assertLen(params[0], 3)
  for item in params[0]:
    self.assertIsInstance(item, int)
  self.assertIsInstance(params[1], tensor_shape.TensorShape)
  self.assertIsInstance(params[2], float)
def testConvertAttribute(self, attr_type, attr_val, expected):
  """Converts one attribute "x" and checks both its value and exact type.

  Args:
    attr_type: Attribute type string declared for "x" (e.g. "list(int)").
    attr_val: Value passed to `Convert` for "x".
    expected: Value (and type) that "x" should hold after conversion.
  """
  info = self.makeApiInfoFromParamSpecs(
      "ConvertAttributes", ["x"], {}, {"x": attr_type})
  converter = self.makeTensorConverter()
  params = [attr_val]
  result = Convert(info, converter, params)
  # An attribute-only API should infer nothing.
  for field_name in ("types", "type_lists", "lengths"):
    self.assertEqual(getattr(result, field_name), [])
  self.assertLen(params, 1)
  converted = params[0]
  self.assertEqual(converted, expected)
  # Python equality can hold across types (e.g. 1 == 1.0), so also require
  # the exact type of the converted value.
  self.assertIs(type(converted), type(expected))
  if not isinstance(expected, list):
    return
  self.assertLen(converted, len(expected))
  for got, want in zip(converted, expected):
    self.assertIs(type(got), type(want))
def testConvertAttributeError(self, attr_type, attr_val, message):
  """Asserts that converting attribute "x" raises a matching TypeError.

  Args:
    attr_type: Attribute type string declared for "x".
    attr_val: Value for "x" that `Convert` should reject.
    message: Regex that the TypeError message must match.
  """
  info = self.makeApiInfoFromParamSpecs("Foo", ["x"], {}, {"x": attr_type})
  converter = self.makeTensorConverter()
  params = [attr_val]
  with self.assertRaisesRegex(TypeError, message):
    Convert(info, converter, params)