# Build the training components: the model directory (with a copy of the model
# blueprint), the egs bunch, the model, the optimizer and the lr scheduler.
# Every log call is gated on utils.is_main_training(), so it is emitted only when
# that helper returns True (presumably the main rank in multi-GPU runs - confirm).
if utils.is_main_training():
    logger.info("Get model_blueprint from model directory.")
# Save the raw model_blueprint in model_dir/config and get the path of the copied model_blueprint.
model_blueprint = utils.create_model_dir(model_dir, model_blueprint, stage=train_stage)

if utils.is_main_training():
    logger.info("Load egs to bunch.")
# The dict [info] contains feat_dim and num_targets.
bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params, loader_params)

if utils.is_main_training():
    logger.info("Create model from model blueprint.")
# Alternative: import model.py directly in this script. That is not friendly to the
# extraction shell script, and the extraction script should not need to change
# whenever model.py changes, so the model is created from the blueprint copy instead.
model_py = utils.create_model_from_py(model_blueprint)
model = model_py.ResNetXvector(info["feat_dim"], info["num_targets"], **model_params)

# If multiple GPUs are used, batchnorm is converted to synchronized batchnorm, which is
# important to keep performance stable. It changes nothing for single-GPU training.
model = utils.convert_synchronized_batchnorm(model)

if utils.is_main_training():
    logger.info("Define optimizer and lr_scheduler.")
optimizer = optim.get_optimizer(model, optimizer_params)
lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params)

# Record params to model_dir
args = parser.parse_args() # Start try: # nnet_config include model_blueprint and model_creation if args.nnet_config != "": model_blueprint, model_creation = utils.read_nnet_config(args.nnet_config) elif args.model_blueprint is not None and args.model_creation is not None: model_blueprint = args.model_blueprint model_creation = args.model_creation else: raise ValueError("Expected nnet_config or (model_blueprint, model_creation) to exist.") model = utils.create_model_from_py(model_blueprint, model_creation) model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=False) # Select device model = utils.select_model_device(model, args.use_gpu, gpu_id=args.gpu_id) model.eval() with kaldi_io.open_or_fd(args.feats_rspecifier, "rb") as r, \ kaldi_io.open_or_fd(args.vectors_wspecifier, 'wb') as w: for line in r: # (key, rxfile, chunk_start, chunk_end) = line.decode().split(' ') # chunk=[chunk_start, chunk_end] # print("Process utterance for key {0}".format(key)) # feats = kaldi_io.read_mat(rxfile, chunk=chunk) (key, rxfile) = line.decode().split(' ')