def main(argv):
    """Print the mobile cost-model estimate for the architecture in --indices.

    The `indices` flag is parsed into a list of ints with
    `search_space_utils.parse_list`. That list and the `ssd` flag are passed
    to `mobile_cost_model.estimate_cost`, and the result is printed.

    Args:
      argv: Positional command-line arguments; anything beyond the first
        element is rejected.

    Raises:
      app.UsageError: If more than one positional argument is supplied.
    """
    if argv[1:]:
        raise app.UsageError('Too many command-line arguments.')

    # Arguments are evaluated left to right: indices are parsed before the
    # ssd flag is read, matching the original statement order.
    estimated = mobile_cost_model.estimate_cost(
        search_space_utils.parse_list(FLAGS.indices, int), FLAGS.ssd)
    print('estimated cost: {:f}'.format(estimated))
 def test_estimate_cost_integration_test(self):
     """Cost of a fixed 'proxylessnas_search' architecture is about 84.0.

     The expected value is checked with an absolute tolerance of 1.0.
     """
     architecture = [
         0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 1, 0, 2, 1, 0, 2, 1, 0, 0, 1, 1,
         0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 2, 1, 0, 0, 0
     ]
     self.assertNear(
         mobile_cost_model.estimate_cost(architecture, 'proxylessnas_search'),
         84.0,
         err=1.0)
def _scan_directory(directory, output_format, ssd):
    """Scan a directory for log files and write the final model to stdout.

    Reads `model_spec.json` from `directory` and collects the spec's OneOf
    paths. It then loads per-step path logits via
    `analyze_mobile_search_lib.read_path_logits`. At the largest step key
    found, it picks the argmax choice for each OneOf and reports the
    resulting indices together with the cost-model estimate.

    Args:
      directory: Directory expected to contain `model_spec.json` and the
        event data read by `read_path_logits`.
      output_format: `_OUTPUT_FORMAT_LINES` prints human-readable
        `key = value` lines followed by a blank line. `_OUTPUT_FORMAT_CSV`
        prints a single `indices,global_step,directory,cost` row.
      ssd: Forwarded unchanged to `mobile_cost_model.estimate_cost`.

    Raises:
      ValueError: If the OneOf paths in the event data at the chosen step do
        not match the OneOf paths in the model spec.
    """
    import sys  # Local import: only needed to send CSV-mode diagnostics to stderr.

    def report_skip(message):
        """Explain why `directory` is skipped without corrupting CSV output."""
        if output_format == _OUTPUT_FORMAT_CSV:
            # In CSV mode stdout must contain only data rows, so diagnostics
            # are written to stderr instead.
            print(message, file=sys.stderr)
            return
        print(message)
        if output_format == _OUTPUT_FORMAT_LINES:
            print()  # Blank line that terminates this directory's report.

    if output_format == _OUTPUT_FORMAT_LINES:
        print('directory =', directory)

    model_spec_filename = os.path.join(directory, 'model_spec.json')
    if not tf.io.gfile.exists(model_spec_filename):
        report_skip('file {} not found; skipping'.format(model_spec_filename))
        return

    with tf.io.gfile.GFile(model_spec_filename, 'r') as handle:
        model_spec = schema_io.deserialize(handle.read())

    # `paths` keeps OneOf paths in the order map_oneofs_with_paths visits
    # them, so the emitted indices line up with that order. `oneofs` maps
    # each path to its OneOf and supplies the key set for the consistency
    # check below.
    paths = []
    oneofs = dict()

    def populate_oneofs(path, oneof):
        paths.append(path)
        oneofs[path] = oneof

    schema.map_oneofs_with_paths(populate_oneofs, model_spec)

    all_path_logits = analyze_mobile_search_lib.read_path_logits(directory)
    if not all_path_logits:
        report_skip(
            'event data missing from directory {}; skipping'.format(directory))
        return

    # Use the logits recorded under the largest step key.
    global_step = max(all_path_logits)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('global_step = {:d}'.format(global_step))

    step_logits = all_path_logits[global_step]
    all_path_logit_keys = six.viewkeys(step_logits)
    oneof_keys = six.viewkeys(oneofs)
    if all_path_logit_keys != oneof_keys:
        raise ValueError(
            'OneOf key mismatch. Present in event files but not in model_spec: {}. '
            'Present in model_spec but not in event files: {}'.format(
                all_path_logit_keys - oneof_keys,
                oneof_keys - all_path_logit_keys))

    # For each OneOf, select the choice with the largest logit.
    indices = [np.argmax(step_logits[path]) for path in paths]

    indices_str = ':'.join(map(str, indices))
    if output_format == _OUTPUT_FORMAT_LINES:
        print('indices = {:s}'.format(indices_str))

    cost_model_time = mobile_cost_model.estimate_cost(indices, ssd)
    if output_format == _OUTPUT_FORMAT_LINES:
        print('cost_model = {:f}'.format(cost_model_time))
        print()
    elif output_format == _OUTPUT_FORMAT_CSV:
        fields = [indices_str, global_step, directory, cost_model_time]
        print(','.join(map(str, fields)))