Example #1
0
import matplotlib as mpl
import torch

from data_management import load_dataset
from networks import IterativeNet, Tiramisu
from operators import TVAnalysis, get_tikhonov_matrix

# --- load configuration -----
import config  # isort:skip

# ----- general setup -----
# Use the non-interactive "agg" backend so figures are rendered to files and
# no display is required; it is selected up front, before any plotting.
mpl.use("agg")
# Prefer the first GPU, but fall back to the CPU rather than failing on the
# first CUDA allocation when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ----- operators -----
# Measurement operator built from the configured sizes/parameters; its
# dimensions are read back below as OpA.m / OpA.n.
OpA = config.meas_op(config.m, config.n, device=device, **config.meas_params)
# TV analysis operator, passed to get_tikhonov_matrix below (presumably as the
# regularization operator -- confirm against operators.get_tikhonov_matrix).
OpTV = TVAnalysis(config.n, device=device)

# ----- build linear inverter  ------
# Regularization weight handed to get_tikhonov_matrix.
reg_fac = 2e-2

# Bias-free linear layer mapping vectors of length OpA.m to length OpA.n.
# Constructing it draws PyTorch's default random initialization (kept so the
# global RNG is consumed exactly as before); that weight is replaced next.
inverter = torch.nn.Linear(OpA.m, OpA.n, bias=False)
# Register the Tikhonov matrix as a frozen parameter instead of mutating
# `.data` in place: the inverter is fixed and must never receive gradients.
inverter.weight = torch.nn.Parameter(
    get_tikhonov_matrix(OpA, OpTV, reg_fac), requires_grad=False
)

# ----- network configuration -----
subnet_params = {
    "in_channels": 1,
    "out_channels": 1,
    "drop_factor": 0.0,
    "down_blocks": (5, 7, 9, 12, 15),
Example #2
0
        self.scale_hi = scale_hi
        self.trng = torch.Generator()
        if t_seed is not None:
            self.trng.manual_seed(t_seed)

    def __call__(self, inp):
        """Return ``inp`` perturbed by randomly scaled additive Gaussian noise.

        A uniform factor between ``self.scale_lo`` and ``self.scale_hi`` is
        drawn for every index of ``inp.shape[:-1]`` and broadcast along the
        last axis. Each entry of ``inp`` receives standard normal noise
        multiplied by ``self.eta / sqrt(inp.shape[-1])`` and by that factor.

        All random numbers come from ``self.trng`` (a CPU ``torch.Generator``)
        and are moved to ``inp.device`` afterwards. The draws therefore depend
        only on the generator state (seeded from ``t_seed`` when one was
        given), not on where ``inp`` lives.

        Args:
            inp: tensor to perturb; only its shape and device are inspected
                here.

        Returns:
            ``inp`` plus the scaled noise.
        """
        leading_shape = inp.shape[:-1]
        last_len = inp.shape[-1]

        # Uniform draw first, then Gaussian draw: this order fixes which
        # generator outputs feed the scale and which feed the noise.
        uniform = torch.rand(
            leading_shape,
            generator=self.trng,
        ).to(inp.device)
        spread = self.scale_hi - self.scale_lo
        scale = self.scale_lo + spread * uniform

        gauss = torch.randn(inp.shape, generator=self.trng).to(inp.device)

        # Divide the noise level by sqrt(last-axis length).
        sigma = self.eta / np.sqrt(last_len)
        perturbation = sigma * gauss * scale.unsqueeze(-1)
        return inp + perturbation


# ---- run data generation -----
if __name__ == "__main__":
    # Project configuration: sizes, operator factory, seeds and dataset
    # settings. Imported only when this file is run as a script.
    import config

    # Problem dimensions, shared by the operator and by create_dataset.
    dims = (config.m, config.n)

    # NOTE(review): the operator is constructed *before* the global RNGs are
    # seeded. If meas_op draws from them, it is not covered by
    # numpy_seed / torch_seed -- confirm this ordering is intended.
    OpA = config.meas_op(*dims, **config.meas_params)

    # Seed NumPy's and PyTorch's global generators so that later random
    # draws from them are repeatable.
    np.random.seed(config.numpy_seed)
    torch.manual_seed(config.torch_seed)

    create_dataset(
        *dims,
        OpA,
        config.set_params,
        config.data_gen,
        config.data_params,
    )