def config_benchmark_logger(flag_obj=None):
  """Configure the global benchmark logger."""
  # Hold the lock so concurrent callers see a consistent global logger.
  _logger_lock.acquire()
  try:
    global _benchmark_logger
    # Fall back to the module-level FLAGS when no flag object is provided.
    if not flag_obj:
      flag_obj = FLAGS

    if (not hasattr(flag_obj, "benchmark_logger_type") or
        flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"):
      _benchmark_logger = BaseBenchmarkLogger()
    elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger":
      _benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
    elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger":
      from benchmark import benchmark_uploader as bu  # pylint: disable=g-import-not-at-top
      bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
      _benchmark_logger = BenchmarkBigQueryLogger(
          bigquery_uploader=bq_uploader,
          bigquery_data_set=flag_obj.bigquery_data_set,
          bigquery_run_table=flag_obj.bigquery_run_table,
          bigquery_run_status_table=flag_obj.bigquery_run_status_table,
          bigquery_metric_table=flag_obj.bigquery_metric_table,
          run_id=str(uuid.uuid4()))
    else:
      raise ValueError("Unrecognized benchmark_logger_type: %s"
                       % flag_obj.benchmark_logger_type)
  finally:
    _logger_lock.release()
  return _benchmark_logger
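# Illustrative usage sketch (not part of the original module): any parsed-flags
# container exposing the attributes read above can be passed in. The
# argparse.Namespace stub, the log directory path, and the log_metric call
# below are assumptions for illustration only.
#
#   import argparse
#
#   flags_stub = argparse.Namespace(
#       benchmark_logger_type="BenchmarkFileLogger",
#       benchmark_log_dir="/tmp/benchmark_logs")
#   logger = config_benchmark_logger(flags_stub)
#   logger.log_metric("accuracy", 0.97)  # assumes the configured logger exposes log_metric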
def setUp(self, mock_bigquery):
  # Wire up a mocked BigQuery client so the test makes no real network calls.
  self.mock_client = mock_bigquery.return_value
  self.mock_dataset = MagicMock(name="dataset")
  self.mock_table = MagicMock(name="table")
  self.mock_client.dataset.return_value = self.mock_dataset
  self.mock_dataset.table.return_value = self.mock_table
  self.mock_client.insert_rows_json.return_value = []

  self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
  self.benchmark_uploader._bq_client = self.mock_client

  # Write small metric and run logs for the uploader to read back.
  self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
  with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
    json.dump({"name": "accuracy", "value": 1.0}, f)
    f.write("\n")
    json.dump({"name": "loss", "value": 0.5}, f)
    f.write("\n")
  with open(os.path.join(self.log_dir, "run.log"), "w") as f:
    json.dump({"model_name": "value"}, f)
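# Illustrative test sketch (assumption, not in the original file): one way the
# fixture above could be exercised. The uploader method name
# upload_benchmark_metric_json_file is assumed here; substitute the actual
# BigQueryUploader API if it differs.
#
#   def test_upload_metric_file(self):
#     self.benchmark_uploader.upload_benchmark_metric_json_file(
#         "dataset", "table", "run_id",
#         os.path.join(self.log_dir, "metric.log"))
#     # The mocked client should receive exactly one insert_rows_json call
#     # covering the two metric rows written in setUp.
#     self.mock_client.insert_rows_json.assert_called_once()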