Example #1
0
    def test_keyed_parse_json(self):
        """Reads JSON-encoded Examples line by line and checks keys and values.

        Writes three one-line JSON `Example`s to a temp file and reads them back
        one record per batch through `graph_io.read_keyed_batch_examples` with a
        `TextLineReader`. For the n-th line it checks that the key is
        `<filename>:<n>` and that the parsed "age" feature equals the value
        written on that line. After the single epoch it checks that reading
        again raises `OutOfRangeError`.
        """
        gfile.Glob = self._orig_glob
        filename = self._create_temp_file(
            '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
            '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
            '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n'
        )

        batch_size = 1
        queue_capacity = 5
        name = "my_batch"

        with ops.Graph().as_default() as g, self.test_session(
                graph=g) as session:
            # Feature spec for parsing; named `features` so it is not confused
            # with a dtypes module.
            features = {
                "age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)
            }

            def parse_fn(example):
                # Each record is a JSON text line: convert it to a serialized
                # Example proto first, then parse it with the feature spec.
                return parsing_ops.parse_single_example(
                    parsing_ops.decode_json_example(example), features)

            keys, inputs = graph_io.read_keyed_batch_examples(
                filename,
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                parse_fn=parse_fn,
                name=name)
            self.assertAllEqual((None, ), keys.get_shape().as_list())
            self.assertEqual(1, len(inputs))
            self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
            session.run(variables.local_variables_initializer())

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(session,
                                                            coord=coord)
            # Stop and join the queue-runner threads even when an assertion
            # below fails, so they never outlive the session.
            try:
                encoded_name = filename.encode("utf-8")
                for expected_age, line_suffix in zip(
                        (0, 1, 2), (b":1", b":2", b":3")):
                    key, age = session.run([keys, inputs["age"]])
                    self.assertAllEqual(age, [[expected_age]])
                    self.assertAllEqual(key, [encoded_name + line_suffix])
                # num_epochs=1: the input is exhausted after three records.
                with self.assertRaises(errors.OutOfRangeError):
                    session.run(inputs)
            finally:
                coord.request_stop()
                coord.join(threads)
  def test_keyed_parse_json(self):
    """Reads JSON-encoded Examples line by line and checks keys and values.

    Writes three one-line JSON `Example`s to a temp file and reads them back
    one record per batch through `graph_io.read_keyed_batch_examples` with a
    `TextLineReader`. For the n-th line it checks that the key is
    `<filename>:<n>` and that the parsed "age" feature equals the value written
    on that line. After the single epoch it checks that reading again raises
    `OutOfRangeError`.
    """
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      # Feature spec for parsing; named `features` so it is not confused with
      # a dtypes module.
      features = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}

      def parse_fn(example):
        # Each record is a JSON text line: convert it to a serialized Example
        # proto first, then parse it with the feature spec.
        return parsing_ops.parse_single_example(
            parsing_ops.decode_json_example(example), features)

      keys, inputs = graph_io.read_keyed_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          parse_fn=parse_fn,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(inputs))
      self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      # Stop and join the queue-runner threads even when an assertion below
      # fails, so they never outlive the session.
      try:
        encoded_name = filename.encode("utf-8")
        for expected_age, line_suffix in zip((0, 1, 2), (b":1", b":2", b":3")):
          key, age = session.run([keys, inputs["age"]])
          self.assertAllEqual(age, [[expected_age]])
          self.assertAllEqual(key, [encoded_name + line_suffix])
        # num_epochs=1: the input is exhausted after three records.
        with self.assertRaises(errors.OutOfRangeError):
          session.run(inputs)
      finally:
        coord.request_stop()
        coord.join(threads)
  def _testRoundTrip(self, examples):
    """Checks that `decode_json_example` inverts `json_format.MessageToJson`.

    Each message in `examples` is serialized to JSON and fed through
    `parsing_ops.decode_json_example`. Each resulting binary string is then
    parsed back into an `example_pb2.Example` and compared with the input.

    Args:
      examples: a single proto message or a (possibly nested) list of them;
        presumably `example_pb2.Example` instances. It is wrapped in an object
        ndarray, so both 0-d and N-d inputs are supported.
    """
    with self.test_session() as sess:
      # `np.object` was merely an alias for the builtin and was removed in
      # NumPy 1.24; `object` works on every NumPy version.
      examples = np.array(examples, dtype=object)

      json_tensor = constant_op.constant(
          [json_format.MessageToJson(m) for m in examples.flatten()],
          shape=examples.shape,
          dtype=dtypes.string)
      binary_tensor = parsing_ops.decode_json_example(json_tensor)
      binary_val = sess.run(binary_tensor)

      if examples.shape:
        self.assertShapeEqual(binary_val, json_tensor)
        # `examples` is already an ndarray; no need to re-wrap it.
        for input_example, output_binary in zip(
            examples.flatten(), binary_val.flatten()):
          output_example = example_pb2.Example()
          output_example.ParseFromString(output_binary)
          self.assertProtoEquals(input_example, output_example)
      else:
        # 0-d input: the run result is a single serialized value.
        output_example = example_pb2.Example()
        output_example.ParseFromString(binary_val)
        self.assertProtoEquals(examples.item(), output_example)
Example #4
0
  def _testRoundTrip(self, examples):
    """Checks that `decode_json_example` inverts `json_format.MessageToJson`.

    Each message in `examples` is serialized to JSON and fed through
    `parsing_ops.decode_json_example`. Each resulting binary string is then
    parsed back into an `example_pb2.Example` and compared with the input.

    Args:
      examples: a single proto message or a (possibly nested) list of them;
        presumably `example_pb2.Example` instances. It is wrapped in an object
        ndarray, so both 0-d and N-d inputs are supported.
    """
    with self.test_session() as sess:
      # `np.object` was merely an alias for the builtin and was removed in
      # NumPy 1.24; `object` works on every NumPy version.
      examples = np.array(examples, dtype=object)

      json_tensor = constant_op.constant(
          [json_format.MessageToJson(m) for m in examples.flatten()],
          shape=examples.shape,
          dtype=dtypes.string)
      binary_tensor = parsing_ops.decode_json_example(json_tensor)
      binary_val = sess.run(binary_tensor)

      if examples.shape:
        self.assertShapeEqual(binary_val, json_tensor)
        # `examples` is already an ndarray; no need to re-wrap it.
        for input_example, output_binary in zip(
            examples.flatten(), binary_val.flatten()):
          output_example = example_pb2.Example()
          output_example.ParseFromString(output_binary)
          self.assertProtoEquals(input_example, output_example)
      else:
        # 0-d input: the run result is a single serialized value.
        output_example = example_pb2.Example()
        output_example.ParseFromString(binary_val)
        self.assertProtoEquals(examples.item(), output_example)
Example #5
0
 def filter_fn(keys, examples_json):
   """Returns `age < 2` for each JSON-encoded example in `examples_json`.

   `keys` is discarded; presumably it is accepted only to fit a
   (keys, examples) callback signature. `features` is the parsing spec taken
   from the enclosing scope (not visible here).
   """
   del keys  # Unused.
   parsed = parsing_ops.parse_example(
       parsing_ops.decode_json_example(examples_json), features)
   return math_ops.less(parsed["age"], 2)
 def testInvalidSyntax(self):
   """Decoding malformed JSON must fail with a JSON parse error at run time."""
   with self.test_session() as sess:
     decoded = parsing_ops.decode_json_example(constant_op.constant(["{]"]))
     # The error surfaces only when the op is executed, not when it is built.
     with self.assertRaisesOpError("Error while parsing JSON"):
       sess.run(decoded)
Example #7
0
 def testInvalidSyntax(self):
   """Running decode_json_example on `{]` raises a JSON parsing op error."""
   with self.test_session() as sess:
     malformed = constant_op.constant(["{]"])
     decode_op = parsing_ops.decode_json_example(malformed)
     expected_message = "Error while parsing JSON"
     with self.assertRaisesOpError(expected_message):
       sess.run(decode_op)
Example #8
0
 def filter_fn(keys, examples_json):
   """Elementwise test of whether the parsed "age" feature is below 2.

   The JSON strings are first converted to serialized Example protos, then
   parsed with `features` from the enclosing scope (not visible here). `keys`
   is ignored; presumably it exists only to match the caller's signature.
   """
   del keys
   age = parsing_ops.parse_example(
       parsing_ops.decode_json_example(examples_json), features)["age"]
   return math_ops.less(age, 2)