def main():
    """Entry point: parse CLI args, build the config, and dispatch to
    train / train_eval / eval / infer according to ``args.mode``.

    Side effects: may initialize Horovod, validate/create the log
    directory, and (with ``--enable_logs``) redirect stdout/stderr to log
    files for the duration of the run, restoring them at the end.
    """
    # Parse args and create config
    args, base_config, base_model, config_module = get_base_config(
        sys.argv[1:])

    if args.mode == "interactive_infer":
        # Single message string. (The original passed two comma-separated
        # strings, which made ValueError carry a tuple instead of a message.)
        raise ValueError(
            "Interactive infer is meant to be run from an IPython "
            "notebook not from run.py.")

    # Initialize Horovod
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
    else:
        hvd = None

    restore_best_checkpoint = base_config.get('restore_best_checkpoint',
                                              False)
    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)

    if args.enable_logs:
        # Only rank 0 redirects stdout/stderr to log files; every rank
        # points the model's logdir at the 'logs' subdirectory.
        if hvd is None or hvd.rank() == 0:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'train' or args.mode == 'train_eval' or args.benchmark:
        if hvd is None or hvd.rank() == 0:
            if checkpoint is None or args.benchmark:
                deco_print("Starting training from scratch")
            else:
                deco_print(
                    "Restored checkpoint from {}. Resuming training".format(
                        checkpoint),
                )
    elif args.mode == 'eval' or args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))

    # Create model and train/eval/infer
    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model, hvd)
        if args.mode == "train_eval":
            # create_model returns a (train_model, eval_model) pair here
            train(model[0], model[1], debug_port=args.debug_port)
        elif args.mode == "train":
            train(model, None, debug_port=args.debug_port)
        elif args.mode == "eval":
            evaluate(model, checkpoint)
        elif args.mode == "infer":
            infer(model, checkpoint, args.infer_output_file, args.use_trt)

    if args.enable_logs and (hvd is None or hvd.rank() == 0):
        # Undo the stdout/stderr redirection done above and close the logs.
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        stdout_log.close()
        stderr_log.close()
def run():
    """Execute a saved checkpoint in infer mode to collect calibration data.

    Runs the model on 50 LibriSpeech dev-clean files (listed in
    ``calibration/sample.csv``; their alignments are stored in
    ``calibration/target.json``) and, because the decoder is forced to
    ``infer_logits_to_pickle``, saves the resulting logits as a pickle
    file (``calibration/sample.pkl``).

    :return: ``args.calibration_out`` (path for the calibration output)
    """
    args, base_config, base_model, config_module = get_calibration_config(
        sys.argv[1:])
    # Force the calibration dataset and make the decoder dump raw logits.
    config_module["infer_params"]["data_layer_params"]["dataset_files"] = \
        ["calibration/sample.csv"]
    config_module["base_params"]["decoder_params"][
        "infer_logits_to_pickle"] = True

    load_model = base_config.get('load_model', None)
    restore_best_checkpoint = base_config.get('restore_best_checkpoint',
                                              False)
    base_ckpt_dir = check_base_model_logdir(load_model, args,
                                            restore_best_checkpoint)
    base_config['load_model'] = base_ckpt_dir

    # Check logdir and create it if necessary
    checkpoint = check_logdir(args, base_config, restore_best_checkpoint)

    # Initialize Horovod
    if base_config['use_horovod']:
        import horovod.tensorflow as hvd
        hvd.init()
        if hvd.rank() == 0:
            deco_print("Using horovod")
        from mpi4py import MPI
        MPI.COMM_WORLD.Barrier()
    else:
        hvd = None

    if args.enable_logs:
        # Only rank 0 redirects stdout/stderr to log files; every rank
        # points the model's logdir at the 'logs' subdirectory.
        if hvd is None or hvd.rank() == 0:
            old_stdout, old_stderr, stdout_log, stderr_log = create_logdir(
                args, base_config)
        base_config['logdir'] = os.path.join(base_config['logdir'], 'logs')

    if args.mode == 'infer':
        if hvd is None or hvd.rank() == 0:
            deco_print("Loading model from {}".format(checkpoint))
    else:
        print("Run in infer mode only")
        sys.exit()

    with tf.Graph().as_default():
        model = create_model(args, base_config, config_module, base_model,
                             hvd, checkpoint)
        infer(model, checkpoint, args.infer_output_file)

    if args.enable_logs and (hvd is None or hvd.rank() == 0):
        # Restore the redirected streams and close the log files (the
        # original never undid the redirection, leaking both handles).
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        stdout_log.close()
        stderr_log.close()

    return args.calibration_out