def main(args):
    """Stream dataset batches to the ``demosaic_inspect`` visualizers.

    Builds an X-Trans or Bayer dataset (chosen by ``args.xtrans``) from
    ``args.data_dir`` and, for every batch, pushes the mask, the mosaic, the
    target, ``target - mosaic`` and a histogram of target channel 1 to the
    visualizers. The debugger is entered twice per batch so that the
    intermediate tensors (including ``mos_mask``) can be inspected by hand.

    Args:
        args: options object; ``xtrans``, ``data_dir`` and ``batch_size``
            are read.
    """
    linearize = False
    dataset_cls = dset.XtransDataset if args.xtrans else dset.BayerDataset
    data = dataset_cls(
        args.data_dir, transform=None, augment=True, linearize=linearize)
    loader = DataLoader(
        data, batch_size=args.batch_size, shuffle=True, num_workers=8)

    # Pattern period: 6 for X-Trans data, 2 otherwise.
    period = 6 if args.xtrans else 2

    env = "demosaic_inspect"
    mask_viz = viz.BatchVisualizer("mask", env=env)
    mos_viz = viz.BatchVisualizer("mosaic", env=env)
    diff_viz = viz.BatchVisualizer("diff", env=env)
    target_viz = viz.BatchVisualizer("target", env=env)
    input_hist = viz.HistogramVisualizer("color_hist", env=env)

    # Non-overlapping square windows spanning two periods per side.
    window = 2 * period

    for batch in loader:
        mosaic = batch["mosaic"]
        mask = batch["mask"]
        target = batch["target"]

        diff = target - mosaic

        # Stack mosaic and mask along dim 1, cut dims 2 and 3 into windows,
        # then flatten so that each row is one window: (windows, c, kh * kw).
        stacked = th.cat([mosaic, mask], 1)
        tiles = stacked.unfold(2, window, window).unfold(3, window, window)
        bs, c, h, w, kh, kw = tiles.shape
        tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
        mos_mask = tiles.view(bs * h * w, c, kh * kw)

        # First stop: inspect the raw tensors before anything is displayed.
        import ipdb
        ipdb.set_trace()

        mask_viz.update(mask)
        mos_viz.update(mosaic)
        diff_viz.update(diff)
        target_viz.update(target)
        input_hist.update(target[:, 1].contiguous().view(-1).numpy())

        # Second stop: pause after the visualizers have been refreshed.
        import ipdb
        ipdb.set_trace()
Exemple #2
0
  def __init__(self, fov=7):
    """Create the convolutional stack and the debugging visualizers.

    Args:
      fov: value stored unchanged as ``self.fov``.
    """
    super(BayerLog, self).__init__()

    self.fov = fov

    # 4 input channels -> 3 output channels, every convolution is 3x3 with no
    # padding argument; a 2x upsampling sits between the two halves.
    layers = [
        nn.Conv2d(4, 16, 3),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(16, 16, 3),
        nn.LeakyReLU(inplace=True),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(16, 32, 3),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(32, 3, 3),
    ]
    self.net = nn.Sequential(*layers)

    debug_env = "mosaic_debug"
    self.debug_viz = viz.BatchVisualizer("batch", env=debug_env)
    self.debug_viz2 = viz.BatchVisualizer("batch2", env=debug_env)
    # Debug flag, off by default.
    self.debug = False
Exemple #3
0
  def __init__(self, model, reference, num_batches, val_loader, env=None):
    """Store the training collaborators and create the visualizers.

    Args:
      model: stored as ``self.model``.
      reference: stored as ``self.reference``.
      num_batches: stored as ``self.num_batches``.
      val_loader: stored as ``self.val_loader``.
      env: environment name forwarded to every visualizer.
    """
    self.model = model
    self.reference = reference
    self.num_batches = num_batches
    self.val_loader = val_loader

    def scalar_plot(title, legend):
      # One scalar plot whose traces are labelled by `legend`.
      return viz.ScalarVisualizer(title, opts={"legend": legend}, env=env)

    self.viz = viz.BatchVisualizer("demosaick", env=env)

    self.loss_viz = scalar_plot("loss", ["train", "val"])
    self.psnr_viz = scalar_plot("psnr", ["train", "train_g", "val"])
    self.ssim_viz = scalar_plot("1-ssim", ["train", "val"])
    self.l1_viz = scalar_plot("l1", ["train", "val"])

    self.current_epoch = 0
  def __init__(self, data, model, ref, env=None, batch_size=16,
               shuffle=False, cuda=True, period=500):
    """Set up the visualization callback state.

    Args:
      data: dataset wrapped in a single-process ``DataLoader``.
      model: stored as ``self.model``.
      ref: stored as ``self.ref``.
      env: environment name forwarded to the batch visualizer.
      batch_size: batch size of the loader; incomplete final batches are
        dropped (``drop_last=True``).
      shuffle: whether the loader shuffles ``data``.
      cuda: stored as ``self._cuda``; also decides whether the image-gradient
        module is moved to the GPU.
      period: stored as ``self.period``.
    """
    super(DemosaicVizCallback, self).__init__()
    self.batch_size = batch_size
    self.model = model
    self.ref = ref
    self.batch_viz = viz.BatchVisualizer("batch", env=env)
    self._cuda = cuda

    self.loader = DataLoader(
        data, batch_size=batch_size,
        shuffle=shuffle, num_workers=0, drop_last=True)

    self.period = period
    self.counter = 0

    self.psnr = losses.PSNR(crop=8)

    # Honour the `cuda` flag: the module used to be moved to the GPU
    # unconditionally, which made `cuda=False` unusable on CPU-only hosts.
    self.grads = ImageGradients(3)
    if cuda:
      self.grads = self.grads.cuda()
    def __init__(self, data, model, env=None, batch_size=8, shuffle=False,
                 cuda=True, period=100):
        """Set up the visualization callback state.

        Args:
            data: dataset wrapped in a single-process ``DataLoader``.
            model: stored as ``self.model``.
            env: environment name forwarded to the batch visualizer.
            batch_size: batch size of the loader; incomplete final batches
                are dropped (``drop_last=True``).
            shuffle: whether the loader shuffles ``data``.
            cuda: stored as ``self._cuda``.
            period: stored as ``self.period``.
        """
        super(DemosaicVizCallback, self).__init__()

        self.batch_size = batch_size
        self.model = model
        self.batch_viz = viz.BatchVisualizer("batch", env=env)
        self._cuda = cuda

        loader_options = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": 0,
            "drop_last": True,
        }
        self.loader = DataLoader(data, **loader_options)

        self.period = period
        # Counter starts at zero.
        self.counter = 0
Exemple #6
0
    def __init__(self, model, ref_model, val_loader, cuda, env=None):
        """Store the models/loader and create the "denoise" visualizers.

        The visualizer port comes from ``args.port``; ``args`` is not a
        parameter of this method and is resolved from an enclosing scope.

        Args:
            model: stored as ``self.model``.
            ref_model: stored as ``self.ref_model``.
            val_loader: stored as ``self.val_loader``.
            cuda: stored as ``self.cuda``.
            env: environment name forwarded to every visualizer.
        """
        self.model = model
        self.ref_model = ref_model
        self.val_loader = val_loader
        self.cuda = cuda

        def scalar(title):
            # Scalar plot bound to the shared port and environment.
            return viz.ScalarVisualizer(title, port=args.port, env=env)

        self.viz = viz.BatchVisualizer("denoise", port=args.port, env=env)

        self.loss_viz = scalar("loss")
        self.psnr_viz = scalar("psnr")
        self.val_loss_viz = scalar("val_loss")
        self.val_psnr_viz = scalar("val_psnr")
        self.ref_loss_viz = scalar("ref_loss")
        self.ref_psnr_viz = scalar("ref_psnr")
    def __init__(self, model, ref_model, val_loader, cuda, env=None):
        """Store the models/loader and create the "deconv" visualizers.

        The visualizer port comes from ``args.port``; ``args`` is not a
        parameter of this method and is resolved from an enclosing scope.
        Plots that track a model tensor get one trace per entry along that
        tensor's dimension 1.

        Args:
            model: stored as ``self.model``; its ``data_kernel_weights``,
                ``reg_kernel_weights``, ``filter_s``, ``filter_r`` and
                ``reg_thresholds`` attributes are read for their shapes.
            ref_model: stored as ``self.ref_model``.
            val_loader: stored as ``self.val_loader``.
            cuda: stored as ``self.cuda``.
            env: environment name forwarded to every visualizer.
        """
        self.model = model
        self.ref_model = ref_model
        self.val_loader = val_loader
        self.cuda = cuda

        def images(title):
            # Image-batch window bound to the shared port and environment.
            return viz.BatchVisualizer(title, port=args.port, env=env)

        def curves(title, source=None):
            # Scalar plot; when `source` is given, the trace count is the
            # size of its dimension 1.
            if source is None:
                return viz.ScalarVisualizer(title, port=args.port, env=env)
            return viz.ScalarVisualizer(
                title, ntraces=source.shape[1], port=args.port, env=env)

        self.viz = images("deconv")
        self.psf_viz = images("psf")

        # Data-term kernels and their weights.
        self.data_kernels0_viz = images("data_kernels0")
        self.data_kernels1_viz = images("data_kernels1")
        self.data_kernel_weights0_viz = curves(
            "data_kernel_weights0", self.model.data_kernel_weights)
        self.data_kernel_weights1_viz = curves(
            "data_kernel_weights1", self.model.data_kernel_weights)

        # Regularization kernels and their weights.
        self.reg_kernels0_viz = images("reg_kernels0")
        self.reg_kernels1_viz = images("reg_kernels1")
        self.reg_kernel_weights0_viz = curves(
            "reg_kernel_weights0", self.model.reg_kernel_weights)
        self.reg_kernel_weights1_viz = curves(
            "reg_kernel_weights1", self.model.reg_kernel_weights)

        # Remaining per-model parameter plots.
        self.filter_s_viz = curves("filter_s0", self.model.filter_s)
        self.filter_r_viz = curves("filter_r0", self.model.filter_r)
        self.reg_thresholds_viz = curves(
            "reg_thresholds", self.model.reg_thresholds)

        # Training / validation / reference metrics.
        self.loss_viz = curves("loss")
        self.psnr_viz = curves("psnr")
        self.val_loss_viz = curves("val_loss")
        self.val_psnr_viz = curves("val_psnr")
        self.ref_loss_viz = curves("ref_loss")
        self.ref_psnr_viz = curves("ref_psnr")
Exemple #8
0
def main(args, params):
  """Train a matting model with periodic logging, visualization and checkpoints.

  The loop runs until `args.epochs` epochs have completed (forever when
  `args.epochs` is not positive) or until interrupted from the keyboard.

  Args:
    args: options object; the fields read here are `data_dir`, `output`,
      `checkpoint`, `lr`, `weight_decay`, `log_step`, `viz_step`,
      `checkpoint_interval` and `epochs`.
    params: model parameters handed to `modules.get`; replaced by the
      `params` entry of the checkpoint when a checkpoint file is found.

  Returns:
    None. Returns early, after logging, when the dataset is empty.
  """
  data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())
  # The validation set is built from the same directory as the training set.
  val_data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())

  if len(data) == 0:
    log.info("no input files found, aborting.")
    return

  dataloader = DataLoader(data, 
      batch_size=1,
      shuffle=True, num_workers=4)

  val_dataloader = DataLoader(val_data, 
      batch_size=1, shuffle=True, num_workers=0)

  log.info("Training with {} samples".format(len(data)))

  # Starting checkpoint file
  # Defaults to <output>/checkpoint.ph unless args.checkpoint names another.
  checkpoint = os.path.join(args.output, "checkpoint.ph")
  if args.checkpoint is not None:
    checkpoint = args.checkpoint

  chkpt = None
  if os.path.isfile(checkpoint):
    log.info("Resuming from checkpoint {}".format(checkpoint))
    chkpt = th.load(checkpoint)
    params = chkpt['params']  # override params

  log.info("Model parameters: {}".format(params))

  model = modules.get(params)

  # loss_fn = modules.CharbonnierLoss()
  loss_fn = modules.AlphaLoss()
  optimizer = optim.Adam(model.parameters(), lr=args.lr,
                         weight_decay=args.weight_decay)

  if not os.path.exists(args.output):
    os.makedirs(args.output)

  global_step = 0

  # Restore model weights, optimizer state and step count when resuming.
  if chkpt is not None:
    model.load_state_dict(chkpt['model_state'])
    optimizer.load_state_dict(chkpt['optimizer'])
    global_step = chkpt['step']

  # Destination checkpoint file
  # Always inside args.output, even when resuming from args.checkpoint.
  checkpoint = os.path.join(args.output, "checkpoint.ph")

  # The visualizer environment is named after the output directory.
  name = os.path.basename(args.output)
  loss_viz = viz.ScalarVisualizer("loss", env=name)
  image_viz = viz.BatchVisualizer("images", env=name)
  matte_viz = viz.BatchVisualizer("mattes", env=name)
  weights_viz = viz.BatchVisualizer("weights", env=name)
  trimap_viz = viz.BatchVisualizer("trimap", env=name)

  log.info("Model: {}\n".format(model))

  # The model and the loss are moved to the GPU unconditionally.
  model.cuda()
  loss_fn.cuda()

  log.info("Starting training from step {}".format(global_step))

  # Exponential moving averages (decay `ema_alpha`), all starting at zero.
  smooth_loss = 0
  smooth_loss_ifm = 0
  smooth_time = 0
  ema_alpha = 0.9
  last_checkpoint_time = time.time()
  try:
    epoch = 0
    while True:
      # Train for one epoch
      for step, batch in enumerate(dataloader):
        batch_start = time.time()
        # Fractional epoch, used as the x coordinate in logs and plots.
        frac_epoch =  epoch+1.0*step/len(dataloader)

        batch_v = make_variable(batch, cuda=True)

        optimizer.zero_grad()
        output = model(batch_v)
        target = crop_like(batch_v['matte'], output)
        # The batch's 'vanilla' entry is scored with the same loss and
        # reported as the "ref_ifm" baseline; it does not affect training.
        ifm = crop_like(batch_v['vanilla'], output)
        loss = loss_fn(output, target)
        loss_ifm = loss_fn(ifm, target)

        loss.backward()
        # th.nn.utils.clip_grad_norm(model.parameters(), 1e-1)
        optimizer.step()
        global_step += 1

        batch_end = time.time()
        # NOTE(review): `.data[0]` on a loss presumably targets an old PyTorch
        # release where losses were 1-element tensors - confirm before upgrading.
        smooth_loss = (1.0-ema_alpha)*loss.data[0] + ema_alpha*smooth_loss
        smooth_loss_ifm = (1.0-ema_alpha)*loss_ifm.data[0] + ema_alpha*smooth_loss_ifm
        smooth_time = (1.0-ema_alpha)*(batch_end-batch_start) + ema_alpha*smooth_time

        if global_step % args.log_step == 0:
          log.info("Epoch {:.1f} | loss = {:.7f} | {:.1f} samples/s".format(
            frac_epoch, smooth_loss, target.shape[0]/smooth_time))

        if args.viz_step > 0 and global_step % args.viz_step == 0:
          # Evaluation mode for the visualization pass; restored below.
          model.train(False)
          for val_batch in val_dataloader:
            val_batchv = make_variable(val_batch, cuda=True)
            output = model(val_batchv)
            target = crop_like(val_batchv['matte'], output)
            vanilla = crop_like(val_batchv['vanilla'], output)
            val_loss = loss_fn(output, target)

            # Rescale every displayed image by the target's value range.
            # NOTE(review): divides by zero when the target is constant.
            mini, maxi = target.min(), target.max()

            diff = (th.abs(output-target))
            vizdata = th.cat((target, output, vanilla, diff), 0)
            vizdata = (vizdata-mini)/(maxi-mini)
            # Clip to [0, 1], then raise to the power 1/2.2 for display.
            imgs = np.power(np.clip(vizdata.cpu().data, 0, 1), 1.0/2.2)

            image_viz.update(val_batchv['image'].cpu().data, per_row=1)
            trimap_viz.update(val_batchv['trimap'].cpu().data, per_row=1)
            # Swap the first two dimensions so that each slice along
            # dimension 0 can be normalized on its own.
            weights = model.predicted_weights.permute(1, 0, 2, 3)
            new_w = []
            means = []
            var = []
            for ii in range(weights.shape[0]):
              w = weights[ii:ii+1, ...]
              mu = w.mean()
              sigma = w.std()
              # Maps [mu - 2*sigma, mu + 2*sigma] onto [0, 1]; values outside
              # that range are clamped after the loop.
              new_w.append(0.5*((w-mu)/(2*sigma)+1.0))
              means.append(mu.data.cpu()[0])
              # `var` collects standard deviations, despite its name.
              var.append(sigma.data.cpu()[0])
            weights = th.cat(new_w, 0)
            weights = th.clamp(weights, 0, 1)
            # The caption indexes entries 0..3, so at least four weight maps
            # are required here.
            weights_viz.update(weights.cpu().data,
                caption="CM {:.4f} ({:.4f})| LOC {:.4f} ({:.4f}) | IU {:.4f} ({:.4f}) | KU {:.4f} ({:.4f})".format(
                  means[0], var[0],
                  means[1], var[1],
                  means[2], var[2],
                  means[3], var[3]), per_row=4)
            matte_viz.update(
                imgs,
                caption="Epoch {:.1f} | loss = {:.6f} | target, output, vanilla, diff".format(
                  frac_epoch, val_loss.data[0]), per_row=4)
            log.info("  viz at step {}, loss = {:.6f}".format(global_step, val_loss.cpu().data[0]))
            break  # Only one batch for validation

          # Plot the smoothed training loss next to the smoothed baseline loss.
          losses = [smooth_loss, smooth_loss_ifm]
          legend = ["ours", "ref_ifm"]
          loss_viz.update(frac_epoch, losses, legend=legend)

          model.train(True)

        # Time-based checkpointing: at most one save per checkpoint_interval.
        if batch_end-last_checkpoint_time > args.checkpoint_interval:
          last_checkpoint_time = time.time()
          save(checkpoint, model, params, optimizer, global_step)


      epoch += 1
      if args.epochs > 0 and epoch >= args.epochs:
        log.info("Ending training at epoch {} of {}".format(epoch, args.epochs))
        # NOTE(review): no checkpoint is written on this normal exit path;
        # only the timed save above and the interrupt handler below save.
        break

  except KeyboardInterrupt:
    # On Ctrl-C, save the current state to a separate file.
    log.info("training interrupted at step {}".format(global_step))
    checkpoint = os.path.join(args.output, "on_stop.ph")
    save(checkpoint, model, params, optimizer, global_step)
Exemple #9
0
def main(args):
    """Cluster and embed normalized mosaic patches for visual inspection.

    For every batch of the X-Trans or Bayer dataset (chosen by
    ``args.xtrans``) this function:

    1. crops the mosaic so that patch extraction starts at the requested
       offset within the pattern period,
    2. cuts it into ``args.ksize`` x ``args.ksize`` patches with a stride of
       one period and standardizes each patch (zero mean, unit std),
    3. draws 1024 patches at random (with replacement), clusters them with
       mini-batch k-means (16 clusters) and embeds them in 2-D with t-SNE,
    4. sends the patches, the scatter plot, the cluster centres and the
       members of each cluster to ``torchlib.debug``,
    5. drops into the debugger.

    Args:
        args: options object; ``xtrans``, ``data_dir``, ``batch_size``,
            ``ksize``, ``offset_x`` and ``offset_y`` are read.
    """
    linearize = False
    if args.xtrans:
        period = 6
        dataset_cls = dset.XtransDataset
    else:
        period = 2
        dataset_cls = dset.BayerDataset
    data = dataset_cls(args.data_dir,
                       transform=None,
                       augment=False,
                       linearize=linearize)
    loader = DataLoader(data,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=8)

    # These visualizers are never updated below; they are still constructed
    # because the original script did so.
    # NOTE(review): confirm whether construction has side effects that matter.
    mask_viz = viz.BatchVisualizer("mask", env="demosaic_inspect")
    mos_viz = viz.BatchVisualizer("mosaic", env="demosaic_inspect")
    diff_viz = viz.BatchVisualizer("diff", env="demosaic_inspect")
    target_viz = viz.BatchVisualizer("target", env="demosaic_inspect")
    input_hist = viz.HistogramVisualizer("color_hist", env="demosaic_inspect")

    # Loop-invariant settings.
    ksize = args.ksize
    pad = ksize // 2
    dx = (pad - args.offset_x) % period
    dy = (pad - args.offset_y) % period
    new_bs = 1024  # number of patches sampled per batch
    nclusts = 16  # number of k-means clusters

    def to_patches(arr):
        """Cut ``arr`` into (n, c, ksize, ksize) patches, stride ``period``."""
        patches = arr.unfold(2, ksize, period).unfold(3, ksize, period)
        bs, c, h, w, _, _ = patches.shape
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        return patches.view(bs * h * w, c, ksize, ksize)

    import torchlib.debug as D

    for sample in loader:
        mosaic = sample["mosaic"]
        mask = sample["mask"]

        print("dx {} dy {}".format(dx, dy))
        mosaic = mosaic[..., dy:, dx:]
        # Not used below; kept so it can be inspected at the breakpoint.
        mask = mask[..., dy:, dx:]

        patches = to_patches(mosaic)
        bs = patches.shape[0]
        means = patches.view(bs, -1).mean(-1).view(bs, 1, 1, 1)
        std = patches.view(bs, -1).std(-1).view(bs, 1, 1, 1)
        print(means.min().item(), means.max().item())

        # Standardize each patch; the epsilon guards against flat patches.
        patches -= means
        patches /= std + 1e-8

        # Random subset, sampled with replacement.
        idx = np.random.randint(0, patches.shape[0], (new_bs, ))
        patches = patches[idx]

        D.tensor(patches)

        flat = patches.view(new_bs, -1).cpu().numpy()

        clst = cluster.MiniBatchKMeans(n_clusters=nclusts)
        clst_idx = clst.fit_predict(flat)
        colors = np.random.uniform(size=(nclusts, 3))

        manif = manifold.TSNE(n_components=2)
        new_coords = manif.fit_transform(flat)
        # One RGB colour per point, taken from its cluster.
        color = (colors[clst_idx, :] * 255).astype(np.uint8)
        print(color.shape)
        D.scatter(th.from_numpy(new_coords[:, 0]),
                  th.from_numpy(new_coords[:, 1]),
                  color=color,
                  key="tsne")

        # Fold the flat cluster centres back into patch shape; the channel
        # count comes from the patches instead of being assumed to be 3.
        nchans = patches.shape[1]
        centers = th.from_numpy(clst.cluster_centers_).view(
            nclusts, nchans, ksize, ksize)
        D.tensor(centers, "centers")

        patches_np = patches.numpy()
        for cidx in range(nclusts):
            members = clst_idx == cidx
            p = th.from_numpy(patches_np[members])
            D.tensor(p, key="cluster_{:02d}".format(cidx))

        import ipdb
        ipdb.set_trace()