def testMultipleInputs(self):
    """Check `module._convert_dict_inputs` against a two-input signature.

    A dict supplying both "x" and "y" is expected to convert into two
    float32 values whose shapes are compatible with `[None]`. `None`, an
    empty dict and a dict holding only an unknown key are each expected to
    raise a TypeError that reports both inputs as missing, and a bare
    non-dict value is expected to raise a TypeError asking for a dict.
    """
    inputs_info = {
        "x": tensor_info.ParsedTensorInfo(
            tf.float32, tf.TensorShape([None]), is_sparse=False),
        "y": tensor_info.ParsedTensorInfo(
            tf.float32, tf.TensorShape([None]), is_sparse=False),
    }

    def _check(dict_inputs):
        # Both declared inputs must come back, converted to the declared
        # dtype and with a shape that fits the declared one.
        self.assertEqual(len(dict_inputs), 2)
        for key in ("x", "y"):
            self.assertEqual(dict_inputs[key].dtype, tf.float32)
            self.assertTrue(dict_inputs[key].shape.is_compatible_with([None]))

    _check(module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]}, inputs_info))

    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists in `unittest.TestCase` as of Python 3.12.
    for bad_inputs in (None, {}, {"z": 1}):
        with self.assertRaisesRegex(TypeError, r"missing \['x', 'y'\]"):
            module._convert_dict_inputs(bad_inputs, inputs_info)

    # A single non-dict value is ambiguous when several inputs are declared.
    with self.assertRaisesRegex(
            TypeError, "Signature expects multiple inputs. Use a dict."):
        module._convert_dict_inputs(1, inputs_info)
def get_output_info_dict(self, signature=None, tags=None):
    """Return the output tensor infos for the given signature and tags.

    The result always has a "default" entry describing a dense float32
    tensor of shape `[None]`. When `tags` equals `{"special"}` and
    `signature` is "extra", a "z" entry referring to the very same info
    object is added as well.

    Args:
      signature: Signature name; only "extra" changes the result.
      tags: Tag set; only a value equal to `{"special"}` changes the result.

    Returns:
      A dict mapping output names to `tensor_info.ParsedTensorInfo`.
    """
    default_info = tensor_info.ParsedTensorInfo(
        tf.float32, tf.TensorShape([None]), is_sparse=False)
    outputs = {"default": default_info}
    wants_extra = tags == {"special"} and signature == "extra"
    if not wants_extra:
        return outputs
    # "z" deliberately aliases the "default" info object rather than copying.
    outputs["z"] = default_info
    return outputs
def get_input_info_dict(self, signature=None, tags=None):
    """Return the input tensor infos for the given signature and tags.

    The result always has an "x" entry describing a float32 tensor of shape
    `[None]`. That entry is marked sparse only when `signature` is "sparse"
    and `tags` equals `{"special"}`. When `tags` equals `{"special"}` and
    `signature` is "extra", a "y" entry referring to the same info object
    as "x" is added.

    Args:
      signature: Signature name; "sparse" and "extra" are treated specially.
      tags: Tag set; only a value equal to `{"special"}` enables the
        special cases.

    Returns:
      A dict mapping input names to `tensor_info.ParsedTensorInfo`.
    """
    is_special = tags == {"special"}
    x_info = tensor_info.ParsedTensorInfo(
        tf.float32,
        tf.TensorShape([None]),
        is_sparse=(signature == "sparse" and is_special))
    inputs = {"x": x_info}
    if is_special and signature == "extra":
        # "y" shares the info object of "x" instead of getting its own.
        inputs["y"] = x_info
    return inputs
def testSingleInput(self):
    """Check `module._convert_dict_inputs` against a one-input signature.

    Both a bare value and a dict keyed by "x" are expected to convert into
    a single float32 value whose shape is compatible with `[None]`. `None`
    is expected to raise a TypeError reporting "x" as missing, and a dict
    with an additional "y" key is expected to raise a TypeError reporting
    "y" as extra.
    """
    inputs_info = {
        "x": tensor_info.ParsedTensorInfo(
            tf.float32, tf.TensorShape([None]), is_sparse=False),
    }

    def _check(dict_inputs):
        # Exactly the one declared input must come back, converted to the
        # declared dtype and with a shape that fits the declared one.
        self.assertEqual(len(dict_inputs), 1)
        self.assertEqual(dict_inputs["x"].dtype, tf.float32)
        self.assertTrue(dict_inputs["x"].shape.is_compatible_with([None]))

    # With a single declared input, the value may be passed bare or in a dict.
    _check(module._convert_dict_inputs([1, 2], inputs_info))
    _check(module._convert_dict_inputs({"x": [1, 2]}, inputs_info))

    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists in `unittest.TestCase` as of Python 3.12.
    with self.assertRaisesRegex(TypeError, r"missing \['x'\]"):
        module._convert_dict_inputs(None, inputs_info)
    with self.assertRaisesRegex(TypeError, r"extra given \['y'\]"):
        module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]}, inputs_info)
def get_input_info_dict(self, signature=None, tags=None):
    """Return the input tensor infos for the given signature and tags.

    The "x" entry depends on the request:

    * `signature` "ragged" with `tags` equal to `{"special"}`: built from a
      float32 `tf.RaggedTensorSpec` of shape `[None, None, None, 3]` with
      ragged rank 2.
    * otherwise: a float32 tensor of shape `[None]`, marked sparse only
      when `signature` is "sparse" and `tags` equals `{"special"}`.

    When `tags` equals `{"special"}` and `signature` is "extra", a "y" entry
    referring to the same info object as "x" is added.

    Args:
      signature: Signature name; "ragged", "sparse" and "extra" are treated
        specially.
      tags: Tag set; only a value equal to `{"special"}` enables the
        special cases.

    Returns:
      A dict mapping input names to `tensor_info.ParsedTensorInfo`.
    """
    is_special = tags == {"special"}
    if signature == "ragged" and is_special:
        ragged_spec = tf.RaggedTensorSpec(
            shape=[None, None, None, 3], dtype=tf.float32, ragged_rank=2)
        x_info = tensor_info.ParsedTensorInfo.from_type_spec(
            type_spec=ragged_spec)
    else:
        x_info = tensor_info.ParsedTensorInfo(
            tf.float32,
            tf.TensorShape([None]),
            is_sparse=(signature == "sparse" and is_special))
    inputs = {"x": x_info}
    if is_special and signature == "extra":
        # "y" shares the info object of "x" instead of getting its own.
        inputs["y"] = x_info
    return inputs