Exemplo n.º 1
0
def main(args):
    """Step through shuffled training batches for manual inspection.

    The dataset (X-Trans when ``args.xtrans`` is set, Bayer otherwise) is
    read from ``args.data_dir`` with augmentation on and linearization off,
    and served by a shuffling DataLoader with 8 workers.

    For each batch the mosaic and mask are concatenated along dim 1 and cut
    into non-overlapping square tiles of ``2 * period`` pixels per side
    (``period`` is 6 for X-Trans, 2 for Bayer). Execution then stops in
    ``ipdb`` twice: once before and once after the batch is pushed to the
    visualizers of the "demosaic_inspect" environment.
    """
    period = 6 if args.xtrans else 2
    dataset_cls = dset.XtransDataset if args.xtrans else dset.BayerDataset
    data = dataset_cls(args.data_dir,
                       transform=None,
                       augment=True,
                       linearize=False)
    loader = DataLoader(data,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=8)

    env = "demosaic_inspect"
    mask_viz = viz.BatchVisualizer("mask", env=env)
    mos_viz = viz.BatchVisualizer("mosaic", env=env)
    diff_viz = viz.BatchVisualizer("diff", env=env)
    target_viz = viz.BatchVisualizer("target", env=env)
    input_hist = viz.HistogramVisualizer("color_hist", env=env)

    # Tile edge length; it is also the unfold step, so tiles do not overlap.
    tile = 2 * period
    for sample in loader:
        mosaic, mask, target = (sample[key]
                                for key in ("mosaic", "mask", "target"))

        # Debug aid: zeroing channels 0 and 2 of both `target` and `mosaic`
        # at this point leaves only channel 1 non-zero.

        diff = target - mosaic

        mos_mask = th.cat([mosaic, mask], 1)
        mos_mask = mos_mask.unfold(2, tile, tile).unfold(3, tile, tile)
        bs, c, h, w, kh, kw = mos_mask.shape
        # (bs, c, h, w, kh, kw) -> (bs * h * w, c, kh * kw): one row per tile.
        mos_mask = mos_mask.permute(0, 2, 3, 1, 4, 5).contiguous()
        mos_mask = mos_mask.view(bs * h * w, c, kh * kw)

        import ipdb
        ipdb.set_trace()  # inspect the raw batch / tiles

        pairs = ((mask_viz, mask), (mos_viz, mosaic), (diff_viz, diff),
                 (target_viz, target))
        for visualizer, batch in pairs:
            visualizer.update(batch)
        input_hist.update(target[:, 1].contiguous().view(-1).numpy())

        ipdb.set_trace()  # inspect again after the visualizers refreshed
Exemplo n.º 2
0
def main(args):
    """Cluster and embed normalized mosaic patches for interactive inspection.

    For each shuffled batch of the non-augmented X-Trans (``args.xtrans``) or
    Bayer dataset in ``args.data_dir``:

    1. crop the mosaic and mask by ``(dy, dx)``, derived from
       ``args.ksize // 2``, ``args.offset_y`` / ``args.offset_x`` and the
       pattern period (6 for X-Trans, 2 for Bayer);
    2. cut the mosaic into ``args.ksize`` x ``args.ksize`` patches with a
       stride of one period, and normalize each patch by its own mean and
       standard deviation (+1e-8);
    3. draw 1024 patches at random (with replacement), cluster them into 16
       clusters with ``MiniBatchKMeans`` and embed them in 2-D with t-SNE;
    4. send the patches, the embedding (coloured by cluster), the cluster
       centres and the members of every non-empty cluster to
       ``torchlib.debug``, then stop in ``ipdb``.
    """
    import torchlib.debug as D

    linearize = False
    if args.xtrans:
        period = 6
        dataset_cls = dset.XtransDataset
    else:
        period = 2
        dataset_cls = dset.BayerDataset
    data = dataset_cls(args.data_dir,
                       transform=None,
                       augment=False,
                       linearize=linearize)
    loader = DataLoader(data,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=8)

    # NOTE(review): these visualizers are never updated in this function;
    # presumably they are kept for interactive use from the ipdb prompt.
    mask_viz = viz.BatchVisualizer("mask", env="demosaic_inspect")
    mos_viz = viz.BatchVisualizer("mosaic", env="demosaic_inspect")
    diff_viz = viz.BatchVisualizer("diff", env="demosaic_inspect")
    target_viz = viz.BatchVisualizer("target", env="demosaic_inspect")
    input_hist = viz.HistogramVisualizer("color_hist", env="demosaic_inspect")

    # The crop offsets depend only on the CLI arguments, so compute them once.
    # Since patches are taken with a stride of `period`, every patch then
    # starts at the same phase (modulo `period`) of the cropped mosaic.
    pad = args.ksize // 2
    dx = (pad - args.offset_x) % period
    dy = (pad - args.offset_y) % period
    print("dx {} dy {}".format(dx, dy))

    def to_patches(arr):
        """Split a 4-D tensor into stride-`period` ksize x ksize patches.

        (bs, c, H, W) -> (bs * h * w, c, ksize, ksize), where h and w are the
        number of patch positions along each spatial axis.
        """
        patches = arr.unfold(2, args.ksize,
                             period).unfold(3, args.ksize, period)
        bs, c, h, w, _, _ = patches.shape
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        return patches.view(bs * h * w, c, args.ksize, args.ksize)

    nclusts = 16
    new_bs = 1024
    for sample in loader:
        mosaic = sample["mosaic"][..., dy:, dx:]
        mask = sample["mask"][..., dy:, dx:]

        patches = to_patches(mosaic)
        # Channel count comes from the data rather than being assumed to be 3.
        bs, nchans = patches.shape[:2]
        means = patches.view(bs, -1).mean(-1).view(bs, 1, 1, 1)
        std = patches.view(bs, -1).std(-1).view(bs, 1, 1, 1)
        print(means.min().item(), means.max().item())

        # Per-patch normalization; the epsilon avoids dividing by zero for
        # constant patches.
        patches -= means
        patches /= std + 1e-8

        # Random subset (with replacement) to keep clustering / t-SNE cheap.
        sample_idx = np.random.randint(0, bs, (new_bs, ))
        patches = patches[sample_idx]

        D.tensor(patches)

        flat = patches.view(new_bs, -1).cpu().numpy()

        clst = cluster.MiniBatchKMeans(n_clusters=nclusts)
        clst_idx = clst.fit_predict(flat)
        colors = np.random.uniform(size=(nclusts, 3))

        manif = manifold.TSNE(n_components=2)
        new_coords = manif.fit_transform(flat)
        # One RGB colour (0-255) per embedded point, picked by its cluster.
        color = (colors[clst_idx, :] * 255).astype(np.uint8)
        print(color.shape)
        D.scatter(th.from_numpy(new_coords[:, 0]),
                  th.from_numpy(new_coords[:, 1]),
                  color=color,
                  key="tsne")

        centers = th.from_numpy(clst.cluster_centers_).view(
            nclusts, nchans, args.ksize, args.ksize)
        D.tensor(centers, "centers")

        patches_np = patches.numpy()
        for cidx in range(nclusts):
            members = clst_idx == cidx
            if not members.any():
                # A cluster can end up with no assigned sample; skip it rather
                # than sending an empty tensor to the debug viewer.
                continue
            p = th.from_numpy(patches_np[members])
            D.tensor(p, key="cluster_{:02d}".format(cidx))

        import ipdb
        ipdb.set_trace()
Exemplo n.º 3
0
def main(args, model_params):
    """Train a demosaicking model, optionally alongside a Caffe reference.

    Args:
        args: parsed command-line options. Fields read here: fix_seed,
            green_only, xtrans, data_dir, linear, val_data, pretrained,
            subsample, output, loss, optimizer, lr, batch_size, debug.
        model_params: model configuration forwarded to ``modules.get`` and
            to the ``Trainer``.

    Raises:
        ValueError: if ``args.loss`` is not one of "l2", "l1", "gradient",
            "laplacian" or "vgg".
    """
    if args.fix_seed:
        np.random.seed(0)
        th.manual_seed(0)

    # ------------ Set up datasets ----------------------------------------------
    xforms = [dset.ToTensor()]
    if args.green_only:
        xforms.append(dset.GreenOnly())
    xforms = transforms.Compose(xforms)
    dataset_cls = dset.XtransDataset if args.xtrans else dset.BayerDataset
    data = dataset_cls(args.data_dir,
                       transform=xforms,
                       augment=True,
                       linearize=args.linear)
    data[0]  # load one sample eagerly (presumably a fail-fast sanity check)

    if args.val_data is not None:
        # Apply the same linearization as the training set, so validation
        # inputs match what the (possibly DeLinearize-wrapped) model sees.
        val_data = dataset_cls(args.val_data,
                               transform=xforms,
                               augment=False,
                               linearize=args.linear)
    else:
        val_data = None
    # ---------------------------------------------------------------------------

    model = modules.get(model_params)
    log.info("Model configuration: {}".format(model_params))

    if args.pretrained:
        log.info("Loading Caffe weights")
        ref_name = "XtransNetwork" if args.xtrans else "BayerNetwork"
        model_ref = modules.get({"model": ref_name})
        cvt = converter.Converter(args.pretrained, ref_name)
        cvt.convert(model_ref)
        model_ref.cuda()
    else:
        model_ref = None

    def wrap_pair(wrapper, net, ref, *extra, **kwargs):
        """Apply `wrapper` to `net` and, only when it exists, to `ref`.

        Without the guard, a missing reference model would become
        `wrapper(None)` and reach the visualization callback as if a
        reference network were available.
        """
        wrapped = wrapper(net, *extra, **kwargs)
        wrapped_ref = None if ref is None else wrapper(ref, *extra, **kwargs)
        return wrapped, wrapped_ref

    if args.green_only:
        model, model_ref = wrap_pair(modules.GreenOnly, model, model_ref)

    if args.subsample:
        dx = 1
        dy = 0
        period = 6 if args.xtrans else 2
        model, model_ref = wrap_pair(modules.Subsample, model, model_ref,
                                     period, dx=dx, dy=dy)

    if args.linear:
        model, model_ref = wrap_pair(modules.DeLinearize, model, model_ref)

    name = os.path.basename(args.output)
    cbacks = [
        default_callbacks.LossCallback(env=name),
        callbacks.DemosaicVizCallback(val_data,
                                      model,
                                      model_ref,
                                      cuda=True,
                                      shuffle=False,
                                      env=name),
        callbacks.PSNRCallback(env=name),
    ]

    metrics = {"psnr": losses.PSNR(crop=4)}

    # Supported training losses, keyed by the --loss value. Only the selected
    # class is instantiated.
    loss_classes = {
        "l2": losses.L2Loss,
        "l1": losses.L1Loss,
        "gradient": losses.GradientLoss,
        "laplacian": losses.LaplacianLoss,
        "vgg": losses.VGGLoss,
    }
    log.info("Using {} loss".format(args.loss))
    if args.loss not in loss_classes:
        raise ValueError("not implemented")
    criteria = {args.loss: loss_classes[args.loss]()}

    optimizer = optim.Adam
    optimizer_params = {}
    if args.optimizer == "sgd":
        optimizer = optim.SGD
        optimizer_params = {"momentum": 0.9}
    train_params = Trainer.Parameters(viz_step=100,
                                      lr=args.lr,
                                      batch_size=args.batch_size,
                                      optimizer=optimizer,
                                      optimizer_params=optimizer_params)

    trainer = Trainer(data,
                      model,
                      criteria,
                      output=args.output,
                      params=train_params,
                      model_params=model_params,
                      verbose=args.debug,
                      callbacks=cbacks,
                      metrics=metrics,
                      valset=val_data,
                      cuda=True)

    trainer.train()
Exemplo n.º 4
0
def main(args, model_params):
    """Train a Bayer demosaicking model.

    Optionally seeds NumPy and torch with 0. It builds an augmented training
    set and, when ``args.val_data`` is given, a non-augmented validation set.
    The model is instantiated from ``model_params``; when ``args.pretrained``
    is set, it is initialised from Caffe weights. Training runs through the
    ``Trainer`` with ``cuda=True``.

    Raises:
        ValueError: if ``args.loss`` is neither "l2" nor "vgg".
    """
    if args.fix_seed:
        np.random.seed(0)
        th.manual_seed(0)

    # ------------ Set up datasets ----------------------------------------------
    to_tensor = dset.ToTensor()
    data = dset.BayerDataset(args.data_dir, transform=to_tensor, augment=True)
    data[0]  # load one sample eagerly (presumably a fail-fast sanity check)

    val_data = None
    if args.val_data is not None:
        val_data = dset.BayerDataset(args.val_data,
                                     transform=to_tensor,
                                     augment=False)
    # ---------------------------------------------------------------------------

    model = modules.get(model_params)
    log.info("Model configuration: {}".format(model_params))

    if args.pretrained:
        log.info("Loading Caffe weights")
        converter.Converter(args.pretrained,
                            model_params["model"]).convert(model)

    env_name = os.path.basename(args.output)
    cbacks = [default_callbacks.LossCallback(env=env_name)]
    cbacks.append(
        callbacks.DemosaicVizCallback(val_data,
                                      model,
                                      cuda=True,
                                      shuffle=True,
                                      env=env_name))

    metrics = {"psnr": losses.PSNR()}

    log.info("Using {} loss".format(args.loss))
    if args.loss == "l2":
        criterion = losses.L2Loss()
    elif args.loss == "vgg":
        criterion = losses.VGGLoss()
    else:
        raise ValueError("not implemented")
    criteria = {args.loss: criterion}

    # NOTE: an alternative optimizer (toptim.SVAG) was once passed here;
    # the Trainer's default optimizer is used instead.
    train_params = Trainer.Parameters(viz_step=100,
                                      lr=args.lr,
                                      batch_size=args.batch_size)

    trainer = Trainer(data,
                      model,
                      criteria,
                      output=args.output,
                      params=train_params,
                      model_params=model_params,
                      verbose=args.debug,
                      callbacks=cbacks,
                      metrics=metrics,
                      valset=val_data,
                      cuda=True)
    trainer.train()