Example #1
def _parse_operation(s: Text) -> delay_model_pb2.Operation:
  """Parses a text proto representation of an Operation."""
  return text_format.Parse(s, delay_model_pb2.Operation())
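A minimal usage sketch; the `op` field name is an assumption about delay_model_pb2.Operation, not taken from the snippet:

# Hypothetical call; assumes Operation has a string field named `op`.
operation = _parse_operation('op: "kAdd"')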
Example #2
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.extractors import sql_slice_key_extractor
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_serving.apis import input_pb2  # for ExampleListWithContext below
from tfx_bsl.tfxio import tf_example_record

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

_SCHEMA = text_format.Parse(
    """
        feature {
          name: "fixed_int"
          type: INT
        }
        feature {
          name: "fixed_float"
          type: FLOAT
        }
        feature {
          name: "fixed_string"
          type: BYTES
        }
        """, schema_pb2.Schema())


class SqlSliceKeyExtractorTest(testutil.TensorflowModelAnalysisTest):

  def testSqlSliceKeyExtractor(self):
    eval_config = config_pb2.EvalConfig(slicing_specs=[
        config_pb2.SlicingSpec(slice_keys_sql="""
        SELECT
    text_format.Parse(
        """
context {
  features {
    feature {
      key: "ctx.int"  # dot in the feature name is intended.
      value {
        int64_list {
          value: [1, 2]
        }
      }
    }
    feature {
      key: "ctx.float"
      value {
        float_list {
          value: [1.0, 2.0]
        }
      }
    }
    feature {
      key: "ctx.bytes"
      value {
        bytes_list {
          value: []
        }
      }
    }
  }
}
examples {
  features {
    feature {
      key: "example_int"
      value {
        int64_list {
          value: [11]
        }
      }
    }
    feature {
      key: "example_float"
      value {
        float_list {
          value: [11.0, 12.0]
        }
      }
    }
    feature {
      key: "example_bytes"
      value {
        bytes_list {
          value: ["u", "v"]
        }
      }
    }
  }
}
examples {
  features {
    feature {
      key: "example_int"
      value {
        int64_list {
          value: [22]
        }
      }
    }
    # example_float is not present.
    feature {
      key: "example_bytes"
      value {
        bytes_list {
          value: ["w"]
        }
      }
    }
  }
}
""", input_pb2.ExampleListWithContext()).SerializeToString(),
 def testUncertaintyValuedMetrics(self):
     slice_key = _make_slice_key()
     slice_metrics = {
         'one_dim':
         types.ValueWithTDistribution(2.0, 1.0, 3, 2.0),
         'nans':
         types.ValueWithTDistribution(float('nan'), float('nan'), -1,
                                      float('nan')),
     }
     expected_metrics_for_slice = text_format.Parse(
         """
     slice_key {}
     metrics {
       key: "one_dim"
       value {
         bounded_value {
           value {
             value: 2.0
           }
           lower_bound {
             value: -1.1824463
           }
           upper_bound {
             value: 5.1824463
           }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound {
             value: -1.1824463
           }
           upper_bound {
             value: 5.1824463
           }
           t_distribution_value {
             sample_mean {
               value: 2.0
             }
             sample_standard_deviation {
               value: 1.0
             }
             sample_degrees_of_freedom {
               value: 3
             }
             unsampled_value {
               value: 2.0
             }
           }
         }
       }
     }
     metrics {
       key: "nans"
       value {
         bounded_value {
           value {
             value: nan
           }
           lower_bound {
             value: nan
           }
           upper_bound {
             value: nan
           }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound {
             value: nan
           }
           upper_bound {
             value: nan
           }
           t_distribution_value {
             sample_mean {
               value: nan
             }
             sample_standard_deviation {
               value: nan
             }
             sample_degrees_of_freedom {
               value: -1
             }
             unsampled_value {
               value: nan
             }
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice())
     got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
         (slice_key, slice_metrics), [])
     self.assertProtoEquals(expected_metrics_for_slice, got)
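The expected interval bounds above can be reproduced by hand, assuming the writer derives them as a two-sided 95% Student-t interval around the sample mean:

# Sketch: mean +/- t_crit * std_dev reproduces the expected bounds.
from scipy import stats

t_crit = stats.t.ppf(0.975, df=3)    # ~3.1824463
print(2.0 - t_crit * 1.0)            # -1.1824463 (lower_bound)
print(2.0 + t_crit * 1.0)            # 5.1824463 (upper_bound)
# The 0.5737843 / 1.0262157 bounds in later examples follow from the same
# formula with mean 0.8, std_dev 0.1, df 9 (t_crit ~2.2621572).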
    def testConvertSlicePlotsToProto(self):
        slice_key = _make_slice_key('fruit', 'apple')
        plot_key = metric_types.PlotKey(name='calibration_plot',
                                        output_name='output_name')
        calibration_plot = text_format.Parse(
            """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.5
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 1.0 }
          total_weighted_refined_prediction { value: 0.3 }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 1.0
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.7 }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
     """, metrics_for_slice_pb2.CalibrationHistogramBuckets())

        expected_plots_for_slice = text_format.Parse(
            """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

        got = metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
            (slice_key, {
                plot_key: calibration_plot
            }), None)
        self.assertProtoEquals(expected_plots_for_slice, got)
    def testWriteMetricsAndPlots(self, output_file_format):
        metrics_file = os.path.join(self._getTempDir(), 'metrics')
        plots_file = os.path.join(self._getTempDir(), 'plots')
        temp_eval_export_dir = os.path.join(self._getTempDir(),
                                            'eval_export_dir')

        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        eval_config = config.EvalConfig(
            model_specs=[config.ModelSpec()],
            options=config.Options(
                disabled_outputs={'values': ['eval_config.json']}))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.calibration_plot_and_prediction_histogram(
                    num_buckets=2)
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]
        evaluators = [
            metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(
                eval_shared_model)
        ]
        output_paths = {
            constants.METRICS_KEY: metrics_file,
            constants.PLOTS_KEY: plots_file
        }
        writers = [
            metrics_plots_and_validations_writer.
            MetricsPlotsAndValidationsWriter(
                output_paths,
                eval_config=eval_config,
                add_metrics_callbacks=eval_shared_model.add_metrics_callbacks,
                output_file_format=output_file_format)
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=1.0, label=1.0)

            # pylint: disable=no-value-for-parameter
            _ = (pipeline
                 | 'Create' >> beam.Create([
                     example1.SerializeToString(),
                     example2.SerializeToString(),
                 ])
                 | 'ExtractEvaluateAndWriteResults' >>
                 model_eval_lib.ExtractEvaluateAndWriteResults(
                     eval_config=eval_config,
                     eval_shared_model=eval_shared_model,
                     extractors=extractors,
                     evaluators=evaluators,
                     writers=writers))
            # pylint: enable=no-value-for-parameter

        expected_metrics_for_slice = text_format.Parse(
            """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

        metric_records = list(
            metrics_plots_and_validations_writer.load_and_deserialize_metrics(
                metrics_file))
        self.assertLen(metric_records, 1, 'metrics: %s' % metric_records)
        self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

        expected_plots_for_slice = text_format.Parse(
            """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
         }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

        plot_records = list(
            metrics_plots_and_validations_writer.load_and_deserialize_plots(
                plots_file))
        self.assertLen(plot_records, 1, 'plots: %s' % plot_records)
        self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
    def testConvertSliceMetricsToProtoMetricsRanges(self):
        slice_key = _make_slice_key('age', 5, 'language', 'english', 'price',
                                    0.3)
        slice_metrics = {
            'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
            metric_keys.AUPRC: 0.1,
            metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
            metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
            metric_keys.AUC: 0.2,
            metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
            metric_keys.upper_bound_key(metric_keys.AUC): 0.3
        }
        expected_metrics_for_slice = text_format.Parse(
            string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            bounded_value {
              value {
                value: 0.8
              }
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              t_distribution_value {
                sample_mean {
                  value: 0.8
                }
                sample_standard_deviation {
                  value: 0.1
                }
                sample_degrees_of_freedom {
                  value: 9
                }
                unsampled_value {
                  value: 0.8
                }
              }
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
            metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics),
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')])
        self.assertProtoEquals(expected_metrics_for_slice, got)
# This script updates the models.config used by tensorflow/serving.
import grpc
from google.protobuf import text_format
from tensorflow_serving.apis import model_service_pb2_grpc, model_management_pb2
from tensorflow_serving.config import model_server_config_pb2
from tensorflow_serving.sources.storage_path.file_system_storage_path_source_pb2 import FileSystemStoragePathSourceConfig

# Path to models.config
model_config_file_path = "./models.config"
with open(model_config_file_path, 'r') as f:
    config_ini = f.read()

request = model_management_pb2.ReloadConfigRequest()
model_server_config = model_server_config_pb2.ModelServerConfig()
config_list = model_server_config_pb2.ModelConfigList()
model_server_config = text_format.Parse(text=config_ini,
                                        message=model_server_config)

# Create a config to add to the list of served models
one_config = config_list.config.add()
one_config.name = "lmj"
one_config.base_path = "/models/lmj"
one_config.model_platform = "tensorflow"
servable_version_policy = FileSystemStoragePathSourceConfig(
).ServableVersionPolicy()
one_config.model_version_policy.all.CopyFrom(servable_version_policy.All())

model_server_config.model_config_list.MergeFrom(config_list)
request.config.CopyFrom(model_server_config)

# Serving address 192.168.1.193:8510; host port 8510 is mapped to
# tensorflow/serving's port 8500.
channel = grpc.insecure_channel('192.168.1.193:8510')
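The snippet stops after opening the channel; a minimal sketch of the remaining reload call, using the ModelService stub imported above:

# Sketch: send the ReloadConfigRequest over the channel and print the status.
stub = model_service_pb2_grpc.ModelServiceStub(channel)
response = stub.HandleReloadConfigRequest(request, 30)  # 30-second timeout
print(response.status.error_code, response.status.error_message)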
    def testUnbatchExtractor(self):
        model_spec = config.ModelSpec(label_key='label',
                                      example_weight_key='example_weight')
        eval_config = config.EvalConfig(model_specs=[model_spec])
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor()

        schema = text_format.Parse(
            """
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "fixed_int"
          type: INT
        }
        feature {
          name: "fixed_float"
          type: FLOAT
        }
        feature {
          name: "fixed_string"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        examples = [
            self._makeExample(label=1.0,
                              example_weight=0.5,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(label=0.0,
                              example_weight=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string2'),
            self._makeExample(label=0.0,
                              example_weight=1.0,
                              fixed_int=2,
                              fixed_float=0.0,
                              fixed_string='fixed_string3')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | unbatch_inputs_extractor.stage_name >>
                unbatch_inputs_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 3)
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                            'fixed_float': np.array([1.0]),
                        })
                    self.assertEqual(
                        got[0][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string1']))
                    self.assertAlmostEqual(got[0][constants.LABELS_KEY],
                                           np.array([1.0]))
                    self.assertAlmostEqual(
                        got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]))
                    self.assertDictElementsAlmostEqual(
                        got[1][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                            'fixed_float': np.array([1.0]),
                        })
                    self.assertEqual(
                        got[1][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string2']))
                    self.assertAlmostEqual(got[1][constants.LABELS_KEY],
                                           np.array([0.0]))
                    self.assertAlmostEqual(
                        got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]))
                    self.assertDictElementsAlmostEqual(
                        got[2][constants.FEATURES_KEY], {
                            'fixed_int': np.array([2]),
                            'fixed_float': np.array([0.0]),
                        })
                    self.assertEqual(
                        got[2][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string3']))
                    self.assertAlmostEqual(got[2][constants.LABELS_KEY],
                                           np.array([0.0]))
                    self.assertAlmostEqual(
                        got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]))

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def _get_csv_test(self, delimiter=',', with_header=False):
        fields = [['feature1', 'feature2'], ['1.0', 'aa'], ['2.0', 'bb'],
                  ['3.0', 'cc'], ['4.0', 'dd'], ['5.0', 'ee'], ['6.0', 'ff'],
                  ['7.0', 'gg'], ['', '']]
        records = []
        for row in fields:
            records.append(delimiter.join(row))

        expected_result = text_format.Parse(
            """
    datasets {
  num_examples: 8
  features {
    path {
      step: "feature1"
    }
    type: FLOAT
    num_stats {
      common_stats {
        num_non_missing: 7
        num_missing: 1
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 3.5
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 3.5
          }
          type: QUANTILES
        }
        tot_num_values: 7
      }
      mean: 4.0
      std_dev: 2.0
      min: 1.0
      max: 7.0
      median: 4.0
      histograms {
        buckets {
          low_value: 1.0
          high_value: 4.0
          sample_count: 3.01
        }
        buckets {
          low_value: 4.0
          high_value: 7.0
          sample_count: 3.99
        }
      }
      histograms {
        buckets {
          low_value: 1.0
          high_value: 4.0
          sample_count: 3.5
        }
        buckets {
          low_value: 4.0
          high_value: 7.0
          sample_count: 3.5
        }
        type: QUANTILES
      }
    }
  }
  features {
    path {
      step: "feature2"
    }
    type: STRING
    string_stats {
      common_stats {
        num_non_missing: 7
        num_missing: 1
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 3.5
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 3.5
          }
          type: QUANTILES
        }
        tot_num_values: 7
      }
      unique: 7
      top_values {
        value: "gg"
        frequency: 1.0
      }
      top_values {
        value: "ff"
        frequency: 1.0
      }
      avg_length: 2.0
      rank_histogram {
        buckets {
          label: "gg"
          sample_count: 1.0
        }
        buckets {
          low_rank: 1
          high_rank: 1
          label: "ff"
          sample_count: 1.0
        }
      }
    }
  }
    }
    """, statistics_pb2.DatasetFeatureStatisticsList())

        if with_header:
            return (records, None, expected_result)
        return (records[1:], records[0].split(delimiter), expected_result)
    def test_stats_gen_with_csv_with_schema(self):
        records = ['feature1', '1']
        input_data_path = self._write_records_to_csv(records,
                                                     self._get_temp_dir(),
                                                     'input_data.csv')
        schema = text_format.Parse(
            """
        feature { name: "feature1" type: BYTES }
        """, schema_pb2.Schema())

        expected_result = text_format.Parse(
            """
    datasets {
  num_examples: 1
  features {
    path {
      step: "feature1"
    }
    type: STRING
    string_stats {
      common_stats {
        num_non_missing: 1
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 0.5
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 0.5
          }
          type: QUANTILES
        }
        tot_num_values: 1
      }
      unique: 1
      top_values {
        value: "1"
        frequency: 1.0
      }
      avg_length: 1.0
      rank_histogram {
        buckets {
          label: "1"
          sample_count: 1.0
        }
      }
    }
  }
    }
    """, statistics_pb2.DatasetFeatureStatisticsList())

        self._default_stats_options.schema = schema
        self._default_stats_options.infer_type_from_schema = True
        result = stats_gen_lib.generate_statistics_from_csv(
            data_location=input_data_path,
            delimiter=',',
            stats_options=self._default_stats_options)
        compare_fn = test_util.make_dataset_feature_stats_list_proto_equal_fn(
            self, expected_result)
        compare_fn([result])
    def test_stats_gen_with_tfrecords_of_tfexamples(self, compression_type):
        examples = [
            self._make_example({
                'a': ('float', [1.0, 2.0]),
                'b': ('bytes', [b'a', b'b', b'c', b'e'])
            }),
            self._make_example({
                'a': ('float', [3.0, 4.0, float('nan'), 5.0]),
                'b': ('bytes', [b'a', b'c', b'd', b'a'])
            }),
            self._make_example({
                'a': ('float', [1.0]),
                'b': ('bytes', [b'a', b'b', b'c', b'd'])
            })
        ]
        tf_compression_lookup = {
            CompressionTypes.AUTO:
            tf.compat.v1.python_io.TFRecordCompressionType.NONE,
            CompressionTypes.GZIP:
            tf.compat.v1.python_io.TFRecordCompressionType.GZIP
        }
        input_data_path = self._write_tfexamples_to_tfrecords(
            examples, tf_compression_lookup[compression_type])

        expected_result = text_format.Parse(
            """
    datasets {
      num_examples: 3
      features {
        path {
          step: "a"
        }
        type: FLOAT
        num_stats {
          common_stats {
            num_non_missing: 3
            num_missing: 0
            min_num_values: 1
            max_num_values: 4
            avg_num_values: 2.33333333
            tot_num_values: 7
            num_values_histogram {
              buckets {
                low_value: 1.0
                high_value: 2.0
                sample_count: 1.5
              }
              buckets {
                low_value: 2.0
                high_value: 4.0
                sample_count: 1.5
              }
              type: QUANTILES
            }
          }
          mean: 2.66666666
          std_dev: 1.49071198
          num_zeros: 0
          min: 1.0
          max: 5.0
          median: 3.0
          histograms {
            num_nan: 1
            buckets {
              low_value: 1.0
              high_value: 3.0
              sample_count: 3.0
            }
            buckets {
              low_value: 3.0
              high_value: 5.0
              sample_count: 3.0
            }
            type: STANDARD
          }
          histograms {
            num_nan: 1
            buckets {
              low_value: 1.0
              high_value: 3.0
              sample_count: 3.0
            }
            buckets {
              low_value: 3.0
              high_value: 5.0
              sample_count: 3.0
            }
            type: QUANTILES
          }
        }
      }
      features {
        path {
          step: "b"
        }
        type: STRING
        string_stats {
          common_stats {
            num_non_missing: 3
            min_num_values: 4
            max_num_values: 4
            avg_num_values: 4.0
            tot_num_values: 12
            num_values_histogram {
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.5
              }
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.5
              }
              type: QUANTILES
            }
          }
          unique: 5
          top_values {
            value: "a"
            frequency: 4.0
          }
          top_values {
            value: "c"
            frequency: 3.0
          }
          avg_length: 1.0
          rank_histogram {
            buckets {
              low_rank: 0
              high_rank: 0
              label: "a"
              sample_count: 4.0
            }
            buckets {
              low_rank: 1
              high_rank: 1
              label: "c"
              sample_count: 3.0
            }
          }
        }
      }
    }
    """, statistics_pb2.DatasetFeatureStatisticsList())

        result = stats_gen_lib.generate_statistics_from_tfrecord(
            data_location=input_data_path,
            stats_options=self._default_stats_options,
            compression_type=compression_type)
        compare_fn = test_util.make_dataset_feature_stats_list_proto_equal_fn(
            self, expected_result)
        compare_fn([result])
Example #13
def _ReadProto(proto, path):
  with open(path, 'r', encoding='utf-8') as f:
    proto = text_format.Parse(f.read(), proto)
    return proto
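A complementary writer, sketched here with the symmetric text_format API (the helper name is hypothetical):

def _WriteProto(proto, path):
  with open(path, 'w', encoding='utf-8') as f:
    f.write(text_format.MessageToString(proto))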
Example #14
def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
    job_conf = text_format.Parse(job_conf_str, job_conf_pb.JobConfigProto())
    parallel_conf = text_format.Parse(parallel_conf_str, placement_pb.ParallelConf())
    return compiler.MakeInitialScope(
        job_conf, parallel_conf.device_tag, list(parallel_conf.device_name), is_mirrored
    ).symbol_id
Example #15
def read_project(f):
    return text_format.Parse(f.read(), config_pb2.Project())
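text_format.Parse raises text_format.ParseError on malformed input; a hedged variant that surfaces this explicitly (the helper name is hypothetical):

def read_project_or_none(f):
    try:
        return text_format.Parse(f.read(), config_pb2.Project())
    except text_format.ParseError:
        return None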
Example #16
 def test_find_significant_slices(self):
   metrics = [
       text_format.Parse(
           """
       slice_key {
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.8 }
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             t_distribution_value {
               sample_mean { value: 0.8 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.8 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 1500 }
             lower_bound { value: 1500 }
             upper_bound { value: 1500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 1500 }
             upper_bound { value: 1500 }
             t_distribution_value {
               sample_mean { value: 1500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 1500 }
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'age'
           bytes_value: '[1.0, 6.0)'
         }
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.4 }
             lower_bound { value: 0.3737843 }
             upper_bound { value: 0.6262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.3737843 }
             upper_bound { value: 0.6262157 }
             t_distribution_value {
               sample_mean { value: 0.4 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.4 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 500 }
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             t_distribution_value {
               sample_mean { value: 500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 500 }
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'age'
           bytes_value: '[6.0, 12.0)'
         }
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.79 }
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             t_distribution_value {
               sample_mean { value: 0.79 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.79 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 500 }
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             t_distribution_value {
               sample_mean { value: 500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 500}
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'age'
           bytes_value: '[12.0, 18.0)'
         }
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.9 }
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             t_distribution_value {
               sample_mean { value: 0.9 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.9 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 500 }
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             t_distribution_value {
               sample_mean { value: 500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 500}
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'country'
           bytes_value: 'USA'
         }
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.9 }
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             t_distribution_value {
               sample_mean { value: 0.9 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.9 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 500 }
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             t_distribution_value {
               sample_mean { value: 500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 500}
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'country'
           bytes_value: 'USA'
         }
         single_slice_keys {
           column: 'age'
           bytes_value: '[12.0, 18.0)'
         }
       }
       metric_keys_and_values {
         key { name: "accuracy" }
         value {
           bounded_value {
             value { value: 0.9 }
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 0.5737843 }
             upper_bound { value: 1.0262157 }
             t_distribution_value {
               sample_mean { value: 0.9 }
               sample_standard_deviation { value: 0.1 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 0.9 }
             }
           }
         }
       }
       metric_keys_and_values {
         key { name: "example_count" }
         value {
           bounded_value {
             value { value: 500 }
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             methodology: POISSON_BOOTSTRAP
           }
           confidence_interval {
             lower_bound { value: 500 }
             upper_bound { value: 500 }
             t_distribution_value {
               sample_mean { value: 500 }
               sample_standard_deviation { value: 0 }
               sample_degrees_of_freedom { value: 9 }
               unsampled_value { value: 500}
             }
           }
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice())
   ]
   self.assertCountEqual(
       auto_slicing_util.find_significant_slices(
           metrics, metric_key='accuracy', comparison_type='LOWER'), [
               auto_slicing_util.SliceComparisonResult(
                   slice_key=(('age', '[1.0, 6.0)'),),
                   num_examples=500.0,
                   slice_metric=0.4,
                   base_metric=0.8,
                   p_value=0.0,
                   effect_size=4.0,
                   raw_slice_metrics=metrics[1])
           ])
   self.assertCountEqual(
       auto_slicing_util.find_significant_slices(
           metrics, metric_key='accuracy', comparison_type='HIGHER'), [
               auto_slicing_util.SliceComparisonResult(
                   slice_key=(('age', '[12.0, 18.0)'),),
                   num_examples=500.0,
                   slice_metric=0.9,
                   base_metric=0.8,
                   p_value=7.356017854191938e-70,
                   effect_size=0.9999999999999996,
                   raw_slice_metrics=metrics[3]),
               auto_slicing_util.SliceComparisonResult(
                   slice_key=(('country', 'USA'),),
                   num_examples=500.0,
                   slice_metric=0.9,
                   base_metric=0.8,
                   p_value=7.356017854191938e-70,
                   effect_size=0.9999999999999996,
                   raw_slice_metrics=metrics[4]),
               auto_slicing_util.SliceComparisonResult(
                   slice_key=(('age', '[12.0, 18.0)'), ('country', 'USA')),
                   num_examples=500.0,
                   slice_metric=0.9,
                   base_metric=0.8,
                   p_value=7.356017854191938e-70,
                   effect_size=0.9999999999999996,
                   raw_slice_metrics=metrics[5])
           ])
Example #17
def read_config(f):
    return text_format.Parse(f.read(), config_pb2.Config())
Example #18
 def test_revert_slice_keys_for_transformed_features(self):
   statistics = text_format.Parse(
       """
       datasets{
         num_examples: 1500
         features {
           path { step: 'country' }
           type: STRING
           string_stats {
             unique: 10
           }
         }
         features {
           path { step: 'age' }
           type: INT
           num_stats {
             common_stats {
               num_non_missing: 1500
               min_num_values: 1
               max_num_values: 1
             }
             min: 1
             max: 18
             histograms {
               buckets {
                 low_value: 1
                 high_value: 6.0
                 sample_count: 500
               }
               buckets {
                 low_value: 6.0
                 high_value: 12.0
                 sample_count: 500
               }
               buckets {
                 low_value: 12.0
                 high_value: 18.0
                 sample_count: 500
               }
               type: QUANTILES
             }
           }
         }
       }
       """, statistics_pb2.DatasetFeatureStatisticsList())
   metrics = [
       text_format.Parse("""
       slice_key {
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'transformed_age'
           int64_value: 1
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'transformed_age'
           int64_value: 2
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'country'
           bytes_value: 'USA'
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice())
   ]
   expected_metrics = [
       text_format.Parse("""
       slice_key {
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'age'
           bytes_value: '[1.0, 6.0)'
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'age'
           bytes_value: '[6.0, 12.0)'
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice()),
       text_format.Parse(
           """
       slice_key {
         single_slice_keys {
           column: 'country'
           bytes_value: 'USA'
         }
       }
       """, metrics_for_slice_pb2.MetricsForSlice())
   ]
   actual = auto_slicing_util.revert_slice_keys_for_transformed_features(
       metrics, statistics)
   self.assertEqual(actual, expected_metrics)
    def testConvertSliceMetricsToProtoConfusionMatrices(self):
        slice_key = _make_slice_key()

        thresholds = [0.25, 0.75, 1.00]
        matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
                    [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]

        slice_metrics = {
            metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES: matrices,
            metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS: thresholds,
        }
        expected_metrics_for_slice = text_format.Parse(
            """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: nan
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics),
            [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
        self.assertProtoEquals(expected_metrics_for_slice, got)
Example #20
def str_to_bond_topology(s):
    bt = dataset_pb2.BondTopology()
    text_format.Parse(s, bt)
    return bt
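Since text_format.Parse returns the message it populates, the same helper collapses to a one-liner (shown as a hypothetical variant):

def str_to_bond_topology_oneline(s):
    return text_format.Parse(s, dataset_pb2.BondTopology())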
    def testConvertSliceMetricsToProtoFromLegacyStrings(self):
        slice_key = _make_slice_key('age', 5, 'language', 'english', 'price',
                                    0.3)
        slice_metrics = {
            'accuracy': 0.8,
            metric_keys.AUPRC: 0.1,
            metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
            metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
            metric_keys.AUC: 0.2,
            metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
            metric_keys.upper_bound_key(metric_keys.AUC): 0.3
        }
        expected_metrics_for_slice = text_format.Parse(
            string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
            metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics),
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')])
        self.assertProtoEquals(expected_metrics_for_slice, got)
Example #22
def generate_parallel_module(modules: Sequence[
    module_signature_mod.ModuleGeneratorResult], module_name: str) -> str:
    """Generates a module composed of instantiated instances of the given modules.

  Each module in 'modules' is instantiated exactly once in an enclosing
  composite module. Inputs to each instantiation are provided by inputs to the
  enclosing module. For example, if given two modules, add8_module and
  add16_module, the generated module might look like:

    module add8_module(
      input wire clk,
      input wire [7:0] op0,
      input wire [7:0] op1,
      output wire [7:0] out
    );
    // contents of module elided...
    endmodule

    module add16_module(
      input wire clk,
      input wire [15:0] op0,
      input wire [15:0] op1,
      output wire [15:0] out
    );
    // contents of module elided...
    endmodule

    module foo(
      input wire clk,
      input wire [7:0] add8_module_op0,
      input wire [7:0] add8_module_op1,
      output wire [7:0] add8_module_out,
      input wire [15:0] add16_module_op0,
      input wire [15:0] add16_module_op1,
      output wire [15:0] add16_module_out
    );
    add8_module add8_module_inst(
      .clk(clk),
      .op0(add8_module_op0),
      .op1(add8_module_op1),
      .out(add8_module_out)
    );
    add16_module add16_module_inst(
      .clk(clk),
      .op0(add16_module_op0),
      .op1(add16_module_op1),
      .out(add16_module_out)
    );
  endmodule


  Args:
    modules: Modules to instantiate.
    module_name: Name of the module containing the instantiated input modules.

  Returns:
    Verilog text containing the composite module and component modules.
  """
    module_protos = [
        text_format.Parse(m.signature.as_text_proto(),
                          module_signature_pb2.ModuleSignatureProto())
        for m in modules
    ]
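    # Build the enclosing module's port list: clk plus one prefixed port per
    # data port of each component module (prefixing avoids name collisions).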
    ports = ['input wire clk']
    for module in module_protos:
        for data_port in module.data_ports:
            width_str = f'[{data_port.width - 1}:0]'
            signal_name = f'{module.module_name}_{data_port.name}'
            if data_port.direction == module_signature_pb2.DIRECTION_INPUT:
                ports.append(f'input wire {width_str} {signal_name}')
            elif data_port.direction == module_signature_pb2.DIRECTION_OUTPUT:
                ports.append(f'output wire {width_str} {signal_name}')
    header = """module {module_name}(\n{ports}\n);""".format(
        module_name=module_name, ports=',\n'.join(f'  {p}' for p in ports))
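    # Instantiate each module exactly once, wiring clk straight through and
    # connecting each data port to its prefixed top-level signal.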
    instantiations = []
    for module in module_protos:
        connections = ['.clk(clk)']
        for data_port in module.data_ports:
            connections.append(
                f'.{data_port.name}({module.module_name}_{data_port.name})')
        instantiations.append(
            '  {name} {name}_inst(\n{connections}\n  );'.format(
                name=module.module_name,
                connections=',\n'.join(f'    {c}' for c in connections)))
    return '{modules}\n\n{header}\n{instantiations}\nendmodule\n'.format(
        modules='\n\n'.join(m.verilog_text for m in modules),
        header=header,
        instantiations='\n'.join(instantiations))
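# A minimal usage sketch. 'add8_result' and 'add16_result' are hypothetical
# ModuleGeneratorResult values produced elsewhere by the module-generation
# pipeline; they are not defined in this example:
#
#   verilog_text = generate_parallel_module(
#       modules=[add8_result, add16_result], module_name='foo')
#   with open('parallel.v', 'w') as f:
#     f.write(verilog_text)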
Example #23
0
    def testWriteValidationResults(self, output_file_format):
        model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir()
        eval_shared_model = self._build_keras_model(model_dir, mul=0)
        baseline_eval_shared_model = self._build_keras_model(baseline_dir,
                                                             mul=1)
        validations_file = os.path.join(self._getTempDir(),
                                        constants.VALIDATIONS_KEY)
        schema = text_format.Parse(
            """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "input"
              value {
                dense_tensor {
                  column_name: "input"
                  shape { dim { size: 1 } }
                }
              }
            }
          }
        }
        feature {
          name: "input"
          type: FLOAT
        }
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "extra_feature"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        examples = [
            self._makeExample(input=0.0,
                              label=1.0,
                              example_weight=1.0,
                              extra_feature='non_model_feature'),
            self._makeExample(input=1.0,
                              label=0.0,
                              example_weight=0.5,
                              extra_feature='non_model_feature'),
        ]

        slicing_specs = [
            config.SlicingSpec(),
            config.SlicingSpec(feature_keys=['slice_does_not_exist'])
        ]
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(name='candidate',
                                 label_key='label',
                                 example_weight_key='example_weight'),
                config.ModelSpec(name='baseline',
                                 label_key='label',
                                 example_weight_key='example_weight',
                                 is_baseline=True)
            ],
            slicing_specs=slicing_specs,
            metrics_specs=[
                config.MetricsSpec(
                    metrics=[
                        config.MetricConfig(
                            class_name='WeightedExampleCount',
                            per_slice_thresholds=[
                                config.PerSliceMetricThreshold(
                                    slicing_specs=slicing_specs,
                                    # 1.5 < 1, NOT OK.
                                    threshold=config.MetricThreshold(
                                        value_threshold=config.
                                        GenericValueThreshold(
                                            upper_bound={'value': 1})))
                            ]),
                        config.MetricConfig(
                            class_name='ExampleCount',
                            # 2 > 10, NOT OK.
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(
                                    lower_bound={'value': 10}))),
                        config.MetricConfig(
                            class_name='MeanLabel',
                            # 0 > 0 and 0 > 0%: NOT OK.
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                    direction=config.MetricDirection.
                                    HIGHER_IS_BETTER,
                                    relative={'value': 0},
                                    absolute={'value': 0}))),
                        config.MetricConfig(
                            # MeanPrediction = (0+0)/(1+0.5) = 0
                            class_name='MeanPrediction',
                            # -.01 < 0 < .01, OK.
                            # Diff% = -.333/.333 = -100% < -99%, OK.
                            # Diff = 0 - .333 = -.333 < 0, OK.
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(
                                    upper_bound={'value': .01},
                                    lower_bound={'value': -.01}),
                                change_threshold=config.GenericChangeThreshold(
                                    direction=config.MetricDirection.
                                    LOWER_IS_BETTER,
                                    relative={'value': -.99},
                                    absolute={'value': 0})))
                    ],
                    model_names=['candidate', 'baseline']),
            ],
            options=config.Options(
                disabled_outputs={'values': ['eval_config.json']}),
        )
        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        eval_shared_models = {
            'candidate': eval_shared_model,
            'baseline': baseline_eval_shared_model
        }
        extractors = [
            batched_input_extractor.BatchedInputExtractor(eval_config),
            batched_predict_extractor_v2.BatchedPredictExtractor(
                eval_shared_model=eval_shared_models,
                eval_config=eval_config,
                tensor_adapter_config=tensor_adapter_config),
            unbatch_extractor.UnbatchExtractor(),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config, eval_shared_model=eval_shared_models)
        ]
        output_paths = {
            constants.VALIDATIONS_KEY: validations_file,
        }
        writers = [
            metrics_plots_and_validations_writer.
            MetricsPlotsAndValidationsWriter(
                output_paths,
                eval_config=eval_config,
                add_metrics_callbacks=[],
                output_file_format=output_file_format)
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            _ = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'BatchExamples' >> tfx_io.BeamSource()
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | 'ExtractEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators)
                |
                'WriteResults' >> model_eval_lib.WriteResults(writers=writers))
            # pylint: enable=no-value-for-parameter

        validation_result = (metrics_plots_and_validations_writer.
                             load_and_deserialize_validation_result(
                                 os.path.dirname(validations_file)))

        expected_validations = [
            text_format.Parse(
                """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """, validation_result_pb2.ValidationFailure()),
            text_format.Parse(
                """
            metric_key {
              name: "example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
            text_format.Parse(
                """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        ]
        self.assertFalse(validation_result.validation_ok)
        self.assertLen(validation_result.metric_validations_per_slice, 1)
        self.assertCountEqual(
            expected_validations,
            validation_result.metric_validations_per_slice[0].failures)

        expected_missing_slices = [
            config.SlicingSpec(feature_keys=['slice_does_not_exist'])
        ]
        self.assertLen(validation_result.missing_slices, 1)
        self.assertCountEqual(expected_missing_slices,
                              validation_result.missing_slices)

        expected_slicing_details = [
            text_format.Parse(
                """
            slicing_spec {
            }
            num_matching_slices: 1
            """, validation_result_pb2.SlicingDetails()),
        ]
        self.assertLen(validation_result.validation_details.slicing_details, 1)
        self.assertCountEqual(
            expected_slicing_details,
            validation_result.validation_details.slicing_details)
Example #24
0
    def testConvertAttributionsProto(self):
        attributions_for_slice = text_format.Parse(
            """
      slice_key {}
      attributions_keys_and_values {
        key {
          name: "total_attributions"
        }
        values {
          key: "feature1"
          value: {
            double_value {
              value: 1.0
            }
          }
        }
        values {
          key: "feature2"
          value: {
            double_value {
              value: 2.0
            }
          }
        }
      }
      attributions_keys_and_values {
        key {
          name: "total_attributions"
          output_name: "output1"
          sub_key: {
            class_id: { value: 1 }
          }
        }
        values {
          key: "feature1"
          value: {
            double_value {
              value: 1.0
            }
          }
        }
      }""", metrics_for_slice_pb2.AttributionsForSlice())

        got = util.convert_attributions_proto_to_dict(attributions_for_slice,
                                                      None)
        self.assertEqual(got, ((), {
            '': {
                '': {
                    'total_attributions': {
                        'feature2': {
                            'doubleValue': 2.0
                        },
                        'feature1': {
                            'doubleValue': 1.0
                        }
                    }
                }
            },
            'output1': {
                'classId:1': {
                    'total_attributions': {
                        'feature1': {
                            'doubleValue': 1.0
                        }
                    }
                }
            }
        }))
Example #25
0
def _get_version_config(version_config_path):
  with open(version_config_path) as f:
    return text_format.Parse(f.read(), version_config_pb2.VersionConfig())
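# Usage sketch (the path is a hypothetical placeholder; the file must contain
# a VersionConfig message in text proto format):
#
#   version_config = _get_version_config('/path/to/version_config.textproto')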
Example #26
0
def main():
    opts = parse_command_line()

    # logging._levelNames is a Python 2 private API; resolve the level name
    # against the logging module directly instead.
    logger.setLevel(getattr(logging, opts.log_level.upper()))
    logger.addHandler(logging.StreamHandler())

    logger.info('Endpoint: {}'.format(opts.bid_endpoint))

    headers = {
        'Content-type': 'application/x-protobuf',
    }
    if opts.header_secret:
        headers['beeswax-auth-secret'] = opts.header_secret

    try:
        input_request_file = open(opts.path_to_requests_file, 'rb')
    except (IOError, OSError) as exc:
        logger.error('Could not open bid agent requests input file: {}'.format(exc))
        return -1

    output_file = None
    if opts.path_to_responses_file:
        try:
            output_file = open(opts.path_to_responses_file, 'wb')
        except (IOError, OSError) as exc:
            logger.error('Could not open bid agent responses output file: {}'.format(exc))
            return -1

    try:
        session = requests.Session()
        session.headers.update(headers)

        success_count = 0
        failure_count = 0

        min_time = 0
        max_time = 0
        total_time = 0
        print_info = ''

        current_milli_time = lambda: int(round(time.time() * 1000))

        for request_text in _request_text_generator(input_request_file):
            request_proto = BidAgentRequest()

            try:
                text_format.Parse(request_text, request_proto)
            except ParseError as exc:
                msg = 'Could not parse bid agent request: {}.\nRequest: {}'.format(
                    exc, request_text)
                logger.error(msg)
                # Intentionally write errors into the output file so that
                # (1) responses (errors) stay aligned with requests and
                # (2) the user can analyze failures in the output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue

            try:
                logger.debug('Sending request: {}'.format(request_proto))
                start_time_milli = current_milli_time()
                response = session.post(opts.bid_endpoint,
                                        data=request_proto.SerializeToString(),
                                        timeout=_HTTP_TIMEOUT_S)
                elapsed_time_milli = current_milli_time() - start_time_milli
                total_time += elapsed_time_milli

                if min_time == 0 or min_time > elapsed_time_milli:
                    print_info += '\nreplacing min with: {}'.format(elapsed_time_milli)
                    min_time = elapsed_time_milli

                if max_time == 0 or max_time < elapsed_time_milli:
                    print_info += '\nreplacing max with: {}'.format(elapsed_time_milli)
                    max_time = elapsed_time_milli
            except Exception as exc:
                msg = 'Error in sending http request: {}'.format(exc)
                logger.error(msg)
                # Intentionally write errors into the output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue

            try:
                response_message = _get_response_message(response)
            except DecodeError as exc:
                msg = 'Failed to deserialize response body: {}'.format(exc)
                logger.error(msg)
                # Intentionally write errors into the output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue

            _write_response(output_file, response_message)
            success_count += 1
            logger.debug('Successfully processed request: {}'.format(request_proto))

        input_request_file.close()
    finally:
        if output_file:
            output_file.close()

        logger.info('Print info: {}'.format(print_info))
        logger.info('Finished processing all requests. Success count: {}, failure count: {}'
                    .format(success_count, failure_count))

        # Guard against an empty input file to avoid dividing by zero.
        request_count = success_count + failure_count
        if request_count:
            average_time = total_time / request_count
            logger.info('Stats: Average latency: {} ms, Min latency: {} ms, Max latency: {} ms'
                        .format(average_time, min_time, max_time))

    return 0
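# Invocation sketch. The script name is hypothetical and the flag spellings are
# inferred from the 'opts' attributes above; the actual parser may differ:
#
#   python bid_agent_tester.py \
#     --bid_endpoint=http://localhost:8080/bid \
#     --path_to_requests_file=requests.txt \
#     --path_to_responses_file=responses.txt \
#     --log_level=info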
Example #27
0
    def testIsDesiredOutputEvent(self):
        output_event = text_format.Parse(
            """
        type: OUTPUT
        path {
          steps {
            key: 'right_key'
          }
          steps {
            index: 1
          }
        }
        """, metadata_store_pb2.Event())
        declared_output_event = text_format.Parse(
            """
        type: DECLARED_OUTPUT
        path {
          steps {
            key: 'right_key'
          }
          steps {
            index: 1
          }
        }
        """, metadata_store_pb2.Event())
        internal_output_event = text_format.Parse(
            """
        type: INTERNAL_OUTPUT
        path {
          steps {
            key: 'right_key'
          }
          steps {
            index: 1
          }
        }
        """, metadata_store_pb2.Event())
        input_event = text_format.Parse(
            """
        type: INPUT
        path {
          steps {
            key: 'right_key'
          }
          steps {
            index: 1
          }
        }
        """, metadata_store_pb2.Event())
        empty_event = text_format.Parse('type: OUTPUT',
                                        metadata_store_pb2.Event())

        self.assertTrue(
            event_lib.is_valid_output_event(output_event, 'right_key'))
        self.assertTrue(
            event_lib.is_valid_output_event(declared_output_event,
                                            'right_key'))
        self.assertTrue(
            event_lib.is_valid_output_event(internal_output_event,
                                            'right_key'))
        self.assertFalse(
            event_lib.is_valid_output_event(output_event, 'wrong_key'))
        self.assertFalse(
            event_lib.is_valid_output_event(input_event, 'right_key'))
        self.assertFalse(
            event_lib.is_valid_output_event(empty_event, 'right_key'))
        self.assertTrue(event_lib.is_valid_output_event(empty_event))
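        # For comparison, the first OUTPUT event above could also be built
        # without text_format via the standard protobuf API (a sketch
        # mirroring the same fields):
        #
        #   event = metadata_store_pb2.Event(type=metadata_store_pb2.Event.OUTPUT)
        #   event.path.steps.add().key = 'right_key'
        #   event.path.steps.add().index = 1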
Example #28
0
def InterpretCompletedOp(op_attribute_str, parallel_conf):
    op_attribute = text_format.Parse(op_attribute_str,
                                     op_attribute_pb.OpAttribute())
    blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    _InterpretCompletedOp(op_attribute, parallel_conf, blob_register)
    gradient_util.ReleaseUnusedBlobObject(op_attribute, blob_register)
Example #29
0
def window_selector_config(flags_obj):
  """Creates a WindowSelectorOptions proto based on input and default settings.

  Args:
    flags_obj: configuration FLAGS.

  Returns:
    realigner_pb2.WindowSelectorOptions protobuf.

  Raises:
    ValueError: If either of ws_{min,max}_num_supporting_reads is set and
      ws_use_window_selector_model is True.
      Or if ws_min_num_supporting_reads > ws_max_num_supporting_reads.
      Or if ws_use_window_selector_model is False and
      ws_window_selector_model is not None.
  """
  if not flags_obj.ws_use_window_selector_model:
    if flags_obj.ws_window_selector_model is not None:
      raise ValueError('Cannot specify a ws_window_selector_model '
                       'if ws_use_window_selector_model is False.')

    min_num_supporting_reads = (
        _DEFAULT_MIN_SUPPORTING_READS
        if flags_obj.ws_min_num_supporting_reads == _UNSET_WS_INT_FLAG else
        flags_obj.ws_min_num_supporting_reads)
    max_num_supporting_reads = (
        _DEFAULT_MAX_SUPPORTING_READS
        if flags_obj.ws_max_num_supporting_reads == _UNSET_WS_INT_FLAG else
        flags_obj.ws_max_num_supporting_reads)
    window_selector_model = realigner_pb2.WindowSelectorModel(
        model_type=realigner_pb2.WindowSelectorModel.VARIANT_READS,
        variant_reads_model=realigner_pb2.WindowSelectorModel
        .VariantReadsThresholdModel(
            min_num_supporting_reads=min_num_supporting_reads,
            max_num_supporting_reads=max_num_supporting_reads))
  else:
    if flags_obj.ws_min_num_supporting_reads != _UNSET_WS_INT_FLAG:
      raise ValueError('Cannot use both ws_min_num_supporting_reads and '
                       'ws_use_window_selector_model flags.')
    if flags_obj.ws_max_num_supporting_reads != _UNSET_WS_INT_FLAG:
      raise ValueError('Cannot use both ws_max_num_supporting_reads and '
                       'ws_use_window_selector_model flags.')

    if flags_obj.ws_window_selector_model is None:
      window_selector_model = _ALLELE_COUNT_LINEAR_MODEL_DEFAULT
    else:
      with tf.io.gfile.GFile(flags_obj.ws_window_selector_model) as f:
        window_selector_model = text_format.Parse(
            f.read(), realigner_pb2.WindowSelectorModel())

  if (window_selector_model.model_type ==
      realigner_pb2.WindowSelectorModel.VARIANT_READS):
    model = window_selector_model.variant_reads_model
    if model.max_num_supporting_reads < model.min_num_supporting_reads:
      raise ValueError('ws_min_num_supporting_reads should be smaller than '
                       'ws_max_num_supporting_reads.')

  ws_config = realigner_pb2.WindowSelectorOptions(
      min_mapq=flags_obj.ws_min_mapq,
      min_base_quality=flags_obj.ws_min_base_quality,
      min_windows_distance=flags_obj.ws_min_windows_distance,
      max_window_size=flags_obj.ws_max_window_size,
      region_expansion_in_bp=flags_obj.ws_region_expansion_in_bp,
      window_selector_model=window_selector_model)

  return ws_config
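# Illustrative contents for a file passed via ws_window_selector_model (field
# names are taken from the parsing code above; the values are examples only):
#
#   model_type: VARIANT_READS
#   variant_reads_model {
#     min_num_supporting_reads: 2
#     max_num_supporting_reads: 20
#   }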
Example #30
0
def _parse_data_point(s: Text) -> delay_model_pb2.DataPoint:
  """Parses a text proto representation of a DataPoint."""
  return text_format.Parse(s, delay_model_pb2.DataPoint())