Exemplo n.º 1
0
def setup_flower_shutter(rawdata, savedir):
    """Build the experiment state for the flower/shutter denoising run.

    The output directory is wiped and re-initialised, the raw stack is read
    and normalised, and a 2D U-Net is created and loaded with stored
    initial weights.

    Args:
        rawdata: Path of the raw image stack, handed to ``imread``. The
            stack must hold exactly 100*4*256*4*256 values, because the
            reshape below hard-codes that layout.
        savedir: Output directory; converted to ``Path``.

    Returns:
        SimpleNamespace with
          * ``net``     -- the network (left on the CPU here),
          * ``savedir`` -- the output directory as ``Path``,
          * ``xs``      -- float tensor of shape (100, 4, 4, 1, 256, 256),
          * ``ys``      -- mean of ``xs`` over its first axis,
          * ``cuda``    -- ``False``.
    """
    out_dir = Path(savedir)
    wipe_dirs(out_dir)
    init_dirs(out_dir)

    ## pmin=2, pmax=99.6 (presumably percentiles -- confirm against normalize3)
    stack = normalize3(imread(rawdata), 2, 99.6)  ## normalize across all dims?

    d = SimpleNamespace()
    d.net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    init_weights_file = '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
    d.net.load_state_dict(torch.load(init_weights_file))
    d.savedir = out_dir

    ## split each 1024x1024 frame into a 4x4 grid of 256x256 tiles and add a
    ## singleton channel axis: (frame, tile_row, tile_col, channel, y, x)
    frames = torch.from_numpy(stack).float()
    tiled = frames.reshape(100, 4, 256, 4, 256, 1)
    d.xs = tiled.permute(0, 1, 3, 5, 2, 4)
    ## average over the 100 frames
    d.ys = d.xs.mean(0)

    ## preview images: first frame and the frame-average, tiles re-assembled
    first_frame = collapse2(d.xs[0, :, :, 0].numpy(), "12yx", "1y,2x")
    io.imsave(out_dir / 'xs.png', first_frame)
    frame_mean = collapse2(d.ys[:, :, 0].numpy(), "12yx", "1y,2x")
    io.imsave(out_dir / 'ys.png', frame_mean)

    d.cuda = False
    return d
Exemplo n.º 2
0
def datagen(params=None, savedir=None):
    """Sample random 3D training patches from the C. elegans time series.

    For each of the 190 time points (``t000`` .. ``t189``) the raw volume is
    read, normalised with a randomly drawn ``pmin``/``pmax`` pair, and ten
    patches are cut out at random positions. A patch is only accepted when
    its mean intensity is at least 0.03.

    Args:
        params: Unused; kept for interface compatibility.
        savedir: Optional output directory (``str`` or ``Path``). When given,
            two preview PNGs, the compressed patch array and the pickled
            list of patch slices are written there.

    Returns:
        ``np.ndarray`` built from the nested patch lists; its leading axes
        are (time=190, sample=10, channel=1) followed by the patch axes
        (z, y, x -- presumably (32, 64, 64); confirm against random_slice).
    """
    patch_shape = (32, 64, 64)
    ## threshold for "interesting content", chosen by manual inspection
    ## NOTE(review): the original comment said 0.02 but the code used 0.03.
    min_mean = 0.03

    data = []
    ## One entry per accepted patch, in (time, sample) order. This list must
    ## live outside the time loop; re-creating it per time point would leave
    ## only the last time point's slices in the saved pickle.
    slicelist = []

    def random_patch(img):
        ss = random_slice(img.shape, patch_shape)
        ## NOTE(review): loops forever if no patch in `img` reaches min_mean.
        while img[ss].mean() < min_mean:
            ss = random_slice(img.shape, patch_shape)
        slicelist.append(ss)
        ## 1-tuple acts as a singleton channel axis after np.array()
        return (img[ss].copy(), )

    for i in np.r_[:190]:
        img = imread(
            f'/lustre/projects/project-broaddus/devseg_data/raw/celegans_isbi/Fluo-N3DH-CE/01/t{i:03d}.tif'
        )

        pmin, pmax = np.random.uniform(1, 3), np.random.uniform(99.5, 99.8)
        img = normalize3(img, pmin, pmax).astype(np.float32, copy=False)

        data.append([random_patch(img) for _ in range(10)])  #ts(xys)czyx

    data = np.array(data)

    print("data.shape: ", data.shape)

    if savedir:
        from pathlib import Path
        savedir = Path(savedir)  ## accept plain strings as well

        ## xy view at z index 16
        rgb = collapse2(data[:, :, :, 16], 'tscyx', 'ty,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xy_cele.png', rgb)
        ## xz view at y index 32
        rgb = collapse2(data[:, :, :, :, 32], 'tsczx',
                        'tz,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xz_cele.png', rgb)
        np.savez_compressed(savedir / 'data_cele.npz', data)
        pklsave(slicelist, savedir / 'slicelist_cele.pkl')

    return data
Exemplo n.º 3
0
def setup(savedir):
    """Build the experiment state for the 3D C. elegans denoising run.

    Wipes and re-initialises ``savedir``, loads the pre-generated patch
    archive, creates a ``Unet2`` on the GPU and loads stored initial weights
    into it.

    Args:
        savedir: Output directory; converted to ``Path``.

    Returns:
        SimpleNamespace with ``net`` (on CUDA), ``savedir`` (``Path``) and
        ``x1_all`` (float CUDA tensor holding the rearranged patches).
    """
    out_dir = Path(savedir)
    wipe_dirs(out_dir)
    init_dirs(out_dir)

    ## patches were generated once and stored; load them instead of resampling
    archive = np.load(
        '/lustre/projects/project-broaddus/denoise_experiments/cele/e01/data_cele.npz'
    )
    patches = archive['arr_0']
    ## prepend a dummy 'r' axis, merge the t and s axes into one, then drop
    ## the leading channel axis again via [0]
    ## (axis letters as given to collapse2; semantics of collapse2 itself are
    ## defined elsewhere -- TODO confirm)
    patches = collapse2(patches[None], 'rtsczyx', 'c,ts,r,z,y,x')[0]

    net = torch_models.Unet2(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    init_state = torch.load(
        '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit3D.pt'
    )
    net.load_state_dict(init_state)

    d = SimpleNamespace()
    d.net = net
    d.savedir = out_dir
    d.x1_all = torch.from_numpy(patches).float().cuda()
    return d
Exemplo n.º 4
0
def datagen(savedir=None):
    """Cut the first ten frames of the flower prediction stack into tiles.

    The stack is normalised with fixed ``pmin=2`` / ``pmax=99.6`` and every
    frame is split into a 4x4 grid of 256x256 tiles, giving an array of
    shape (n_tiles, 1, 256, 256).

    Args:
        savedir: Optional output directory. It is combined with file names
            via ``/``, so it has to be a ``Path``-like object. When given, a
            preview grid PNG and a compressed ``.npz`` (keys ``data``,
            ``pmin``, ``pmax``) are written.

    Returns:
        SimpleNamespace with attributes ``data``, ``pmin`` and ``pmax``.
    """
    source = '/lustre/projects/project-broaddus/denoise_experiments/flower/e02/pred_flower.tif'
    img = imread(source)[:10]
    print(img.shape)

    ## fixed normalisation bounds (no random augmentation of the range)
    pmin, pmax = 2, 99.6
    print(f"pmin = {pmin}; pmax = {pmax}")
    img = normalize3(img, pmin, pmax).astype(np.float32, copy=False)

    ## (frame, 4*256, 4*256) -> (frame, row, y, col, x)
    tiles = img.reshape((-1, 4, 256, 4, 256))
    ## -> (frame, row, col, y, x)
    tiles = tiles.transpose((0, 1, 3, 2, 4))
    ## -> (sample, channel=1, y, x)
    data = tiles.reshape((-1, 1, 256, 256))
    print("data.shape: ", data.shape)

    #SCZYX
    if savedir:
        preview = collapse2(data[:, :], 'scyx', 's,y,x,c')
        preview = preview[..., [0, 0, 0]]  ## replicate last axis three times
        preview = normalize3(preview)
        preview = plotgrid([preview], 10)
        io.imsave(savedir / 'data_xy_flower.png', preview)
        np.savez_compressed(savedir / 'data_flower.npz',
                            data=data,
                            pmin=pmin,
                            pmax=pmax)

    return SimpleNamespace(data=data, pmin=pmin, pmax=pmax)
Exemplo n.º 5
0
def train(d, ta=None, end_epoch=300, mask_shape=(1, 2, 3, 4)):
    """Train ``d.net`` with a masked-pixel reconstruction loss.

    For every batch a sparse random mask is drawn. Masked input pixels are
    replaced by uniform random values, the network predicts from the damaged
    input, and the squared error against the undamaged input is averaged
    over the mask's centre pixels only.

    Args:
        d: Namespace with ``net``, ``savedir`` and ``x1_all`` (a tensor whose
            first axis indexes samples and whose axes from index 2 on are the
            patch axes). ``w1_all``, ``eg_xs`` and ``xs_fft`` are (re)set on
            ``d`` here, and all tensors are moved to CUDA.
        ta: Training-artifacts object providing ``e``, ``losses`` and
            ``lossdists``; created with ``init_training_artifacts()`` if
            ``None``. Training starts at epoch ``ta.e``.
        end_epoch: Last epoch (inclusive).
        mask_shape: Offsets along x, on both sides of each centre pixel, that
            are masked out in addition to the centre pixel itself.

    Returns:
        The updated ``ta``.

    Side effects:
        Writes ``loss.png`` every epoch, example images every 10th epoch,
        raw example arrays every 100th epoch, model weights every 50th epoch
        and at the end, and ``losses.pkl`` at the end, all under
        ``d.savedir``.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## set up const variables necessary for training
    batch_size = 4
    n_samples = d.x1_all.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    patch_size = d.x1_all.shape[2:]
    d.w1_all = torch.ones(d.x1_all.shape).float()

    ## set up variables for monitoring training
    d.eg_xs = d.x1_all[[0, 3, 5, 12]].clone()
    ## NOTE(review): `torch.fft(x, 2)` is the legacy functional API that takes
    ## a trailing (real, imag) axis; it only exists in older PyTorch releases.
    d.xs_fft = torch.fft((d.eg_xs - d.eg_xs.mean())[..., None][..., [0, 0]],
                         2).norm(p=2, dim=-1)
    ## fftshift runs in numpy, so the tensor has to be on the CPU first
    ## (d.x1_all may already live on the GPU when this function is called)
    d.xs_fft = torch.from_numpy(
        np.fft.fftshift(d.xs_fft.cpu().numpy(), axes=(-1, -2)))
    ## -2: never computed, -1: not visited in the current epoch
    lossdist = torch.zeros(n_samples) - 2

    ## move everything to cuda
    d.net = d.net.cuda()
    d.x1_all = d.x1_all.cuda()
    d.w1_all = d.w1_all.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.eg_xs = d.eg_xs.cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)

    def sparse_3set_mask():
        "build random mask for small number of central pixels"
        ## mask values: 0 = untouched, 1 = masked neighbour, 2 = masked centre
        n = int(np.prod(patch_size) * 0.02)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)

        for i in mask_shape:
            m = x_inds - i >= 0
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]
            ma[y_inds[m], x_inds[m] + i] = 1

        ma = ma.astype(np.uint8)
        ## centres are written last so they win over overlapping neighbours
        ma[y_inds, x_inds] = 2

        return ma

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]
            w1 = d.w1_all[idxs]  ## index-array lookup: a copy, safe to overwrite

            ma = sparse_3set_mask()

            ## apply mask to input
            ## (np.float64 is what the removed `np.float` alias referred to)
            w1[:, :] = torch.from_numpy(ma.astype(np.float64)).cuda()
            x1_damaged = x1.clone()
            x1_damaged[w1 > 0] = torch.rand(x1.shape).cuda()[w1 > 0]

            y1p = d.net(x1_damaged)

            dims = (1, 2, 3)  ## all dims except batch

            tm = (w1 == 2).float()  ## target mask
            loss_per_patch = (tm * torch.abs(y1p - x1)**2).sum(dims) / tm.sum(
                dims)

            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            loss.backward()
            opt.step()

        ## predict on examples and save predictions as images
        with torch.no_grad():
            example_yp = d.net(d.eg_xs)
            yp_fft = torch.fft(
                (example_yp - example_yp.mean())[..., None][..., [0, 0]],
                2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(
                np.fft.fftshift(yp_fft.cpu().numpy(), axes=(-1, -2))).cuda()

            ## rows: input, last batch's mask, input fft, prediction, prediction fft
            rgb = torch.stack([
                d.eg_xs, w1[[0] * len(d.eg_xs)] / 2, d.xs_fft, example_yp,
                yp_fft
            ], 0).cpu().detach().numpy()
            arr = rgb.copy()
            # type,samples,channels,y,x
            rgb = normalize3(rgb, axs=(1, 2, 3, 4))
            ## the two fft rows get their own normalisation range
            rgb[[2, 4]] = normalize3(rgb[[2, 4]],
                                     pmin=0,
                                     pmax=99.0,
                                     axs=(1, 2, 3, 4))
            # remove channels and permute
            rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if e % 10 == 0:
                    io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
                if e % 100 == 0:
                    np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss after each epoch
        ta.lossdists.append(lossdist.numpy().copy())
        x_axis = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(x_axis, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / f'loss.png', dpi=300)

        ## and save the model state
        if e % 50 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / f'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
Exemplo n.º 6
0
def train(d, ta=None, end_epoch=300, already_on_cuda=False):
    """Train ``d.net`` to map every tiled frame onto the frame-average ``d.ys``.

    Each batch of frames is flattened to single-channel 256x256 tiles, passed
    through the network, reshaped back to the tiled layout and compared with
    ``d.ys`` using a mean squared error per frame.

    Args:
        d: Namespace with ``net``, ``savedir``, ``xs`` and ``ys``. ``xs`` is
            assumed to be laid out as (frame, 4, 4, 1, 256, 256) -- that is
            what the reshapes below hard-code. ``ws``, ``example_xs`` and
            ``xs_fft`` are (re)set on ``d`` here and everything is moved to
            CUDA.
        ta: Training-artifacts object providing ``e``, ``losses`` and
            ``lossdists``; created with ``init_training_artifacts()`` if
            ``None``. Training starts at epoch ``ta.e``.
        end_epoch: Last epoch (inclusive).
        already_on_cuda: Unused; kept for interface compatibility.

    Returns:
        The updated ``ta``.

    Side effects:
        Writes ``loss.png`` every epoch, example images every 10th epoch,
        raw example arrays and model weights every 100th epoch, and
        ``losses.pkl`` plus the final weights at the end, all under
        ``d.savedir``.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## setup const variables necessary for training
    batch_size = 4
    n_samples = d.xs.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    d.ws = torch.ones(d.xs.shape).float()

    ## set up variables for monitoring training
    ## (first tile of four fixed frames, kept on the CPU for the numpy fftshift)
    d.example_xs = d.xs[[0, 3, 5, 12], 0, 0].reshape(-1, 1, 256,
                                                     256).clone().cpu()
    ## NOTE(review): `torch.fft(x, 2)` is the legacy functional API that takes
    ## a trailing (real, imag) axis; it only exists in older PyTorch releases.
    d.xs_fft = torch.fft(
        (d.example_xs - d.example_xs.mean())[..., None][..., [0, 0]],
        2).norm(p=2, dim=-1)
    d.xs_fft = torch.from_numpy(np.fft.fftshift(d.xs_fft, axes=(-1, -2)))
    ## -2: never computed, -1: not visited in the current epoch
    lossdist = torch.zeros(n_samples) - 2

    ## move vars to gpu
    d.net = d.net.cuda()
    d.xs = d.xs.cuda()
    d.ys = d.ys.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.example_xs = d.example_xs.cuda()
    d.ws = d.ws.cuda()

    ## initialize optimizer (must be done after moving data to gpu ?)
    opt = torch.optim.Adam(d.net.parameters(), lr=2e-4)

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        ## NOTE(review): this stores the *previous* epoch's distribution (the
        ## -2 sentinel on the first pass); the last epoch's values are never
        ## appended. Confirm whether that is intended.
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.xs[idxs]

            ## every tile becomes its own single-channel sample for the 2D net
            y1p = d.net(x1.reshape(-1, 1, 256, 256))
            ## back to (frame, 4, 4, 1, 256, 256); -1 instead of a fixed 4 so
            ## a trailing batch with fewer than batch_size frames also works
            y1p = y1p.reshape(-1, 4, 4, 1, 256, 256)

            dims = (1, 2, 3, 4, 5)  ## all dims except batch

            ## d.ys has no frame axis and broadcasts over the batch
            loss_per_patch = ((y1p - d.ys)**2).mean(dims)

            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            loss.backward()
            opt.step()

        ## predict on examples and save every 10th epoch
        if e % 10 == 0:
            with torch.no_grad():
                example_yp = d.net(d.example_xs)

                ## compute fft from predictions
                yp_fft = torch.fft(
                    (example_yp - example_yp.mean())[..., None][..., [0, 0]],
                    2).norm(p=2, dim=-1)
                ## shift frequency domain s.t. zero freq is at center of array
                yp_fft = torch.from_numpy(
                    np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()

                ## stack (real space, real fft, predictions, and prediction fft) along a new dimension
                rgb = torch.stack([d.example_xs, d.xs_fft, example_yp, yp_fft],
                                  0).cpu().detach().numpy()
                arr = rgb.copy()
                ## first normalize each type to [0,1] independently
                rgb = normalize3(rgb,
                                 axs=(1, 2, 3,
                                      4))  # dims=type,samples,channels,y,x
                ## then normalize fft's and real-space dims separately
                rgb[[1, 3]] = normalize3(rgb[[1, 3]],
                                         pmin=0,
                                         pmax=99.0,
                                         axs=(1, 2, 3, 4))

                ## remove channels and permute into a 2D image
                rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
                    if e % 100 == 0:
                        np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot loss
        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / f'loss.png', dpi=300)

        ## save model weights
        if e % 100 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / f'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
Exemplo n.º 7
0
def train(d, ta=None, end_epoch=301):
    """Jointly train two chained networks with masked-pixel losses.

    ``d.net`` predicts from a masked copy of the input and is scored against
    the input at its mask's centre pixels. ``d.net2`` predicts from a masked
    copy of that prediction and is scored against the first prediction at
    the centre pixels of a second, independently drawn mask. Both loss terms
    are summed per sample.

    Args:
        d: Namespace with ``net``, ``net2``, ``savedir`` and ``x1_all``.
            ``x1_all`` is indexed as (sample, channel, y, x); the tensors and
            both networks are expected to be on CUDA already -- nothing is
            moved here except freshly created masks and random values.
        ta: Training-artifacts object providing ``e``, ``losses`` and
            ``lossdists``; created with ``init_training_artifacts()`` if
            ``None``. Training starts at epoch ``ta.e``.
        end_epoch: Epoch at which to stop (exclusive).

    Returns:
        The updated ``ta``.

    Side effects:
        Writes ``loss.png`` every epoch, example images every 10th epoch,
        raw example arrays and ``d.net`` weights every 100th epoch, and
        ``losses.pkl`` plus the final ``d.net`` weights at the end, all under
        ``d.savedir``.
        NOTE(review): ``d.net2`` weights are never saved -- confirm whether
        that is intended.
    """
    if ta is None:
        ta = init_training_artifacts()

    batch_size = 4
    n_samples = d.x1_all.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    patch_size = d.x1_all.shape[2:]

    ## fixed examples for monitoring
    example_xs = d.x1_all[[0, 3, 5, 12]].clone()
    ## NOTE(review): `torch.fft(x, 2)` is the legacy functional API that takes
    ## a trailing (real, imag) axis; it only exists in older PyTorch releases.
    xs_fft = torch.fft((example_xs - example_xs.mean())[..., None][...,
                                                                   [0, 0]],
                       2).norm(p=2, dim=-1)
    xs_fft = torch.from_numpy(np.fft.fftshift(xs_fft.cpu(),
                                              axes=(-1, -2))).cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)
    opt2 = torch.optim.Adam(d.net2.parameters(), lr=2e-5)
    ## -2: never computed, -1: not visited in the current epoch
    lossdist = torch.zeros(n_samples) - 2

    def sparse_3set_mask(p=0.02, xs=(1, 2), ys=()):
        "build random mask for small number of central pixels"
        ## mask values: 0 = untouched, 1 = masked neighbour, 2 = masked centre
        n = int(np.prod(patch_size) * p)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)

        for i in xs:
            m = x_inds - i >= 0
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]
            ma[y_inds[m], x_inds[m] + i] = 1

        for i in ys:
            m = y_inds - i >= 0
            ma[y_inds[m] - i, x_inds[m]] = 1
            m = y_inds + i < patch_size[0]
            ma[y_inds[m] + i, x_inds[m]] = 1

        ma = ma.astype(np.uint8)
        ## centres are written last so they win over overlapping neighbours
        ma[y_inds, x_inds] = 2

        return ma

    plt.figure()
    for e in range(ta.e, end_epoch):
        ta.e = e
        np.random.shuffle(inds)
        ## NOTE(review): this stores the *previous* epoch's distribution (the
        ## -2 sentinel on the first pass); the last epoch's values are never
        ## appended. Confirm whether that is intended.
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]
            n_batch = x1.shape[0]  ## may be < batch_size for the last batch

            ## (np.float64 is what the removed `np.float` alias referred to)
            ma = sparse_3set_mask(xs=[1, 2]).astype(np.float64)
            ma2 = sparse_3set_mask(xs=[1, 2]).astype(np.float64)

            ## apply mask to input
            ma = torch.from_numpy(ma).cuda()
            x1_damaged = x1.clone()
            x1_damaged[:, :, ma > 0] = torch.rand(x1.shape).cuda()[:, :,
                                                                   ma > 0]
            y1p = d.net(x1_damaged)

            ## apply the second mask to the first prediction; net2 has to see
            ## the damaged version, otherwise the masked target pixels are
            ## directly visible in its input
            ma2 = torch.from_numpy(ma2).cuda()
            y1p_damaged = y1p.clone()
            y1p_damaged[:, :, ma2 > 0] = torch.rand(y1p.shape).cuda()[:, :,
                                                                      ma2 > 0]
            y2p = d.net2(y1p_damaged)

            dims = (1, 2, 3)  ## all dims except batch

            ## target masks, shape (n_batch, 1, y, x)
            tm1 = (ma == 2).float().repeat(n_batch, 1, 1, 1)
            tm2 = (ma2 == 2).float().repeat(n_batch, 1, 1, 1)
            loss_per_patch = (tm1 *
                              torch.abs(y1p - x1)**2).sum(dims) / tm1.sum(dims)
            loss_per_patch += (
                tm2 * torch.abs(y2p - y1p)**2).sum(dims) / tm2.sum(dims)

            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            opt2.zero_grad()
            loss.backward()
            opt.step()
            opt2.step()

        ## predict on examples and save each epoch
        with torch.no_grad():
            example_yp = d.net(example_xs)
            example_yp2 = d.net2(example_yp)

            yp_fft = torch.fft(
                (example_yp2 - example_yp2.mean())[..., None][..., [0, 0]],
                2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(
                np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()

            ## rows: input, last batch's first mask, input fft,
            ## second-stage prediction, its fft
            rgb = torch.stack([
                example_xs,
                ma.float().repeat(len(example_xs), 1, 1, 1) / 2, xs_fft,
                example_yp2, yp_fft
            ], 0).cpu().detach().numpy()
            arr = rgb.copy()
            # type,samples,channels,y,x
            rgb = normalize3(rgb, axs=(1, 2, 3, 4))
            ## the two fft rows get their own normalisation range
            rgb[[2, 4]] = normalize3(rgb[[2, 4]],
                                     pmin=0,
                                     pmax=99.0,
                                     axs=(1, 2, 3, 4))

            # remove channels and permute
            rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if e % 10 == 0:
                    io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
                if e % 100 == 0:
                    np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / f'loss.png', dpi=300)
        if e % 100 == 0:
            ## the output directory lives on `d`; there is no local `savedir`
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / f'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta