コード例 #1
0
def throw_if_file_access_not_allowed(file_path, logdir, allowed_dir=None):
    """Raises an error if a file path may not be loaded for inference.

    `file_path` is expanded with `filepath_to_filepath_list`, and every
    resulting file must lie under `logdir` or, if given, under `allowed_dir`.
    No readability/permission check is performed here.

    Args:
      file_path: A file path (may expand to several files).
      logdir: The path to the logdir of the TensorBoard context.
      allowed_dir: An optional path to allow loading files from, outside of
        the logdir.

    Raises:
      InvalidUserInputError: If `file_path` expands to no files, or if any
        expanded file is inside neither `logdir` nor `allowed_dir`.
    """
    file_paths = filepath_to_filepath_list(file_path)
    if not file_paths:
        raise common_utils.InvalidUserInputError(file_path +
                                                 ' contains no files')

    for path in file_paths:
        # Check if the file is inside the logdir or the allowed dir
        # (presumably the directory given via --whatif-data-dir, per the
        # error message below).
        if not (path_is_parent(logdir, path) or
                (allowed_dir and path_is_parent(allowed_dir, path))):
            raise common_utils.InvalidUserInputError(
                path + ' is not inside the TensorBoard logdir or '
                '--whatif-data-dir argument directory.')
コード例 #2
0
def mutant_charts_for_feature(example_protos, feature_name, serving_bundles,
                              viz_params):
    """Returns JSON formatted for rendering all charts for a feature.

    Args:
      example_protos: The example protos to mutate. The original type and
        length of `feature_name` are read from the first proto.
      feature_name: The string feature name to mutate.
      serving_bundles: One `ServingBundle` object per model, that contains the
        information to make the serving request.
      viz_params: A `VizParams` object that contains the UI state of the
        request. When `viz_params.feature_indices` is falsy (e.g. empty),
        every index of the feature is mutated.

    Raises:
      InvalidUserInputError: If mutating a requested index raises IndexError,
        e.g. when `viz_params.feature_indices` requests out of range indices
        for `feature_name` within `example_protos`.

    Returns:
      A JSON-able dict parsed in `tf-inference-dashboard.html`:
      {
        'chartType': 'numeric', # oneof('numeric', 'categorical')
        'data': [...]  # one entry per mutated index; each entry is a list
                       # with one chart per serving bundle, parseable by
                       # vz-line-chart or vz-bar-chart
      }
      If the feature cannot be parsed from the first example, returns
      {'chartType': 'categorical', 'data': []}.
    """
    def chart_for_index(index_to_mutate):
        # Builds one chart per model for a single feature index.
        # `original_feature` is bound below, before this helper is called.
        mutant_features, mutant_examples = make_mutant_tuples(
            example_protos, original_feature, index_to_mutate, viz_params)

        charts = []
        for serving_bundle in serving_bundles:
            inference_result_proto, _ = run_inference(mutant_examples,
                                                      serving_bundle)
            charts.append(
                make_json_formatted_for_single_chart(mutant_features,
                                                     inference_result_proto,
                                                     index_to_mutate))
        return charts

    try:
        original_feature = parse_original_feature_from_example(
            example_protos[0], feature_name)
    except ValueError:
        # The feature could not be parsed from the first example: nothing to
        # chart.
        return {'chartType': 'categorical', 'data': []}

    indices_to_mutate = viz_params.feature_indices or range(
        original_feature.length)
    chart_type = ('categorical' if original_feature.feature_type
                  == 'bytes_list' else 'numeric')

    try:
        data = [
            chart_for_index(index_to_mutate)
            for index_to_mutate in indices_to_mutate
        ]
    except IndexError as e:
        # Out of range indices requested by the user.
        raise common_utils.InvalidUserInputError(e) from e
    return {'chartType': chart_type, 'data': data}
コード例 #3
0
    def _parse_request_arguments(self, request):
        """Parses comma separated request arguments

    Args:
      request: A request that should contain 'inference_address', 'model_name',
        'model_version', 'model_signature'.

    Returns:
      A tuple of lists for model parameters
    """
        inference_addresses = request.args.get('inference_address').split(',')
        model_names = request.args.get('model_name').split(',')
        model_versions = request.args.get('model_version').split(',')
        model_signatures = request.args.get('model_signature').split(',')
        if len(model_names) != len(inference_addresses):
            raise common_utils.InvalidUserInputError(
                'Every model should have a ' + 'name and address.')
        return inference_addresses, model_names, model_versions, model_signatures
コード例 #4
0
 def to_int(x):
     """Converts `x` to an int, reporting failures as invalid user input.

     Raises:
       InvalidUserInputError: If `int(x)` raises ValueError or TypeError.
     """
     try:
         return int(x)
     except (ValueError, TypeError) as e:
         # Chain explicitly so the original conversion error is the cause.
         raise common_utils.InvalidUserInputError(e) from e
コード例 #5
0
def example_protos_from_path(path,
                             num_examples=10,
                             start_index=0,
                             parse_examples=True,
                             sampling_odds=1,
                             example_class=tf.train.Example):
    """Returns a number of examples from the provided path.

    Paths ending in '.csv' are read as CSV files with a header row. Every
    other path is expanded with `filepath_to_filepath_list` and read as
    TFRecord files, trying no compression, GZIP and ZLIB for each file.

    Args:
      path: A string path to the examples.
      num_examples: The maximum number of examples to return from the path.
      start_index: Currently ignored; kept for backwards compatibility.
      parse_examples: If true then parses the serialized proto from the path
        into proto objects. Defaults to True.
      sampling_odds: Odds of loading an example, used for sampling. When >= 1
        (the default), then all examples are loaded.
      example_class: tf.train.Example or tf.train.SequenceExample class to
        load from TFRecord files. Defaults to tf.train.Example. CSV rows are
        always converted to tf.train.Example.

    Returns:
      A list of Example protos or serialized proto strings at the path. For
      CSV input the list may be empty.

    Raises:
      InvalidUserInputError: If examples cannot be procured from the path.
    """
    def keep_sample():
        # Only draw a random number when actually sampling.
        return sampling_odds >= 1 or random.random() < sampling_odds

    def append_examples_from_iterable(iterable, examples):
        for value in iterable:
            if keep_sample():
                examples.append(
                    example_class.FromString(value) if parse_examples else value)
                if len(examples) >= num_examples:
                    return

    examples = []

    if path.endswith('.csv'):

        def are_floats(values):
            for value in values:
                try:
                    float(value)
                except ValueError:
                    return False
            return True

        csv.register_dialect('CsvDialect', skipinitialspace=True)
        try:
            # newline='' is required by the csv module so quoted fields with
            # embedded line breaks are parsed correctly.
            csv_file = open(path, newline='')
        except IOError as e:
            raise common_utils.InvalidUserInputError(e) from e
        with csv_file:
            for row in csv.DictReader(csv_file, dialect='CsvDialect'):
                if not keep_sample():
                    continue
                example = tf.train.Example()
                for col in row.keys():
                    # Parse out individual values from vertical-bar-delimited
                    # lists.
                    values = [val.strip() for val in row[col].split('|')]
                    if are_floats(values):
                        example.features.feature[col].float_list.value.extend(
                            [float(val) for val in values])
                    else:
                        example.features.feature[col].bytes_list.value.extend(
                            [val.encode('utf-8') for val in values])
                examples.append(
                    example if parse_examples else example.SerializeToString())
                if len(examples) >= num_examples:
                    break
        return examples

    filenames = filepath_to_filepath_list(path)
    compression_types = [
        '',  # no compression (distinct from `None`!)
        'GZIP',
        'ZLIB',
    ]
    # Compression that last worked; tried first for the next file since
    # shards usually share one compression, but every type is still tried.
    compression_idx = 0
    for filename in filenames:
        for offset in range(len(compression_types)):
            candidate_idx = (compression_idx + offset) % len(compression_types)
            try:
                record_iterator = tf.compat.v1.python_io.tf_record_iterator(
                    path=filename,
                    options=tf.io.TFRecordOptions(
                        compression_types[candidate_idx]))
                append_examples_from_iterable(record_iterator, examples)
            except tf.errors.DataLossError:
                # Wrong compression type (or corrupt data): try the next one.
                continue
            except (IOError, tf.errors.NotFoundError) as e:
                raise common_utils.InvalidUserInputError(e) from e
            compression_idx = candidate_idx
            break
        else:
            # No compression type could read this file; stop reading, as the
            # original sequential scan did once all types were exhausted.
            break
        if len(examples) >= num_examples:
            break

    if examples:
        return examples
    raise common_utils.InvalidUserInputError(
        'No examples found at ' + path +
        '. Valid formats are TFRecord files.')