Example #1
File: main.py  Project: adler-j/msd_pytorch
def regression(
    msd,
    epochs,
    batch_size,
    train_input_glob,
    train_target_glob,
    val_input_glob,
    val_target_glob,
    weights_path,
):
    logging.info("Load training dataset")
    # Create train (always) and validation (only if specified) datasets.
    train_ds = mp.ImageDataset(train_input_glob, train_target_glob)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True)
    if val_input_glob:
        logging.info("Load validation set")
        val_ds = mp.ImageDataset(val_input_glob, val_target_glob)
        val_dl = DataLoader(val_ds, batch_size, shuffle=False)
    else:
        logging.info("No validation set loaded")
        val_dl = None

    logging.info("Create network model")
    model = mp.MSDRegressionModel()
    train(model, epochs, train_dl, val_dl, weights_path)
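
The example above delegates the actual optimization to a train helper defined elsewhere in main.py. A minimal sketch of what that helper could look like, assuming the model exposes learn(), validate(), and save() as in the msd_pytorch examples (the real implementation may differ):

def train(model, epochs, train_dl, val_dl, weights_path):
    # Estimate input/output normalization from the training data.
    model.set_normalization(train_dl)
    best_validation_error = float("inf")
    for epoch in range(epochs):
        # One pass over the training set.
        for inp, tgt in train_dl:
            model.learn(inp, tgt)
        if val_dl is not None:
            # Track validation error and keep the best weights seen so far.
            validation_error = model.validate(val_dl)
            logging.info("Epoch %d: validation error %f", epoch, validation_error)
            if validation_error < best_validation_error:
                best_validation_error = validation_error
                model.save(weights_path, epoch)
        else:
            model.save(weights_path, epoch)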
Example #2
def load_concat_data(inp, tar):
    # Single input/target glob pair: build the dataset directly.
    if len(inp) == 1 and len(tar) == 1:
        inp = Path(inp[0]).expanduser().resolve()
        tar = Path(tar[0]).expanduser().resolve()
        train_ds = mp.ImageDataset(inp, tar)
    else:
        # Multiple glob pairs: build a dataset from the first pair and
        # append the image paths of the remaining pairs to its stacks.
        for i, (tig, ttg) in enumerate(zip(inp, tar)):
            if i == 0:
                train_ds = mp.ImageDataset(tig, ttg)
            else:
                ds = mp.ImageDataset(tig, ttg)
                train_ds.input_stack.paths += ds.input_stack.paths
                train_ds.target_stack.paths += ds.target_stack.paths

    return train_ds
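
load_concat_data is presumably called with lists of glob patterns; a short hypothetical usage (the paths below are made up for illustration):

# Hypothetical usage: combine two scan directories into one training dataset.
train_ds = load_concat_data(
    ["/data/scan_a/input/*.tiff", "/data/scan_b/input/*.tiff"],
    ["/data/scan_a/target/*.tiff", "/data/scan_b/target/*.tiff"],
)
train_dl = DataLoader(train_ds, batch_size=5, shuffle=True)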
Example #3
print('loading model')
model = mp.MSDRegressionModel(c_in,
                              c_out,
                              depth,
                              width,
                              dilations=dilations,
                              loss=loss)
constraint = 'sub12'
alg = 'sirt'
model_path = './programs/' + alg + 'models/' + constraint + '/modelparam_e99.torch'
model.load(model_path)

# Get image
data_path = '/export/scratch3/jjow/' + alg + '_data/' + constraint + '/val/0rec0199.tiff'
print('loading dataset')
ds = mp.ImageDataset(data_path, data_path)
dl = DataLoader(ds, 1, shuffle=False)

for i, data in enumerate(dl):
    # Unpack data
    inp, _ = data
    output = model.net(inp.cuda())
    output_np = output.detach().cpu().numpy()
    output_np = output_np.squeeze()

    # Saving image
    # All values above max get truncated
    save_dir = f'./transfer/enh_{alg}_{constraint}.png'
    im_max = 0.01
    im_float = output_np / im_max
    im_float = np.clip(im_float, 0, 1)
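
The snippet ends before the clipped image is actually written. One way to finish the step, assuming imageio is available (the original code may use a different writer):

import imageio

# Convert the clipped [0, 1] float image to 8-bit and write it to save_dir.
im_uint8 = (im_float * 255).astype(np.uint8)
imageio.imwrite(save_dir, im_uint8)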
Example #4
# Trainings params
epochs = 100
batch_size = 5

# Data params
print('loading data...')
base_dir = "/export/scratch3/jjow/sirt_data/"
input_str = "dec12/"
output_str = "full/"
train_input_glob = base_dir + input_str + "*.tiff"
train_target_glob = base_dir + output_str + "*.tiff"

val_input_glob = base_dir + input_str + "val/*.tiff"
val_target_glob = base_dir + output_str + "val/*.tiff"
train_ds = mp.ImageDataset(train_input_glob, train_target_glob) 
train_dl = DataLoader(train_ds, batch_size, shuffle=True)

val_ds = mp.ImageDataset(val_input_glob, val_target_glob)
val_dl = DataLoader(val_ds, batch_size, shuffle=False)

# Create model
model = mp.MSDRegressionModel(c_in, c_out, depth, width, dilations=dilations, loss=loss)
model.set_normalization(train_dl)

best_val_err = np.inf
val_err = 0

train_errs = []
val_errs = []
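
Example #4 only sets up the data loaders, model, and bookkeeping variables; the training loop itself is not shown here. A plausible continuation, again assuming learn(), validate(), and save() as in the msd_pytorch examples:

for epoch in range(epochs):
    # Train for one epoch and record the mean training error.
    for inp, tgt in train_dl:
        model.learn(inp, tgt)
    train_errs.append(model.validate(train_dl))

    # Validate and keep the weights with the lowest validation error so far.
    val_err = model.validate(val_dl)
    val_errs.append(val_err)
    if val_err < best_val_err:
        best_val_err = val_err
        model.save(f"modelparam_e{epoch}.torch", epoch)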