def testMultipleInputs(self):
    """Checks `module._convert_dict_inputs` against a two-input signature.

    A dict supplying both "x" and "y" is accepted, and each converted value
    has dtype float32 and a shape compatible with [None]. `None`, an empty
    dict, or a dict holding only an unknown key raise TypeError naming the
    missing inputs. A bare non-dict value raises TypeError because the
    signature expects more than one input.
    """
    inputs_info = {
        "x": tensor_info.ParsedTensorInfo(
            tf.float32,
            tf.TensorShape([None]),
            is_sparse=False),
        "y": tensor_info.ParsedTensorInfo(
            tf.float32,
            tf.TensorShape([None]),
            is_sparse=False),
    }
    def _check(dict_inputs):
      # Both declared inputs must come back converted to the declared spec.
      self.assertEqual(len(dict_inputs), 2)
      for key in ("x", "y"):
        self.assertEqual(dict_inputs[key].dtype, tf.float32)
        self.assertTrue(dict_inputs[key].shape.is_compatible_with([None]))

    _check(module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]},
                                       inputs_info))

    # assertRaisesRegex replaces the `assertRaisesRegexp` alias, which was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(TypeError, r"missing \['x', 'y'\]"):
      module._convert_dict_inputs(None, inputs_info)
    with self.assertRaisesRegex(TypeError, r"missing \['x', 'y'\]"):
      module._convert_dict_inputs({}, inputs_info)
    with self.assertRaisesRegex(TypeError, r"missing \['x', 'y'\]"):
      module._convert_dict_inputs({"z": 1}, inputs_info)

    # The periods are escaped so they match literally rather than any char.
    with self.assertRaisesRegex(
        TypeError, r"Signature expects multiple inputs\. Use a dict\."):
      module._convert_dict_inputs(1, inputs_info)
 def get_output_info_dict(self, signature=None, tags=None):
   """Describes the outputs for the given signature and tags.

   Args:
     signature: Signature name; "extra" adds an output when `tags` is
       {"special"}.
     tags: Set of tags; only {"special"} enables the extra output.

   Returns:
     A dict whose "default" key maps to a dense float32 `ParsedTensorInfo`
     of shape [None]. For signature "extra" with tags {"special"}, a "z" key
     is added that refers to the same info object as "default".
   """
   default_info = tensor_info.ParsedTensorInfo(
       tf.float32, tf.TensorShape([None]), is_sparse=False)
   outputs = {"default": default_info}
   wants_extra_output = tags == {"special"} and signature == "extra"
   if wants_extra_output:
     outputs["z"] = default_info
   return outputs
Beispiel #3
0
 def get_input_info_dict(self, signature=None, tags=None):
   """Describes the inputs for the given signature and tags.

   Args:
     signature: Signature name; "sparse" and "extra" change the result when
       `tags` is {"special"}.
     tags: Set of tags; only {"special"} enables those variants.

   Returns:
     A dict whose "x" key maps to a float32 `ParsedTensorInfo` of shape
     [None]. "x" is marked sparse only for signature "sparse" with tags
     {"special"}. For signature "extra" with tags {"special"}, a "y" key is
     added that refers to the same info object as "x".
   """
   is_special = tags == {"special"}
   x_info = tensor_info.ParsedTensorInfo(
       tf.float32,
       tf.TensorShape([None]),
       is_sparse=(signature == "sparse" and is_special))
   inputs = {"x": x_info}
   if is_special and signature == "extra":
     inputs["y"] = x_info
   return inputs
  def testSingleInput(self):
    """Checks `module._convert_dict_inputs` against a one-input signature.

    Both a bare value and a dict keyed by "x" are accepted, and the converted
    value has dtype float32 and a shape compatible with [None]. `None` raises
    TypeError naming the missing input. A dict with an unexpected additional
    key raises TypeError naming that extra key.
    """
    inputs_info = {
        "x": tensor_info.ParsedTensorInfo(
            tf.float32,
            tf.TensorShape([None]),
            is_sparse=False),
    }
    def _check(dict_inputs):
      # The single declared input must come back converted to its spec.
      self.assertEqual(len(dict_inputs), 1)
      self.assertEqual(dict_inputs["x"].dtype, tf.float32)
      self.assertTrue(dict_inputs["x"].shape.is_compatible_with([None]))

    # With exactly one input, a bare (non-dict) value is also accepted.
    _check(module._convert_dict_inputs([1, 2], inputs_info))
    _check(module._convert_dict_inputs({"x": [1, 2]}, inputs_info))

    # assertRaisesRegex replaces the `assertRaisesRegexp` alias, which was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(TypeError, r"missing \['x'\]"):
      module._convert_dict_inputs(None, inputs_info)

    with self.assertRaisesRegex(TypeError, r"extra given \['y'\]"):
      module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]}, inputs_info)
 def get_input_info_dict(self, signature=None, tags=None):
   """Describes the inputs for the given signature and tags.

   Args:
     signature: Signature name; "ragged", "sparse" and "extra" change the
       result when `tags` is {"special"}.
     tags: Set of tags; only {"special"} enables those variants.

   Returns:
     A dict whose "x" key maps to a `ParsedTensorInfo`.
     - Signature "ragged" with tags {"special"}: "x" is built from a float32
       `tf.RaggedTensorSpec` of shape [None, None, None, 3] with ragged_rank=2.
     - Otherwise: "x" is a float32 `ParsedTensorInfo` of shape [None]. It is
       marked sparse only for signature "sparse" with tags {"special"}.
     - Signature "extra" with tags {"special"}: a "y" key is added that refers
       to the same info object as "x".
   """
   is_special = tags == {"special"}
   if is_special and signature == "ragged":
     ragged_spec = tf.RaggedTensorSpec(
         shape=[None, None, None, 3], dtype=tf.float32, ragged_rank=2)
     x_info = tensor_info.ParsedTensorInfo.from_type_spec(
         type_spec=ragged_spec)
   else:
     x_info = tensor_info.ParsedTensorInfo(
         tf.float32,
         tf.TensorShape([None]),
         is_sparse=(signature == "sparse" and is_special))
   inputs = {"x": x_info}
   if is_special and signature == "extra":
     inputs["y"] = x_info
   return inputs