def testSetTensorShapeEmpty(self):
    """Passing an empty shapes map leaves the placeholder's shape as declared."""
    # NOTE(review): a method with this exact name and body is defined again
    # further down; if both are in the same class, only the later definition
    # is collected by the test runner and this copy is dead code (F811).
    expected = [None, 3, 5]
    placeholder = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual(expected, placeholder.shape.as_list())

    no_overrides = {}
    convert_saved_model.set_tensor_shapes([placeholder], no_overrides)

    # Nothing was requested, so the declared shape must survive the call.
    self.assertEqual(expected, placeholder.shape.as_list())
def testSetTensorShapeNoneValid(self):
    """A placeholder created without a `shape` takes the shape given in the map."""
    # NOTE(review): a method with this exact name and body is defined again
    # further down; if both are in the same class, the later definition
    # replaces this one and this copy never runs (F811).
    unshaped = array_ops.placeholder(dtype=dtypes.float32)
    self.assertEqual(None, unshaped.shape)

    requested_shapes = {"Placeholder": [1, 3, 5]}
    convert_saved_model.set_tensor_shapes([unshaped], requested_shapes)

    self.assertEqual([1, 3, 5], unshaped.shape.as_list())
def testSetTensorShapeDimensionInvalid(self):
    """Requesting a shape incompatible with the declared one raises ValueError.

    The placeholder is declared as ``[None, 3, 5]``; the request ``[1, 5, 5]``
    conflicts at index 1 (3 vs. 5), so ``set_tensor_shapes`` must raise and the
    declared shape must be left unchanged.
    """
    # Tests set_tensor_shapes where the shape passed in is incompatible.
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

    with self.assertRaises(ValueError) as error:
        convert_saved_model.set_tensor_shapes(
            [tensor], {"Placeholder": [1, 5, 5]})

    # These checks must sit outside the `with` block: any statement placed
    # inside it after the raising call would never execute.
    self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
                  str(error.exception))
    self.assertEqual([None, 3, 5], tensor.shape.as_list())
def testSetTensorShapeInvalid(self):
    """A shapes-map key naming no given tensor is rejected with ValueError."""
    # NOTE(review): this body is identical to testSetTensorShapeArrayInvalid;
    # consider removing one or pointing this test at a distinct failure mode.
    declared = [None, 3, 5]
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual(declared, tensor.shape.as_list())

    bogus_shapes = {"invalid-input": [5, 3, 5]}
    with self.assertRaises(ValueError) as ctx:
        convert_saved_model.set_tensor_shapes([tensor], bogus_shapes)

    expected_message = (
        "Invalid tensor 'invalid-input' found in tensor shapes map.")
    self.assertEqual(expected_message, str(ctx.exception))
    # The rejected call must not have altered the declared shape.
    self.assertEqual(declared, tensor.shape.as_list())
def testSetTensorShapeArrayInvalid(self):
    """set_tensor_shapes raises when the map names a tensor that isn't passed in."""
    placeholder = array_ops.placeholder(
        shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], placeholder.shape.as_list())

    with self.assertRaises(ValueError) as raised:
        convert_saved_model.set_tensor_shapes(
            [placeholder], {"invalid-input": [5, 3, 5]})

    message = str(raised.exception)
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.", message)
    # Shape is unchanged after the failed update.
    self.assertEqual([None, 3, 5], placeholder.shape.as_list())
def testSetTensorShapeNoneValid(self):
    """An unshaped placeholder receives the shape listed under its name."""
    # NOTE(review): an identical method with this name appears earlier; within
    # a single class body this later definition is the one that takes effect.
    free_form = array_ops.placeholder(dtype=dtypes.float32)
    self.assertEqual(None, free_form.shape)

    convert_saved_model.set_tensor_shapes(
        [free_form], {"Placeholder": [1, 3, 5]})

    target = [1, 3, 5]
    self.assertEqual(target, free_form.shape.as_list())
def testSetTensorShapeEmpty(self):
    """With no shape overrides, the placeholder keeps its declared shape."""
    # NOTE(review): an identical method with this name appears earlier; within
    # a single class body this later definition is the one that takes effect.
    tensor = array_ops.placeholder(
        shape=[None, 3, 5], dtype=dtypes.float32)
    before = tensor.shape.as_list()
    self.assertEqual([None, 3, 5], before)

    convert_saved_model.set_tensor_shapes([tensor], {})

    after = tensor.shape.as_list()
    self.assertEqual([None, 3, 5], after)