Example #1
File: 0.py Project: tkkcc/tnrd
def train(m, p=None):
    d = DataLoader(BSD3000(), o.batch_size, num_workers=o.num_workers)
    optimizer = torch.optim.Adam(m.parameters(), lr=o.lr)
    iter_num = len(d)
    num = 0
    losss = []
    stage = 1 if not p else p.stage + 1
    for epoch in range(o.epoch):
        for i in tqdm(d):
            g, y, k, s = [x.to(o.device) for x in i]  # batch: ground truth g, observation y, kernel k, s (presumably the noise level)
            x = y  # the observation is also the network input
            optimizer.zero_grad()
            out = m(x)
            log("out", out)
            loss = npsnr(out, g)  # negative-PSNR loss against the ground truth
            loss.backward()
            optimizer.step()
            losss.append(loss.detach().item())
            assert not isnan(losss[-1])
            print("stage", stage, "epoch", epoch + 1)
            log("loss", mean(losss[-5:]))
            num += 1
            # if num > (o.epoch * iter_num - 4):
            if num % 50 == 1:  # periodically preview observation / ground truth / output
                show(torch.cat((y[0, 0], g[0, 0], out[0, 0]), 1),
                     # save=f"save/{stage:02}{epoch:02}.png",
                     )
    plt.clf()
    plt.plot(range(len(losss)), losss)
    plt.xlabel("batch")
    plt.ylabel("loss")
    plt.title(f"{iter_num} iter x {o.epoch} epoch")
    plt.savefig(f"save/{stage:02}loss.png")
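All four examples minimize npsnr(out, g), and the test example below negates the stored loss before logging it as "psnr", which suggests npsnr returns the negative peak signal-to-noise ratio so that it can be minimized directly. A minimal sketch of such a loss, assuming images scaled to [0, 1]; the project's actual npsnr may average or crop differently:

import torch

def npsnr_sketch(out, target, max_val=1.0, eps=1e-12):
    # Negative PSNR over the whole batch: lower is better, so it can be
    # minimized as a training loss; negate the result to report PSNR.
    mse = torch.mean((out - target) ** 2)
    return -10.0 * torch.log10(max_val ** 2 / (mse + eps))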
Example #2
def test(m):
    m.eval()
    with torch.no_grad():
        d = DataLoader(Sun(), 1)
        losss = []
        for i in tqdm(d):
            g, y, k, s = [x.to(o.device) for x in i]
            out = m([y, y, k, s])  # the model expects [x, y, k, s]; start x from the observation y
            out = crop(out, k)
            out = center_crop(out, *g.shape[-2:])
            loss = npsnr(out, g)
            losss.append(-loss.detach().item())  # negate the negative-PSNR loss to store PSNR
            log("psnr", losss[-1])
            show(
                torch.cat(
                    (center_crop(y, *g.shape[-2:])[0, 0], g[0, 0], out[0, 0]),
                    1))
        log("psnr avg", sum(losss) / len(losss))
Example #3
File: 4.py Project: tkkcc/prior
def train(m, p=None):
    d = DataLoader(BSD3000(noise=False, edgetaper=False),
                   o.batch_size,
                   num_workers=o.num_workers)
    optimizer = torch.optim.Adam(m.parameters(), lr=o.lr)
    iter_num = len(d)
    num = 0
    losss = []
    stage = 1 if not p else p.stage + 1
    for epoch in range(o.epoch):
        for i in tqdm(d):
            g, y, k, s = [x.to(o.device) for x in i]
            k = k.flip(1, 2)  # flip the kernel's spatial dimensions
            x = y.clone().detach().requires_grad_(True)  # copy of y as the stage input; torch.tensor(y, ...) warns on tensor inputs
            if p:
                with torch.no_grad():
                    x = p([x, y, k, s])
            optimizer.zero_grad()
            out = m([x, y, k, s])
            log("out", out)
            out = center_crop(out, *g.shape[-2:])
            loss = npsnr(out, g)
            loss.backward()
            optimizer.step()
            losss.append(loss.detach().item())
            assert not isnan(losss[-1])
            print("stage", stage, "epoch", epoch + 1)
            log("loss", mean(losss[-5:]))
            num += 1
            # if num > (o.epoch * iter_num - 4):
            if num % 20 == 0:
                show(
                    torch.cat((center_crop(
                        y, *g.shape[-2:])[0, 0], g[0, 0], out[0, 0]), 1),
                    save=f"save/{stage:02}{epoch:02}.png",
                )
    plt.clf()
    plt.plot(range(len(losss)), losss)
    plt.xlabel("batch")
    plt.ylabel("loss")
    plt.title(f"{iter_num} iter x {o.epoch} epoch")
    plt.savefig(f"save/{stage:02}loss.png")
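Example #3 (and Example #4 below) flips the kernel's spatial dimensions before passing it to the model. PyTorch's conv2d actually computes cross-correlation, and correlating with a flipped kernel is the standard way to obtain a true convolution, which is a plausible reason for the flip. A small self-contained check of that identity (not project code):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
k = torch.randn(1, 1, 3, 3)

# F.conv2d is cross-correlation; correlating with the flipped kernel
# gives the same result as a true convolution with the original kernel
# (here checked by flipping the input and the output instead).
corr_flipped = F.conv2d(x, k.flip(-1, -2), padding=1)
conv_reference = F.conv2d(x.flip(-1, -2), k, padding=1).flip(-1, -2)
assert torch.allclose(corr_flipped, conv_reference, atol=1e-6)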
Example #4
def train(m, p=None):
    d = DataLoader(BSD3000(),
                   o.batch_size,
                   num_workers=o.num_workers,
                   shuffle=True)
    optimizer = torch.optim.Adam(m.parameters(), lr=o.lr)
    iter_num = len(d)
    num = 0
    losss = []
    mse = torch.nn.MSELoss()
    stage = 1 if not p else p.stage + 1
    for epoch in range(o.epoch):
        for i in tqdm(d):
            g, y, k, s = [x.to(o.device) for x in i]
            k = k.flip(1, 2)
            x = y
            if p:
                with torch.no_grad():
                    x = p([x, y, k, s])
            optimizer.zero_grad()
            out = m([x, y, k, s])
            log("out", out)
            out = crop(out, k)
            loss = npsnr(out, g)
            loss.backward()
            optimizer.step()
            losss.append(loss.detach().item())
            assert not isnan(losss[-1])
            print("stage", stage, "epoch", epoch + 1)
            log("loss", mean(losss[-5:]))
            num += 1
            if num > (o.epoch * iter_num - 4):
                # if num % 6 == 1:
                show(torch.cat((crop(y, k)[0, 0], g[0, 0], out[0, 0]), 1),
                     # save=f"save/{stage:02}{epoch:02}.png",
                     )
    plt.clf()
    plt.plot([i + 1 for i in range(len(losss))], losss)
    plt.xlabel("batch")
    plt.ylabel("loss")
    plt.title(f"{iter_num} iter x {o.epoch} epoch")
    plt.savefig(f"save/{stage:02}loss.png")
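The stage = 1 if not p else p.stage + 1 bookkeeping, together with the with torch.no_grad(): x = p([x, y, k, s]) step, points to greedy stage-wise training: each new stage is trained on top of the frozen previous stages and then evaluated with test. A hypothetical driver along those lines; ModelStage and the stage count are placeholders, not project code:

prev = None
for idx in range(1, 3):            # two stages here; the real count is unknown
    m = ModelStage().to(o.device)  # ModelStage is a placeholder model class
    m.stage = idx                  # train() reads p.stage to number the next stage
    train(m, p=prev)               # the previous stages run under no_grad inside train()
    prev = m
test(prev)                         # logs per-image PSNR and the average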