def main(_):
    """Entry point: run training or continuous pair evaluation.

    Reads its configuration from ``flags.FLAGS`` and creates
    ``config.checkpoint_dir`` (via ``gfile.makedirs``) before dispatching on
    ``config.mode``:

    * ``"train"`` -- calls ``train(config)``.
    * ``"evaluate_pair"`` -- loops without end. On each pass it asks
      ``utils.maybe_pick_models_to_evaluate`` for a checkpoint under
      ``config.checkpoint_dir``. If one is returned, it runs
      ``evaluate_pair`` on it; otherwise it logs and sleeps for
      ``EVALUATOR_SLEEP_PERIOD`` seconds.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        an ``app.run``-style launcher -- TODO confirm).

    Raises:
      ValueError: If ``config.mode`` is neither ``"train"`` nor
        ``"evaluate_pair"``.
    """
    config = flags.FLAGS

    gfile.makedirs(config.checkpoint_dir)
    if config.mode == "train":
        train(config)
    elif config.mode == "evaluate_pair":
        # No break: the evaluator keeps polling for a checkpoint to evaluate
        # until the process is stopped or evaluate_pair raises.
        while True:
            checkpoint_path = utils.maybe_pick_models_to_evaluate(
                checkpoint_dir=config.checkpoint_dir)
            if checkpoint_path:
                evaluate_pair(
                    config=config,
                    batch_size=config.batch_size,
                    checkpoint_path=checkpoint_path,
                    data_dir=config.data_dir,
                    dataset=config.dataset,
                    num_examples_for_eval=config.num_examples_for_eval)
            else:
                logging.info(
                    "No models to evaluate found, sleeping for %d seconds",
                    EVALUATOR_SLEEP_PERIOD)
                time.sleep(EVALUATOR_SLEEP_PERIOD)
    else:
        # ValueError is the conventional type for an invalid argument value.
        # It subclasses Exception, so existing `except Exception` handlers
        # still catch it.
        raise ValueError(
            "Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\""
            % config.mode)
Example #2
0
def _enable_gpu_memory_growth():
  """Enable memory growth on every visible physical GPU, best effort.

  A RuntimeError from TensorFlow is logged as a warning rather than
  propagated. Memory growth must be set before the GPUs have been
  initialized, so this can fail if they already are.
  """
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if not gpus:
    return
  try:
    # Currently, memory growth needs to be the same across GPUs.
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    logging.warning('Failed to set memory growth: %s', e)
    return
  logging.info('%d Physical GPUs, %d Logical GPUs', len(gpus),
               len(logical_gpus))
  logging.info('GPU memory is set to grow')


def main(_):
  """Entry point: configure GPUs, then run training or pair evaluation.

  Turns on TensorFlow device-placement logging and (best effort) GPU memory
  growth. It then reads ``flags.FLAGS`` and creates
  ``config.checkpoint_dir`` inside a ``tf.device('/GPU:0')`` scope, and
  dispatches on ``config.mode``:

  * ``"train"`` -- calls ``train(config)``.
  * ``"evaluate_pair"`` -- loops without end. On each pass it asks
    ``utils.maybe_pick_models_to_evaluate`` for a checkpoint under
    ``config.checkpoint_dir``. If one is returned, it runs ``evaluate_pair``
    on it; otherwise it logs and sleeps for ``EVALUATOR_SLEEP_PERIOD``
    seconds.

  Args:
    _: Unused positional argument (presumably the argv list handed over by
      an ``app.run``-style launcher -- TODO confirm).

  Raises:
    ValueError: If ``config.mode`` is neither ``"train"`` nor
      ``"evaluate_pair"``.
  """
  # Logs which device each op is placed on; verbose, mainly a debugging aid.
  tf.debugging.set_log_device_placement(True)
  _enable_gpu_memory_growth()

  # Place tensors on the GPU.
  # NOTE(review): on a machine without a GPU this relies on TensorFlow's soft
  # device placement falling back to CPU -- confirm for the TF version used.
  with tf.device('/GPU:0'):
    config = flags.FLAGS

    gfile.makedirs(config.checkpoint_dir)
    if config.mode == "train":
      train(config)
    elif config.mode == "evaluate_pair":
      # No break: the evaluator keeps polling for a checkpoint to evaluate
      # until the process is stopped or evaluate_pair raises.
      while True:
        checkpoint_path = utils.maybe_pick_models_to_evaluate(
            checkpoint_dir=config.checkpoint_dir)
        if checkpoint_path:
          evaluate_pair(
              config=config,
              batch_size=config.batch_size,
              checkpoint_path=checkpoint_path,
              data_dir=config.data_dir,
              dataset=config.dataset,
              num_examples_for_eval=config.num_examples_for_eval)
        else:
          logging.info("No models to evaluate found, sleeping for %d seconds",
                       EVALUATOR_SLEEP_PERIOD)
          time.sleep(EVALUATOR_SLEEP_PERIOD)
    else:
      # ValueError is the conventional type for an invalid argument value.
      # It subclasses Exception, so existing `except Exception` handlers
      # still catch it.
      raise ValueError(
          "Unexpected mode %s, supported modes are \"train\" or \"evaluate_pair\""
          % config.mode)