Exemplo n.º 1
0
def setup_flower_shutter(
    rawdata='/lustre/projects/project-broaddus/rawdata/artifacts/flower.tif',
    savedir='/lustre/projects/project-broaddus/denoise_experiments/flower/e01/flower_test'
):
    """Prepare a fresh experiment namespace for the flower shutter data.

    Steps:
    - Wipe and re-create ``savedir``.
    - Read ``rawdata`` and percentile-normalize it (2 / 99.6) to float32.
    - Cut the stack into 256x256 tiles.

    The tiling reshape assumes the trailing image dims are 1024x1024, i.e. a
    4x4 grid of tiles per frame (TODO confirm for new data).

    Returns a ``SimpleNamespace`` with:
        net     -- ``Unet2_2d`` (not moved to GPU here), loaded from the stored
                   random init and then passed through ``init_weights``.
        savedir -- ``savedir`` as a ``Path``.
        x1_all  -- float tensor of tiles, shape ``(n_tiles, 1, 256, 256)``.
    """
    out_dir = Path(savedir)
    wipe_dirs(out_dir)
    init_dirs(out_dir)

    raw = imread(rawdata)
    lo_pct, hi_pct = 2, 99.6
    print(f"pmin = {lo_pct}; pmax = {hi_pct}")
    normed = normalize3(raw, lo_pct, hi_pct).astype(np.float32, copy=False)

    # (n, 4, 256, 4, 256) -> (n, 4, 4, 256, 256) -> (n*16, 1, 256, 256)
    grid = normed.reshape((-1, 4, 256, 4, 256))
    grid = np.transpose(grid, (0, 1, 3, 2, 4))
    tiles = grid.reshape((-1, 1, 256, 256))

    model = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    randinit_path = '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
    model.load_state_dict(torch.load(randinit_path))
    # NOTE(review): init_weights runs after loading, so it presumably overrides
    # the stored random init — confirm that is intended.
    model.apply(init_weights)

    return SimpleNamespace(
        net=model,
        savedir=out_dir,
        x1_all=torch.from_numpy(tiles).float(),
    )
Exemplo n.º 2
0
def predict_full(
    savedir=None,
    rawdata='/lustre/projects/project-broaddus/devseg_data/raw/artifacts/flower.tif',
    weights='/lustre/projects/project-broaddus/denoise/flower/e01/flower3_6/models/net600.pt',
):
    """Denoise every slice of the flower stack and save ``pred_flower.tif``.

    Each slice of ``rawdata`` is percentile-normalized (2 / 99.6 over axes 1, 2)
    and passed through ``apply_net_tiled`` with a CUDA ``Unet2_2d`` loaded from
    ``weights``. The stacked predictions are written to
    ``savedir / 'pred_flower.tif'``.

    Parameters
    ----------
    savedir : str or Path, optional
        Output directory. When omitted, falls back to a module-level
        ``savedir`` global; earlier versions read that global implicitly.
    rawdata : str or Path
        Input stack readable by ``imread``.
    weights : str or Path
        State dict for the network.

    Raises
    ------
    ValueError
        If ``savedir`` is not given and no module-level ``savedir`` exists.
        Previously this surfaced as a NameError only after all predictions
        had been computed.
    """
    # Resolve the output dir first so a missing path fails before GPU work.
    if savedir is None:
        savedir = globals().get("savedir")
        if savedir is None:
            raise ValueError(
                "predict_full: pass `savedir` or define a module-level `savedir`"
            )
    savedir = Path(savedir)

    net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    net.load_state_dict(torch.load(weights))

    img = imread(rawdata)
    pmin, pmax = 2, 99.6
    img = normalize3(img, pmin, pmax, axs=(1, 2)).astype(np.float32,
                                                         copy=False)
    # x[None] adds a leading singleton axis, presumably the channel axis
    # apply_net_tiled expects — TODO confirm.
    pimg = np.array([apply_net_tiled(net, x[None]) for x in img])
    imsave(pimg, savedir / 'pred_flower.tif')
Exemplo n.º 3
0
def setup_flower_shutter(rawdata, savedir):
    """Prepare the experiment namespace for flower shutter training.

    Wipes and re-creates ``savedir`` and loads ``rawdata`` as a float tensor of
    256x256 tiles (``d.xs``) plus their mean over frames (``d.ys``). Also writes
    preview images ``xs.png`` (first frame) and ``ys.png`` (mean) into
    ``savedir``.

    Returns a ``SimpleNamespace`` with:
        net     -- ``Unet2_2d`` loaded from the stored random init.
        savedir -- ``savedir`` as a ``Path``.
        xs      -- tiles with shape ``(n_frames, 4, 4, 1, 256, 256)``.
        ys      -- ``xs.mean(0)``, shape ``(4, 4, 1, 256, 256)``.
        cuda    -- ``False``.
    """
    savedir = Path(savedir)
    wipe_dirs(savedir)
    init_dirs(savedir)

    data = imread(rawdata)
    # NOTE(review): normalize3 is called without ``axs`` here, so percentiles
    # are presumably taken over the whole stack, not per frame — confirm.
    data = normalize3(data, 2, 99.6)

    d = SimpleNamespace()
    d.net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    d.net.load_state_dict(
        torch.load(
            '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
        ))
    # d.net.apply(init_weights)  # disabled: keep the stored random init
    d.savedir = savedir

    # Split each frame into a 4x4 grid of 256x256 tiles (assumes 1024x1024
    # frames): (n, 4, 256, 4, 256, 1) -> (n, 4, 4, 1, 256, 256).
    # The frame count is inferred (-1) rather than hard-coded to 100, so
    # stacks of any length work; 100-frame stacks give the same result.
    d.xs = torch.from_numpy(data).float()
    d.xs = d.xs.reshape(-1, 4, 256, 4, 256, 1).permute(0, 1, 3, 5, 2, 4)
    d.ys = d.xs.mean(0)

    io.imsave(savedir / 'xs.png',
              collapse2(d.xs[0, :, :, 0].numpy(), "12yx", "1y,2x"))
    io.imsave(savedir / 'ys.png',
              collapse2(d.ys[:, :, 0].numpy(), "12yx", "1y,2x"))
    d.cuda = False

    return d
Exemplo n.º 4
0
def predict_on_full_2d_stack(rawdata, savedir, weights):
    """Run a trained 2D Unet over every slice of a stack and save the result.

    Parameters
    ----------
    rawdata : str or Path
        Input stack readable by ``imread``.
    savedir : str or Path
        Directory that receives ``pred.tif``.
    weights : str or Path
        State dict loaded into a CUDA ``Unet2_2d``.

    Each slice is normalized with ``normalize3(..., 2, 99.6, axs=(1, 2))`` and
    passed through ``apply_net_tiled_2d``. The stacked predictions are saved
    as float16 with ``compress=9``.
    """
    out_dir = Path(savedir)
    model = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    model.load_state_dict(torch.load(weights))

    stack = imread(rawdata)
    lo_pct, hi_pct = 2, 99.6
    stack = normalize3(stack, lo_pct, hi_pct, axs=(1, 2))
    stack = stack.astype(np.float32, copy=False)

    # plane[None] adds a leading singleton axis before tiled prediction.
    predictions = np.array(
        [apply_net_tiled_2d(model, plane[None]) for plane in stack])
    imsave(predictions.astype(np.float16), out_dir / 'pred.tif', compress=9)
Exemplo n.º 5
0
def setup(params=None, savedir=None):
    """Build the experiment namespace for the flower e01 denoising run.

    Wipes and re-creates ``savedir`` and loads the ``data`` array from
    ``data_flower3.npz``. It then creates two CUDA ``Unet2_2d`` networks, both
    loaded from the same stored random initialisation.

    Parameters
    ----------
    params : dict, optional
        Currently unused; kept for call compatibility. Defaults to ``None``
        instead of a shared mutable ``{}``.
    savedir : str or Path, optional
        Experiment directory. When omitted, the module-level ``savedir``
        global is used, as earlier versions did implicitly.

    Returns
    -------
    SimpleNamespace
        ``net``, ``net2``, ``savedir`` and ``x1_all`` (the loaded data as a
        float CUDA tensor).

    Raises
    ------
    ValueError
        If ``savedir`` is not given and no module-level ``savedir`` exists.
    """
    if savedir is None:
        savedir = globals().get("savedir")
        if savedir is None:
            raise ValueError(
                "setup: pass `savedir` or define a module-level `savedir`"
            )

    wipe_dirs(savedir)
    init_dirs(savedir)

    # Alternative datasets used in earlier runs:
    #   /lustre/projects/project-broaddus/devseg_data/cl_datagen/grid/data_shutter.npz
    #   /lustre/projects/project-broaddus/denoise/flower/e02/data_flower.npz
    # Use a context manager so the NpzFile handle is closed after reading.
    with np.load(
            '/lustre/projects/project-broaddus/denoise/flower/e01/data_flower3.npz'
    ) as npz:
        data = npz['data']

    randinit_path = '/lustre/projects/project-broaddus/denoise/flower/models/net_randinit.pt'

    def _randinit_unet():
        # Fresh CUDA Unet2_2d loaded with the stored random initialisation.
        net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU).cuda()
        net.load_state_dict(torch.load(randinit_path))
        return net

    d = SimpleNamespace()
    d.net = _randinit_unet()
    d.net2 = _randinit_unet()
    d.savedir = savedir
    d.x1_all = torch.from_numpy(data).float().cuda()
    return d