예제 #1
0
def main(_):
  """Run the Keras NCF model inside the benchmark and MLPerf logging contexts.

  Args:
    _: Unused.

  Raises:
    ValueError: If FLAGS.tpu is set; this Keras NCF path has no TPU support.
  """
  with logger.benchmark_context(FLAGS):
    with mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
      # Hand mlperf_helper the directory that contains this source file.
      script_dir = os.path.dirname(os.path.abspath(__file__))
      mlperf_helper.set_ncf_root(script_dir)
      if not FLAGS.tpu:
        run_ncf(FLAGS)
      else:
        raise ValueError("NCF in Keras does not support TPU for now")
예제 #2
0
def main(_):
  """Keras NCF entry point.

  Enters `logger.benchmark_context` and the MLPerf compliance `LOGGER`,
  passes this file's directory to `mlperf_helper.set_ncf_root`, rejects TPU
  runs, and then calls `run_ncf`.

  Args:
    _: Ignored.

  Raises:
    ValueError: When FLAGS.tpu is truthy.
  """
  benchmark_ctx = logger.benchmark_context(FLAGS)
  with benchmark_ctx:
    compliance_ctx = mlperf_helper.LOGGER(
        FLAGS.output_ml_perf_compliance_logging)
    with compliance_ctx:
      mlperf_helper.set_ncf_root(os.path.dirname(os.path.abspath(__file__)))
      if FLAGS.tpu:
        raise ValueError("NCF in Keras does not support TPU for now")
      run_ncf(FLAGS)
예제 #3
0
def main(_):
    """Run NCF inside the benchmark context and the MLPerf compliance LOGGER.

    Args:
        _: Unused.
    """
    with logger.benchmark_context(FLAGS):
        with mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
            # The directory holding this source file is used as the NCF root.
            ncf_root = os.path.dirname(os.path.abspath(__file__))
            mlperf_helper.set_ncf_root(ncf_root)
            run_ncf(FLAGS)
예제 #4
0
def main(_):
  """Run NCF under benchmark/MLPerf logging, then call `stitch_ncf`.

  Args:
    _: Unused.
  """
  benchmark = logger.benchmark_context(FLAGS)
  with benchmark:
    with mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
      here = os.path.dirname(os.path.abspath(__file__))
      mlperf_helper.set_ncf_root(here)
      run_ncf(FLAGS)
      # Reached only if run_ncf returns normally; still inside both contexts.
      mlperf_helper.stitch_ncf()
예제 #5
0
def main(_):
    """Run NCF with benchmark logging and an MLPerf LOGGER gated on --ml_perf.

    After `run_ncf` returns, `mlperf_helper.stitch_ncf` is called while both
    contexts are still active.

    Args:
        _: Unused.
    """
    with logger.benchmark_context(FLAGS):
        with mlperf_helper.LOGGER(FLAGS.ml_perf):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            mlperf_helper.set_ncf_root(script_dir)
            run_ncf(FLAGS)
            mlperf_helper.stitch_ncf()
예제 #6
0
def main(_):
    """Run the asynchronous NCF data generation process.

    Ordering contract: before doing anything else this process must, in
    order,
      1) write the alive file, and
      2) wait for the flagfile to be written.

    When the redirect_logs flag is set, output is redirected to
    ``data_gen_proc_<cache_id>.log`` in the data directory. If that
    directory is a ``gs://`` path, the log goes to the local temp directory
    instead. NumPy's global RNG is seeded when a seed flag is given, and
    ``_generation_loop`` runs inside ``mlperf_helper.LOGGER``.

    Args:
        _: Unused.

    Raises:
        Any exception other than KeyboardInterrupt is printed with
        ``traceback.print_exc`` and re-raised. A KeyboardInterrupt is logged
        and swallowed.
    """
    global _log_file

    # 1) Alive file first, 2) then the flagfile.
    paths = rconst.Paths(data_dir=flags.FLAGS.data_dir,
                         cache_id=flags.FLAGS.cache_id)
    write_alive_file(cache_paths=paths)
    _parse_flagfile(os.path.join(paths.cache_root, rconst.FLAGFILE))

    redirect_logs = flags.FLAGS.redirect_logs
    log_name = "data_gen_proc_{}.log".format(paths.cache_id)
    log_path = os.path.join(paths.data_dir, log_name)

    # This server is generally run in a subprocess.
    if redirect_logs:
        if log_path.startswith("gs://"):
            # The builtin open() used below cannot write to a gs:// path.
            local_log_path = os.path.join(tempfile.gettempdir(), log_name)
            print("Unable to log to {}. Falling back to {}".format(
                log_path, local_log_path))
            log_path = local_log_path
        print("Redirecting output of data_async_generation.py process to {}".
              format(log_path))
        _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().

    try:
        log_msg("sys.argv: {}".format(" ".join(sys.argv)))

        seed = flags.FLAGS.seed
        if seed is not None:
            np.random.seed(seed)

        with mlperf_helper.LOGGER(enable=flags.FLAGS.ml_perf):
            mlperf_helper.set_ncf_root(
                os.path.dirname(os.path.abspath(__file__)))
            _generation_loop(
                num_workers=flags.FLAGS.num_workers,
                cache_paths=paths,
                num_readers=flags.FLAGS.num_readers,
                num_neg=flags.FLAGS.num_neg,
                num_train_positives=flags.FLAGS.num_train_positives,
                num_items=flags.FLAGS.num_items,
                num_users=flags.FLAGS.num_users,
                epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
                train_batch_size=flags.FLAGS.train_batch_size,
                eval_batch_size=flags.FLAGS.eval_batch_size,
                deterministic=seed is not None,
                match_mlperf=flags.FLAGS.ml_perf,
            )
    except KeyboardInterrupt:
        log_msg("KeyboardInterrupt registered.")
    except BaseException:
        # Write the traceback to _log_file before propagating the error.
        traceback.print_exc(file=_log_file)
        raise
    finally:
        log_msg("Shutting down generation subprocess.")
        for stream in (sys.stdout, sys.stderr):
            stream.flush()
        if redirect_logs:
            _log_file.close()
예제 #7
0
def main(_):
  """Run the asynchronous NCF data generation process.

  Ordering contract: before doing anything else this process must, in order,
    1) write the alive file, and
    2) wait for the flagfile to be written.

  With the redirect_logs flag set, output is redirected to
  `data_gen_proc_<cache_id>.log` in the data directory, or in the local temp
  directory when the data directory is a `gs://` path. `_generation_loop`
  runs inside an `mlperf_helper.LOGGER` gated on the
  output_ml_perf_compliance_logging flag.

  Args:
    _: Unused.

  Raises:
    Any exception other than KeyboardInterrupt is printed with
    `traceback.print_exc` and re-raised. A KeyboardInterrupt is logged and
    swallowed.
  """
  global _log_file

  # Steps 1 and 2 of the ordering contract.
  cache_paths = rconst.Paths(
      data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
  write_alive_file(cache_paths=cache_paths)
  _parse_flagfile(os.path.join(cache_paths.cache_root, rconst.FLAGFILE))

  redirect_logs = flags.FLAGS.redirect_logs
  log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
  log_path = os.path.join(cache_paths.data_dir, log_file_name)

  # This server is generally run in a subprocess.
  if redirect_logs:
    if log_path.startswith("gs://"):
      # A plain open() cannot target gs://, so log to the local temp dir.
      local_path = os.path.join(tempfile.gettempdir(), log_file_name)
      print("Unable to log to {}. Falling back to {}"
            .format(log_path, local_path))
      log_path = local_path
    print("Redirecting output of data_async_generation.py process to {}"
          .format(log_path))
    _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().

  try:
    log_msg("sys.argv: {}".format(" ".join(sys.argv)))
    if flags.FLAGS.seed is not None:
      np.random.seed(flags.FLAGS.seed)

    with mlperf_helper.LOGGER(
        enable=flags.FLAGS.output_ml_perf_compliance_logging):
      mlperf_helper.set_ncf_root(os.path.dirname(os.path.abspath(__file__)))
      loop_kwargs = dict(
          num_workers=flags.FLAGS.num_workers,
          cache_paths=cache_paths,
          num_readers=flags.FLAGS.num_readers,
          num_neg=flags.FLAGS.num_neg,
          num_train_positives=flags.FLAGS.num_train_positives,
          num_items=flags.FLAGS.num_items,
          num_users=flags.FLAGS.num_users,
          epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
          num_cycles=flags.FLAGS.num_cycles,
          train_batch_size=flags.FLAGS.train_batch_size,
          eval_batch_size=flags.FLAGS.eval_batch_size,
          deterministic=flags.FLAGS.seed is not None,
          match_mlperf=flags.FLAGS.ml_perf,
      )
      _generation_loop(**loop_kwargs)
  except KeyboardInterrupt:
    log_msg("KeyboardInterrupt registered.")
  except BaseException:
    # Dump the traceback to _log_file, then let the error propagate.
    traceback.print_exc(file=_log_file)
    raise
  finally:
    log_msg("Shutting down generation subprocess.")
    sys.stdout.flush()
    sys.stderr.flush()
    if redirect_logs:
      _log_file.close()