def testConvert(self, param_names, input_specs, attr_specs, inputs,
                outputs, inferred):
    """Checks that `Convert` infers attrs and rewrites params as expected.

    Args:
      param_names: Parameter names for the synthetic "TestFunc" API.
      input_specs: Input specs forwarded to `makeApiInfoFromParamSpecs`.
      attr_specs: Attr specs forwarded to `makeApiInfoFromParamSpecs`.
      inputs: Zero-argument callable producing the parameter values that are
        handed to `Convert`.
      outputs: Zero-argument callable producing the values the parameters
        should hold once `Convert` has run.
      inferred: Expected inferred-attribute result, compared via
        `assertInferredEqual`.
    """
    info = self.makeApiInfoFromParamSpecs("TestFunc", param_names,
                                          input_specs, attr_specs)
    converter = self.makeTensorConverter()
    # NOTE(review): inputs/outputs are callables, presumably so that tensor
    # values are constructed inside the test body — confirm with callers.
    values = inputs()
    result = Convert(info, converter, values)
    self.assertInferredEqual(info, result, inferred)
    # `values` is checked after the call: `Convert` rewrites it in place.
    expected_values = outputs()
    self.assertParamsEqual(values, expected_values)
 def testConvertError(self,
                      param_names,
                      input_specs,
                      attr_specs,
                      inputs,
                      message,
                      exception=TypeError):
     """Checks that `Convert` fails with `exception` whose text matches `message`.

     Args:
       param_names: Parameter names for the synthetic "TestFunc" API.
       input_specs: Input specs forwarded to `makeApiInfoFromParamSpecs`.
       attr_specs: Attr specs forwarded to `makeApiInfoFromParamSpecs`.
       inputs: Zero-argument callable producing the parameter values that are
         handed to `Convert`.
       message: Regular expression the raised exception's message must match.
       exception: Exception type `Convert` is expected to raise; defaults to
         `TypeError`.
     """
     info = self.makeApiInfoFromParamSpecs("TestFunc", param_names,
                                           input_specs, attr_specs)
     converter = self.makeTensorConverter()
     values = inputs()
     # Callable form of assertRaisesRegex: invokes Convert(info, converter,
     # values) and asserts it raises `exception` matching `message`.
     self.assertRaisesRegex(exception, message, Convert, info, converter,
                            values)
    def testConvertMultipleAttributes(self):
        """Converts several attrs in one call and checks values and types.

        `x` is a list(int) given an int, a float and a 0-d numpy array; `y` is
        a shape given a plain list; `z` is a float given an int. `Convert`
        rewrites `params` in place, so the assertions inspect `params` after
        the call.
        """
        attr_specs = {"x": "list(int)", "y": "shape", "z": "float"}
        api_info = self.makeApiInfoFromParamSpecs("ConvertAttributes",
                                                  ["x", "y", "z"], {},
                                                  attr_specs)
        tensor_converter = self.makeTensorConverter()

        params = [[1, 2.0, np.array(3.0)], [1, 2], 10]
        inferred = Convert(api_info, tensor_converter, params)

        # No inputs are declared (input_specs is {}), so nothing is inferred.
        self.assertEqual(inferred.types, [])
        self.assertEqual(inferred.type_lists, [])
        self.assertEqual(inferred.lengths, [])
        self.assertLen(params, 3)
        self.assertEqual(
            params, [[1, 2, 3], tensor_shape.as_shape([1, 2]), 10.0])
        # Equality alone cannot detect a missed conversion (2.0 == 2 and
        # np.array(3.0) == 3), so check the type of every list element --
        # including the ones that were not ints before conversion.
        for index, item in enumerate(params[0]):
            self.assertIsInstance(item, int, msg="params[0][%d]" % index)
        self.assertIsInstance(params[1], tensor_shape.TensorShape)
        self.assertIsInstance(params[2], float)
    def testConvertAttribute(self, attr_type, attr_val, expected):
        """Converts a single attr `x` and checks both its value and its type.

        Args:
          attr_type: Attr type string declared for `x`.
          attr_val: Value supplied for `x` before conversion.
          expected: Value `x` should hold after conversion; its type (and,
            for lists, the type of each item) must also match exactly.
        """
        info = self.makeApiInfoFromParamSpecs("ConvertAttributes", ["x"], {},
                                              {"x": attr_type})
        converter = self.makeTensorConverter()

        values = [attr_val]
        result = Convert(info, converter, values)
        # With only an attr and no inputs, every inferred field stays empty.
        for field_name in ("types", "type_lists", "lengths"):
            self.assertEqual(getattr(result, field_name), [])
        self.assertLen(values, 1)
        (converted,) = values
        self.assertEqual(converted, expected)

        # Equal values may still have different types in Python (e.g.
        # 1 == 1.0), so require the exact type as well.
        self.assertIs(type(converted), type(expected))
        if not isinstance(expected, list):
            return
        self.assertLen(converted, len(expected))
        for got, want in zip(converted, expected):
            self.assertIs(type(got), type(want))
 def testConvertAttributeError(self, attr_type, attr_val, message):
     """Checks that converting attr `x` of `attr_type` raises a TypeError.

     Args:
       attr_type: Attr type string declared for `x` on the "Foo" API.
       attr_val: Value supplied for `x`; expected to be rejected.
       message: Regular expression the TypeError's message must match.
     """
     info = self.makeApiInfoFromParamSpecs("Foo", ["x"], {},
                                           {"x": attr_type})
     converter = self.makeTensorConverter()
     values = [attr_val]
     # Callable form: runs Convert(info, converter, values) and asserts that
     # it raises TypeError with a matching message.
     self.assertRaisesRegex(TypeError, message, Convert, info, converter,
                            values)