def _load_net(path, subnet, subnet_params, it_net_params):
    """Build an ``IterativeNet`` around a fresh sub-network and load saved weights.

    Args:
        path: Location of the saved state dict, passed to ``torch.load``.
        subnet: Callable (e.g. a network class) invoked as
            ``subnet(**subnet_params)`` to create the inner network.
        subnet_params: Keyword arguments used to construct ``subnet``.
        it_net_params: Keyword arguments forwarded to ``IterativeNet``.

    Returns:
        The ``IterativeNet`` instance, moved to the module-level ``device``,
        with the weights from ``path`` loaded and ``freeze()`` and ``eval()``
        called on it.
    """
    inner_net = subnet(**subnet_params).to(device)
    net = IterativeNet(inner_net, **it_net_params).to(device)

    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a different device can still be loaded here.
    state = torch.load(path, map_location=torch.device(device))
    net.load_state_dict(state)

    net.freeze()
    net.eval()
    return net
"resnet_factor": 1.0, "operator": OpA_m, "inverter": inverter, } # ------ construct network and load weights ----- subnet = subnet(**subnet_params).to(device) it_net = IterativeNet(subnet, **it_net_params).to(device) it_net.load_state_dict( torch.load( "results/Fourier_UNet_it_jit-nojit_train_phase_1/model_weights.pt", map_location=torch.device(device), ) ) it_net.freeze() it_net.eval() # ----- evaluation setup ----- # select samples samples = range(150) test_data = IPDataset("test", config.DATA_PATH) # dynamic range for plotting v_min = 0.0 v_max = 0.9 # ----- plotting ----- def _implot(sub, im, vmin=v_min, vmax=v_max): if im.shape[-3] == 2: # complex image