def test_keyed_features_filter(self):
  """Checks that `filter_fn` drops examples from keyed batches.

  Six JSON-encoded examples with ages [2, 0, 1, 0, 3, 5] are written to a
  temp file and read with a `TextLineReader` in batches of 2 for one epoch.
  `filter_fn` keeps only examples whose age is < 2, so the reader must yield
  file lines 2 and 3 in the first batch, line 4 alone in the second batch,
  and then raise `OutOfRangeError`.
  """
  # NOTE(review): presumably undoes a Glob stub installed elsewhere (e.g. in
  # setUp) so the real temp file is matched -- confirm.
  gfile.Glob = self._orig_glob
  lines = [
      '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
  ]
  filename = self._create_temp_file("\n".join(lines))

  batch_size = 2
  queue_capacity = 4
  name = "my_batch"
  features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

  def filter_fn(keys, examples_json):
    """Returns a boolean mask keeping examples whose "age" is less than 2."""
    del keys  # The decision depends only on the example contents.
    serialized = parsing_ops.decode_json_example(examples_json)
    examples = parsing_ops.parse_example(serialized, features)
    return math_ops.less(examples["age"], 2)

  with ops.Graph().as_default() as g, self.session(graph=g) as session:
    keys, inputs = graph_io._read_keyed_batch_examples_helper(
        filename,
        batch_size,
        reader=io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        read_batch_size=batch_size,
        queue_capacity=queue_capacity,
        filter_fn=filter_fn,
        name=name)
    # The batch dimension is not known statically (the second batch read
    # below holds fewer than batch_size elements).
    self.assertAllEqual((None,), keys.get_shape().as_list())
    self.assertAllEqual((None,), inputs.get_shape().as_list())
    session.run(variables.local_variables_initializer())

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    try:
      # Keys have the form b"<filename>:<1-based line number>".
      key_prefix = filename.encode("utf-8")

      # First batch of two filtered examples.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([key_prefix + b":2", key_prefix + b":3"], out_keys)
      self.assertAllEqual(
          [lines[1].encode("utf-8"), lines[2].encode("utf-8")], out_vals)

      # Second batch will only have one filtered example as that's the only
      # remaining example that satisfies the filtering criterion.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([key_prefix + b":4"], out_keys)
      self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

      # Exhausted input.
      with self.assertRaises(errors.OutOfRangeError):
        session.run((keys, inputs))
    finally:
      # Stop and join the queue-runner threads even when an assertion above
      # fails, so they never outlive the test.
      coord.request_stop()
      coord.join(threads)
def test_keyed_features_filter(self):
  """Checks that `filter_fn` drops examples from keyed batches.

  Six JSON-encoded examples with ages [2, 0, 1, 0, 3, 5] are written to a
  temp file and read with a `TextLineReader` in batches of 2 for one epoch.
  `filter_fn` keeps only examples whose age is < 2, so the reader must yield
  file lines 2 and 3 in the first batch, line 4 alone in the second batch,
  and then raise `OutOfRangeError`.
  """
  # NOTE(review): presumably undoes a Glob stub installed elsewhere (e.g. in
  # setUp) so the real temp file is matched -- confirm.
  gfile.Glob = self._orig_glob
  lines = [
      '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
  ]
  filename = self._create_temp_file("\n".join(lines))

  batch_size = 2
  queue_capacity = 4
  name = "my_batch"
  features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

  def filter_fn(keys, examples_json):
    """Returns a boolean mask keeping examples whose "age" is less than 2."""
    del keys  # The decision depends only on the example contents.
    serialized = parsing_ops.decode_json_example(examples_json)
    examples = parsing_ops.parse_example(serialized, features)
    return math_ops.less(examples["age"], 2)

  # `self.session` replaces the deprecated `self.test_session`.
  with ops.Graph().as_default() as g, self.session(graph=g) as session:
    keys, inputs = graph_io._read_keyed_batch_examples_helper(
        filename,
        batch_size,
        reader=io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        read_batch_size=batch_size,
        queue_capacity=queue_capacity,
        filter_fn=filter_fn,
        name=name)
    # The batch dimension is not known statically (the second batch read
    # below holds fewer than batch_size elements).
    self.assertAllEqual((None,), keys.get_shape().as_list())
    self.assertAllEqual((None,), inputs.get_shape().as_list())
    session.run(variables.local_variables_initializer())

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    try:
      # Keys have the form b"<filename>:<1-based line number>".
      key_prefix = filename.encode("utf-8")

      # First batch of two filtered examples.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([key_prefix + b":2", key_prefix + b":3"], out_keys)
      self.assertAllEqual(
          [lines[1].encode("utf-8"), lines[2].encode("utf-8")], out_vals)

      # Second batch will only have one filtered example as that's the only
      # remaining example that satisfies the filtering criterion.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([key_prefix + b":4"], out_keys)
      self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

      # Exhausted input.
      with self.assertRaises(errors.OutOfRangeError):
        session.run((keys, inputs))
    finally:
      # Stop and join the queue-runner threads even when an assertion above
      # fails, so they never outlive the test.
      coord.request_stop()
      coord.join(threads)