Esempio n. 1
0
def _load_net(path, subnet, subnet_params, it_net_params):
    """Rebuild an IterativeNet from a saved state dict, ready for inference.

    Args:
        path: File passed to ``torch.load`` holding the network's state dict.
        subnet: Callable (e.g. a network class) invoked as
            ``subnet(**subnet_params)`` to build the inner network.
        subnet_params: Keyword arguments for ``subnet``.
        it_net_params: Keyword arguments forwarded to ``IterativeNet``.

    Returns:
        The ``IterativeNet`` with the loaded weights, after ``freeze()`` has
        been called on it and it has been put into ``eval()`` mode.
    """
    # Both the inner network and its wrapper are moved to the module-level
    # ``device``.
    inner = subnet(**subnet_params).to(device)
    net = IterativeNet(inner, **it_net_params).to(device)

    # map_location remaps tensors saved on another device onto ``device``.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(path, map_location=torch.device(device))
    net.load_state_dict(state)

    net.freeze()
    net.eval()
    return net
Esempio n. 2
0
    for tmp in load_dataset(config.set_params["path"], subset="val")
]

# ------ save hyperparameters -------
# Log every configuration entry as a "key: value" line in the directory given
# by the last entry of train_params["save_path"] (presumably the final phase).
os.makedirs(train_params["save_path"][-1], exist_ok=True)
with open(
    os.path.join(train_params["save_path"][-1], "hyperparameters.txt"),
    "w",
    # Explicit encoding: the platform default is locale-dependent and may
    # fail on non-ASCII parameter values.
    encoding="utf-8",
) as file:
    # Same order as before: subnet, iterative net, then training parameters.
    file.writelines(
        key + ": " + str(value) + "\n"
        for config_dict in (subnet_params, it_net_params, train_params)
        for key, value in config_dict.items()
    )
    file.write("train_phases" + ": " + str(train_phases) + "\n")

# ------ construct network and train -----
# Note: this rebinds ``subnet`` from the factory to the built instance.
subnet = subnet(**subnet_params).to(device)
it_net = IterativeNet(subnet, **it_net_params).to(device)
for i in range(train_phases):
    # Per-phase settings: tuple/list values are indexed by the phase number,
    # any other value is shared by all phases.
    train_params_cur = {
        key: (value[i] if isinstance(value, (tuple, list)) else value)
        for key, value in train_params.items()
    }

    print("Phase {}:".format(i + 1))
    for key, value in train_params_cur.items():
        print(key + ": " + str(value))

    it_net.train_on((Y_train, X_train), (Y_val, X_val), **train_params_cur)
Esempio n. 3
0
          "w") as file:
    for key, value in subnet_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in it_net_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in val_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    file.write("train_phases" + ": " + str(train_phases) + "\n")

# ------ construct network and train -----
# ``subnet``, ``train_data`` and ``val_data`` start out as factories
# (presumably classes) and are rebound to the objects they build.
subnet = subnet(**subnet_params).to(device)
it_net = IterativeNet(subnet, **it_net_params).to(device)
train_data = train_data("train", **train_data_params)
val_data = val_data("val", **val_data_params)

for i in range(train_phases):
    # Resolve this phase's settings: sequence-valued entries are indexed by
    # the phase number, everything else applies to every phase.
    train_params_cur = {
        key: value[i] if isinstance(value, (tuple, list)) else value
        for key, value in train_params.items()
    }

    print("Phase {}:".format(i + 1))
    for key, value in train_params_cur.items():
        print(key + ": " + str(value))

    it_net.train_on(train_data, val_data, **train_params_cur)
Esempio n. 4
0
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in val_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    file.write("train_phases" + ": " + str(train_phases) + "\n")

# ------ construct network and train -----
# Warm start: build a temporary IterativeNet with a fixed configuration, load
# the saved checkpoint into it, then reuse its trained ``subnet`` inside a new
# IterativeNet configured by ``it_net_params``.
# NOTE(review): the hard-coded settings presumably mirror the run that
# produced the checkpoint; load_state_dict must accept the saved keys.
subnet_tmp = subnet(**subnet_params).to(device)
it_net_tmp = IterativeNet(
    subnet_tmp,
    num_iter=8,
    lam=8 * [0.1],
    lam_learnable=False,
    final_dc=False,
    resnet_factor=1.0,
    concat_mask=False,
    multi_slice=False,
).to(device)
it_net_tmp.load_state_dict(
    torch.load(
        "results/radial_50_no_fs_unet_it_preinit_v1_train_phase_1/"
        "model_weights.pt",
        map_location=torch.device(device),
    )
)
subnet = it_net_tmp.subnet
it_net = IterativeNet(subnet, **it_net_params).to(device)

train_data = train_data("train", **train_data_params)
val_data = val_data("val", **val_data_params)
Esempio n. 5
0
    for key, value in train_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in val_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    file.write("train_phases" + ": " + str(train_phases) + "\n")

# ------ construct network and train -----
# Warm start: wrap the subnet in a single-iteration IterativeNet (lam=0.0),
# load the saved checkpoint into it, then reuse the trained ``subnet`` in a
# new IterativeNet configured by ``it_net_params``.
# NOTE(review): this temporary configuration presumably matches the run that
# produced the checkpoint — confirm if either side changes.
subnet_tmp = subnet(**subnet_params).to(device)
it_net_tmp = IterativeNet(
    subnet_tmp,
    num_iter=1,
    lam=0.0,
    lam_learnable=False,
    final_dc=False,
    resnet_factor=1.0,
    operator=OpA,
    inverter=inverter,
).to(device)
it_net_tmp.load_state_dict(
    torch.load(
        "results/Fourier_UNet_jitter_v3_train_phase_2/model_weights.pt",
        map_location=torch.device(device),
    )
)
subnet = it_net_tmp.subnet
it_net = IterativeNet(subnet, **it_net_params).to(device)

train_data = train_data("train", **train_data_params)
}
subnet = UNet

# Configuration of the iterative network whose weights are loaded below.
it_net_params = dict(
    num_iter=8,
    lam=8 * [0.1],
    lam_learnable=False,
    final_dc=True,
    resnet_factor=1.0,
    operator=OpA_m,
    inverter=inverter,
)

# ------ construct network and load weights -----
subnet = subnet(**subnet_params).to(device)
it_net = IterativeNet(subnet, **it_net_params).to(device)
it_net.load_state_dict(
    torch.load(
        "results/Fourier_UNet_it_jit-nojit_train_phase_1/model_weights.pt",
        map_location=torch.device(device),
    )
)
# Prepare for evaluation: call freeze() and switch to eval() mode.
it_net.freeze()
it_net.eval()

# ----- evaluation setup -----

# select samples (presumably indices into test_data)
samples = range(150)
test_data = IPDataset("test", config.DATA_PATH)
Esempio n. 7
0
) as file:
    for key, value in subnet_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in it_net_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in train_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    for key, value in val_data_params.items():
        file.write(key + ": " + str(value) + "\n")
    file.write("train_phases" + ": " + str(train_phases) + "\n")

# ------ construct network and train -----
# Build the network, then initialise it from the "..._pre_train_phase_1"
# checkpoint before training continues.
subnet = subnet(**subnet_params).to(device)
it_net = IterativeNet(subnet, **it_net_params).to(device)
it_net.load_state_dict(torch.load(
    "results/Fourier_UNet_it_jit-nojit_pre_train_phase_1/"
    "model_weights.pt",
    map_location=torch.device(device),
))

# Dataset factories are rebound to the constructed train/val datasets.
train_data = train_data("train", **train_data_params)
val_data = val_data("val", **val_data_params)

for i in range(train_phases):
    train_params_cur = {}
    for key, value in train_params.items():
        train_params_cur[key] = (
            value[i] if isinstance(value, (tuple, list)) else value