def test_keyed_parse_json(self):
    """Read JSON-encoded examples as keyed batches and parse the `age` feature.

    Writes three JSON lines to a temp file, reads them one per batch through
    `graph_io.read_keyed_batch_examples` with a `TextLineReader`, and checks
    that each batch yields the expected `age` value together with a key of
    the form `<filename>:<line number>`. After the single epoch is consumed,
    a further run must raise `OutOfRangeError`.
    """
    # NOTE(review): the visible source defines `test_keyed_parse_json` a
    # second time right after this definition; if both are in the same class
    # the later one replaces this one -- confirm and drop one of them.

    # Put the saved `_orig_glob` back on gfile (presumably undoing a stub
    # installed by the test fixture -- confirm against setUp).
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}

        def parse_fn(example):
            # Decode the JSON text first, then parse it against `dtypes`.
            return parsing_ops.parse_single_example(
                parsing_ops.decode_json_example(example), dtypes)

        keys, inputs = graph_io.read_keyed_batch_examples(
            filename,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            parse_fn=parse_fn,
            name=name)
        self.assertAllEqual((None,), keys.get_shape().as_list())
        self.assertEqual(1, len(inputs))
        self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            # One batch per input line: (expected age, expected key suffix).
            for expected_age, key_suffix in ((0, b":1"), (1, b":2"), (2, b":3")):
                key, age = session.run([keys, inputs["age"]])
                self.assertAllEqual(age, [[expected_age]])
                self.assertAllEqual(key, [filename.encode("utf-8") + key_suffix])

            # num_epochs=1, so the input is exhausted after three batches.
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they do not outlive this test.
            coord.request_stop()
            coord.join(threads)
def test_keyed_parse_json(self):
    """Verify keyed reading plus JSON parsing of a three-line example file.

    Each line of the temp file is a JSON-encoded example holding a single
    `age` int64 value. The file is read with `TextLineReader` through
    `graph_io.read_keyed_batch_examples` (batch size 1, one epoch, no
    shuffling), and every batch must return the matching `age` and a key of
    `<filename>:<line number>`. A fourth read must raise `OutOfRangeError`.
    """
    # NOTE(review): the visible source contains an earlier definition with
    # this same name directly before this one; if both are in the same class
    # this definition is the one that takes effect -- confirm and dedupe.

    # Reinstate the saved `_orig_glob` on gfile (presumably the fixture
    # replaces it elsewhere -- confirm against setUp).
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')
    encoded_filename = filename.encode("utf-8")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
        parse_fn = lambda example: parsing_ops.parse_single_example(  # pylint: disable=g-long-lambda
            parsing_ops.decode_json_example(example), dtypes)
        keys, inputs = graph_io.read_keyed_batch_examples(
            filename,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            parse_fn=parse_fn,
            name=name)

        # Static shapes: batch dimension is unknown, `age` has one value.
        self.assertAllEqual((None,), keys.get_shape().as_list())
        self.assertEqual(1, len(inputs))
        self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            expected_batches = (
                (b":1", [[0]]),
                (b":2", [[1]]),
                (b":3", [[2]]),
            )
            for key_suffix, expected_age in expected_batches:
                key, age = session.run([keys, inputs["age"]])
                self.assertAllEqual(age, expected_age)
                self.assertAllEqual(key, [encoded_filename + key_suffix])

            # Only one epoch was requested, so nothing is left to read.
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Run cleanup unconditionally so a failed assertion cannot leave
            # the queue-runner threads running.
            coord.request_stop()
            coord.join(threads)
def test_keyed_read_text_lines(self):
    """Read a three-line text file as keyed batches of raw lines.

    Writes "ABC", "DEF" and "GHK" to a temp file, reads it with
    `TextLineReader` through `graph_io.read_keyed_batch_examples` (batch
    size 1, one epoch, no shuffling), and checks that each batch returns the
    line bytes together with a key of `<filename>:<line number>`. Reading
    past the last line must raise `OutOfRangeError`.
    """
    # Put the saved `_orig_glob` back on gfile (presumably undoing a stub
    # installed by the test fixture -- confirm against setUp).
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        keys, inputs = graph_io.read_keyed_batch_examples(
            filename,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            name=name)
        self.assertAllEqual((None,), keys.get_shape().as_list())
        self.assertAllEqual((None,), inputs.get_shape().as_list())
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            # One batch per input line: (expected key suffix, expected line).
            expected_batches = (
                (b":1", b"ABC"),
                (b":2", b"DEF"),
                (b":3", b"GHK"),
            )
            for key_suffix, expected_line in expected_batches:
                self.assertAllEqual(
                    session.run([keys, inputs]),
                    [[filename.encode("utf-8") + key_suffix], [expected_line]])

            # num_epochs=1, so the input is exhausted after three batches.
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they do not outlive this test.
            coord.request_stop()
            coord.join(threads)