Example #1
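A test helper that builds train and eval models in one graph, trains, restores the latest checkpoint, computes a loss for one batch of each model via feed dicts, and asserts that merely computing the loss leaves the trainable variables unchanged.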
    def run_model(self, train_config, eval_config):
        with tf.Graph().as_default() as g:
            train_model = base_model(params=train_config,
                                     mode="train",
                                     hvd=None)
            train_model.compile()
            eval_model = base_model(params=eval_config, mode="eval", hvd=None)
            eval_model.compile(force_var_reuse=True)

            train(train_model, eval_model, hvd=None)
            saver = tf.train.Saver()
            checkpoint = tf.train.latest_checkpoint(
                train_model.params['logdir'])
            with self.test_session(g, use_gpu=True) as sess:
                saver.restore(sess, checkpoint)

                weights = sess.run(tf.trainable_variables())
                loss = sess.run(
                    train_model.loss,
                    train_model.data_layer.next_batch_feed_dict(),
                )
                eval_loss = sess.run(
                    eval_model.loss,
                    eval_model.data_layer.next_batch_feed_dict(),
                )
                weights_new = sess.run(tf.trainable_variables())

                # checking that the weights have not changed from just computing the loss
                for w, w_new in zip(weights, weights_new):
                    npt.assert_allclose(w, w_new)
            eval_dict = evaluate(eval_model, checkpoint)
        return loss, eval_loss, eval_dict
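These snippets omit their imports; a minimal set covering the symbols used across the test examples below might look like this (a sketch — the real test modules additionally import project helpers such as base_model, train, evaluate, infer, and get_available_gpus from the open_seq2seq packages):

import os

import numpy as np
import numpy.testing as npt
import pandas as pd
import tensorflow as tf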
Example #2
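A variant of the same helper that accepts an optional Horovod handle and uses iterator-based data layers: it initializes the iterators on every GPU, computes the train loss and the mean of the per-tower eval losses, and again verifies that the weights are untouched.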
  def run_model(self, train_config, eval_config, hvd=None):
    with tf.Graph().as_default() as g:
      # pylint: disable=not-callable
      train_model = self.base_model(params=train_config, mode="train", hvd=hvd)
      train_model.compile()
      eval_model = self.base_model(params=eval_config, mode="eval", hvd=hvd)
      eval_model.compile(force_var_reuse=True)

      train(train_model, eval_model)
      saver = tf.train.Saver()
      checkpoint = tf.train.latest_checkpoint(train_model.params['logdir'])
      with self.test_session(g, use_gpu=True) as sess:
        saver.restore(sess, checkpoint)
        sess.run([train_model.get_data_layer(i).iterator.initializer
                  for i in range(train_model.num_gpus)])
        sess.run([eval_model.get_data_layer(i).iterator.initializer
                  for i in range(eval_model.num_gpus)])

        weights = sess.run(tf.trainable_variables())
        loss = sess.run(train_model.loss)
        eval_losses = sess.run(eval_model.eval_losses)
        eval_loss = np.mean(eval_losses)
        weights_new = sess.run(tf.trainable_variables())

        # checking that the weights have not changed from
        # just computing the loss
        for w, w_new in zip(weights, weights_new):
          npt.assert_allclose(w, w_new)
      eval_dict = evaluate(eval_model, checkpoint)
    return loss, eval_loss, eval_dict
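A hypothetical call site for this helper inside a test case; prepare_config appears in the inference tests below, while the test name and loss thresholds here are placeholders, not values from the original suite:

  def test_convergence(self):
    train_config, eval_config = self.prepare_config()
    loss, eval_loss, eval_dict = self.run_model(train_config, eval_config)
    self.assertLess(loss, 0.1)       # placeholder threshold
    self.assertLess(eval_loss, 0.1)  # placeholder threshold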
Example #3
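The same iterator-based helper without the Horovod parameter; both models are built with hvd=None.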
  def run_model(self, train_config, eval_config):
    with tf.Graph().as_default() as g:
      train_model = base_model(params=train_config, mode="train", hvd=None)
      train_model.compile()
      eval_model = base_model(params=eval_config, mode="eval", hvd=None)
      eval_model.compile(force_var_reuse=True)

      train(train_model, eval_model)
      saver = tf.train.Saver()
      checkpoint = tf.train.latest_checkpoint(train_model.params['logdir'])
      with self.test_session(g, use_gpu=True) as sess:
        saver.restore(sess, checkpoint)
        sess.run([train_model.get_data_layer(i).iterator.initializer
                  for i in range(train_model.num_gpus)])
        sess.run([eval_model.get_data_layer(i).iterator.initializer
                  for i in range(eval_model.num_gpus)])

        weights = sess.run(tf.trainable_variables())
        loss = sess.run(train_model.loss)
        eval_losses = sess.run(eval_model.eval_losses)
        eval_loss = np.mean(eval_losses)
        weights_new = sess.run(tf.trainable_variables())

        # checking that the weights have not changed from just computing the loss
        for w, w_new in zip(weights, weights_new):
          npt.assert_allclose(w, w_new)
      eval_dict = evaluate(eval_model, checkpoint)
    return loss, eval_loss, eval_dict
Example #4
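An entry point that obtains the parsed arguments and config from get_base_config, optionally initializes Horovod, validates the log directory via check_logdir, redirects logs when --enable_logs is set, and dispatches to train, evaluate, or infer based on --mode.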
def main():
    # Parse args and create config
    args, base_config, base_model, config_module = get_base_config(
        sys.argv[1:])

    if args.mode == "interactive_infer":
        raise ValueError(
            "Interactive infer is meant to be run from an IPython",
            "notebook not from run.py.")

    # Initialize Horovod
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
    else:
        hvd = None

    restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)
    if args.enable_logs:
        if hvd is None or hvd.rank() == 0:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'train' or args.mode == 'train_eval' or args.benchmark:
        if hvd is None or hvd.rank() == 0:
            if checkpoint is None or args.benchmark:
                deco_print("Starting training from scratch")
            else:
                deco_print(
                    "Restored checkpoint from {}. Resuming training".format(
                        checkpoint))
    elif args.mode == 'eval' or args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))

    # Create model and train/eval/infer
    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd)
        if args.mode == "train_eval":
            train(model[0], model[1], debug_port=args.debug_port)
        elif args.mode == "train":
            train(model, None, debug_port=args.debug_port)
        elif args.mode == "eval":
            evaluate(model, checkpoint)
        elif args.mode == "infer":
            infer(model, checkpoint, args.infer_output_file, args.use_trt)

    if args.enable_logs and (hvd is None or hvd.rank() == 0):
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        stdout_log.close()
        stderr_log.close()
Example #5
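An inference test: it trains on toy speech data, runs inference from the latest checkpoint (on one or two GPUs depending on availability), and compares each prediction to the reference CSV, allowing at most five characters of edit distance per row.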
    def infer_test(self):
        train_config, infer_config = self.prepare_config()
        train_config['num_epochs'] = 250
        infer_config['batch_size_per_gpu'] = 4

        with tf.Graph().as_default() as g:
            with self.test_session(g, use_gpu=True) as sess:
                gpus = get_available_gpus()

        if len(gpus) > 1:
            infer_config['num_gpus'] = 2
        else:
            infer_config['num_gpus'] = 1

        with tf.Graph().as_default():
            # pylint: disable=not-callable
            train_model = self.base_model(params=train_config,
                                          mode="train",
                                          hvd=None)
            train_model.compile()
            train(train_model, None)

        with tf.Graph().as_default():
            # pylint: disable=not-callable
            infer_model = self.base_model(params=infer_config,
                                          mode="infer",
                                          hvd=None)
            infer_model.compile()

            print(train_model.params['logdir'])
            output_file = os.path.join(train_model.params['logdir'],
                                       'infer_out.csv')
            infer(
                infer_model,
                tf.train.latest_checkpoint(train_model.params['logdir']),
                output_file,
            )
            pred_csv = pd.read_csv(output_file)
            true_csv = pd.read_csv(
                'open_seq2seq/test_utils/toy_speech_data/toy_data.csv')
            for pred_row, true_row in zip(pred_csv.as_matrix(),
                                          true_csv.as_matrix()):
                # checking file name
                self.assertEqual(pred_row[0], true_row[0])
                # checking prediction: no more than 5 chars difference
                self.assertLess(levenshtein(pred_row[-1], true_row[-1]), 5)
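The levenshtein helper is imported from the project's test utilities and is not shown here; assuming it computes the standard edit distance, a minimal pure-Python equivalent would be:

def levenshtein(a, b):
    """Edit distance between sequences a and b (insertions, deletions, substitutions)."""
    if len(a) < len(b):
        a, b = b, a
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]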
Example #6
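A stricter variant of the inference test that requires predictions to match the reference transcripts exactly.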
  def test_infer(self):
    train_config, infer_config = self.prepare_config()
    train_config['num_epochs'] = 200
    infer_config['batch_size_per_gpu'] = 4

    with tf.Graph().as_default() as g:
      with self.test_session(g, use_gpu=True) as sess:
        gpus = get_available_gpus()

    if len(gpus) > 1:
      infer_config['num_gpus'] = 2
    else:
      infer_config['num_gpus'] = 1

    with tf.Graph().as_default():
      train_model = base_model(params=train_config, mode="train", hvd=None)
      train_model.compile()
      train(train_model, None)

    with tf.Graph().as_default():
      infer_model = base_model(params=infer_config, mode="infer", hvd=None)
      infer_model.compile()

      print(train_model.params['logdir'])
      output_file = os.path.join(train_model.params['logdir'], 'infer_out.csv')
      infer(
        infer_model,
        tf.train.latest_checkpoint(train_model.params['logdir']),
        output_file,
      )
      pred_csv = pd.read_csv(output_file)
      true_csv = pd.read_csv(
        'open_seq2seq/test_utils/toy_speech_data/toy_data.csv',
      )
      for pred_row, true_row in zip(pred_csv.as_matrix(), true_csv.as_matrix()):
        # checking file name
        self.assertEqual(pred_row[0], true_row[0])
        # checking prediction
        self.assertEqual(pred_row[-1], true_row[-1])
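Note that these inference tests use DataFrame.as_matrix(), which was deprecated and later removed from pandas; on current pandas versions the equivalent iteration would use to_numpy():

      for pred_row, true_row in zip(pred_csv.to_numpy(), true_csv.to_numpy()):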
Example #7
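The same exact-match inference test pinned to a single GPU.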
    def test_infer(self):
        train_config, infer_config = self.prepare_config()
        train_config['num_epochs'] = 200
        infer_config['batch_size_per_gpu'] = 4
        infer_config['num_gpus'] = 1

        with tf.Graph().as_default():
            train_model = base_model(params=train_config,
                                     mode="train",
                                     hvd=None)
            train_model.compile()
            train(train_model, None, hvd=None)

        with tf.Graph().as_default():
            infer_model = base_model(params=infer_config,
                                     mode="infer",
                                     hvd=None)
            infer_model.compile()

            print(train_model.params['logdir'])
            output_file = os.path.join(train_model.params['logdir'],
                                       'infer_out.csv')
            infer(
                infer_model,
                tf.train.latest_checkpoint(train_model.params['logdir']),
                output_file,
            )
            pred_csv = pd.read_csv(output_file)
            true_csv = pd.read_csv(
                'open_seq2seq/test_utils/toy_speech_data/toy_data.csv')
            for pred_row, true_row in zip(pred_csv.as_matrix(),
                                          true_csv.as_matrix()):
                # checking file name
                self.assertEqual(pred_row[0], true_row[0])
                # checking prediction
                self.assertEqual(pred_row[-1], true_row[-1])
Example #8
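A self-contained command-line entry point: it parses arguments, loads the config module with runpy, lets unknown command-line flags override flat config values, validates the log directory, optionally rewrites the config for benchmarking, and then builds and runs the model in the requested mode.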
def main():
  parser = argparse.ArgumentParser(description='Experiment parameters')
  parser.add_argument("--config_file", required=True,
                      help="Path to the configuration file")
  parser.add_argument("--mode", default='train',
                      help="Could be \"train\", \"eval\", "
                           "\"train_eval\" or \"infer\"")
  parser.add_argument("--infer_output_file",
                      help="Path to the output of inference")
  parser.add_argument('--continue_learning', dest='continue_learning',
                      action='store_true', help="whether to continue learning")
  parser.add_argument('--no_dir_check', dest='no_dir_check',
                      action='store_true',
                      help="whether to check that everything is correct "
                           "with log directory")
  parser.add_argument('--benchmark', dest='benchmark', action='store_true',
                      help='automatic config change for benchmarking')
  parser.add_argument('--bench_steps', type=int, default=20,
                      help='max_steps for benchmarking')
  parser.add_argument('--bench_start', type=int,
                      help='first step to start counting time for benchmarking')
  parser.add_argument('--debug_port', type=int,
                      help='run TensorFlow in debug mode on specified port')
  args, unknown = parser.parse_known_args()

  if args.mode not in ['train', 'eval', 'train_eval', 'infer']:
    raise ValueError("Mode has to be one of "
                     "['train', 'eval', 'train_eval', 'infer']")
  config_module = runpy.run_path(args.config_file, init_globals={'tf': tf})

  base_config = config_module.get('base_params', None)
  if base_config is None:
    raise ValueError('base_config dictionary has to be '
                     'defined in the config file')
  base_model = config_module.get('base_model', None)
  if base_model is None:
    raise ValueError('base_model class has to be defined in the config file')

  # After reading the config, try to overwrite some of its properties
  # with command line arguments that were passed to the script
  parser_unk = argparse.ArgumentParser()
  for pm, value in base_config.items():
    if type(value) is int or type(value) is float or type(value) is str or \
       type(value) is bool:
      parser_unk.add_argument('--' + pm, default=value, type=type(value))
  config_update = parser_unk.parse_args(unknown)
  base_config.update(vars(config_update))

  train_config = copy.deepcopy(base_config)
  eval_config = copy.deepcopy(base_config)
  infer_config = copy.deepcopy(base_config)

  if base_config['use_horovod']:
    if args.mode == "infer" or args.mode == "eval":
      raise NotImplementedError("Inference or evaluation on horovod "
                                "is not supported yet")
    if args.mode == "train_eval":
      deco_print("Evaluation during training is not yet supported on horovod, "
                 "defaulting to just doing mode=\"train\"")
      args.mode = "train"
    import horovod.tensorflow as hvd
    hvd.init()
    if hvd.rank() == 0:
      deco_print("Using horovod")
  else:
    hvd = None

  if args.mode == 'train' or args.mode == 'train_eval':
    if 'train_params' in config_module:
      train_config.update(copy.deepcopy(config_module['train_params']))
    if hvd is None or hvd.rank() == 0:
      deco_print("Training config:")
      pprint.pprint(train_config)
  if args.mode == 'eval' or args.mode == 'train_eval':
    if 'eval_params' in config_module:
      eval_config.update(copy.deepcopy(config_module['eval_params']))
      eval_config['gpu_ids'] = [eval_config['num_gpus'] - 1]
      if 'num_gpus' in eval_config:
        del eval_config['num_gpus']
    if hvd is None or hvd.rank() == 0:
      deco_print("Evaluation can only be run on one GPU. "
                 "Setting num_gpus to 1 for eval model")
      deco_print("Evaluation config:")
      pprint.pprint(eval_config)
  if args.mode == "infer":
    if args.infer_output_file is None:
      raise ValueError("\"infer_output_file\" command line parameter is "
                       "required in inference mode")
    infer_config.update(copy.deepcopy(config_module['infer_params']))
    deco_print("Inference can be run only on one GPU. Setting num_gpus to 1")
    infer_config['num_gpus'] = 1
    deco_print("Inference config:")
    pprint.pprint(infer_config)

  # checking that everything is correct with log directory
  logdir = base_config['logdir']
  if args.benchmark:
    args.no_dir_check = True
  try:
    if args.mode == 'train' or args.mode == 'train_eval':
      if os.path.isfile(logdir):
        raise IOError("There is a file with the same name as \"logdir\" "
                      "parameter. You should change the log directory path "
                      "or delete the file to continue.")

      # check if "logdir" directory exists and non-empty
      if os.path.isdir(logdir) and os.listdir(logdir) != []:
        checkpoint = tf.train.latest_checkpoint(logdir)
        if not args.continue_learning:
          raise IOError("Log directory is not empty. If you want to continue "
                        "learning, you should provide "
                        "\"--continue_learning\" flag")
        if checkpoint is None:
          raise IOError("There is no valid TensorFlow checkpoint in the "
                        "log directory. Can't restore variables.")
      else:
        if args.continue_learning:
          raise IOError("The log directory is empty or does not exist. "
                        "You should probably not provide "
                        "\"--continue_learning\" flag?")
        checkpoint = None
    elif args.mode == 'infer' or args.mode == 'eval':
      if os.path.isdir(logdir) and os.listdir(logdir) != []:
        checkpoint = tf.train.latest_checkpoint(logdir)
        if checkpoint is None:
          raise IOError("There is no valid TensorFlow checkpoint in the "
                        "{} directory. Can't load model".format(logdir))
      else:
        raise IOError(
          "{} does not exist or is empty, can't restore model".format(logdir)
        )
  except IOError as e:
    if args.no_dir_check:
      print("Warning: {}".format(e))
      print("Resuming operation since no_dir_check argument was provided")
    else:
      raise

  if args.benchmark:
    deco_print("Adjusting config for benchmarking")
    train_config['print_samples_steps'] = None
    train_config['print_loss_steps'] = 1
    train_config['summary_steps'] = None
    train_config['save_checkpoint_steps'] = None
    train_config['logdir'] = str("")
    if 'num_epochs' in train_config:
      del train_config['num_epochs']
    train_config['max_steps'] = args.bench_steps
    if args.bench_start:
      train_config['bench_start'] = args.bench_start
    elif 'bench_start' not in train_config:
      train_config['bench_start'] = 10  # default value

    deco_print("New benchmarking config:")
    pprint.pprint(train_config)
    args.mode = "train"
    checkpoint = None

  if args.mode == 'train' or args.mode == 'train_eval':
    if hvd is None or hvd.rank() == 0:
      if checkpoint is None:
        deco_print("Starting training from scratch")
      else:
        deco_print(
          "Restored checkpoint from {}. Resuming training".format(checkpoint),
        )
  elif args.mode == 'eval' or args.mode == 'infer':
    deco_print("Loading model from {}".format(checkpoint))

  with tf.Graph().as_default():
    if args.mode == 'train':
      train_model = base_model(params=train_config, mode="train", hvd=hvd)
      train_model.compile()
      train(train_model, None, hvd=hvd, debug_port=args.debug_port)
    elif args.mode == 'train_eval':
      train_model = base_model(params=train_config, mode="train", hvd=hvd)
      train_model.compile()
      eval_model = base_model(params=eval_config, mode="eval", hvd=hvd)
      eval_model.compile(force_var_reuse=True)
      train(train_model, eval_model, hvd=hvd, debug_port=args.debug_port)
    elif args.mode == "eval":
      eval_model = base_model(params=eval_config, mode="eval", hvd=hvd)
      eval_model.compile()
      evaluate(eval_model, checkpoint)
    elif args.mode == "infer":
      infer_model = base_model(params=infer_config, mode="infer", hvd=hvd)
      infer_model.compile()
      infer(infer_model, checkpoint, args.infer_output_file)
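A hypothetical invocation of this entry point (the script name and config path are placeholders; the flags match the parser defined above):

python run.py --config_file=example_configs/my_config.py --mode=train_eval --continue_learning
python run.py --config_file=example_configs/my_config.py --benchmark --bench_steps=50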
Example #9
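A later revision of the entry point that adds --enable_logs (copying the config, command line, and git info into the log directory and teeing stdout/stderr through a Logger), supports nested config overrides via flatten_dict/nest_dict, and looks for checkpoints under logdir/logs when logging is enabled.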
def main():
  parser = argparse.ArgumentParser(description='Experiment parameters')
  parser.add_argument("--config_file", required=True,
                      help="Path to the configuration file")
  parser.add_argument("--mode", default='train',
                      help="Could be \"train\", \"eval\", "
                           "\"train_eval\" or \"infer\"")
  parser.add_argument("--infer_output_file",
                      help="Path to the output of inference")
  parser.add_argument('--continue_learning', dest='continue_learning',
                      action='store_true', help="whether to continue learning")
  parser.add_argument('--no_dir_check', dest='no_dir_check',
                      action='store_true',
                      help="whether to check that everything is correct "
                           "with log directory")
  parser.add_argument('--benchmark', dest='benchmark', action='store_true',
                      help='automatic config change for benchmarking')
  parser.add_argument('--bench_steps', type=int, default=20,
                      help='max_steps for benchmarking')
  parser.add_argument('--bench_start', type=int,
                      help='first step to start counting time for benchmarking')
  parser.add_argument('--debug_port', type=int,
                      help='run TensorFlow in debug mode on specified port')
  parser.add_argument('--enable_logs', dest='enable_logs', action='store_true',
                      help='whether to log output, git info, cmd args, etc.')
  args, unknown = parser.parse_known_args()

  if args.mode not in ['train', 'eval', 'train_eval', 'infer']:
    raise ValueError("Mode has to be one of "
                     "['train', 'eval', 'train_eval', 'infer']")
  config_module = runpy.run_path(args.config_file, init_globals={'tf': tf})

  base_config = config_module.get('base_params', None)
  if base_config is None:
    raise ValueError('base_config dictionary has to be '
                     'defined in the config file')
  base_model = config_module.get('base_model', None)
  if base_model is None:
    raise ValueError('base_model class has to be defined in the config file')

  # After reading the config, try to overwrite some of its properties
  # with command line arguments that were passed to the script
  parser_unk = argparse.ArgumentParser()
  for pm, value in flatten_dict(base_config).items():
    if type(value) == int or type(value) == float or \
       isinstance(value, string_types):
      parser_unk.add_argument('--' + pm, default=value, type=type(value))
    elif type(value) == bool:
      parser_unk.add_argument('--' + pm, default=value, type=ast.literal_eval)
  config_update = parser_unk.parse_args(unknown)
  nested_update(base_config, nest_dict(vars(config_update)))

  # checking that everything is correct with log directory
  logdir = base_config['logdir']
  if args.benchmark:
    args.no_dir_check = True
  try:
    if args.enable_logs:
      ckpt_dir = os.path.join(logdir, 'logs')
    else:
      ckpt_dir = logdir
    if args.mode == 'train' or args.mode == 'train_eval':
      if os.path.isfile(logdir):
        raise IOError("There is a file with the same name as \"logdir\" "
                      "parameter. You should change the log directory path "
                      "or delete the file to continue.")

      # check if "logdir" directory exists and non-empty
      if os.path.isdir(logdir) and os.listdir(logdir) != []:
        if not args.continue_learning:
          raise IOError("Log directory is not empty. If you want to continue "
                        "learning, you should provide "
                        "\"--continue_learning\" flag")
        checkpoint = tf.train.latest_checkpoint(ckpt_dir)
        if checkpoint is None:
          raise IOError(
              "There is no valid TensorFlow checkpoint in the "
              "{} directory. Can't load model".format(ckpt_dir)
          )
      else:
        if args.continue_learning:
          raise IOError("The log directory is empty or does not exist. "
                        "You should probably not provide "
                        "\"--continue_learning\" flag?")
        checkpoint = None
    elif args.mode == 'infer' or args.mode == 'eval':
      if os.path.isdir(logdir) and os.listdir(logdir) != []:
        checkpoint = tf.train.latest_checkpoint(ckpt_dir)
        if checkpoint is None:
          raise IOError(
              "There is no valid TensorFlow checkpoint in the "
              "{} directory. Can't load model".format(ckpt_dir)
          )
      else:
        raise IOError(
            "{} does not exist or is empty, can't restore model".format(
                ckpt_dir
            )
        )
  except IOError as e:
    if args.no_dir_check:
      print("Warning: {}".format(e))
      print("Resuming operation since no_dir_check argument was provided")
    else:
      raise

  if base_config['use_horovod']:
    import horovod.tensorflow as hvd
    hvd.init()
    if hvd.rank() == 0:
      deco_print("Using horovod")
  else:
    hvd = None

  if args.enable_logs:
    if hvd is None or hvd.rank() == 0:
      if not os.path.exists(logdir):
        os.makedirs(logdir)

      tm_suf = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
      shutil.copy(
          args.config_file,
          os.path.join(logdir, 'config_{}.py'.format(tm_suf)),
      )

      with open(os.path.join(logdir, 'cmd-args_{}.log'.format(tm_suf)),
                'w') as f:
        f.write(" ".join(sys.argv))

      with open(os.path.join(logdir, 'git-info_{}.log'.format(tm_suf)),
                'w') as f:
        f.write('commit hash: {}'.format(get_git_hash()))
        f.write(get_git_diff())

      old_stdout = sys.stdout
      old_stderr = sys.stderr
      stdout_log = open(
          os.path.join(logdir, 'stdout_{}.log'.format(tm_suf)), 'a', 1
      )
      stderr_log = open(
          os.path.join(logdir, 'stderr_{}.log'.format(tm_suf)), 'a', 1
      )
      sys.stdout = Logger(sys.stdout, stdout_log)
      sys.stderr = Logger(sys.stderr, stderr_log)

    base_config['logdir'] = os.path.join(logdir, 'logs')

  train_config = copy.deepcopy(base_config)
  eval_config = copy.deepcopy(base_config)
  infer_config = copy.deepcopy(base_config)

  if args.mode == 'train' or args.mode == 'train_eval':
    if 'train_params' in config_module:
      nested_update(train_config, copy.deepcopy(config_module['train_params']))
    if hvd is None or hvd.rank() == 0:
      deco_print("Training config:")
      pprint.pprint(train_config)
  if args.mode == 'eval' or args.mode == 'train_eval':
    if 'eval_params' in config_module:
      nested_update(eval_config, copy.deepcopy(config_module['eval_params']))
    if hvd is None or hvd.rank() == 0:
      deco_print("Evaluation config:")
      pprint.pprint(eval_config)
  if args.mode == "infer":
    if args.infer_output_file is None:
      raise ValueError("\"infer_output_file\" command line parameter is "
                       "required in inference mode")
    if "infer_params" in config_module:
      nested_update(infer_config, copy.deepcopy(config_module['infer_params']))

    if hvd is None or hvd.rank() == 0:
      deco_print("Inference config:")
      pprint.pprint(infer_config)

  if args.benchmark:
    deco_print("Adjusting config for benchmarking")
    train_config['print_samples_steps'] = None
    train_config['print_loss_steps'] = 1
    train_config['save_summaries_steps'] = None
    train_config['save_checkpoint_steps'] = None
    train_config['logdir'] = str("")
    if 'num_epochs' in train_config:
      del train_config['num_epochs']
    train_config['max_steps'] = args.bench_steps
    if args.bench_start:
      train_config['bench_start'] = args.bench_start
    elif 'bench_start' not in train_config:
      train_config['bench_start'] = 10  # default value

    if hvd is None or hvd.rank() == 0:
      deco_print("New benchmarking config:")
      pprint.pprint(train_config)
    args.mode = "train"
    checkpoint = None

  if args.mode == 'train' or args.mode == 'train_eval':
    if hvd is None or hvd.rank() == 0:
      if checkpoint is None:
        deco_print("Starting training from scratch")
      else:
        deco_print(
            "Restored checkpoint from {}. Resuming training".format(checkpoint),
        )
  elif args.mode == 'eval' or args.mode == 'infer':
    if hvd is None or hvd.rank() == 0:
      deco_print("Loading model from {}".format(checkpoint))

  with tf.Graph().as_default():
    if args.mode == 'train':
      train_model = base_model(params=train_config, mode="train", hvd=hvd)
      train_model.compile()
      train(train_model, None, debug_port=args.debug_port)
    elif args.mode == 'train_eval':
      train_model = base_model(params=train_config, mode="train", hvd=hvd)
      train_model.compile()
      eval_model = base_model(params=eval_config, mode="eval", hvd=hvd)
      eval_model.compile(force_var_reuse=True)
      train(train_model, eval_model, debug_port=args.debug_port)
    elif args.mode == "eval":
      eval_model = base_model(params=eval_config, mode="eval", hvd=hvd)
      eval_model.compile()
      evaluate(eval_model, checkpoint)
    elif args.mode == "infer":
      infer_model = base_model(params=infer_config, mode="infer", hvd=hvd)
      infer_model.compile()
      infer(infer_model, checkpoint, args.infer_output_file)

  if args.enable_logs and (hvd is None or hvd.rank() == 0):
    sys.stdout = old_stdout
    sys.stderr = old_stderr
    stdout_log.close()
    stderr_log.close()