def scaffold_assets(self) -> model_card_module.ModelCard:
    """Generates the model card toolkit's assets.

    Assets include the model card JSON file and the customizable model card
    UI templates. If the toolkit was initialized with an `mlmd_store`, the
    model card's properties are auto-populated from the store, and related
    plots (model performance, data distributions) are attached as well.

    Returns:
      A ModelCard representing the given model.
    """
    model_card = model_card_module.ModelCard()
    if self._store:
        # Populate text properties from the model's MLMD lineage.
        model_card = tfx_util.generate_model_card_for_model(
            self._store, self._artifact_with_model_uri.id)
        # Attach evaluation plots for each metrics artifact that yields a
        # readable eval result.
        for metrics_artifact in tfx_util.get_metrics_artifacts_for_model(
                self._store, self._artifact_with_model_uri.id):
            eval_result = tfx_util.read_metrics_eval_result(
                metrics_artifact.uri)
            if eval_result is not None:
                graphics.annotate_eval_result_plots(model_card, eval_result)
        # Attach train/eval feature-distribution plots from each statistics
        # artifact.
        for stats_artifact in tfx_util.get_stats_artifacts_for_model(
                self._store, self._artifact_with_model_uri.id):
            train_stats = tfx_util.read_stats_proto(stats_artifact.uri, 'train')
            eval_stats = tfx_util.read_stats_proto(stats_artifact.uri, 'eval')
            graphics.annotate_dataset_feature_statistics_plots(
                model_card, train_stats, eval_stats)
    # Write the model card JSON file.
    self._write_file(self._mcta_json_file, model_card.to_json())
    # Copy the UI template files into the assets directory.
    # NOTE(review): copytree raises FileExistsError if the target directory
    # already exists -- confirm scaffold_assets is only invoked once per
    # assets directory.
    shutil.copytree(_UI_TEMPLATE_DIR, self._mcta_template_dir)
    return model_card
def test_train_and_eval_data(self):
    """Renders a card whose model parameters carry train and eval data."""
    card = model_card.ModelCard()
    owner = {'name': 'bar', 'contact': '*****@*****.**'}
    version = {
        'name': '0.3',
        'date': '2020-01-01',
        'diff': 'Updated dataset.',
    }
    card.model_details = {
        'name': 'train and eval model',
        'owners': [owner],
        'version': version,
        'license': 'Apache 2.0',
        'references': ['https://my_model.xyz.com', 'https://example.com'],
        'citation': 'https://doi.org/foo/bar',
    }
    card.model_parameters = model_card.ModelParameters(
        model_architecture='knn',
        data=model_card.Data(train=_TRAIN_DATA, eval=_EVAL_DATA))
    self.assertRenderedTemplate(card, ExpectedGraphic())
def test_no_quantitative_analysis(self):
    """Renders a card that has model parameters but no attached data."""
    card = model_card.ModelCard()
    owner = {'name': 'bar', 'contact': '*****@*****.**'}
    version = {
        'name': '0.4',
        'date': '2020-01-01',
        'diff': 'Updated dataset.',
    }
    card.model_details = {
        'name': 'quantitative analysis',
        'overview': 'This demonstrates a quantitative analysis graphic.',
        'owners': [owner],
        'version': version,
        'license': 'Apache 2.0',
        'references': ['https://my_model.xyz.com', 'https://example.com'],
        'citation': 'https://doi.org/foo/bar',
    }
    card.model_parameters = model_card.ModelParameters(
        model_architecture='knn')
    self.assertRenderedTemplate(card, ExpectedGraphic())
def generate_model_card_for_model(
    store: mlmd.MetadataStore,
    model_id: int,
    pipeline_types: Optional[PipelineTypes] = None
) -> model_card_module.ModelCard:
    """Populates model card properties for a model artifact.

    Traverses the parents and children of the model artifact and maps related
    artifact properties and lineage information onto model card properties.
    Graphics derived from artifact payloads are handled separately.

    Args:
      store: A ml-metadata MetadataStore instance.
      model_id: The id for the model artifact in the `store`.
      pipeline_types: An optional set of types if the `store` uses custom
        types.

    Returns:
      A ModelCard data object with the properties.

    Raises:
      ValueError: If the `model_id` cannot be resolved as a model artifact in
        the given `store`.
    """
    if not pipeline_types:
        pipeline_types = _get_tfx_pipeline_types(store)
    _validate_model_id(store, pipeline_types.model_type, model_id)

    result = model_card_module.ModelCard()
    details = result.model_details
    trainers = _get_one_hop_executions(store, [model_id], _Direction.ANCESTOR,
                                       pipeline_types.trainer_type)
    if trainers:
        # NOTE(review): the name comes from the *last* trainer while version
        # and references come from the *first* -- confirm this asymmetry is
        # intentional.
        details.name = _property_value(trainers[-1], 'module_file')
        details.version.name = _property_value(trainers[0], 'checksum_md5')
        details.references = [
            model_card_module.Reference(
                reference=_property_value(trainers[0], 'pipeline_name'))
        ]
    return result
def test_annotate_eval_results_plots(self):
    """Checks eval-result plot annotation on a sliced TFMA result.

    Builds a TFMA EvalResult with metrics sliced by 'weekday', by
    ('gender', 'age'), and for the overall (empty) slice, then verifies that
    `graphics.annotate_eval_result_plots` attaches one non-empty graphic per
    metric/slice-key combination named in `expected_metrics_names` below.
    """
    # Each entry is (slice_key, {output_name: {sub_key: {metric: value}}}).
    slicing_metrics = [
        ((('weekday', 0), ), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 0.07875693589448929
                    },
                    'prediction/mean': {
                        'boundedValue': {
                            'value': 0.5100112557411194,
                            'lowerBound': 0.4100112557411194,
                            'upperBound': 0.6100112557411194,
                        }
                    },
                    'average_loss_diff': {}
                }
            }
        }),
        ((('weekday', 1), ), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 4.4887189865112305
                    },
                    'prediction/mean': {
                        'boundedValue': {
                            'value': 0.4839990735054016,
                            'lowerBound': 0.3839990735054016,
                            'upperBound': 0.5839990735054016,
                        }
                    },
                    'average_loss_diff': {}
                }
            }
        }),
        ((('weekday', 2), ), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 2.092138290405273
                    },
                    'prediction/mean': {
                        'boundedValue': {
                            'value': 0.3767518997192383,
                            'lowerBound': 0.1767518997192383,
                            'upperBound': 0.5767518997192383,
                        }
                    },
                    'average_loss_diff': {}
                }
            }
        }),
        # Cross-feature slice: metrics keyed by ('gender', 'age').
        ((('gender', 'male'), ('age', 10)), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 2.092138290405273
                    },
                    'prediction/mean': {
                        'boundedValue': {
                            'value': 0.3767518997192383,
                            'lowerBound': 0.1767518997192383,
                            'upperBound': 0.5767518997192383,
                        }
                    },
                    'average_loss_diff': {}
                }
            }
        }),
        # This slice carries a TFMA error payload instead of a confidence
        # interval for prediction/mean.
        ((('gender', 'female'), ('age', 20)), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 2.092138290405273
                    },
                    'prediction/mean': {
                        'doubleValue': 0.3767518997192383
                    },
                    'average_loss_diff': {},
                    '__ERROR__': {
                        # CI not computed because only 16 samples
                        # were non-empty. Expected 20.
                        'bytesValue':
                            'Q0kgbm90IGNvbXB1dGVkIGJlY2F1c2Ugb25seSAxNiBzYW1wbGVzIHdlcmUgbm9uLWVtcHR5LiBFeHBlY3RlZCAyMC4='
                    }
                }
            }
        }),
        # Overall (un-sliced) metrics; absent from the expected graphic
        # names asserted below.
        ((), {
            '': {
                '': {
                    'average_loss': {
                        'doubleValue': 1.092138290405273
                    },
                    'prediction/mean': {
                        'boundedValue': {
                            'value': 0.4767518997192383,
                            'lowerBound': 0.2767518997192383,
                            'upperBound': 0.6767518997192383,
                        }
                    },
                    'average_loss_diff': {}
                }
            }
        })
    ]
    eval_result = tfma.EvalResult(
        slicing_metrics=slicing_metrics,
        plots=None,
        attributions=None,
        config=None,
        data_location=None,
        file_format=None,
        model_location=None)
    model_card = model_card_module.ModelCard()
    graphics.annotate_eval_result_plots(model_card, eval_result)
    expected_metrics_names = {
        'average_loss | weekday', 'prediction/mean | weekday',
        'average_loss | gender, age', 'prediction/mean | gender, age'
    }
    self.assertSameElements(expected_metrics_names, [
        g.name for g in model_card.quantitative_analysis.graphics.collection
    ])
    # Every generated graphic must have a rendered (non-empty) image.
    for graph in model_card.quantitative_analysis.graphics.collection:
        logging.info('%s: %s', graph.name, graph.image)
        self.assertNotEmpty(graph.image, f'feature {graph.name} has empty plot')
def test_annotate_dataset_feature_statistics_plots(self):
    """Checks dataset-statistics plot annotation for train and eval stats.

    Verifies that `graphics.annotate_dataset_feature_statistics_plots`
    produces one named, non-empty 'counts' graphic per plottable feature in
    each split's statistics proto.
    """
    # Train stats: three FLOAT features with standard + quantile histograms,
    # plus a STRING feature (LDA_03) carrying only bytes_stats -- it does not
    # appear in the expected plot names asserted below.
    train_stats = text_format.Parse(
        """
        datasets {
          features {
            path { step: "LDA_00" }
            type: FLOAT
            num_stats {
              histograms {
                buckets { low_value: 0.0 high_value: 100.0 sample_count: 10.0 }
              }
              histograms {
                buckets { low_value: 0.0 high_value: 50.0 sample_count: 4.0 }
                buckets { low_value: 50.0 high_value: 100.0 sample_count: 4.0 }
                type: QUANTILES
              }
            }
          }
          features {
            path { step: "LDA_01" }
            type: FLOAT
            num_stats {
              histograms {
                buckets { low_value: 0.0 high_value: 100.0 sample_count: 10.0 }
              }
              histograms {
                buckets { low_value: 0.0 high_value: 50.0 sample_count: 4.0 }
                buckets { low_value: 50.0 high_value: 100.0 sample_count: 4.0 }
                type: QUANTILES
              }
            }
          }
          features {
            path { step: "LDA_02" }
            type: FLOAT
            num_stats {
              histograms {
                buckets { low_value: 0.0 high_value: 100.0 sample_count: 10.0 }
              }
              histograms {
                buckets { low_value: 0.0 high_value: 50.0 sample_count: 4.0 }
                buckets { low_value: 50.0 high_value: 100.0 sample_count: 4.0 }
                type: QUANTILES
              }
            }
          }
          features {
            path { step: "LDA_03" }
            type: STRING
            bytes_stats { unique: 1 }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    # Eval stats: three STRING features with rank histograms.
    eval_stats = text_format.Parse(
        """
        datasets {
          features {
            path { step: "data_channel" }
            type: STRING
            string_stats {
              rank_histogram {
                buckets { label: 'News' sample_count: 1387.0 }
                buckets { label: 'Tech' sample_count: 3395.0 }
                buckets { label: 'Sports' sample_count: 2395.0 }
              }
            }
          }
          features {
            path { step: "date" }
            type: STRING
            string_stats {
              rank_histogram {
                buckets { label: '2014-12-10' sample_count: 40.0 }
                buckets { label: '2014-11-06' sample_count: 37.0 }
              }
            }
          }
          features {
            path { step: "slug" }
            type: STRING
            string_stats {
              rank_histogram {
                buckets { label: 'zynga-q3-earnings' sample_count: 1.0 }
                buckets { label: 'zumba-ad' sample_count: 1.0 }
              }
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    model_card = model_card_module.ModelCard()
    graphics.annotate_dataset_feature_statistics_plots(
        model_card, train_stats, eval_stats)
    expected_plot_names_train = {
        'counts | LDA_00', 'counts | LDA_01', 'counts | LDA_02'
    }
    expected_plot_names_eval = {
        'counts | data_channel', 'counts | date', 'counts | slug'
    }
    self.assertSameElements([
        g.name
        for g in model_card.model_parameters.data.train.graphics.collection
    ], expected_plot_names_train)
    self.assertSameElements([
        g.name
        for g in model_card.model_parameters.data.eval.graphics.collection
    ], expected_plot_names_eval)
    # Every generated graphic must have a rendered (non-empty) image.
    graphs = model_card.model_parameters.data.train.graphics.collection + model_card.model_parameters.data.eval.graphics.collection
    for graph in graphs:
        logging.info('%s: %s', graph.name, graph.image)
        self.assertNotEmpty(graph.image, f'feature {graph.name} has empty plot')
def test_from_invalid_json(self):
    """from_json raises a schema ValidationError for an unknown field."""
    bad_fields = {"model_name": "the_greatest_model"}
    card = model_card.ModelCard()
    with self.assertRaises(jsonschema.ValidationError):
        card.from_json(bad_fields)
def test_from_json_and_to_json_with_all_fields(self):
    """A fully-populated JSON round-trips through from_json/to_json."""
    expected = json.loads(_FULL_JSON)
    card = model_card.ModelCard()
    card.from_json(expected)
    actual = json.loads(card.to_json())
    self.assertEqual(expected, actual)
def test_annotate_eval_results_plots(self):
    """Checks eval-result plot annotation on a sliced TFMA result.

    Builds a TFMA EvalResult with metrics sliced by 'weekday', by
    ('gender', 'age'), and for the overall (empty) slice, then verifies that
    `graphics.annotate_eval_result_plots` attaches one non-empty graphic per
    metric/slice-key combination named in `expected_metrics_names` below.
    """
    # Each entry is (slice_key, {output_name: {sub_key: {metric: value}}}).
    slicing_metrics = [((('weekday', 0),), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 0.07875693589448929
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.5100112557411194,
                        'lowerBound': 0.4100112557411194,
                        'upperBound': 0.6100112557411194,
                    }
                },
                'average_loss_diff': {}
            }
        }
    }), ((('weekday', 1),), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 4.4887189865112305
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.4839990735054016,
                        'lowerBound': 0.3839990735054016,
                        'upperBound': 0.5839990735054016,
                    }
                },
                'average_loss_diff': {}
            }
        }
    }), ((('weekday', 2),), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 2.092138290405273
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.3767518997192383,
                        'lowerBound': 0.1767518997192383,
                        'upperBound': 0.5767518997192383,
                    }
                },
                'average_loss_diff': {}
            }
        }
    }), ((('gender', 'male'), ('age', 10)), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 2.092138290405273
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.3767518997192383,
                        'lowerBound': 0.1767518997192383,
                        'upperBound': 0.5767518997192383,
                    }
                },
                'average_loss_diff': {}
            }
        }
    }), ((('gender', 'female'), ('age', 20)), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 2.092138290405273
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.3767518997192383,
                        'lowerBound': 0.1767518997192383,
                        'upperBound': 0.5767518997192383,
                    }
                },
                'average_loss_diff': {}
            }
        }
    }), ((), {
        '': {
            '': {
                'average_loss': {
                    'doubleValue': 1.092138290405273
                },
                'prediction/mean': {
                    'boundedValue': {
                        'value': 0.4767518997192383,
                        'lowerBound': 0.2767518997192383,
                        'upperBound': 0.6767518997192383,
                    }
                },
                'average_loss_diff': {}
            }
        }
    })]
    # Positional args: slicing_metrics, plots, config, data_location,
    # file_format, model_location -- presumably matching this tfma
    # version's EvalResult fields; confirm against the installed tfma.
    eval_result = tfma.EvalResult(slicing_metrics, None, None, None, None,
                                  None)
    model_card = model_card_module.ModelCard()
    graphics.annotate_eval_result_plots(model_card, eval_result)
    expected_metrics_names = {
        'average_loss | weekday', 'prediction/mean | weekday',
        'average_loss | gender, age', 'prediction/mean | gender, age'
    }
    self.assertSameElements(
        expected_metrics_names,
        [g.name for g in model_card.quantitative_analysis.graphics.collection])
    # Every generated graphic must have a rendered (non-empty) image.
    for graph in model_card.quantitative_analysis.graphics.collection:
        logging.info('%s: %s', graph.name, graph.image)
        self.assertNotEmpty(graph.image, f'feature {graph.name} has empty plot')
def test_update_model_card_with_invalid_schema_version(self):
    """update_model_card_json rejects a card with an unknown schema version."""
    toolkit = model_card_toolkit.ModelCardToolkit()
    bad_card = model_card_module.ModelCard(schema_version='100.0.0')
    with self.assertRaisesRegex(ValueError, 'Cannot find schema version'):
        toolkit.update_model_card_json(bad_card)