Example #1
0
def main():
    """Entry point: parse CLI args, optionally initialize Horovod, prepare
    the log directory, then run train / eval / infer according to args.mode.

    Side effects: when --enable_logs is set, sys.stdout/sys.stderr are
    redirected to log files on the master process and restored on exit.

    Raises:
        ValueError: if args.mode == "interactive_infer" (must be run from
            an IPython notebook, not from run.py).
    """
    # Parse args and create config
    args, base_config, base_model, config_module = get_base_config(
        sys.argv[1:])

    if args.mode == "interactive_infer":
        # BUG FIX: the two message fragments were previously separated by a
        # comma, so ValueError received two arguments and the error rendered
        # as a tuple. Implicit string concatenation yields one message.
        raise ValueError(
            "Interactive infer is meant to be run from an IPython "
            "notebook not from run.py.")

    # Initialize Horovod. The import is deferred so that non-Horovod runs
    # do not require the package to be installed.
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
    else:
        hvd = None

    # True on the single process (no Horovod) or on Horovod rank 0;
    # used to restrict console/log I/O to one process.
    is_master = hvd is None or hvd.rank() == 0

    restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)
    if args.enable_logs:
        if is_master:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        # Every rank switches to the 'logs' subdirectory, even though only
        # the master created it.
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode in ('train', 'train_eval') or args.benchmark:
        if is_master:
            if checkpoint is None or args.benchmark:
                deco_print("Starting training from scratch")
            else:
                deco_print(
                    "Restored checkpoint from {}. Resuming training".format(
                        checkpoint), )
    elif args.mode in ('eval', 'infer'):
        if is_master:
            deco_print("Loading model from {}".format(checkpoint))

    # Create model and train/eval/infer
    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd)
        if args.mode == "train_eval":
            train(model[0], model[1], debug_port=args.debug_port)
        elif args.mode == "train":
            train(model, None, debug_port=args.debug_port)
        elif args.mode == "eval":
            evaluate(model, checkpoint)
        elif args.mode == "infer":
            infer(model, checkpoint, args.infer_output_file, args.use_trt)

    # Restore original streams and close the log files. Same guard under
    # which they were created, so the names are guaranteed to be bound.
    if args.enable_logs and is_master:
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        stdout_log.close()
        stderr_log.close()
Example #2
0
def run():
    """Execute a saved checkpoint over the calibration sample set.

    Runs the model in infer mode for 50 LibriSpeech dev-clean files whose
    alignments are stored in calibration/target.json, and saves the logits
    produced by the model to calibration/sample.pkl (via the decoder's
    ``infer_logits_to_pickle`` flag).

    :return: path of the calibration output (``args.calibration_out``)
    """
    args, base_config, base_model, config_module = get_calibration_config(
        sys.argv[1:])
    # Force the calibration sample set and make the decoder pickle its
    # logits instead of decoding them.
    config_module["infer_params"]["data_layer_params"]["dataset_files"] = \
      ["calibration/sample.csv"]
    config_module["base_params"]["decoder_params"][
        "infer_logits_to_pickle"] = True
    load_model = base_config.get('load_model', None)
    restore_best_checkpoint = base_config.get('restore_best_checkpoint', False)
    base_ckpt_dir = check_base_model_logdir(load_model, args,
                                            restore_best_checkpoint)
    base_config['load_model'] = base_ckpt_dir

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)

    # Initialize Horovod (deferred import) and barrier so all workers see
    # the checkpoint/logdir state before continuing.
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
        from mpi4py import MPI
        MPI.COMM_WORLD.Barrier()
    else:
        hvd = None

    if args.enable_logs:
        if hvd is None or hvd.rank() == 0:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        # BUG FIX: this assignment was previously nested inside the rank-0
        # guard, so non-zero Horovod ranks kept the parent logdir. main()
        # updates 'logdir' on every rank; this now matches.
        # NOTE(review): unlike main(), this function never restores
        # stdout/stderr or closes the log files — confirm create_logdir's
        # redirection is acceptable to leave in place for a one-shot run.
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))
    else:
        # Calibration only makes sense in infer mode; bail out otherwise.
        print("Run in infer mode only")
        sys.exit()

    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd,
                             checkpoint)
        infer(model, checkpoint, args.infer_output_file)

    return args.calibration_out