def regression(
    msd,
    epochs,
    batch_size,
    train_input_glob,
    train_target_glob,
    val_input_glob,
    val_target_glob,
    weights_path,
):
    """Train a default ``mp.MSDRegressionModel`` on image stacks given by globs.

    A training DataLoader (shuffled) is always built. A validation DataLoader
    (not shuffled) is built only when ``val_input_glob`` is truthy; otherwise
    ``None`` is passed to ``train`` in its place.

    Args:
        msd: Not referenced in this function body.
            NOTE(review): confirm whether callers expect it to affect the model.
        epochs: Forwarded unchanged to ``train``.
        batch_size: Batch size used for both the training and validation loaders.
        train_input_glob: Input pattern for ``mp.ImageDataset`` (training split).
        train_target_glob: Target pattern for ``mp.ImageDataset`` (training split).
        val_input_glob: Input pattern for the validation split; falsy disables
            validation.
        val_target_glob: Target pattern for the validation split.
        weights_path: Forwarded unchanged to ``train``.

    Returns:
        None.
    """
    logging.info("Load training dataset")
    train_dl = DataLoader(
        mp.ImageDataset(train_input_glob, train_target_glob),
        batch_size,
        shuffle=True,
    )

    # Validation is optional: keep None unless a validation input glob is given.
    val_dl = None
    if not val_input_glob:
        logging.info("No validation set loaded")
    else:
        logging.info("Load validation set")
        val_dl = DataLoader(
            mp.ImageDataset(val_input_glob, val_target_glob),
            batch_size,
            shuffle=False,
        )

    logging.info("Create network model")
    model = mp.MSDRegressionModel()
    train(model, epochs, train_dl, val_dl, weights_path)
def load_concat_data(inp, tar):
    """Build one ``mp.ImageDataset`` from one or more (input, target) path pairs.

    ``inp`` and ``tar`` are parallel sequences of path/glob specifications.
    The first pair creates the dataset. For every further pair, its
    ``input_stack.paths`` and ``target_stack.paths`` are appended to those of
    the first dataset, so the result covers all pairs in order.

    Every specification is normalised with ``Path(...).expanduser().resolve()``.
    Previously this happened only when exactly one pair was given, so ``~`` and
    relative specs now behave the same regardless of how many pairs are passed.

    Args:
        inp: Sequence of input path/glob specifications.
        tar: Sequence of target path/glob specifications, same length as ``inp``.

    Returns:
        The combined ``mp.ImageDataset``.

    Raises:
        ValueError: If ``inp`` is empty, or ``inp`` and ``tar`` differ in length.
            Previously these cases hit an ``UnboundLocalError`` or were silently
            truncated by ``zip``.
    """
    inp = list(inp)
    tar = list(tar)
    # Validate before touching the filesystem or constructing any dataset.
    if not inp:
        raise ValueError("load_concat_data: at least one input/target pair is required")
    if len(inp) != len(tar):
        raise ValueError(
            f"load_concat_data: got {len(inp)} input specs but {len(tar)} target specs"
        )

    pairs = [
        (Path(i).expanduser().resolve(), Path(t).expanduser().resolve())
        for i, t in zip(inp, tar)
    ]
    (first_inp, first_tar), *rest = pairs

    train_ds = mp.ImageDataset(first_inp, first_tar)
    for tig, ttg in rest:
        ds = mp.ImageDataset(tig, ttg)
        # Merge by extending the path lists of the first dataset in place.
        train_ds.input_stack.paths += ds.input_stack.paths
        train_ds.target_stack.paths += ds.target_stack.paths
    return train_ds
# Inference fragment: load trained MSD regression weights and apply the network
# to one reconstructed slice. c_in, c_out, depth, width, dilations and loss are
# defined elsewhere in the script (not visible here).
print('loading model')
model = mp.MSDRegressionModel(c_in, c_out, depth, width, dilations=dilations, loss=loss)
# Experiment selection: these two strings choose both the weights file and the
# input slice below.
constraint = 'sub12'
alg = 'sirt'
# NOTE(review): this builds './programs/sirtmodels/...' (no separator between
# alg and 'models/'), unlike the '<alg>_data/' data path below. Confirm the
# directory name is intended.
model_path = './programs/' + alg + 'models/' + constraint + '/modelparam_e99.torch'
model.load(model_path)
# Get image
data_path = '/export/scratch3/jjow/' + alg + '_data/' + constraint + '/val/0rec0199.tiff'
print('loading dataset')
# The same file is passed as both input and target; only the input half of
# each batch is used below.
ds = mp.ImageDataset(data_path, data_path)
dl = DataLoader(ds, 1, shuffle=False)
for i, data in enumerate(dl):
    # Unpack data
    inp, _ = data
    # Forward pass on the CUDA device. The result is detached, copied back to
    # a NumPy array, and size-1 axes are dropped by squeeze().
    # NOTE(review): no torch.no_grad() here, so autograd still records the
    # forward pass. Consider wrapping it if memory matters.
    output = model.net(inp.cuda())
    output_np = output.detach().cpu().numpy()
    output_np = output_np.squeeze()
    # Saving image
    # All values above max get truncated: output is scaled by 1/im_max and
    # clipped to [0, 1], so values >= im_max saturate to 1 and negatives to 0.
    # save_dir is only assigned here; the write itself is not in this fragment.
    save_dir = f'./transfer/enh_{alg}_{constraint}.png'
    im_max = 0.01
    im_float = output_np / im_max
    im_float = np.clip(im_float, 0, 1)
# Training hyper-parameters.
epochs = 100
batch_size = 5

# Data locations: inputs and targets live in sibling directories under
# base_dir, and each has a "val/" subdirectory holding the validation split.
print('loading data...')
base_dir = "/export/scratch3/jjow/sirt_data/"
input_str = "dec12/"
output_str = "full/"
train_input_glob = f"{base_dir}{input_str}*.tiff"
train_target_glob = f"{base_dir}{output_str}*.tiff"
val_input_glob = f"{base_dir}{input_str}val/*.tiff"
val_target_glob = f"{base_dir}{output_str}val/*.tiff"

# The training loader is shuffled; the validation loader keeps file order.
train_ds = mp.ImageDataset(train_input_glob, train_target_glob)
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
val_ds = mp.ImageDataset(val_input_glob, val_target_glob)
val_dl = DataLoader(val_ds, batch_size, shuffle=False)

# Network: c_in, c_out, depth, width, dilations and loss are defined elsewhere
# in the script. Normalisation is derived from the training loader.
model = mp.MSDRegressionModel(
    c_in, c_out, depth, width, dilations=dilations, loss=loss
)
model.set_normalization(train_dl)

# Book-keeping, presumably for a training loop not visible here:
# best_val_err starts at +inf, and the per-epoch error histories start empty.
best_val_err = np.inf
val_err = 0
train_errs, val_errs = [], []