def load_and_deserialize_plots(path):
  """Loads `PlotsForSlice` records from `path` paired with their slice keys.

  Every record is parsed as a `PlotsForSlice` proto. A record contributes one
  entry for its `plot_data` field when that field is set, followed by one
  entry for its `plots` field when that field is non-empty, so a single record
  may produce zero, one or two entries.

  Args:
    path: Location handed to `tf.python_io.tf_record_iterator`.

  Returns:
    A list of `(slice_key, plots)` tuples. `slice_key` is whatever
    `slicer.deserialize_slice_key` returns for the record's `slice_key`, and
    `plots` is the record's `plot_data` or `plots` field.
  """
  deserialized = []
  for serialized in tf.python_io.tf_record_iterator(path):
    proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)

    # Gather the populated payloads first; `plot_data` goes before `plots`.
    payloads = []
    if proto.HasField('plot_data'):
      payloads.append(proto.plot_data)
    if proto.plots:
      payloads.append(proto.plots)

    for payload in payloads:
      slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
      deserialized.append((slice_key, payload))
  return deserialized
def load_and_deserialize_plots(
    path: Text) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Loads plots from `path` as nested dicts keyed by output name and sub key.

  Each record is parsed as a `PlotsForSlice` proto and turned into a dict of
  the form `{output_name: {sub_key_id: plot}}`:

  * If `plots` is non-empty it is converted with `_convert_proto_map_to_dict`
    and stored under `['']['']`. When the converted dict has exactly one key,
    only that key's value is stored; otherwise the whole dict is stored.
  * Otherwise, if `plot_data` is set, its `MessageToDict` form is stored under
    `['']['']`.
  * Every entry of `plot_keys_and_values` is then stored under its
    `output_name` and sub key id (`''` when the key has no `sub_key`).

  Args:
    path: Location handed to `tf.compat.v1.python_io.tf_record_iterator`.

  Returns:
    A list with one `(slice_key, plots_map)` tuple per record, where
    `slice_key` comes from `slicer.deserialize_slice_key`.
  """
  deserialized = []
  for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
    proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
    plots_by_output = {}

    if proto.plots:
      labelled = _convert_proto_map_to_dict(proto.plots)
      labels = list(labelled.keys())
      # A lone label is unwrapped so the plot itself is stored.
      if len(labels) == 1:
        plot_data = labelled[labels[0]]
      else:
        plot_data = labelled
      plots_by_output[''] = {'': plot_data}
    elif proto.HasField('plot_data'):
      plots_by_output[''] = {'': json_format.MessageToDict(proto.plot_data)}

    for entry in proto.plot_keys_and_values:
      plot_key = entry.key
      per_sub_key = plots_by_output.setdefault(plot_key.output_name, {})
      if plot_key.HasField('sub_key'):
        sub_key_id = _get_sub_key_id(plot_key.sub_key)
      else:
        sub_key_id = ''
      per_sub_key[sub_key_id] = json_format.MessageToDict(entry.value)

    slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
    deserialized.append((slice_key, plots_by_output))
  return deserialized
def load_and_deserialize_metrics(
    path: Text,
    model_name: Optional[Text] = None
) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Loads metrics from `path` and selects the metric map of one model.

  Each record is parsed as a `MetricsForSlice` proto and arranged into a
  nested dict `{model_name: {output_name: {multi_class_key_id: {metric_name:
  value}}}}`. A non-empty `metrics` field is stored under `['']['']['']` via
  `_convert_proto_map_to_dict`; `metric_keys` and `metric_values` are paired
  with `zip` and stored as `MessageToDict` values.

  Args:
    path: Location handed to `tf.compat.v1.python_io.tf_record_iterator`.
    model_name: Model whose metrics should be returned. When None, the record
      must contain metrics for exactly one model, which is then used.

  Returns:
    A list with one `(slice_key, metrics_map)` tuple per record, where
    `metrics_map` is the `{output_name: ...}` dict of the selected model.

  Raises:
    ValueError: If `model_name` is not among a record's models and the
      single-model fallback does not apply.
  """
  deserialized = []
  for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
    proto = metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)
    by_model = {}

    if proto.metrics:
      by_model[''] = {'': {'': _convert_proto_map_to_dict(proto.metrics)}}

    for metric_key, metric_value in zip(proto.metric_keys,
                                        proto.metric_values):
      by_output = by_model.setdefault(metric_key.model_name, {})
      by_class = by_output.setdefault(metric_key.output_name, {})
      if metric_key.HasField('multi_class_key'):
        class_key_id = _get_multi_class_key_id(metric_key.multi_class_key)
      else:
        class_key_id = ''
      named_metrics = by_class.setdefault(class_key_id, {})
      named_metrics[metric_key.name] = json_format.MessageToDict(metric_value)

    available = list(by_model.keys())
    if model_name in by_model:
      # An explicit match always wins.
      chosen = model_name
    elif model_name is None and len(available) == 1:
      # Nothing requested and only one candidate: use it.
      chosen = available[0]
    else:
      raise ValueError('Fail to find metrics for model name: %s . '
                       'Available model names are [%s]' %
                       (model_name, ', '.join(available)))

    slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
    deserialized.append((slice_key, by_model[chosen]))
  return deserialized
def load_and_deserialize_metrics(path):
  """Loads `MetricsForSlice` records from `path` paired with their slice keys.

  Args:
    path: Location handed to `tf.python_io.tf_record_iterator`.

  Returns:
    A list with one `(slice_key, metrics)` tuple per record, where `slice_key`
    comes from `slicer.deserialize_slice_key` and `metrics` is the record's
    `metrics` field.
  """
  protos = (
      metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)
      for serialized in tf.python_io.tf_record_iterator(path))
  return [
      (
          slicer.deserialize_slice_key(proto.slice_key),  # pytype: disable=wrong-arg-types
          proto.metrics,
      )
      for proto in protos
  ]
def load_and_deserialize_metrics(path: Text
                                ) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Reads every `MetricsForSlice` record at `path` with its slice key.

  Args:
    path: Location handed to `tf.compat.v1.python_io.tf_record_iterator`.

  Returns:
    A list of `(slice_key, metrics)` tuples, one per record. `slice_key` is
    the result of `slicer.deserialize_slice_key`; `metrics` is the record's
    `metrics` field.
  """
  parse = metrics_for_slice_pb2.MetricsForSlice.FromString
  parsed_records = (
      parse(raw) for raw in tf.compat.v1.python_io.tf_record_iterator(path))
  return [
      (
          slicer.deserialize_slice_key(parsed.slice_key),  # pytype: disable=wrong-arg-types
          parsed.metrics,
      )
      for parsed in parsed_records
  ]
def load_and_deserialize_plots(
    path: Text) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Loads the `plots` map of every `PlotsForSlice` record stored at `path`.

  A record whose `plot_data` field is set has that message copied into its
  `plots` map under the empty-string key before the map is returned.

  Args:
    path: Location handed to `tf.python_io.tf_record_iterator`.

  Returns:
    A list with one `(slice_key, plots)` tuple per record, where `slice_key`
    comes from `slicer.deserialize_slice_key` and `plots` is the record's
    `plots` field.

  Raises:
    RuntimeError: If a record has `plot_data` set and a non-empty `plots`.
  """
  deserialized = []
  for serialized in tf.python_io.tf_record_iterator(path):
    proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)

    has_legacy_plot = proto.HasField('plot_data')
    if has_legacy_plot and proto.plots:
      raise RuntimeError('Both plots and plot_data are set.')
    if has_legacy_plot:
      # Backward compatibility: plots written by older code live in
      # `plot_data`; expose them through the map under the default '' key.
      proto.plots[''].CopyFrom(proto.plot_data)

    slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
    deserialized.append((slice_key, proto.plots))
  return deserialized
def testDeserializeSliceKey(self):
  """Checks `deserialize_slice_key` on int64, bytes and float slice columns."""
  # Adjacent literals concatenate into a single text-format SliceKey proto.
  slice_key_text = (
      """ single_slice_keys { column: 'age' int64_value: 5 }"""
      """ single_slice_keys { column: 'language' bytes_value: 'english' }"""
      """ single_slice_keys { column: 'price' float_value: 1.0 } """)
  slice_key_proto = text_format.Parse(slice_key_text,
                                      metrics_for_slice_pb2.SliceKey())
  expected_slice_key = [
      ('age', 5),
      ('language', b'english'),
      ('price', 1.0),
  ]

  actual_slice_key = slicer.deserialize_slice_key(slice_key_proto)

  self.assertItemsEqual(expected_slice_key, actual_slice_key)