  def test_read_path_logits(self):
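    """Checks that read_path_logits() groups logits by step and key."""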
    # Create fake logs for the read_path_logits() function to consume.
    tempdir = self.get_temp_dir()
    writer = tf2.summary.create_file_writer(tempdir, max_queue=0)
    with writer.as_default():
      # Events matching the path logits pattern.
      tf2.summary.scalar('rlpathlogits/filters/0', 1.0, step=42)
      tf2.summary.scalar('rlpathlogits/filters/1', 2.0, step=42)
      tf2.summary.scalar('rlpathlogits/filters/2', 3.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/0', 4.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/1', 5.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/2', 6.0, step=42)
      # Events not matching any pattern.
      tf2.summary.scalar('global_step/sec', 10.0, step=42)

    self.evaluate(writer.init())
    self.evaluate(tf.summary.all_v2_summary_ops())
    self.evaluate(writer.flush())

    # Try to read the events from file.
    self.assertAllClose(
        analyze_mobile_search_lib.read_path_logits(tempdir),
        {
            42: {
                'filters': [1.0, 2.0, 3.0],
                'layers/0/choices': [4.0, 5.0, 6.0],
            }
        })

  def test_read_path_logits_with_invalid_entry(self):
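    """Checks that steps with incomplete path logits are skipped."""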
    # Create fake logs for the read_path_logits() function to consume.
    tempdir = self.get_temp_dir()
    writer = tf2.summary.create_file_writer(tempdir, max_queue=0)
    with writer.as_default():
      # Events matching the path logits pattern.
      tf2.summary.scalar('rlpathlogits/filters/0', 1.0, step=42)
      tf2.summary.scalar('rlpathlogits/filters/2', 3.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/0', 4.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/1', 5.0, step=42)
      tf2.summary.scalar('rlpathlogits/layers/0/choices/2', 6.0, step=42)
      # 'rlpathlogits/filters/1' is missing from the logs.

    self.evaluate(writer.init())
    self.evaluate(tf.summary.all_v2_summary_ops())
    self.evaluate(writer.flush())

    # Try to read the events from file. The events from Step 42 should be
    # skipped, since some of the data is incomplete.
    self.assertEmpty(analyze_mobile_search_lib.read_path_logits(tempdir))


def _scan_directory(directory, output_format, ssd):
    """Scan a directory for log files and write the final model to stdout."""
    if output_format == _OUTPUT_FORMAT_LINES:
        print('directory =', directory)

    model_spec_filename = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(model_spec_filename):
        print('file {} not found; skipping'.format(model_spec_filename))
        if output_format == _OUTPUT_FORMAT_LINES:
            print()
        return

    with tf.io.gfile.GFile(model_spec_filename, 'r') as handle:
        model_spec = schema_io.deserialize(handle.read())

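    # Collect every OneOf in the model spec, keyed by its path, so that the
    # selected indices below line up with the order of OneOfs in the spec.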
    paths = []
    oneofs = dict()

    def populate_oneofs(path, oneof):
        paths.append(path)
        oneofs[path] = oneof

    schema.map_oneofs_with_paths(populate_oneofs, model_spec)

    all_path_logits = analyze_mobile_search_lib.read_path_logits(directory)
    if not all_path_logits:
        print(
            'event data missing from directory {}; skipping'.format(directory))
        if output_format == _OUTPUT_FORMAT_LINES:
            print()
        return

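    # Only analyze the logits recorded at the most recent training step.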
    global_step = max(all_path_logits)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('global_step = {:d}'.format(global_step))

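    # Sanity-check that the keys in the event files exactly match the OneOfs
    # in the model spec before selecting an architecture.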
    all_path_logit_keys = six.viewkeys(all_path_logits[global_step])
    oneof_keys = six.viewkeys(oneofs)
    if all_path_logit_keys != oneof_keys:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                all_path_logit_keys - oneof_keys,
                oneof_keys - all_path_logit_keys))

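    # For each OneOf, greedily pick the choice with the largest logit.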
    indices = []
    for path in paths:
        index = np.argmax(all_path_logits[global_step][path])
        indices.append(index)

    indices_str = ':'.join(map(str, indices))
    if output_format == _OUTPUT_FORMAT_LINES:
        print('indices = {:s}'.format(indices_str))

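    # Estimate the cost of the selected architecture via the cost model.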
    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('cost_model = {:f}'.format(cost_model_time))

    if output_format == _OUTPUT_FORMAT_LINES:
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        fields = [indices_str, global_step, directory, cost_model_time]
        print(','.join(map(str, fields)))
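

# A minimal usage sketch, not part of the original code: the helper below and
# the CSV header row are assumptions showing how _scan_directory() might be
# driven over several search output directories.
def _example_scan_all(directories, output_format, ssd):
    if output_format == _OUTPUT_FORMAT_CSV:
        # Header row matching the fields printed by _scan_directory().
        print('indices,global_step,directory,cost_model')
    for directory in directories:
        _scan_directory(directory, output_format, ssd)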