def predict_from_model(pre_trained_model, Xpred_file, shrink_factor=1, save_name=''):
    """
    Predicting interface: 1. retrieve the flags 2. get data 3. initialize network 4. predict.

    :param pre_trained_model: Location of the pre-trained model. It is passed unchanged to
        ``load_flags``; when it starts with ``models`` the first 7 characters (the
        ``models/`` prefix) are dropped before it is stored in ``flags.eval_model`` and
        given to the network as ``saved_model``.
    :param Xpred_file: The file to predict on, forwarded to ``ntwk.predict``.
    :param shrink_factor: Forwarded to ``ntwk.predict`` and appended to the save prefix.
    :param save_name: Text placed in front of the save prefix handed to ``ntwk.predict``.
    :return: None
    """
    # Retrieve the flag object
    print("This is doing the prediction for file", Xpred_file)
    print("Retrieving flag object for parameters")

    # Default to the path as given so that eval_model is always bound; previously a
    # path without the "models" prefix hit an UnboundLocalError further down.
    eval_model = pre_trained_model
    if pre_trained_model.startswith("models"):
        eval_model = pre_trained_model[len("models/"):]
        print("after removing prefix models/, now model_dir is:", eval_model)

    flags = load_flags(pre_trained_model)   # Get the flags of the pre-trained model
    flags.eval_model = eval_model           # Reset the eval mode

    # Get the data, this part is useless in prediction but just for simplicity
    train_loader, test_loader = data_reader.read_data(flags)
    print("Making network now")

    # Make Network
    ntwk = Network(Backprop, flags, train_loader, test_loader,
                   inference_mode=True, saved_model=flags.eval_model)
    print("number of trainable parameters is :")
    pytorch_total_params = sum(p.numel() for p in ntwk.model.parameters() if p.requires_grad)
    print(pytorch_total_params)

    # Evaluation process
    print("Start eval now:")
    # The two values handed back by predict are not used any further here.
    pred_file, truth_file = ntwk.predict(
        Xpred_file,
        save_prefix=save_name + 'shrink_factor' + str(shrink_factor),
        shrink_factor=shrink_factor)
def predict(model_dir, Ytruth_file, multi_flag=False):
    """
    Predict the output from given spectra with the saved INN model.

    :param model_dir: Model location. A leading ``models`` prefix (7 characters,
        i.e. ``models/``) is stripped; a value starting with ``/`` is then used
        directly for loading the flags, anything else is looked up under ``models``.
    :param Ytruth_file: Forwarded to ``ntwk.predict``.
    :param multi_flag: Not read anywhere in this function.
    :return: None
    """
    print("Retrieving flag object for parameters")
    if model_dir.startswith("models"):
        model_dir = model_dir[7:]
        print("after removing prefix models/, now model_dir is:", model_dir)

    # An absolute path is taken as-is, a relative one lives under "models".
    flags_location = model_dir if model_dir.startswith('/') else os.path.join("models", model_dir)
    flags = helper_functions.load_flags(flags_location)
    flags.eval_model = model_dir            # Reset the eval mode

    ntwk = Network(INN, flags, train_loader=None, test_loader=None,
                   inference_mode=True, saved_model=flags.eval_model)

    print("number of trainable parameters is :")
    trainable_count = 0
    for param in ntwk.model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    print(trainable_count)

    # Evaluation process
    prediction_path, truth_path = ntwk.predict(Ytruth_file)
    if 'Yang' in flags.data_set:
        # No MSE plot for data sets whose name contains 'Yang'.
        return
    plotMSELossDistrib(prediction_path, truth_path, flags)
def infer(pre_trained_model, Xpred_file, no_plot=True):
    """
    Load a pre-trained Backprop network and run its prediction on a file.

    :param pre_trained_model: Model location, passed to ``load_flags`` and stored
        unchanged in ``flags.eval_model``.
    :param Xpred_file: The file to predict on, forwarded to ``ntwk.predict``.
    :param no_plot: When truthy, ``ntwk.predict`` is called with ``no_save=True`` and
        nothing is plotted; otherwise it is called with ``no_save=False`` and the MSE
        distribution is plotted.
    :return: Tuple ``(pred_file, truth_file, flags)``.
    """
    # Retrieve the flag object
    print("This is doing the prediction for file", Xpred_file)
    print("Retrieving flag object for parameters")
    if pre_trained_model.startswith("models"):
        # NOTE(review): the stripped name is only printed; flags.eval_model below
        # keeps the full path - confirm that this is intended.
        print("after removing prefix models/, now model_dir is:", pre_trained_model[7:])

    flags = load_flags(pre_trained_model)   # Get the flags of the pre-trained model
    flags.eval_model = pre_trained_model    # Reset the eval mode
    flags.test_ratio = 0.1                  # useless number

    # Get the data, this part is useless in prediction but just for simplicity
    loaders = data_reader.read_data(flags)
    train_loader, test_loader = loaders
    print("Making network now")

    # Make Network
    ntwk = Network(Backprop, flags, train_loader, test_loader,
                   inference_mode=True, saved_model=flags.eval_model)
    print("number of trainable parameters is :")
    trainable_count = 0
    for param in ntwk.model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    print(trainable_count)

    # Evaluation process
    print("Start eval now:")
    plot_wanted = not no_plot
    # Saving is switched off exactly when no plot is requested.
    pred_file, truth_file = ntwk.predict(Xpred_file, no_save=not plot_wanted)
    if plot_wanted:
        # Plot the MSE distribution
        flags.eval_model = pred_file.replace('.', '_')  # To make the plot name different
        plotMSELossDistrib(pred_file, truth_file, flags)

    print("Evaluation finished")
    return pred_file, truth_file, flags