def _parse_operation(s: Text) -> delay_model_pb2.Operation:
  """Builds a delay_model_pb2.Operation from its protobuf text encoding.

  Args:
    s: Text-format representation of an Operation message.

  Returns:
    A newly constructed Operation populated from `s`.
  """
  operation = delay_model_pb2.Operation()
  text_format.Parse(s, operation)
  return operation
from tensorflow_model_analysis.eval_saved_model import testutil from tensorflow_model_analysis.extractors import sql_slice_key_extractor from tensorflow_model_analysis.proto import config_pb2 from tfx_bsl.tfxio import tf_example_record from google.protobuf import text_format from tensorflow_metadata.proto.v0 import schema_pb2 _SCHEMA = text_format.Parse( """ feature { name: "fixed_int" type: INT } feature { name: "fixed_float" type: FLOAT } feature { name: "fixed_string" type: BYTES } """, schema_pb2.Schema()) class SqlSliceKeyExtractorTest(testutil.TensorflowModelAnalysisTest): def testSqlSliceKeyExtractor(self): eval_config = config_pb2.EvalConfig(slicing_specs=[ config_pb2.SlicingSpec(slice_keys_sql=""" SELECT
text_format.Parse( """ context { features { feature { key: "ctx.int" # dot in the feature name is intended. value { int64_list { value: [1, 2] } } } feature { key: "ctx.float" value { float_list { value: [1.0, 2.0] } } } feature { key: "ctx.bytes" value { bytes_list { value: [] } } } } } examples { features { feature { key: "example_int" value { int64_list { value: [11] } } } feature { key: "example_float" value { float_list { value: [11.0, 12.0] } } } feature { key: "example_bytes" value { bytes_list { value: ["u", "v"] } } } } } examples { features { feature { key: "example_int" value { int64_list { value: [22] } } } # example_float is not present. feature { key: "example_bytes" value { bytes_list { value: ["w"] } } } } } """, input_pb2.ExampleListWithContext()).SerializeToString(),
def testUncertaintyValuedMetrics(self):
  """Checks conversion of t-distribution-valued metrics, including NaN ones.

  Each ValueWithTDistribution is expected to yield a POISSON_BOOTSTRAP
  `bounded_value` plus a `confidence_interval` holding the t-distribution
  parameters; NaN inputs are expected to surface as NaN fields.
  """
  key = _make_slice_key()
  nan = float('nan')
  metrics_by_name = {}
  metrics_by_name['one_dim'] = types.ValueWithTDistribution(2.0, 1.0, 3, 2.0)
  metrics_by_name['nans'] = types.ValueWithTDistribution(nan, nan, -1, nan)

  expected = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "one_dim"
        value {
          bounded_value {
            value { value: 2.0 }
            lower_bound { value: -1.1824463 }
            upper_bound { value: 5.1824463 }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: -1.1824463 }
            upper_bound { value: 5.1824463 }
            t_distribution_value {
              sample_mean { value: 2.0 }
              sample_standard_deviation { value: 1.0 }
              sample_degrees_of_freedom { value: 3 }
              unsampled_value { value: 2.0 }
            }
          }
        }
      }
      metrics {
        key: "nans"
        value {
          bounded_value {
            value { value: nan }
            lower_bound { value: nan }
            upper_bound { value: nan }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: nan }
            upper_bound { value: nan }
            t_distribution_value {
              sample_mean { value: nan }
              sample_standard_deviation { value: nan }
              sample_degrees_of_freedom { value: -1 }
              unsampled_value { value: nan }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())

  actual = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (key, metrics_by_name), [])
  self.assertProtoEquals(expected, actual)
def testConvertSlicePlotsToProto(self):
  """Checks that a calibration plot is embedded unchanged in PlotsForSlice.

  The expected plot key carries only `output_name`, even though the input
  PlotKey also sets `name`.
  """
  slice_key = _make_slice_key('fruit', 'apple')
  plot_key = metric_types.PlotKey(
      name='calibration_plot', output_name='output_name')

  # Used both as the input plot and inside the expected output. Keeping a
  # single copy means the two cannot drift apart; the buckets are expected to
  # pass through the conversion unchanged.
  buckets_text = """
      buckets {
        lower_threshold_inclusive: -inf
        upper_threshold_exclusive: 0.0
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      buckets {
        lower_threshold_inclusive: 0.0
        upper_threshold_exclusive: 0.5
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 1.0 }
        total_weighted_refined_prediction { value: 0.3 }
      }
      buckets {
        lower_threshold_inclusive: 0.5
        upper_threshold_exclusive: 1.0
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.7 }
      }
      buckets {
        lower_threshold_inclusive: 1.0
        upper_threshold_exclusive: inf
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
  """
  calibration_plot = text_format.Parse(
      buckets_text, metrics_for_slice_pb2.CalibrationHistogramBuckets())

  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            %s
          }
        }
      }
      """ % buckets_text, metrics_for_slice_pb2.PlotsForSlice())

  got = metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
      (slice_key, {plot_key: calibration_plot}), None)
  self.assertProtoEquals(expected_plots_for_slice, got)
def testWriteMetricsAndPlots(self, output_file_format):
  """Runs extraction/evaluation end to end and checks the written outputs.

  Two examples are evaluated against a fixed-prediction estimator; the test
  then reads back the metrics and plots files and compares them with the
  expected protos.

  Args:
    output_file_format: Forwarded to MetricsPlotsAndValidationsWriter.
  """
  metrics_file = os.path.join(self._getTempDir(), 'metrics')
  plots_file = os.path.join(self._getTempDir(), 'plots')
  temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))

  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      options=config.Options(
          disabled_outputs={'values': ['eval_config.json']}))
  metrics_callbacks = [
      post_export_metrics.example_count(),
      post_export_metrics.calibration_plot_and_prediction_histogram(
          num_buckets=2),
  ]
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=metrics_callbacks)

  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor(),
  ]
  evaluators = [
      metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
  ]
  destinations = {
      constants.METRICS_KEY: metrics_file,
      constants.PLOTS_KEY: plots_file,
  }
  writer = metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
      destinations,
      eval_config=eval_config,
      add_metrics_callbacks=eval_shared_model.add_metrics_callbacks,
      output_file_format=output_file_format)

  with beam.Pipeline() as pipeline:
    serialized_examples = [
        example.SerializeToString()
        for example in (self._makeExample(prediction=0.0, label=1.0),
                        self._makeExample(prediction=1.0, label=1.0))
    ]
    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'Create' >> beam.Create(serialized_examples)
        | 'ExtractEvaluateAndWriteResults' >>
        model_eval_lib.ExtractEvaluateAndWriteResults(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            extractors=extractors,
            evaluators=evaluators,
            writers=[writer]))
    # pylint: enable=no-value-for-parameter

  expected_metrics = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "average_loss"
        value {
          double_value { value: 0.5 }
        }
      }
      metrics {
        key: "post_export_metrics/example_count"
        value {
          double_value { value: 2.0 }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  written_metrics = list(
      metrics_plots_and_validations_writer.load_and_deserialize_metrics(
          metrics_file))
  self.assertLen(written_metrics, 1, 'metrics: %s' % written_metrics)
  self.assertProtoEquals(expected_metrics, written_metrics[0])

  expected_plots = text_format.Parse(
      """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 1.0 }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())
  written_plots = list(
      metrics_plots_and_validations_writer.load_and_deserialize_plots(
          plots_file))
  self.assertLen(written_plots, 1, 'plots: %s' % written_plots)
  self.assertProtoEquals(expected_plots, written_plots[0])
def testConvertSliceMetricsToProtoMetricsRanges(self):
  """Checks t-distribution and RIEMANN_SUM-bounded metrics on a keyed slice.

  `accuracy` is a ValueWithTDistribution. AUC and AUPRC are supplied as value
  plus lower/upper-bound keys and are expected to become RIEMANN_SUM
  bounded values.
  """
  key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  auc, auprc = metric_keys.AUC, metric_keys.AUPRC
  metrics = {
      'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
      auprc: 0.1,
      metric_keys.lower_bound_key(auprc): 0.05,
      metric_keys.upper_bound_key(auprc): 0.17,
      auc: 0.2,
      metric_keys.lower_bound_key(auc): 0.1,
      metric_keys.upper_bound_key(auc): 0.3,
  }

  # $auc / $auprc are filled in with the real metric key names below.
  expected_template = string.Template("""
      slice_key {
        single_slice_keys {
          column: 'age'
          int64_value: 5
        }
        single_slice_keys {
          column: 'language'
          bytes_value: 'english'
        }
        single_slice_keys {
          column: 'price'
          float_value: 0.3
        }
      }
      metrics {
        key: "accuracy"
        value {
          bounded_value {
            value { value: 0.8 }
            lower_bound { value: 0.5737843 }
            upper_bound { value: 1.0262157 }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: 0.5737843 }
            upper_bound { value: 1.0262157 }
            t_distribution_value {
              sample_mean { value: 0.8 }
              sample_standard_deviation { value: 0.1 }
              sample_degrees_of_freedom { value: 9 }
              unsampled_value { value: 0.8 }
            }
          }
        }
      }
      metrics {
        key: "$auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
            methodology: RIEMANN_SUM
          }
        }
      }
      metrics {
        key: "$auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
            methodology: RIEMANN_SUM
          }
        }
      }""")
  expected = text_format.Parse(
      expected_template.substitute(auc=auc, auprc=auprc),
      metrics_for_slice_pb2.MetricsForSlice())

  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  actual = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (key, metrics), callbacks)
  self.assertProtoEquals(expected, actual)
# This script updates the models.config used by tensorflow/serving: it reads
# the existing config, appends one more model entry, wraps the result in a
# ReloadConfigRequest and opens a gRPC channel to the server.
import grpc
from google.protobuf import text_format
from tensorflow_serving.apis import model_service_pb2_grpc, model_management_pb2
from tensorflow_serving.config import model_server_config_pb2
from tensorflow_serving.sources.storage_path.file_system_storage_path_source_pb2 import FileSystemStoragePathSourceConfig

# Path of the models.config file.
model_config_file_path = "./models.config"

# The file is only read here, so open it read-only: 'r+' needlessly required
# write permission and failed on read-only config files. Decode explicitly as
# UTF-8 rather than relying on the platform's locale encoding.
with open(model_config_file_path, 'r', encoding='utf-8') as f:
    config_ini = f.read()

request = model_management_pb2.ReloadConfigRequest()
model_server_config = model_server_config_pb2.ModelServerConfig()
config_list = model_server_config_pb2.ModelConfigList()

# Parse the existing text-format config from the file.
model_server_config = text_format.Parse(text=config_ini, message=model_server_config)

# Create a config to add to the list of served models.
one_config = config_list.config.add()
one_config.name = "lmj"
one_config.base_path = "/models/lmj"
one_config.model_platform = "tensorflow"
# ServableVersionPolicy is a nested message type; it is reachable on the class
# itself, so there is no need to instantiate the outer message first.
servable_version_policy = FileSystemStoragePathSourceConfig.ServableVersionPolicy()
one_config.model_version_policy.all.CopyFrom(servable_version_policy.All())

# Merge the new entry into the parsed config and attach it to the request.
model_server_config.model_config_list.MergeFrom(config_list)
request.config.CopyFrom(model_server_config)

# Service address: port 8510 maps to port 8500 of tensorflow/serving.
# NOTE(review): the original comment named host 192.168.1.168, but the code
# connects to 192.168.1.193 - confirm which host is intended.
channel = grpc.insecure_channel('192.168.1.193:8510')
def testUnbatchExtractor(self):
  """Checks that a batch of three examples is split back into three extracts.

  Each unbatched extract must carry its own features, label and example
  weight, in input order.
  """
  model_spec = config.ModelSpec(
      label_key='label', example_weight_key='example_weight')
  eval_config = config.EvalConfig(model_specs=[model_spec])
  input_extractor = batched_input_extractor.BatchedInputExtractor(eval_config)
  unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor()

  schema = text_format.Parse(
      """
      feature {
        name: "label"
        type: FLOAT
      }
      feature {
        name: "example_weight"
        type: FLOAT
      }
      feature {
        name: "fixed_int"
        type: INT
      }
      feature {
        name: "fixed_float"
        type: FLOAT
      }
      feature {
        name: "fixed_string"
        type: BYTES
      }
      """, schema_pb2.Schema())
  tfx_io = test_util.InMemoryTFExampleRecord(
      schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)

  examples = [
      self._makeExample(
          label=1.0,
          example_weight=0.5,
          fixed_int=1,
          fixed_float=1.0,
          fixed_string='fixed_string1'),
      self._makeExample(
          label=0.0,
          example_weight=0.0,
          fixed_int=1,
          fixed_float=1.0,
          fixed_string='fixed_string2'),
      self._makeExample(
          label=0.0,
          example_weight=1.0,
          fixed_int=2,
          fixed_float=0.0,
          fixed_string='fixed_string3'),
  ]

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create' >> beam.Create(
            [e.SerializeToString() for e in examples], reshuffle=False)
        | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
        | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
        | input_extractor.stage_name >> input_extractor.ptransform
        | unbatch_inputs_extractor.stage_name >>
        unbatch_inputs_extractor.ptransform)
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        self.assertLen(got, 3)
        # One row per expected extract, in input order:
        # (fixed_int, fixed_float, fixed_string, label, example_weight).
        expectations = [
            (1, 1.0, b'fixed_string1', 1.0, 0.5),
            (1, 1.0, b'fixed_string2', 0.0, 0.0),
            (2, 0.0, b'fixed_string3', 0.0, 1.0),
        ]
        for extract, (int_value, float_value, string_value, label,
                      weight) in zip(got, expectations):
          features = extract[constants.FEATURES_KEY]
          self.assertDictElementsAlmostEqual(
              features, {
                  'fixed_int': np.array([int_value]),
                  'fixed_float': np.array([float_value]),
              })
          self.assertEqual(features['fixed_string'], np.array([string_value]))
          self.assertAlmostEqual(extract[constants.LABELS_KEY],
                                 np.array([label]))
          self.assertAlmostEqual(extract[constants.EXAMPLE_WEIGHTS_KEY],
                                 np.array([weight]))
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(result, check_result, label='result')
def _get_csv_test(self, delimiter=',', with_header=False):
  """Builds a small two-column CSV fixture and its expected statistics.

  Args:
    delimiter: Separator used to join the fields of each row.
    with_header: If True, the header row is kept as the first record and no
      column names are returned. If False, the header row is removed from the
      records and returned separately as a list of column names.

  Returns:
    A `(records, column_names, expected_result)` tuple, where
    `expected_result` is a DatasetFeatureStatisticsList.
  """
  rows = [
      ['feature1', 'feature2'],
      ['1.0', 'aa'],
      ['2.0', 'bb'],
      ['3.0', 'cc'],
      ['4.0', 'dd'],
      ['5.0', 'ee'],
      ['6.0', 'ff'],
      ['7.0', 'gg'],
      ['', ''],
  ]
  lines = [delimiter.join(fields) for fields in rows]

  expected = text_format.Parse(
      """
      datasets {
        num_examples: 8
        features {
          path { step: "feature1" }
          type: FLOAT
          num_stats {
            common_stats {
              num_non_missing: 7
              num_missing: 1
              min_num_values: 1
              max_num_values: 1
              avg_num_values: 1.0
              num_values_histogram {
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 3.5 }
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 3.5 }
                type: QUANTILES
              }
              tot_num_values: 7
            }
            mean: 4.0
            std_dev: 2.0
            min: 1.0
            max: 7.0
            median: 4.0
            histograms {
              buckets { low_value: 1.0 high_value: 4.0 sample_count: 3.01 }
              buckets { low_value: 4.0 high_value: 7.0 sample_count: 3.99 }
            }
            histograms {
              buckets { low_value: 1.0 high_value: 4.0 sample_count: 3.5 }
              buckets { low_value: 4.0 high_value: 7.0 sample_count: 3.5 }
              type: QUANTILES
            }
          }
        }
        features {
          path { step: "feature2" }
          type: STRING
          string_stats {
            common_stats {
              num_non_missing: 7
              num_missing: 1
              min_num_values: 1
              max_num_values: 1
              avg_num_values: 1.0
              num_values_histogram {
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 3.5 }
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 3.5 }
                type: QUANTILES
              }
              tot_num_values: 7
            }
            unique: 7
            top_values { value: "gg" frequency: 1.0 }
            top_values { value: "ff" frequency: 1.0 }
            avg_length: 2.0
            rank_histogram {
              buckets { label: "gg" sample_count: 1.0 }
              buckets { low_rank: 1 high_rank: 1 label: "ff" sample_count: 1.0 }
            }
          }
        }
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())

  return ((lines, None, expected) if with_header else
          (lines[1:], lines[0].split(delimiter), expected))
def test_stats_gen_with_csv_with_schema(self):
  """Checks that a schema-declared BYTES column is profiled as STRING.

  The single CSV value '1' looks numeric, but with `infer_type_from_schema`
  enabled the schema's BYTES type is expected to win.
  """
  csv_lines = ['feature1', '1']
  input_data_path = self._write_records_to_csv(
      csv_lines, self._get_temp_dir(), 'input_data.csv')

  schema = text_format.Parse(
      """
      feature {
        name: "feature1"
        type: BYTES
      }
      """, schema_pb2.Schema())

  expected = text_format.Parse(
      """
      datasets {
        num_examples: 1
        features {
          path { step: "feature1" }
          type: STRING
          string_stats {
            common_stats {
              num_non_missing: 1
              min_num_values: 1
              max_num_values: 1
              avg_num_values: 1.0
              num_values_histogram {
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 0.5 }
                buckets { low_value: 1.0 high_value: 1.0 sample_count: 0.5 }
                type: QUANTILES
              }
              tot_num_values: 1
            }
            unique: 1
            top_values { value: "1" frequency: 1.0 }
            avg_length: 1.0
            rank_histogram {
              buckets { label: "1" sample_count: 1.0 }
            }
          }
        }
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())

  options = self._default_stats_options
  options.schema = schema
  options.infer_type_from_schema = True
  result = stats_gen_lib.generate_statistics_from_csv(
      data_location=input_data_path, delimiter=',', stats_options=options)

  test_util.make_dataset_feature_stats_list_proto_equal_fn(
      self, expected)([result])
def test_stats_gen_with_tfrecords_of_tfexamples(self, compression_type):
  """Checks statistics generated from (optionally compressed) TFRecord files.

  Args:
    compression_type: A CompressionTypes value; AUTO files are written
      uncompressed and GZIP files are written gzip-compressed.
  """
  feature_dicts = [
      {
          'a': ('float', [1.0, 2.0]),
          'b': ('bytes', [b'a', b'b', b'c', b'e'])
      },
      {
          'a': ('float', [3.0, 4.0, float('nan'), 5.0]),
          'b': ('bytes', [b'a', b'c', b'd', b'a'])
      },
      {
          'a': ('float', [1.0]),
          'b': ('bytes', [b'a', b'b', b'c', b'd'])
      },
  ]
  examples = [self._make_example(features) for features in feature_dicts]

  record_compression = tf.compat.v1.python_io.TFRecordCompressionType
  writer_compression = {
      CompressionTypes.AUTO: record_compression.NONE,
      CompressionTypes.GZIP: record_compression.GZIP,
  }[compression_type]
  input_data_path = self._write_tfexamples_to_tfrecords(
      examples, writer_compression)

  expected = text_format.Parse(
      """
      datasets {
        num_examples: 3
        features {
          path { step: "a" }
          type: FLOAT
          num_stats {
            common_stats {
              num_non_missing: 3
              num_missing: 0
              min_num_values: 1
              max_num_values: 4
              avg_num_values: 2.33333333
              tot_num_values: 7
              num_values_histogram {
                buckets { low_value: 1.0 high_value: 2.0 sample_count: 1.5 }
                buckets { low_value: 2.0 high_value: 4.0 sample_count: 1.5 }
                type: QUANTILES
              }
            }
            mean: 2.66666666
            std_dev: 1.49071198
            num_zeros: 0
            min: 1.0
            max: 5.0
            median: 3.0
            histograms {
              num_nan: 1
              buckets { low_value: 1.0 high_value: 3.0 sample_count: 3.0 }
              buckets { low_value: 3.0 high_value: 5.0 sample_count: 3.0 }
              type: STANDARD
            }
            histograms {
              num_nan: 1
              buckets { low_value: 1.0 high_value: 3.0 sample_count: 3.0 }
              buckets { low_value: 3.0 high_value: 5.0 sample_count: 3.0 }
              type: QUANTILES
            }
          }
        }
        features {
          path { step: "b" }
          type: STRING
          string_stats {
            common_stats {
              num_non_missing: 3
              min_num_values: 4
              max_num_values: 4
              avg_num_values: 4.0
              tot_num_values: 12
              num_values_histogram {
                buckets { low_value: 4.0 high_value: 4.0 sample_count: 1.5 }
                buckets { low_value: 4.0 high_value: 4.0 sample_count: 1.5 }
                type: QUANTILES
              }
            }
            unique: 5
            top_values { value: "a" frequency: 4.0 }
            top_values { value: "c" frequency: 3.0 }
            avg_length: 1.0
            rank_histogram {
              buckets { low_rank: 0 high_rank: 0 label: "a" sample_count: 4.0 }
              buckets { low_rank: 1 high_rank: 1 label: "c" sample_count: 3.0 }
            }
          }
        }
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())

  result = stats_gen_lib.generate_statistics_from_tfrecord(
      data_location=input_data_path,
      stats_options=self._default_stats_options,
      compression_type=compression_type)
  test_util.make_dataset_feature_stats_list_proto_equal_fn(
      self, expected)([result])
def _ReadProto(proto, path):
  """Parses the UTF-8 text-format protobuf file at `path` into `proto`.

  Args:
    proto: Message instance that the file contents are parsed into.
    path: Path of the text-format file to read.

  Returns:
    The message returned by text_format.Parse for `proto`.
  """
  with open(path, 'r', encoding='utf-8') as handle:
    contents = handle.read()
  return text_format.Parse(contents, proto)
def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
    """Creates an initial scope from text-format configs and returns its symbol id.

    Args:
        job_conf_str: Text-format JobConfigProto.
        parallel_conf_str: Text-format ParallelConf; its device tag and device
            names are handed to compiler.MakeInitialScope.
        is_mirrored: Forwarded unchanged to compiler.MakeInitialScope.

    Returns:
        The `symbol_id` of the scope created by compiler.MakeInitialScope.
    """
    job_conf = job_conf_pb.JobConfigProto()
    text_format.Parse(job_conf_str, job_conf)

    parallel_conf = placement_pb.ParallelConf()
    text_format.Parse(parallel_conf_str, parallel_conf)

    device_tag = parallel_conf.device_tag
    device_names = list(parallel_conf.device_name)
    scope = compiler.MakeInitialScope(job_conf, device_tag, device_names, is_mirrored)
    return scope.symbol_id
def read_project(f):
  """Reads all of file object `f` and parses it as a text-format Project.

  Args:
    f: Open, readable file object whose `read()` yields the text proto.

  Returns:
    A config_pb2.Project populated from the file contents.
  """
  project = config_pb2.Project()
  text_format.Parse(f.read(), project)
  return project
def test_find_significant_slices(self):
  """Checks LOWER/HIGHER significance detection against fixed slice metrics.

  The fixtures are the overall slice plus five feature slices. Each carries a
  bootstrapped `accuracy` (std dev 0.1, 9 degrees of freedom) and an
  `example_count` with zero spread. Only the age [1.0, 6.0) slice is expected
  to be significantly LOWER than the overall slice. The three slices with
  accuracy 0.9 are expected to be significantly HIGHER.
  """

  def _slice_metrics(slice_key_text, accuracy, accuracy_lower, accuracy_upper,
                     example_count):
    """Builds a MetricsForSlice fixture with accuracy and example_count."""
    # All six fixtures share this exact structure; only the slice key,
    # accuracy (value and bounds) and example count vary.
    return text_format.Parse(
        """
        slice_key { %(slice_key)s }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: %(accuracy)s }
              lower_bound { value: %(accuracy_lower)s }
              upper_bound { value: %(accuracy_upper)s }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: %(accuracy_lower)s }
              upper_bound { value: %(accuracy_upper)s }
              t_distribution_value {
                sample_mean { value: %(accuracy)s }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: %(accuracy)s }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: %(example_count)s }
              lower_bound { value: %(example_count)s }
              upper_bound { value: %(example_count)s }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: %(example_count)s }
              upper_bound { value: %(example_count)s }
              t_distribution_value {
                sample_mean { value: %(example_count)s }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: %(example_count)s }
              }
            }
          }
        }
        """ % {
            'slice_key': slice_key_text,
            'accuracy': accuracy,
            'accuracy_lower': accuracy_lower,
            'accuracy_upper': accuracy_upper,
            'example_count': example_count,
        }, metrics_for_slice_pb2.MetricsForSlice())

  age_1_6 = "single_slice_keys { column: 'age' bytes_value: '[1.0, 6.0)' }"
  age_6_12 = "single_slice_keys { column: 'age' bytes_value: '[6.0, 12.0)' }"
  age_12_18 = (
      "single_slice_keys { column: 'age' bytes_value: '[12.0, 18.0)' }")
  country_usa = "single_slice_keys { column: 'country' bytes_value: 'USA' }"

  metrics = [
      _slice_metrics('', 0.8, 0.5737843, 1.0262157, 1500),
      _slice_metrics(age_1_6, 0.4, 0.3737843, 0.6262157, 500),
      _slice_metrics(age_6_12, 0.79, 0.5737843, 1.0262157, 500),
      _slice_metrics(age_12_18, 0.9, 0.5737843, 1.0262157, 500),
      _slice_metrics(country_usa, 0.9, 0.5737843, 1.0262157, 500),
      _slice_metrics(country_usa + ' ' + age_12_18, 0.9, 0.5737843, 1.0262157,
                     500),
  ]

  self.assertCountEqual(
      auto_slicing_util.find_significant_slices(
          metrics, metric_key='accuracy', comparison_type='LOWER'), [
              auto_slicing_util.SliceComparisonResult(
                  slice_key=(('age', '[1.0, 6.0)'),),
                  num_examples=500.0,
                  slice_metric=0.4,
                  base_metric=0.8,
                  p_value=0.0,
                  effect_size=4.0,
                  raw_slice_metrics=metrics[1])
          ])
  self.assertCountEqual(
      auto_slicing_util.find_significant_slices(
          metrics, metric_key='accuracy', comparison_type='HIGHER'), [
              auto_slicing_util.SliceComparisonResult(
                  slice_key=(('age', '[12.0, 18.0)'),),
                  num_examples=500.0,
                  slice_metric=0.9,
                  base_metric=0.8,
                  p_value=7.356017854191938e-70,
                  effect_size=0.9999999999999996,
                  raw_slice_metrics=metrics[3]),
              auto_slicing_util.SliceComparisonResult(
                  slice_key=(('country', 'USA'),),
                  num_examples=500.0,
                  slice_metric=0.9,
                  base_metric=0.8,
                  p_value=7.356017854191938e-70,
                  effect_size=0.9999999999999996,
                  raw_slice_metrics=metrics[4]),
              auto_slicing_util.SliceComparisonResult(
                  slice_key=(('age', '[12.0, 18.0)'), ('country', 'USA')),
                  num_examples=500.0,
                  slice_metric=0.9,
                  base_metric=0.8,
                  p_value=7.356017854191938e-70,
                  effect_size=0.9999999999999996,
                  raw_slice_metrics=metrics[5])
          ])
def read_config(f):
  """Reads all of file object `f` and parses it as a text-format Config.

  Args:
    f: Open, readable file object whose `read()` yields the text proto.

  Returns:
    A config_pb2.Config populated from the file contents.
  """
  parsed = config_pb2.Config()
  text_format.Parse(f.read(), parsed)
  return parsed
def test_revert_slice_keys_for_transformed_features(self):
  """Checks that bucketized 'transformed_age' keys map back to 'age' ranges.

  Bucket ids 1 and 2 are expected to become the '[1.0, 6.0)' and
  '[6.0, 12.0)' age ranges. The overall slice and the untransformed
  'country' slice are expected to stay unchanged.
  """

  def _metrics(text):
    """Parses `text` as a MetricsForSlice."""
    return text_format.Parse(text, metrics_for_slice_pb2.MetricsForSlice())

  statistics = text_format.Parse(
      """
      datasets {
        num_examples: 1500
        features {
          path { step: 'country' }
          type: STRING
          string_stats { unique: 10 }
        }
        features {
          path { step: 'age' }
          type: INT
          num_stats {
            common_stats {
              num_non_missing: 1500
              min_num_values: 1
              max_num_values: 1
            }
            min: 1
            max: 18
            histograms {
              buckets { low_value: 1 high_value: 6.0 sample_count: 500 }
              buckets { low_value: 6.0 high_value: 12.0 sample_count: 500 }
              buckets { low_value: 12.0 high_value: 18.0 sample_count: 500 }
              type: QUANTILES
            }
          }
        }
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())

  metrics = [
      _metrics("""
          slice_key { }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'transformed_age' int64_value: 1 }
          }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'transformed_age' int64_value: 2 }
          }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'country' bytes_value: 'USA' }
          }
          """),
  ]
  expected_metrics = [
      _metrics("""
          slice_key { }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'age' bytes_value: '[1.0, 6.0)' }
          }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'age' bytes_value: '[6.0, 12.0)' }
          }
          """),
      _metrics("""
          slice_key {
            single_slice_keys { column: 'country' bytes_value: 'USA' }
          }
          """),
  ]

  reverted = auto_slicing_util.revert_slice_keys_for_transformed_features(
      metrics, statistics)
  self.assertEqual(reverted, expected_metrics)
def testConvertSliceMetricsToProtoConfusionMatrices(self):
  """Checks conversion of legacy confusion-matrix metrics into protos.

  Each matrix row is [false_negatives, true_negatives, false_positives,
  true_positives, precision, recall]. Every entry is expected as:
  - a raw field,
  - a `bounded_*` value,
  - a `t_distribution_*` unsampled value.
  A NaN precision must survive the round trip.
  """
  slice_key = _make_slice_key()
  thresholds = [0.25, 0.75, 1.00]
  matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
              [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]
  slice_metrics = {
      metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES: matrices,
      metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS: thresholds,
  }

  def _expected_matrix(threshold, fn, tn, fp, tp, precision, recall):
    """Renders one expected `matrices` entry from text-literal values."""
    # The same value is repeated in the raw, bounded and t-distribution
    # fields, so a single template keeps the three copies consistent.
    return """
        matrices {
          threshold: %(threshold)s
          false_negatives: %(fn)s
          true_negatives: %(tn)s
          false_positives: %(fp)s
          true_positives: %(tp)s
          precision: %(precision)s
          recall: %(recall)s
          bounded_false_negatives { value { value: %(fn)s } }
          bounded_true_negatives { value { value: %(tn)s } }
          bounded_true_positives { value { value: %(tp)s } }
          bounded_false_positives { value { value: %(fp)s } }
          bounded_precision { value { value: %(precision)s } }
          bounded_recall { value { value: %(recall)s } }
          t_distribution_false_negatives { unsampled_value { value: %(fn)s } }
          t_distribution_true_negatives { unsampled_value { value: %(tn)s } }
          t_distribution_true_positives { unsampled_value { value: %(tp)s } }
          t_distribution_false_positives { unsampled_value { value: %(fp)s } }
          t_distribution_precision { unsampled_value { value: %(precision)s } }
          t_distribution_recall { unsampled_value { value: %(recall)s } }
        }
    """ % {
        'threshold': threshold,
        'fn': fn,
        'tn': tn,
        'fp': fp,
        'tp': tp,
        'precision': precision,
        'recall': recall,
    }

  # Written out independently of `matrices` so the expectation is not derived
  # from the test's own inputs.
  expected_matrices = ''.join([
      _expected_matrix('0.25', fn='0.0', tn='1.0', fp='0.0', tp='2.0',
                       precision='1.0', recall='1.0'),
      _expected_matrix('0.75', fn='1.0', tn='1.0', fp='0.0', tp='1.0',
                       precision='1.0', recall='0.5'),
      _expected_matrix('1.00', fn='2.0', tn='1.0', fp='0.0', tp='0.0',
                       precision='nan', recall='0.0'),
  ])
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "post_export_metrics/confusion_matrix_at_thresholds"
        value {
          confusion_matrix_at_thresholds {
            %s
          }
        }
      }
      """ % expected_matrices, metrics_for_slice_pb2.MetricsForSlice())

  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics),
      [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
  self.assertProtoEquals(expected_metrics_for_slice, got)
def str_to_bond_topology(s):
  """Builds a BondTopology message from its text-format representation.

  Args:
    s: text-format protobuf string describing a dataset_pb2.BondTopology.

  Returns:
    A new dataset_pb2.BondTopology populated from `s`.
  """
  # text_format.Parse fills the message in place and hands the same object
  # back, so its return value can be returned directly.
  return text_format.Parse(s, dataset_pb2.BondTopology())
def testConvertSliceMetricsToProtoFromLegacyStrings(self):
  """Legacy string-keyed metrics with bounds convert to bounded_value protos.

  AUC and AUPRC values, together with their lower/upper bound keys, are
  expected to become RIEMANN_SUM bounded values, while the unbounded
  'accuracy' metric stays a plain double_value.
  """
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)

  auc_key = metric_keys.AUC
  auprc_key = metric_keys.AUPRC
  slice_metrics = {'accuracy': 0.8}
  # (metric key, value, lower bound, upper bound); AUPRC entries are inserted
  # before AUC entries.
  for key, value, lower, upper in ((auprc_key, 0.1, 0.05, 0.17),
                                   (auc_key, 0.2, 0.1, 0.3)):
    slice_metrics[key] = value
    slice_metrics[metric_keys.lower_bound_key(key)] = lower
    slice_metrics[metric_keys.upper_bound_key(key)] = upper

  expected_text = string.Template(""" slice_key { single_slice_keys { column: 'age' int64_value: 5 } single_slice_keys { column: 'language' bytes_value: 'english' } single_slice_keys { column: 'price' float_value: 0.3 } } metrics { key: "accuracy" value { double_value { value: 0.8 } } } metrics { key: "$auc" value { bounded_value { lower_bound { value: 0.1 } upper_bound { value: 0.3 } value { value: 0.2 } methodology: RIEMANN_SUM } } } metrics { key: "$auprc" value { bounded_value { lower_bound { value: 0.05 } upper_bound { value: 0.17 } value { value: 0.1 } methodology: RIEMANN_SUM } } }""").substitute(
      auc=auc_key, auprc=auprc_key)
  expected_metrics_for_slice = text_format.Parse(
      expected_text, metrics_for_slice_pb2.MetricsForSlice())

  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics), callbacks)
  self.assertProtoEquals(expected_metrics_for_slice, got)
def generate_parallel_module(modules: Sequence[
    module_signature_mod.ModuleGeneratorResult], module_name: str) -> str:
  """Generates a module composed of instantiated instances of the given modules.

  Each module in 'modules' is instantiated exactly once in an enclosing,
  composite module. Every data port of every component becomes a port of the
  enclosing module named '<component module name>_<port name>' and is wired to
  the matching port of that component's instance; 'clk' is shared by all
  instances.

  For example, if given two modules, add8_module and add16_module, the
  generated module might look like:

    module add8_module(
      input wire clk,
      input wire [7:0] op0,
      input wire [7:0] op1,
      output wire [7:0] out
    );
      // contents of module elided...
    endmodule

    module add16_module(
      input wire clk,
      input wire [15:0] op0,
      input wire [15:0] op1,
      output wire [15:0] out
    );
      // contents of module elided...
    endmodule

    module foo(
      input wire clk,
      input wire [7:0] add8_module_op0,
      input wire [7:0] add8_module_op1,
      output wire [7:0] add8_module_out,
      input wire [15:0] add16_module_op0,
      input wire [15:0] add16_module_op1,
      output wire [15:0] add16_module_out
    );
      add8_module add8_module_inst(
        .clk(clk),
        .op0(add8_module_op0),
        .op1(add8_module_op1),
        .out(add8_module_out)
      );
      add16_module add16_module_inst(
        .clk(clk),
        .op0(add16_module_op0),
        .op1(add16_module_op1),
        .out(add16_module_out)
      );
    endmodule

  Args:
    modules: Modules to instantiate.
    module_name: Name of the module containing the instantiated input modules.

  Returns:
    Verilog text containing the component modules followed by the composite
    module.
  """
  module_protos = [
      text_format.Parse(m.signature.as_text_proto(),
                        module_signature_pb2.ModuleSignatureProto())
      for m in modules
  ]

  # Lift each component's data ports to the top level, prefixing the signal
  # with the component's module name so same-named ports of different
  # components don't collide (assuming component module names are distinct).
  ports = ['input wire clk']
  for module in module_protos:
    for data_port in module.data_ports:
      width_str = f'[{data_port.width - 1}:0]'
      signal_name = f'{module.module_name}_{data_port.name}'
      if data_port.direction == module_signature_pb2.DIRECTION_INPUT:
        ports.append(f'input wire {width_str} {signal_name}')
      elif data_port.direction == module_signature_pb2.DIRECTION_OUTPUT:
        ports.append(f'output wire {width_str} {signal_name}')
      # NOTE(review): a port with any other direction is not declared here but
      # is still connected below, which references an undeclared signal.
  header = 'module {module_name}(\n{ports}\n);'.format(
      module_name=module_name, ports=',\n'.join(f' {p}' for p in ports))

  instantiations = []
  for module in module_protos:
    connections = ['.clk(clk)']
    for data_port in module.data_ports:
      connections.append(
          f'.{data_port.name}({module.module_name}_{data_port.name})')
    instantiations.append(
        ' {name} {name}_inst(\n{connections}\n );'.format(
            name=module.module_name,
            connections=',\n'.join(f' {c}' for c in connections)))

  # Component module definitions come first, then the composite wrapper.
  return '{modules}\n\n{header}\n{instantiations}\nendmodule\n'.format(
      modules='\n\n'.join(m.verilog_text for m in modules),
      header=header,
      instantiations='\n'.join(instantiations))
def testWriteValidationResults(self, output_file_format):
  """End-to-end check of the validation results written by the writer.

  Builds a candidate Keras model (mul=0) and a baseline model (mul=1), runs a
  Beam pipeline over two examples with several value/change thresholds, then
  loads the written validation result and checks the expected failures, the
  missing slice and the slicing details.

  Args:
    output_file_format: Forwarded to MetricsPlotsAndValidationsWriter as
      `output_file_format`.
  """
  model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir()
  eval_shared_model = self._build_keras_model(model_dir, mul=0)
  baseline_eval_shared_model = self._build_keras_model(baseline_dir, mul=1)
  validations_file = os.path.join(self._getTempDir(),
                                  constants.VALIDATIONS_KEY)
  schema = text_format.Parse(
      """ tensor_representation_group { key: "" value { tensor_representation { key: "input" value { dense_tensor { column_name: "input" shape { dim { size: 1 } } } } } } } feature { name: "input" type: FLOAT } feature { name: "label" type: FLOAT } feature { name: "example_weight" type: FLOAT } feature { name: "extra_feature" type: BYTES } """,
      schema_pb2.Schema())
  tfx_io = test_util.InMemoryTFExampleRecord(
      schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
  tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
      arrow_schema=tfx_io.ArrowSchema(),
      tensor_representations=tfx_io.TensorRepresentations())
  # Two examples with total example weight 1.0 + 0.5 = 1.5.
  examples = [
      self._makeExample(
          input=0.0,
          label=1.0,
          example_weight=1.0,
          extra_feature='non_model_feature'),
      self._makeExample(
          input=1.0,
          label=0.0,
          example_weight=0.5,
          extra_feature='non_model_feature'),
  ]
  # The second spec matches no data and should be reported in missing_slices.
  slicing_specs = [
      config.SlicingSpec(),
      config.SlicingSpec(feature_keys=['slice_does_not_exist'])
  ]
  eval_config = config.EvalConfig(
      model_specs=[
          config.ModelSpec(
              name='candidate',
              label_key='label',
              example_weight_key='example_weight'),
          config.ModelSpec(
              name='baseline',
              label_key='label',
              example_weight_key='example_weight',
              is_baseline=True)
      ],
      slicing_specs=slicing_specs,
      metrics_specs=[
          config.MetricsSpec(
              metrics=[
                  config.MetricConfig(
                      class_name='WeightedExampleCount',
                      per_slice_thresholds=[
                          config.PerSliceMetricThreshold(
                              slicing_specs=slicing_specs,
                              # 1.5 < 1, NOT OK.
                              threshold=config.MetricThreshold(
                                  value_threshold=config.GenericValueThreshold(
                                      upper_bound={'value': 1})))
                      ]),
                  config.MetricConfig(
                      class_name='ExampleCount',
                      # 2 > 10, NOT OK.
                      threshold=config.MetricThreshold(
                          value_threshold=config.GenericValueThreshold(
                              lower_bound={'value': 10}))),
                  config.MetricConfig(
                      class_name='MeanLabel',
                      # 0 > 0 and 0 > 0%?: NOT OK.
                      # Both models see the same labels, so the mean_label diff
                      # is 0; the expected failures below record it as failing.
                      threshold=config.MetricThreshold(
                          change_threshold=config.GenericChangeThreshold(
                              direction=config.MetricDirection.HIGHER_IS_BETTER,
                              relative={'value': 0},
                              absolute={'value': 0}))),
                  config.MetricConfig(
                      # MeanPrediction = (0+0)/(1+0.5) = 0
                      class_name='MeanPrediction',
                      # -.01 < 0 < .01, OK.
                      # Diff% = -.333/.333 = -100% < -99%, OK.
                      # Diff = 0 - .333 = -.333 < 0, OK.
                      threshold=config.MetricThreshold(
                          value_threshold=config.GenericValueThreshold(
                              upper_bound={'value': .01},
                              lower_bound={'value': -.01}),
                          change_threshold=config.GenericChangeThreshold(
                              direction=config.MetricDirection.LOWER_IS_BETTER,
                              relative={'value': -.99},
                              absolute={'value': 0})))
              ],
              model_names=['candidate', 'baseline']),
      ],
      options=config.Options(
          disabled_outputs={'values': ['eval_config.json']}),
  )
  slice_spec = [
      slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
  ]
  eval_shared_models = {
      'candidate': eval_shared_model,
      'baseline': baseline_eval_shared_model
  }
  extractors = [
      batched_input_extractor.BatchedInputExtractor(eval_config),
      batched_predict_extractor_v2.BatchedPredictExtractor(
          eval_shared_model=eval_shared_models,
          eval_config=eval_config,
          tensor_adapter_config=tensor_adapter_config),
      unbatch_extractor.UnbatchExtractor(),
      slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
  ]
  evaluators = [
      metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
          eval_config=eval_config, eval_shared_model=eval_shared_models)
  ]
  # Only validations are written; metrics/plots paths are not configured here.
  output_paths = {
      constants.VALIDATIONS_KEY: validations_file,
  }
  writers = [
      metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
          output_paths,
          eval_config=eval_config,
          add_metrics_callbacks=[],
          output_file_format=output_file_format)
  ]

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'Create' >> beam.Create([e.SerializeToString() for e in examples])
        | 'BatchExamples' >> tfx_io.BeamSource()
        | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
        | 'ExtractEvaluate' >> model_eval_lib.ExtractAndEvaluate(
            extractors=extractors, evaluators=evaluators)
        | 'WriteResults' >> model_eval_lib.WriteResults(writers=writers))
    # pylint: enable=no-value-for-parameter

  # The pipeline runs when the `with` block exits, so the result is read here.
  validation_result = (
      metrics_plots_and_validations_writer
      .load_and_deserialize_validation_result(
          os.path.dirname(validations_file)))

  expected_validations = [
      text_format.Parse(
          """ metric_key { name: "weighted_example_count" model_name: "candidate" } metric_threshold { value_threshold { upper_bound { value: 1.0 } } } metric_value { double_value { value: 1.5 } } """,
          validation_result_pb2.ValidationFailure()),
      text_format.Parse(
          """ metric_key { name: "example_count" model_name: "candidate" } metric_threshold { value_threshold { lower_bound { value: 10.0 } } } metric_value { double_value { value: 2.0 } } """,
          validation_result_pb2.ValidationFailure()),
      text_format.Parse(
          """ metric_key { name: "mean_label" model_name: "candidate" is_diff: true } metric_threshold { change_threshold { absolute { value: 0.0 } relative { value: 0.0 } direction: HIGHER_IS_BETTER } } metric_value { double_value { value: 0.0 } } """,
          validation_result_pb2.ValidationFailure()),
  ]
  self.assertFalse(validation_result.validation_ok)
  self.assertLen(validation_result.metric_validations_per_slice, 1)
  self.assertCountEqual(
      expected_validations,
      validation_result.metric_validations_per_slice[0].failures)

  expected_missing_slices = [
      config.SlicingSpec(feature_keys=['slice_does_not_exist'])
  ]
  self.assertLen(validation_result.missing_slices, 1)
  self.assertCountEqual(expected_missing_slices,
                        validation_result.missing_slices)

  expected_slicing_details = [
      text_format.Parse(
          """ slicing_spec { } num_matching_slices: 1 """,
          validation_result_pb2.SlicingDetails()),
  ]
  self.assertLen(validation_result.validation_details.slicing_details, 1)
  self.assertCountEqual(expected_slicing_details,
                        validation_result.validation_details.slicing_details)
def testConvertAttributionsProto(self):
  """AttributionsForSlice protos flatten to nested dicts keyed by output/sub-key.

  The result is a (slice key, dict) pair where the dict is indexed by output
  name, then sub-key string ('' when absent, 'classId:1' for class_id 1),
  then attribution name, then feature name.
  """
  expected = ((), {
      '': {
          '': {
              'total_attributions': {
                  'feature2': {
                      'doubleValue': 2.0
                  },
                  'feature1': {
                      'doubleValue': 1.0
                  }
              }
          }
      },
      'output1': {
          'classId:1': {
              'total_attributions': {
                  'feature1': {
                      'doubleValue': 1.0
                  }
              }
          }
      }
  })

  proto_text = """ slice_key {} attributions_keys_and_values { key { name: "total_attributions" } values { key: "feature1" value: { double_value { value: 1.0 } } } values { key: "feature2" value: { double_value { value: 2.0 } } } } attributions_keys_and_values { key { name: "total_attributions" output_name: "output1" sub_key: { class_id: { value: 1 } } } values { key: "feature1" value: { double_value { value: 1.0 } } } }"""
  parsed = text_format.Parse(proto_text,
                             metrics_for_slice_pb2.AttributionsForSlice())

  self.assertEqual(
      util.convert_attributions_proto_to_dict(parsed, None), expected)
def _get_version_config(version_config_path):
  """Loads a VersionConfig message from a text-format protobuf file.

  Args:
    version_config_path: Path to a file holding a text-format VersionConfig.

  Returns:
    The parsed version_config_pb2.VersionConfig message.
  """
  # Text-format protos are UTF-8; decode explicitly rather than relying on the
  # platform's locale default, which can garble non-ASCII content.
  with open(version_config_path, encoding='utf-8') as f:
    return text_format.Parse(f.read(), version_config_pb2.VersionConfig())
def main():
    """Replays bid agent requests from a file against a bid endpoint.

    Each request read from the requests file is parsed as a text-format
    BidAgentRequest, POSTed to ``opts.bid_endpoint`` as a serialized protobuf,
    and the decoded response (or an error message) is written to the optional
    responses file, so the output stays aligned with the input. Request counts
    and latency statistics are logged at the end.

    Returns:
        0 after all requests have been processed, or -1 if the input or
        output file cannot be opened.
    """
    opts = parse_command_line()
    # Logger.setLevel accepts a level name directly; the private
    # logging._levelNames table used previously only exists on Python 2.
    logger.setLevel(opts.log_level.upper())
    logger.addHandler(logging.StreamHandler())
    logger.info('Endpoint: {}'.format(opts.bid_endpoint))

    headers = {
        'Content-type': 'application/x-protobuf',
    }
    if opts.header_secret:
        headers['beeswax-auth-secret'] = opts.header_secret

    try:
        input_request_file = open(opts.path_to_requests_file, 'rb')
    except (IOError, OSError) as exc:
        logger.error('Could not open bid agent requests input file: {}'.format(exc))
        return -1

    output_file = None
    if opts.path_to_responses_file:
        try:
            output_file = open(opts.path_to_responses_file, 'wb')
        except (IOError, OSError) as exc:
            logger.error('Could not open bid agent responses output file: {}'.format(exc))
            # Don't leak the already-opened input file on this early exit.
            input_request_file.close()
            return -1

    def current_milli_time():
        return int(round(time.time() * 1000))

    success_count = 0
    failure_count = 0
    # Requests whose HTTP round trip completed and was timed. This, not the
    # total request count, is the denominator of the average latency, since
    # parse failures and HTTP errors never contribute to total_time.
    timed_count = 0
    min_time = 0
    max_time = 0
    total_time = 0
    # Joined once at the end instead of growing a string inside the loop.
    info_lines = []
    try:
        session = requests.Session()
        session.headers.update(headers)
        for request_text in _request_text_generator(input_request_file):
            request_proto = BidAgentRequest()
            try:
                text_format.Parse(request_text, request_proto)
            except ParseError as exc:
                msg = 'Could not parse bid agent request: {}. \nRequest: {}'.format(
                    exc, request_text)
                logger.error(msg)
                # Intentionally write errors into output file so that (1) responses (errors) will
                # be aligned with requests and (2) user can do analysis in the output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue

            logger.debug('Sending request: {}'.format(request_proto))
            start_time_milli = current_milli_time()
            try:
                response = session.post(opts.bid_endpoint,
                                        data=request_proto.SerializeToString(),
                                        timeout=_HTTP_TIMEOUT_S)
            except Exception as exc:
                msg = 'Error in sending http request: {}'.format(exc)
                logger.error(msg)
                # Intentionally write errors into output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue
            elapsed_time_milli = current_milli_time() - start_time_milli

            timed_count += 1
            total_time += elapsed_time_milli
            # Use the first timed request to seed min/max instead of treating 0
            # as "unset", which broke when a request took 0 ms.
            if timed_count == 1 or elapsed_time_milli < min_time:
                info_lines.append('replacing min with: {}'.format(elapsed_time_milli))
                min_time = elapsed_time_milli
            if timed_count == 1 or elapsed_time_milli > max_time:
                info_lines.append('replacing max with: {}'.format(elapsed_time_milli))
                max_time = elapsed_time_milli

            try:
                response_message = _get_response_message(response)
            except DecodeError as exc:
                msg = 'Failed to deserialize response body: {}'.format(exc)
                logger.error(msg)
                # Intentionally write errors into output file.
                _write_response(output_file, msg)
                failure_count += 1
                continue

            _write_response(output_file, response_message)
            success_count += 1
            logger.debug('Successfully processed request: {}'.format(request_proto))
    finally:
        # Close both files even if the loop raises.
        input_request_file.close()
        if output_file:
            output_file.close()

    logger.info('Print info: {}'.format(''.join('\n' + line for line in info_lines)))
    logger.info('Finished processing all requests. Success count: {}, failure count: {}'
                .format(success_count, failure_count))
    # Guard against ZeroDivisionError when no request completed a round trip.
    average_time = total_time / timed_count if timed_count else 0
    logger.info('Stats: Average latency: {} ms, Min latency: {} ms, Max latency: {} ms'
                .format(average_time, min_time, max_time))
    return 0
def testIsDesiredOutputEvent(self):
  """Checks which events event_lib.is_valid_output_event accepts.

  OUTPUT, DECLARED_OUTPUT and INTERNAL_OUTPUT events whose path names
  'right_key' are accepted for that key. A different key, an INPUT event, or
  an OUTPUT event without a path are rejected. When no key is given, the
  path-less OUTPUT event is accepted.
  """

  def _parse(text):
    return text_format.Parse(text, metadata_store_pb2.Event())

  output_event = _parse(
      """ type: OUTPUT path { steps { key: 'right_key' } steps { index: 1 } } """)
  declared_output_event = _parse(
      """ type: DECLARED_OUTPUT path { steps { key: 'right_key' } steps { index: 1 } } """)
  internal_output_event = _parse(
      """ type: INTERNAL_OUTPUT path { steps { key: 'right_key' } steps { index: 1 } } """)
  input_event = _parse(
      """ type: INPUT path { steps { key: 'right_key' } steps { index: 1 } } """)
  empty_event = _parse('type: OUTPUT')

  for accepted in (output_event, declared_output_event, internal_output_event):
    self.assertTrue(event_lib.is_valid_output_event(accepted, 'right_key'))
  self.assertFalse(event_lib.is_valid_output_event(output_event, 'wrong_key'))
  for rejected in (input_event, empty_event):
    self.assertFalse(event_lib.is_valid_output_event(rejected, 'right_key'))
  self.assertTrue(event_lib.is_valid_output_event(empty_event))
def InterpretCompletedOp(op_attribute_str, parallel_conf):
  """Interprets a completed op described by a text-format OpAttribute.

  Args:
    op_attribute_str: Text-format op_attribute_pb.OpAttribute.
    parallel_conf: Parallel configuration forwarded to _InterpretCompletedOp.
  """
  parsed_attribute = text_format.Parse(op_attribute_str,
                                       op_attribute_pb.OpAttribute())
  # The same default backward blob register is used for interpretation and
  # for the subsequent release call.
  register = gradient_util.GetDefaultBackwardBlobRegister()
  _InterpretCompletedOp(parsed_attribute, parallel_conf, register)
  gradient_util.ReleaseUnusedBlobObject(parsed_attribute, register)
def window_selector_config(flags_obj):
  """Creates a WindowSelectorOptions proto based on input and default settings.

  When ws_use_window_selector_model is False, a VARIANT_READS threshold model
  is built from ws_{min,max}_num_supporting_reads (falling back to the module
  defaults for unset flags). Otherwise the model is loaded from
  ws_window_selector_model, or _ALLELE_COUNT_LINEAR_MODEL_DEFAULT is used if
  that flag is not given.

  Args:
    flags_obj: configuration FLAGS.

  Returns:
    realigner_pb2.WindowSelectorOptions protobuf.

  Raises:
    ValueError: If ws_window_selector_model is given while
      ws_use_window_selector_model is False; if ws_min_num_supporting_reads or
      ws_max_num_supporting_reads is set while ws_use_window_selector_model is
      True; or if the resulting VARIANT_READS model has
      max_num_supporting_reads < min_num_supporting_reads.
  """
  if not flags_obj.ws_use_window_selector_model:
    if flags_obj.ws_window_selector_model is not None:
      raise ValueError('Cannot specify a ws_window_selector_model '
                       'if ws_use_window_selector_model is False.')

    # Unset flags fall back to the module-level defaults.
    min_num_supporting_reads = (
        _DEFAULT_MIN_SUPPORTING_READS
        if flags_obj.ws_min_num_supporting_reads == _UNSET_WS_INT_FLAG else
        flags_obj.ws_min_num_supporting_reads)
    max_num_supporting_reads = (
        _DEFAULT_MAX_SUPPORTING_READS
        if flags_obj.ws_max_num_supporting_reads == _UNSET_WS_INT_FLAG else
        flags_obj.ws_max_num_supporting_reads)
    window_selector_model = realigner_pb2.WindowSelectorModel(
        model_type=realigner_pb2.WindowSelectorModel.VARIANT_READS,
        variant_reads_model=realigner_pb2.WindowSelectorModel
        .VariantReadsThresholdModel(
            min_num_supporting_reads=min_num_supporting_reads,
            max_num_supporting_reads=max_num_supporting_reads))
  else:
    # The supporting-reads flags only configure the threshold model, so they
    # conflict with using a window selector model.
    if flags_obj.ws_min_num_supporting_reads != _UNSET_WS_INT_FLAG:
      raise ValueError('Cannot use both ws_min_num_supporting_reads and '
                       'ws_use_window_selector_model flags.')
    if flags_obj.ws_max_num_supporting_reads != _UNSET_WS_INT_FLAG:
      raise ValueError('Cannot use both ws_max_num_supporting_reads and '
                       'ws_use_window_selector_model flags.')
    if flags_obj.ws_window_selector_model is None:
      window_selector_model = _ALLELE_COUNT_LINEAR_MODEL_DEFAULT
    else:
      with tf.io.gfile.GFile(flags_obj.ws_window_selector_model) as f:
        window_selector_model = text_format.Parse(
            f.read(), realigner_pb2.WindowSelectorModel())

  # Validate the threshold range regardless of where the model came from.
  if (window_selector_model.model_type ==
      realigner_pb2.WindowSelectorModel.VARIANT_READS):
    model = window_selector_model.variant_reads_model
    if model.max_num_supporting_reads < model.min_num_supporting_reads:
      raise ValueError('ws_min_supporting_reads should be smaller than '
                       'ws_max_supporting_reads.')

  ws_config = realigner_pb2.WindowSelectorOptions(
      min_mapq=flags_obj.ws_min_mapq,
      min_base_quality=flags_obj.ws_min_base_quality,
      min_windows_distance=flags_obj.ws_min_windows_distance,
      max_window_size=flags_obj.ws_max_window_size,
      region_expansion_in_bp=flags_obj.ws_region_expansion_in_bp,
      window_selector_model=window_selector_model)

  return ws_config
def _parse_data_point(s: Text) -> delay_model_pb2.DataPoint:
  """Returns the DataPoint message described by the text-format string `s`."""
  data_point = delay_model_pb2.DataPoint()
  # Parse populates `data_point` in place (and returns that same object).
  text_format.Parse(s, data_point)
  return data_point