예제 #1
0
def config_benchmark_logger(flag_obj=None):
  """Configure the global benchmark logger and return it.

  Chooses a logger implementation from ``flag_obj.benchmark_logger_type`` and
  stores it in the module-level ``_benchmark_logger`` while holding
  ``_logger_lock``:

    * attribute missing or "BaseBenchmarkLogger" -> ``BaseBenchmarkLogger()``
    * "BenchmarkFileLogger" -> ``BenchmarkFileLogger(benchmark_log_dir)``
    * "BenchmarkBigQueryLogger" -> ``BenchmarkBigQueryLogger`` built from the
      ``gcp_project`` and ``bigquery_*`` attributes plus a fresh uuid4 run id.

  Args:
    flag_obj: Object whose attributes carry the logger settings. When None,
      the module-level ``FLAGS`` is used.

  Returns:
    The logger instance created by this call.

  Raises:
    ValueError: If ``benchmark_logger_type`` is not one of the names above.
      In that case the previously configured global logger is left unchanged.
  """
  global _benchmark_logger
  # Only substitute the default when nothing was passed; a caller-supplied
  # flag object is honoured even if it happens to be falsy.
  if flag_obj is None:
    flag_obj = FLAGS

  # NOTE(review): assumes _logger_lock is a threading-style lock that supports
  # the context-manager protocol (it exposes acquire()/release()).
  with _logger_lock:
    logger_type = getattr(flag_obj, "benchmark_logger_type",
                          "BaseBenchmarkLogger")
    if logger_type == "BaseBenchmarkLogger":
      logger = BaseBenchmarkLogger()
    elif logger_type == "BenchmarkFileLogger":
      logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
    elif logger_type == "BenchmarkBigQueryLogger":
      # Imported only on this branch (deliberately, per the pylint pragma), so
      # the uploader module is loaded only when this logger type is selected.
      from benchmark import benchmark_uploader as bu  # pylint: disable=g-import-not-at-top
      bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
      logger = BenchmarkBigQueryLogger(
          bigquery_uploader=bq_uploader,
          bigquery_data_set=flag_obj.bigquery_data_set,
          bigquery_run_table=flag_obj.bigquery_run_table,
          bigquery_run_status_table=flag_obj.bigquery_run_status_table,
          bigquery_metric_table=flag_obj.bigquery_metric_table,
          run_id=str(uuid.uuid4()))
    else:
      raise ValueError("Unrecognized benchmark_logger_type: %s"
                       % logger_type)
    _benchmark_logger = logger
  # Return the local rather than re-reading the global after the lock is
  # released: a concurrent call could otherwise replace it in between.
  return logger
예제 #2
0
    def setUp(self, mock_bigquery):
        """Wire a mocked BigQuery client into an uploader and seed log files.

        ``mock_bigquery`` is the patched BigQuery client class (presumably
        injected by a ``@patch`` decorator outside this view -- TODO confirm).
        Its ``return_value`` is used as the client. ``dataset()`` and
        ``table()`` are chained to named MagicMocks, and ``insert_rows_json``
        returns an empty list.

        A fresh temporary directory under ``self.get_temp_dir()`` is
        populated with:
          * ``metric.log``: two JSON records, each followed by a newline.
          * ``run.log``: a single JSON object with no trailing newline.
        """
        client = mock_bigquery.return_value
        dataset = MagicMock(name="dataset")
        table = MagicMock(name="table")
        client.dataset.return_value = dataset
        dataset.table.return_value = table
        client.insert_rows_json.return_value = []
        self.mock_client = client
        self.mock_dataset = dataset
        self.mock_table = table

        uploader = benchmark_uploader.BigQueryUploader()
        # Swap in the mocked client so the uploader talks to the mock.
        uploader._bq_client = client
        self.benchmark_uploader = uploader

        self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        metric_records = (
            {"name": "accuracy", "value": 1.0},
            {"name": "loss", "value": 0.5},
        )
        metric_path = os.path.join(self.log_dir, "metric.log")
        with open(metric_path, "a") as metric_file:
            for record in metric_records:
                json.dump(record, metric_file)
                metric_file.write("\n")

        run_path = os.path.join(self.log_dir, "run.log")
        with open(run_path, "w") as run_file:
            json.dump({"model_name": "value"}, run_file)