Beispiel #1
0
    def testSetTensorShapeEmpty(self):
        """An empty shape map must leave the tensor's static shape as-is."""
        expected_dims = [None, 3, 5]
        placeholder = array_ops.placeholder(dtype=dtypes.float32,
                                            shape=[None, 3, 5])
        self.assertEqual(expected_dims, placeholder.shape.as_list())

        # The map has no entries, so no shape should be overridden.
        convert_saved_model.set_tensor_shapes([placeholder], {})
        self.assertEqual(expected_dims, placeholder.shape.as_list())
Beispiel #2
0
    def testSetTensorShapeNoneValid(self):
        """A placeholder created without a shape accepts a concrete one."""
        placeholder = array_ops.placeholder(dtype=dtypes.float32)
        # No shape was given, so the static shape starts out unknown.
        self.assertEqual(None, placeholder.shape)

        # Keyed by tensor name; "Placeholder" is presumably the default
        # name TF assigns here -- confirm if this test starts failing.
        shape_map = {"Placeholder": [1, 3, 5]}
        convert_saved_model.set_tensor_shapes([placeholder], shape_map)
        self.assertEqual([1, 3, 5], placeholder.shape.as_list())
  def testSetTensorShapeDimensionInvalid(self):
    """set_tensor_shapes rejects a shape incompatible with the current one."""
    initial_dims = [None, 3, 5]
    placeholder = array_ops.placeholder(dtype=dtypes.float32,
                                        shape=[None, 3, 5])
    self.assertEqual(initial_dims, placeholder.shape.as_list())

    # Dimension 1 is already fixed at 3, so asking for 5 there must fail.
    conflicting_map = {"Placeholder": [1, 5, 5]}
    with self.assertRaises(ValueError) as ctx:
      convert_saved_model.set_tensor_shapes([placeholder], conflicting_map)
    self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
                  str(ctx.exception))
    # The rejected update must not have altered the static shape.
    self.assertEqual(initial_dims, placeholder.shape.as_list())
Beispiel #4
0
    def testSetTensorShapeInvalid(self):
        """A shape-map key naming no passed-in tensor raises ValueError."""
        starting_dims = [None, 3, 5]
        placeholder = array_ops.placeholder(dtype=dtypes.float32,
                                            shape=[None, 3, 5])
        self.assertEqual(starting_dims, placeholder.shape.as_list())

        bad_map = {"invalid-input": [5, 3, 5]}
        with self.assertRaises(ValueError) as ctx:
            convert_saved_model.set_tensor_shapes([placeholder], bad_map)
        expected_msg = (
            "Invalid tensor 'invalid-input' found in tensor shapes map.")
        self.assertEqual(expected_msg, str(ctx.exception))
        # Rejecting the map must leave the existing shape intact.
        self.assertEqual(starting_dims, placeholder.shape.as_list())
  def testSetTensorShapeArrayInvalid(self):
    """set_tensor_shapes fails when a tensor name in the map doesn't exist."""
    dims_before = [None, 3, 5]
    placeholder = array_ops.placeholder(dtype=dtypes.float32,
                                        shape=[None, 3, 5])
    self.assertEqual(dims_before, placeholder.shape.as_list())

    unknown_name_map = {"invalid-input": [5, 3, 5]}
    with self.assertRaises(ValueError) as ctx:
      convert_saved_model.set_tensor_shapes([placeholder], unknown_name_map)
    message = str(ctx.exception)
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.", message)
    # The failed call must not have touched the placeholder's shape.
    self.assertEqual(dims_before, placeholder.shape.as_list())
  def testSetTensorShapeNoneValid(self):
    """Assigns a concrete shape to a placeholder whose shape is unknown."""
    unshaped = array_ops.placeholder(dtype=dtypes.float32)
    self.assertEqual(None, unshaped.shape)

    expected_dims = [1, 3, 5]
    convert_saved_model.set_tensor_shapes([unshaped],
                                          {"Placeholder": [1, 3, 5]})
    self.assertEqual(expected_dims, unshaped.shape.as_list())
  def testSetTensorShapeEmpty(self):
    """set_tensor_shapes with an empty map is a no-op on the tensor shape."""
    dims = [None, 3, 5]
    placeholder = array_ops.placeholder(dtype=dtypes.float32,
                                        shape=[None, 3, 5])
    self.assertEqual(dims, placeholder.shape.as_list())

    no_overrides = {}
    convert_saved_model.set_tensor_shapes([placeholder], no_overrides)
    self.assertEqual(dims, placeholder.shape.as_list())