Exemplo n.º 1
0
    def test_map_oneofs_with_paths(self):
        """map_oneofs_with_paths rebuilds the structure and reports each path.

        The visitor records every (path, oneof) pair it receives and returns a
        new OneOf whose choices are scaled by 10. The test checks the rebuilt
        structure, the visited paths in list order, and the visited OneOfs.
        """
        visited = []

        def scale_choices(path, oneof):
            visited.append((path, oneof))
            scaled = [choice * 10 for choice in oneof.choices]
            return schema.OneOf(scaled, oneof.tag)

        structure = {
            'foo': [
                schema.OneOf([1, 2], 'tag1'),
                schema.OneOf([3, 4, 5], 'tag2'),
            ]
        }

        result = schema.map_oneofs_with_paths(scale_choices, structure)
        expected = {
            'foo': [
                schema.OneOf([10, 20], 'tag1'),
                schema.OneOf([30, 40, 50], 'tag2'),
            ]
        }
        self.assertEqual(result, expected)

        visited_paths = [path for path, _ in visited]
        visited_oneofs = [oneof for _, oneof in visited]
        self.assertEqual(visited_paths, ['foo/0', 'foo/1'])
        self.assertEqual(visited_oneofs, [
            schema.OneOf([1, 2], 'tag1'),
            schema.OneOf([3, 4, 5], 'tag2'),
        ])
Exemplo n.º 2
0
def _with_constant_masks(indices, model_spec):
  """Pin every OneOf in `model_spec` to a fixed choice via a one-hot mask.

  The i-th OneOf visited by `schema.map_oneofs_with_paths` is replaced by a
  `schema.OneOf` with the same choices and tag and the mask
  `tf.one_hot(indices[i], len(choices))`.

  Args:
    indices: Sequence of integer choice indices, consumed one per OneOf in
      the order `schema.map_oneofs_with_paths` visits them. Its length is
      first checked by `_assert_correct_oneof_count` (presumably against the
      number of OneOfs in `model_spec` -- see that helper).
    model_spec: Structure containing `schema.OneOf` nodes.

  Returns:
    The structure returned by `schema.map_oneofs_with_paths` with every OneOf
    rebuilt to carry its one-hot mask.

  Raises:
    ValueError: If an index is negative or not smaller than the number of
      choices of the OneOf it is applied to.
  """
  _assert_correct_oneof_count(indices, model_spec)

  # A one-element list serves as a mutable cell so the nested visitor can
  # advance the read position without relying on Python 3's `nonlocal`.
  cursor = [0]

  def attach_mask(path, oneof):
    """Consume the next entry of `indices` and attach it to `oneof` as a mask."""
    index = indices[cursor[0]]
    cursor[0] += 1

    num_choices = len(oneof.choices)
    if index < 0 or index >= num_choices:
      raise ValueError(
          'Invalid index: {:d} for path: {:s} with {:d} choices'.format(
              index, path, num_choices))

    one_hot_mask = tf.one_hot(index, num_choices)
    return schema.OneOf(oneof.choices, oneof.tag, one_hot_mask)

  return schema.map_oneofs_with_paths(attach_mask, model_spec)
def _scan_directory(directory, output_format, ssd):
    """Print the final architecture found by the search run in `directory`.

    Reads `model_spec.json` and the path-logit event data from `directory`.
    At the largest recorded global step, it takes the argmax choice of every
    OneOf and estimates the resulting cost with
    `mobile_cost_model.estimate_cost`. Everything is written to stdout.

    Args:
      directory: Directory expected to contain `model_spec.json` and event
        data readable by `analyze_mobile_search_lib.read_path_logits`.
      output_format: `_OUTPUT_FORMAT_LINES` prints a multi-line report
        followed by a blank line. `_OUTPUT_FORMAT_CSV` prints a single
        `indices,global_step,directory,cost` line.
      ssd: Passed through unchanged to `mobile_cost_model.estimate_cost`.

    Raises:
      ValueError: If the OneOf paths in the event data at the final step do
        not match the OneOf paths found in the model spec.
    """
    as_lines = output_format == _OUTPUT_FORMAT_LINES

    def skip(message):
        # Skip notices are printed in every format; the blank separator line
        # is only added for the line-oriented report.
        print(message)
        if as_lines:
            print()

    if as_lines:
        print('directory =', directory)

    spec_path = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(spec_path):
        skip('file {} not found; skipping'.format(spec_path))
        return

    with tf.io.gfile.GFile(spec_path, 'r') as spec_file:
        model_spec = schema_io.deserialize(spec_file.read())

    # The list keeps OneOf paths in traversal order (which fixes the order of
    # `indices`). The dict's keys are compared against the event data below.
    ordered_paths = []
    oneof_by_path = dict()

    def record(path, oneof):
        ordered_paths.append(path)
        oneof_by_path[path] = oneof

    schema.map_oneofs_with_paths(record, model_spec)

    logits_by_step = analyze_mobile_search_lib.read_path_logits(directory)
    if not logits_by_step:
        skip(
            'event data missing from directory {}; skipping'.format(directory))
        return

    global_step = max(logits_by_step)
    if as_lines:
        print('global_step = {:d}'.format(global_step))

    final_logits = logits_by_step[global_step]
    logit_paths = six.viewkeys(final_logits)
    spec_paths = six.viewkeys(oneof_by_path)
    if logit_paths != spec_paths:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                logit_paths - spec_paths,
                spec_paths - logit_paths))

    indices = [np.argmax(final_logits[path]) for path in ordered_paths]
    indices_str = ':'.join(map(str, indices))
    if as_lines:
        print('indices = {:s}'.format(indices_str))

    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)
    if as_lines:
        print('cost_model = {:f}'.format(cost_model_time))
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        fields = [indices_str, global_step, directory, cost_model_time]
        print(','.join(map(str, fields)))