Exemple #1
0
 def _restore_op(self, iterator_resource):
   """Builds an op that restores `iterator_resource` from a checkpoint file.

   The file named by `self._iterator_checkpoint_prefix_local()` is read, its
   contents are parsed as a serialized `dtypes.variant` tensor, and that
   state is handed to `gen_dataset_ops.deserialize_iterator`.

   Args:
     iterator_resource: the iterator resource to restore into.

   Returns:
     The op returned by `gen_dataset_ops.deserialize_iterator`.
   """
   checkpoint_path = self._iterator_checkpoint_prefix_local()
   serialized_state = io_ops.read_file(checkpoint_path)
   state_variant = parsing_ops.parse_tensor(serialized_state, dtypes.variant)
   return gen_dataset_ops.deserialize_iterator(iterator_resource,
                                               state_variant)
  def testInvalidInput(self):
    """Checks that parse_tensor rejects malformed and non-scalar input."""
    with self.test_session():
      serialized_ph = array_ops.placeholder(dtypes.string)
      parsed = parsing_ops.parse_tensor(serialized_ph, dtypes.uint16)

      # Each case pairs the expected op-error regex with the value fed into
      # the string placeholder; cases are evaluated in order.
      cases = (
          # A scalar string that is not a serialized TensorProto.
          ("Could not parse `serialized` as TensorProto: 'bogus'", "bogus"),
          # A rank-1 feed where a scalar string is required.
          (r"Expected `serialized` to be a scalar, got shape: \[1\]",
           ["bogus"]),
      )
      for expected_error, fed_value in cases:
        with self.assertRaisesOpError(expected_error):
          parsed.eval(feed_dict={serialized_ph: fed_value})
  def testTypeMismatch(self):
    """Checks that parse_tensor reports a dtype mismatch.

    A uint8 TensorProto is serialized and then parsed with dtype uint16;
    evaluation is expected to raise a type-mismatch op error.
    """
    with self.test_session():
      # Only the uint8 dtype of the serialized proto matters for this check,
      # not the element values.
      source = np.random.rand(3, 4, 5).astype(np.uint8)
      source_proto = tensor_util.make_tensor_proto(source)

      serialized_ph = array_ops.placeholder(dtypes.string)
      parsed = parsing_ops.parse_tensor(serialized_ph, dtypes.uint16)

      mismatch_pattern = (
          r"Type mismatch between parsed tensor \(uint8\) and dtype "
          r"\(uint16\)")
      proto_bytes = source_proto.SerializeToString()
      with self.assertRaisesOpError(mismatch_pattern):
        parsed.eval(feed_dict={serialized_ph: proto_bytes})
  def testToUint8(self):
    """Round-trips a uint8 array through a serialized TensorProto.

    A (3, 4, 5) uint8 array is serialized with `tensor_util.make_tensor_proto`,
    fed to `parse_tensor` with dtype uint8, and the parsed result is compared
    element-wise with the original array.
    """
    with self.test_session():
      # np.random.rand(...) yields floats in [0, 1), which astype(np.uint8)
      # truncates to all zeros. Draw integers over the full uint8 range so the
      # round-trip comparison exercises non-trivial values.
      expected = np.random.randint(0, 256, size=(3, 4, 5)).astype(np.uint8)
      tensor_proto = tensor_util.make_tensor_proto(expected)

      serialized = array_ops.placeholder(dtypes.string)
      tensor = parsing_ops.parse_tensor(serialized, dtypes.uint8)

      result = tensor.eval(
          feed_dict={serialized: tensor_proto.SerializeToString()})

      self.assertAllEqual(expected, result)
Exemple #5
0
  def testInvalidInput(self):
    """Checks that parse_tensor rejects malformed and non-scalar input."""
    with self.test_session():
      serialized_ph = array_ops.placeholder(dtypes.string)
      parsed = parsing_ops.parse_tensor(serialized_ph, dtypes.uint16)

      # Each case pairs the expected op-error regex with the value fed into
      # the string placeholder; cases are evaluated in order.
      cases = (
          # A scalar string that is not a serialized TensorProto.
          ("Could not parse `serialized` as TensorProto: 'bogus'", "bogus"),
          # A rank-1 feed where a scalar string is required.
          (r"Expected `serialized` to be a scalar, got shape: \[1\]",
           ["bogus"]),
      )
      for expected_error, fed_value in cases:
        with self.assertRaisesOpError(expected_error):
          parsed.eval(feed_dict={serialized_ph: fed_value})
Exemple #6
0
  def testTypeMismatch(self):
    """Checks that parse_tensor reports a dtype mismatch.

    A uint8 TensorProto is serialized and then parsed with dtype uint16;
    evaluation is expected to raise a type-mismatch op error.
    """
    with self.test_session():
      # Only the uint8 dtype of the serialized proto matters for this check,
      # not the element values.
      source = np.random.rand(3, 4, 5).astype(np.uint8)
      source_proto = tensor_util.make_tensor_proto(source)

      serialized_ph = array_ops.placeholder(dtypes.string)
      parsed = parsing_ops.parse_tensor(serialized_ph, dtypes.uint16)

      mismatch_pattern = (
          r"Type mismatch between parsed tensor \(uint8\) and dtype "
          r"\(uint16\)")
      proto_bytes = source_proto.SerializeToString()
      with self.assertRaisesOpError(mismatch_pattern):
        parsed.eval(feed_dict={serialized_ph: proto_bytes})
Exemple #7
0
  def testToUint8(self):
    """Round-trips a uint8 array through a serialized TensorProto.

    A (3, 4, 5) uint8 array is serialized with `tensor_util.make_tensor_proto`,
    fed to `parse_tensor` with dtype uint8, and the parsed result is compared
    element-wise with the original array.
    """
    with self.test_session():
      # np.random.rand(...) yields floats in [0, 1), which astype(np.uint8)
      # truncates to all zeros. Draw integers over the full uint8 range so the
      # round-trip comparison exercises non-trivial values.
      expected = np.random.randint(0, 256, size=(3, 4, 5)).astype(np.uint8)
      tensor_proto = tensor_util.make_tensor_proto(expected)

      serialized = array_ops.placeholder(dtypes.string)
      tensor = parsing_ops.parse_tensor(serialized, dtypes.uint8)

      result = tensor.eval(
          feed_dict={serialized: tensor_proto.SerializeToString()})

      self.assertAllEqual(expected, result)
Exemple #8
0
 def _restore_op(iterator_resource):
   """Builds an op that restores `iterator_resource` from the file at `_path()`.

   `_path` comes from an enclosing scope not visible here. The file contents
   are parsed as a serialized `dtypes.variant` tensor and passed to
   `gen_dataset_ops.deserialize_iterator`, whose op is returned.
   """
   file_contents = io_ops.read_file(_path())
   saved_state = parsing_ops.parse_tensor(file_contents, dtypes.variant)
   return gen_dataset_ops.deserialize_iterator(iterator_resource, saved_state)
 def _restore_op(iterator_resource):
   """Builds an op that restores `iterator_resource` from the file at `_path()`.

   `_path` comes from an enclosing scope not visible here. The file contents
   are parsed as a serialized `dtypes.variant` tensor and passed to
   `gen_dataset_ops.deserialize_iterator`, whose op is returned.
   """
   file_contents = io_ops.read_file(_path())
   saved_state = parsing_ops.parse_tensor(file_contents, dtypes.variant)
   return gen_dataset_ops.deserialize_iterator(iterator_resource, saved_state)