def load_eval_config(output_path: Text) -> config.EvalConfig:
  """Loads the eval config previously written under `output_path`.

  The current on-disk format is a JSON-serialized `EvalConfigAndVersion`
  proto stored at `os.path.join(output_path, _EVAL_CONFIG_FILE)`. If that file
  does not exist, falls back to the legacy format: the same path with its file
  extension stripped, read as a TFRecord file whose first record is a pickled
  dict containing an 'eval_config' entry. The legacy object is converted into
  a `config.EvalConfig`.

  Args:
    output_path: Directory the evaluation output (including the eval config)
      was written to.

  Returns:
    The loaded `config.EvalConfig`.

  Raises:
    ValueError: If the legacy eval config file contains no records.
  """
  path = os.path.join(output_path, _EVAL_CONFIG_FILE)
  if tf.io.gfile.exists(path):
    with tf.io.gfile.GFile(path, 'r') as f:
      pb = json_format.Parse(f.read(), config_pb2.EvalConfigAndVersion())
    _check_version(pb.version, output_path)
    return pb.eval_config

  # Legacy support (to be removed in future).
  # The previous version did not include file extension.
  legacy_path = os.path.splitext(path)[0]
  # Use a default so an empty record file yields a clear error instead of a
  # bare StopIteration escaping this function.
  serialized_record = next(
      tf.compat.v1.python_io.tf_record_iterator(legacy_path), None)
  if serialized_record is None:
    raise ValueError(
        'legacy eval config file at %s contains no records' % legacy_path)
  # NOTE: pickle.loads can execute arbitrary code; only load eval configs
  # from trusted output locations.
  final_dict = pickle.loads(serialized_record)
  # NOTE(review): the JSON path above passes `pb.version`, but this passes the
  # whole legacy dict — confirm `_check_version` handles a dict argument.
  _check_version(final_dict, output_path)
  old_config = final_dict['eval_config']

  # An empty/absent legacy slice_spec leaves slicing_specs unset.
  slicing_specs = None
  if old_config.slice_spec:
    slicing_specs = [s.to_proto() for s in old_config.slice_spec]

  options = config.Options()
  options.compute_confidence_intervals.value = (
      old_config.compute_confidence_intervals)
  options.k_anonymization_count.value = old_config.k_anonymization_count

  return config.EvalConfig(
      input_data_specs=[
          config.InputDataSpec(location=old_config.data_location)
      ],
      model_specs=[config.ModelSpec(location=old_config.model_location)],
      output_data_specs=[
          config.OutputDataSpec(default_location=output_path)
      ],
      slicing_specs=slicing_specs,
      options=options)
def _serialize_eval_config(eval_config: config.EvalConfig) -> Text:
  """Serializes an eval config to JSON, tagged with the current TFMA version.

  Args:
    eval_config: The evaluation config to serialize.

  Returns:
    The JSON text of an `EvalConfigAndVersion` proto that wraps `eval_config`
    together with `tfma_version.VERSION_STRING`.
  """
  versioned_config = config_pb2.EvalConfigAndVersion(
      version=tfma_version.VERSION_STRING, eval_config=eval_config)
  return json_format.MessageToJson(versioned_config)