import matplotlib as mpl
import torch

from data_management import load_dataset
from networks import IterativeNet, Tiramisu
from operators import TVAnalysis, get_tikhonov_matrix


# --- load configuration -----
import config  # isort:skip


# ----- general setup -----
# Select matplotlib's non-interactive "agg" backend so figures can be rendered
# to files without a display.
mpl.use("agg")
# NOTE(review): hard-coded to the first CUDA device; anything allocated on it
# will presumably fail on a machine without CUDA — confirm this is intended.
device = torch.device("cuda:0")


# ----- operators -----
# Measurement operator from the config's factory, sized by config.m / config.n
# and configured with config.meas_params, placed on `device`.
OpA = config.meas_op(config.m, config.n, device=device, **config.meas_params)
# Total-variation analysis operator for signals of size config.n
# (presumably a discrete-gradient operator — verify in operators.py).
OpTV = TVAnalysis(config.n, device=device)


# ----- build linear inverter ------
# Regularization weight handed to get_tikhonov_matrix.
reg_fac = 2e-2

# Bias-free linear layer mapping length-OpA.m inputs to length-OpA.n outputs.
# Its weight is frozen (no gradient) and then overwritten with the matrix from
# get_tikhonov_matrix, presumably a TV-regularized Tikhonov reconstruction
# matrix — confirm in operators.py.
inverter = torch.nn.Linear(OpA.m, OpA.n, bias=False)
inverter.weight.requires_grad = False
# Assigning through `.data` bypasses autograd.
# NOTE(review): confirm the returned matrix has shape (OpA.n, OpA.m) to match
# nn.Linear's (out_features, in_features) weight layout.
inverter.weight.data = get_tikhonov_matrix(OpA, OpTV, reg_fac)


# ----- network configuration -----
# Keyword arguments for the sub-network; this literal continues beyond the
# visible chunk.
subnet_params = {
    "in_channels": 1,
    "out_channels": 1,
    "drop_factor": 0.0,
    "down_blocks": (5, 7, 9, 12, 15),
# NOTE(review): the statements below are the tail of an ``__init__`` whose
# header (and the enclosing ``class`` statement) lies outside this chunk.
        self.scale_hi = scale_hi
        # Dedicated random generator, separate from the global torch RNG, so
        # this noise stream can be seeded on its own via t_seed.
        # NOTE(review): a freshly constructed torch.Generator starts from a
        # fixed default seed, so even when t_seed is None the noise sequence
        # is identical on every run — confirm this is intended.
        self.trng = torch.Generator()
        if t_seed is not None:
            self.trng.manual_seed(t_seed)

    def __call__(self, inp):
        """Return ``inp`` with randomly scaled Gaussian noise added.

        For every index of ``inp.shape[:-1]`` a scale factor is drawn uniformly
        from ``[self.scale_lo, self.scale_hi)`` and broadcast across the last
        axis. Standard-normal noise of the same shape as ``inp`` is multiplied
        by that scale and by ``self.eta / sqrt(inp.shape[-1])``, then added to
        ``inp``.

        All random numbers are drawn from ``self.trng`` (a default, CPU-side
        generator) and then moved to ``inp.device``.

        Args:
            inp: Input tensor; its last axis is treated as the signal axis.

        Returns:
            A tensor of the same shape as ``inp`` containing the noisy input.
        """
        # One uniform scale per signal: shape inp.shape[:-1].
        scale = self.scale_lo + (self.scale_hi - self.scale_lo) * torch.rand(
            inp.shape[:-1],
            generator=self.trng,
        ).to(inp.device)
        noise = torch.randn(inp.shape, generator=self.trng).to(inp.device)
        # Dividing by sqrt(length of last axis) presumably normalizes the noise
        # so its norm along that axis is roughly eta * scale — confirm intent.
        return inp + self.eta / np.sqrt(
            inp.shape[-1]
        ) * noise * scale.unsqueeze(-1)


# ---- run data generation -----
if __name__ == "__main__":
    import config

    # Build the measurement operator (no device argument is passed here), seed
    # the global NumPy and torch RNGs from the config, then generate the
    # dataset with the configured sizes, splits and generator settings.
    OpA = config.meas_op(config.m, config.n, **config.meas_params)
    np.random.seed(config.numpy_seed)
    torch.manual_seed(config.torch_seed)

    create_dataset(
        config.m,
        config.n,
        OpA,
        config.set_params,
        config.data_gen,
        config.data_params,
    )