def test_read_tag_logits_with_pattern_v2(self):
  """Checks read_tag_logits() on scalars tagged with the v2 naming pattern.

  Writes one scalar per logit at step 42, using tags such as
  'rltaglogits/op_indices_0/1'. Also writes one scalar whose tag
  ('global_step/sec') matches no pattern. The test then expects only the
  'rltaglogits/...' values to appear in the result, grouped by tag name.
  """
  log_dir = self.get_temp_dir()
  file_writer = tf2.summary.create_file_writer(log_dir, max_queue=0)

  # Fake log entries, written in this exact order. All 'rltaglogits/...'
  # tags match pattern v2; the final entry matches no pattern at all.
  fake_entries = [
      ('rltaglogits/op_indices_0/0', 1.0),
      ('rltaglogits/op_indices_0/1', 2.0),
      ('rltaglogits/op_indices_0/2', 3.0),
      ('rltaglogits/op_indices_1/0', 4.0),
      ('rltaglogits/op_indices_1/1', 5.0),
      ('rltaglogits/op_indices_1/2', 6.0),
      ('rltaglogits/filters_indices_0/0', 7.0),
      ('global_step/sec', 10.0),
  ]
  with file_writer.as_default():
    for tag_name, scalar_value in fake_entries:
      tf2.summary.scalar(tag_name, scalar_value, step=42)

  # Initialize the writer, run every v2 summary op, then flush to disk.
  self.evaluate(file_writer.init())
  self.evaluate(tf.summary.all_v2_summary_ops())
  self.evaluate(file_writer.flush())

  # Read the events back from the log directory.
  parsed_logits = analyze_mobile_search_lib.read_tag_logits(log_dir)
  expected_logits = {
      42: {
          'op_indices': [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
          'filters_indices': [[7.0]],
      }
  }
  self.assertAllClose(parsed_logits, expected_logits)
def test_read_tag_logits_with_two_valid_steps(self):
  """Checks read_tag_logits() when complete logits exist for two steps.

  A single set of summary ops is built on placeholders and run twice: once
  at step 31 and once at step 42, each time with a different value batch.
  The result is expected to contain one entry per step.
  """
  log_dir = self.get_temp_dir()
  file_writer = tf2.summary.create_file_writer(log_dir, max_queue=0)
  logit_values = tf.placeholder(tf.float32, [8])
  step_value = tf.placeholder(tf.int64, ())

  # Element i of `logit_values` is logged under logit_tags[i]. The ops are
  # created in list order.
  logit_tags = [
      'rltaglogits/op_indices_0/0',
      'rltaglogits/op_indices_0/1',
      'rltaglogits/op_indices_0/2',
      'rltaglogits/op_indices_1/0',
      'rltaglogits/op_indices_1/1',
      'rltaglogits/op_indices_1/2',
      'rltaglogits/filters_indices_0/0',
      'rltaglogits/filters_indices_0/1',
  ]
  with file_writer.as_default():
    for position, tag_name in enumerate(logit_tags):
      tf2.summary.scalar(tag_name, logit_values[position], step=step_value)

  self.evaluate(file_writer.init())
  write_summaries = tf.summary.all_v2_summary_ops()
  flush_writer = file_writer.flush()

  # (step, values) batches, fed in this order. The writer is flushed after
  # each batch.
  batches = [
      (31, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
      (42, [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]),
  ]
  with self.cached_session() as sess:
    for step_number, batch_values in batches:
      sess.run(write_summaries, {
          logit_values: batch_values,
          step_value: step_number,
      })
      sess.run(flush_writer)

  # Both steps are complete, so both should be reported.
  parsed_logits = analyze_mobile_search_lib.read_tag_logits(log_dir)
  expected_logits = {
      31: {
          'op_indices': [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
          'filters_indices': [[7.0, 8.0]],
      },
      42: {
          'op_indices': [[9.0, 10.0, 11.0], [12.0, 13.0, 14.0]],
          'filters_indices': [[15.0, 16.0]],
      },
  }
  self.assertAllClose(parsed_logits, expected_logits)
def test_read_tag_logits_with_invalid_entry(self):
  """Checks that read_tag_logits() drops a step whose logits are incomplete.

  At step 42 the test writes every 'op_indices' entry. For
  'filters_indices_0' it writes only column 1 and omits column 0. The test
  expects read_tag_logits() to return an empty result.
  """
  # Create fake logs for the read_tag_logits() function to consume.
  tempdir = self.get_temp_dir()
  writer = tf2.summary.create_file_writer(tempdir, max_queue=0)
  with writer.as_default():
    tf2.summary.scalar('rltaglogits/op_indices_0/0', 1.0, step=42)
    tf2.summary.scalar('rltaglogits/op_indices_0/1', 2.0, step=42)
    tf2.summary.scalar('rltaglogits/op_indices_0/2', 3.0, step=42)
    tf2.summary.scalar('rltaglogits/op_indices_1/0', 4.0, step=42)
    tf2.summary.scalar('rltaglogits/op_indices_1/1', 5.0, step=42)
    tf2.summary.scalar('rltaglogits/op_indices_1/2', 6.0, step=42)
    # 'rltaglogits/filters_indices_0/0' is missing from the logs; only
    # column 1 of filters_indices_0 is written below.
    tf2.summary.scalar('rltaglogits/filters_indices_0/1', 8.0, step=42)
  self.evaluate(writer.init())
  self.evaluate(tf.summary.all_v2_summary_ops())
  self.evaluate(writer.flush())

  # Try to read the events from file. The events from Step 42 should be
  # skipped, since some of the data is incomplete.
  self.assertEmpty(analyze_mobile_search_lib.read_tag_logits(tempdir))