def _restore_op(self, iterator_resource):
    """Build an op that restores saved iterator state into `iterator_resource`.

    The file named by `self._iterator_checkpoint_prefix_local()` is read and its
    contents are parsed as a serialized TensorProto of dtype `variant`. That
    variant tensor is handed to `gen_dataset_ops.deserialize_iterator` together
    with `iterator_resource`.

    Args:
      iterator_resource: the iterator resource handle to restore into.

    Returns:
      The op produced by `gen_dataset_ops.deserialize_iterator`.
    """
    checkpoint_file = self._iterator_checkpoint_prefix_local()
    serialized_state = io_ops.read_file(checkpoint_file)
    state_variant = parsing_ops.parse_tensor(serialized_state, dtypes.variant)
    return gen_dataset_ops.deserialize_iterator(iterator_resource, state_variant)
def testInvalidInput(self):
    """parse_tensor must reject unparsable bytes and non-scalar input.

    A single uint16 parse op is evaluated once per entry in `bad_feeds`. Each
    evaluation must raise an op error whose message matches the paired pattern.
    """
    # (value fed to the placeholder, expected error-message pattern), in order.
    bad_feeds = (
        ("bogus", "Could not parse `serialized` as TensorProto: 'bogus'"),
        (["bogus"], r"Expected `serialized` to be a scalar, got shape: \[1\]"),
    )
    with self.test_session():
        serialized_input = array_ops.placeholder(dtypes.string)
        parsed = parsing_ops.parse_tensor(serialized_input, dtypes.uint16)
        for feed_value, error_pattern in bad_feeds:
            with self.assertRaisesOpError(error_pattern):
                parsed.eval(feed_dict={serialized_input: feed_value})
def testTypeMismatch(self):
    """Parsing a uint8 TensorProto with dtype uint16 must raise an op error."""
    mismatch_pattern = (r"Type mismatch between parsed tensor \(uint8\) and "
                        r"dtype \(uint16\)")
    with self.test_session():
        source_array = np.random.rand(3, 4, 5).astype(np.uint8)
        proto_bytes = tensor_util.make_tensor_proto(
            source_array).SerializeToString()
        serialized_input = array_ops.placeholder(dtypes.string)
        # Deliberately request a dtype that differs from the serialized one.
        parsed = parsing_ops.parse_tensor(serialized_input, dtypes.uint16)
        with self.assertRaisesOpError(mismatch_pattern):
            parsed.eval(feed_dict={serialized_input: proto_bytes})
def testToUint8(self):
    """Round-trip check: a uint8 array comes back unchanged after parsing.

    The array is serialized to a TensorProto and parsed with dtype uint8.
    """
    with self.test_session():
        source_array = np.random.rand(3, 4, 5).astype(np.uint8)
        proto_bytes = tensor_util.make_tensor_proto(
            source_array).SerializeToString()
        serialized_input = array_ops.placeholder(dtypes.string)
        parse_op = parsing_ops.parse_tensor(serialized_input, dtypes.uint8)
        self.assertAllEqual(
            source_array,
            parse_op.eval(feed_dict={serialized_input: proto_bytes}))
def _restore_op(iterator_resource):
    """Build an op that restores saved iterator state into `iterator_resource`.

    The file at `_path()` is read and its contents are parsed as a serialized
    TensorProto of dtype `variant`. The resulting tensor is passed to
    `gen_dataset_ops.deserialize_iterator` together with `iterator_resource`.

    Args:
      iterator_resource: the iterator resource handle to restore into.

    Returns:
      The op produced by `gen_dataset_ops.deserialize_iterator`.
    """
    serialized_state = io_ops.read_file(_path())
    state_variant = parsing_ops.parse_tensor(serialized_state, dtypes.variant)
    return gen_dataset_ops.deserialize_iterator(iterator_resource, state_variant)