Example #1
    def on_batch_end(self, batch, batch_id, num_batches, logs):
        if self.counter % self.period != 0:
            self.counter += 1
            return

        self.counter = 0
        for batch in self.loader:
            # Get a batch
            batch_v = utils.make_variable(batch, cuda=self._cuda)

            # Forward
            output = self.model(batch_v)

            eps = 1e-8
            mosaic = batch_v["mosaic"].data
            target = batch_v["target"].data
            noise_variance = batch_v["noise_variance"].data
            output = output.data
            target = crop_like(target, output)
            mosaic = crop_like(mosaic, output)

            vizdata = th.cat([mosaic, output, target], 0)
            vizdata = np.clip(vizdata.cpu().numpy(), 0, 1)

            # Display
            self.batch_viz.update(vizdata,
                                  per_row=self.batch_size,
                                  caption="{} | input, ours, reference".format(
                                      self.current_epoch))

            return  # process only one batch
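
Note: crop_like itself does not appear in any of the examples. Judging from calls such as crop_like(target, output), it crops the first tensor's spatial dimensions to match the second's; a minimal sketch under that assumption (not the project's actual implementation):

def crop_like(src, tgt):
    # Assumed behavior: center-crop src's height/width to match tgt's.
    src_h, src_w = src.shape[-2:]
    tgt_h, tgt_w = tgt.shape[-2:]
    dh, dw = (src_h - tgt_h) // 2, (src_w - tgt_w) // 2
    return src[..., dh:dh + tgt_h, dw:dw + tgt_w]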
Example #2
  def on_batch_end(self, batch, batch_id, num_batches, logs):
    if self.counter % self.period != 0:
      self.counter += 1
      return

    self.counter = 0
    for batch in self.loader:
      # Get a batch
      batch_v = utils.make_variable(batch, cuda=self._cuda)

      # Forward
      output = self.model(batch_v)
      if self.ref is None:
        output_ref = th.zeros_like(output)
      else:
        output_ref = self.ref(batch_v)
        # make sure sizes match
        output_ref = crop_like(output_ref, output)
        output = crop_like(output, output_ref)


      eps = 1e-8
      mosaic = batch_v["mosaic"].data
      target = batch_v["target"].data
      # noise_variance = batch_v["noise_variance"].data
      output = output.data
      output_ref = output_ref.data
      target = crop_like(target, output)
      mosaic = crop_like(mosaic, output)

      diff = (output-target)

      gdiff = self.grads(diff).abs()
      gdiff = th.nn.functional.pad(gdiff, (0, 1, 0, 1))

      diff = diff.abs()

      gdiff_x = gdiff[:, :3]
      gdiff_y = gdiff[:, 3:]

      gdiff_x = gdiff_x / (target + 1e-4)
      gdiff_y = gdiff_y / (target + 1e-4)

      vizdata = th.cat([mosaic, output, output_ref, target, diff, gdiff_x, gdiff_y], 0)
      vizdata = np.clip(vizdata.cpu().numpy(), 0, 1)

      psnr = self.psnr(batch_v, output).item()
      psnr_ref = self.psnr(batch_v, output_ref).item()

      # Display
      self.batch_viz.update(
          vizdata, per_row=self.batch_size, 
          caption="{} | input, ours(new) {:.1f} dB, ours(2016) {:.1f} dB, reference, diff, gdiff".format(
            self.current_epoch, psnr, psnr_ref))

      return  # process only one batch
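
The grads module used here is not shown either. Since its output is split into three x-gradient and three y-gradient channels and then padded back by one pixel on the right and bottom, a plausible finite-difference stand-in (an assumption, not the original module) could be:

import torch as th
import torch.nn as nn

class ImageGradients(nn.Module):
    # Hypothetical stand-in for self.grads: horizontal and vertical
    # finite differences, concatenated along the channel axis.
    def forward(self, x):
        dx = x[..., :, 1:] - x[..., :, :-1]   # (N, C, H, W-1)
        dy = x[..., 1:, :] - x[..., :-1, :]   # (N, C, H-1, W)
        dx = dx[..., :-1, :]                  # crop both to (N, C, H-1, W-1)
        dy = dy[..., :, :-1]
        return th.cat([dx, dy], 1)            # (N, 2*C, H-1, W-1)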
Example #3
  def forward(self, samples):
    # 1/4 resolution features
    mosaic = samples["mosaic"]
    gray_mosaic = mosaic.sum(1)
    color_samples = gray_mosaic.unfold(2, 2, 2).unfold(1, 2, 2)
    color_samples = color_samples.permute(0, 3, 4, 1, 2)
    bs, _, _, h, w = color_samples.shape
    color_samples = color_samples.contiguous().view(bs, 4, h, w)

    eps = 1e-8

    color_samples = th.log(color_samples + eps)

    # input_mean = color_samples.mean(1, keepdim=True)
    input_mean = self.local_mean(color_samples)

    # recons_samples = self.net(color_samples)
    recons_samples = self.net(color_samples-input_mean)

    cmean = crop_like(input_mean, recons_samples)

    recons_samples = recons_samples + cmean

    recons_samples = th.exp(recons_samples) - 1e-8

    _, _, h, w = recons_samples.shape

    output = mosaic.new()
    output.resize_(bs, 3, 2*h, 2*w) 
    output.zero_()

    cmosaic = crop_like(mosaic, output)

    # has green
    output[:, 0, ::2, ::2] = recons_samples[:, 0]
    output[:, 1, ::2, ::2] = cmosaic[:, 1, ::2, ::2]
    output[:, 2, ::2, ::2] = recons_samples[:, 1]

    # has red
    output[:, 0, ::2, 1::2] = cmosaic[:, 0, ::2, 1::2]
    output[:, 1, ::2, 1::2] = recons_samples[:, 2]
    output[:, 2, ::2, 1::2] = recons_samples[:, 3]

    # has blue
    output[:, 0, 1::2, 0::2] = recons_samples[:, 4]
    output[:, 1, 1::2, 0::2] = recons_samples[:, 5]
    output[:, 2, 1::2, 0::2] = cmosaic[:, 2, 1::2, 0::2]

    # has green
    output[:, 0, 1::2, 1::2] = recons_samples[:, 6]
    output[:, 1, 1::2, 1::2] = cmosaic[:, 1, 1::2, 1::2]
    output[:, 2, 1::2, 1::2] = recons_samples[:, 7]

    return output
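
For reference, the indexing above corresponds to a GRBG Bayer layout (green/red on even rows, blue/green on odd rows). A hedged sketch of how such a mosaic could be produced from a full RGB image, assuming that layout (the mosaicking step is not part of this example):

import torch as th

def make_grbg_mosaic(rgb):
    # Assumed GRBG layout matching the indexing in the example:
    # even rows: G R G R ...   odd rows: B G B G ...
    mask = th.zeros_like(rgb)
    mask[:, 1, ::2, ::2] = 1    # green
    mask[:, 0, ::2, 1::2] = 1   # red
    mask[:, 2, 1::2, ::2] = 1   # blue
    mask[:, 1, 1::2, 1::2] = 1  # green
    return rgb * mask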
Example #4
  def forward(self, samples, kernel_list=None):
    start = time.time()

    self.t += 1

    mosaic = samples["mosaic"]
    mask = samples["mask"]
    gray_mosaic = mosaic.sum(1, keepdim=True)
    bs, _, h, w = gray_mosaic.shape

    x = th.fmod(th.arange(0, w).float().cuda(), self.period).view(1, 1, 1, w).repeat(bs, 1, h, 1)
    y = th.fmod(th.arange(0, h).float().cuda(), self.period).view(1, 1, h, 1).repeat(bs, 1, 1, w)

    color_samples = gray_mosaic.squeeze(1).unfold(2, 2, self.period).unfold(1, 2, self.period)
    color_samples = color_samples.permute(0, 3, 4, 1, 2)
    bs, _, _, h, w = color_samples.shape
    color_samples = color_samples.contiguous().view(bs, self.period**2, h, w)

    weights0 = self.g0(color_samples)
    weights0 = F.softmax(weights0, 1)
    r0 = self.apply_kernels(weights0, self.kernels[0], color_samples)
    r0 = self.upsampler(r0)

    weights1 = self.g1(color_samples)
    weights1 = F.softmax(weights1, 1)
    r1 = self.apply_kernels(weights1, self.kernels[1], color_samples)
    r1 = self.upsampler(r1)

    h, w = r0.shape[-2:]

    green = r0.new()
    green.resize_(bs, 1, h, w)
    green.zero_()

    x = crop_like(x, green)
    y = crop_like(y, green)
    gray_mosaic = crop_like(gray_mosaic, green)

    mask = (x == 1) & (y == 0)
    green[mask] = r0[mask]

    mask = (x == 0) & (y == 1)
    green[mask] = r1[mask]

    green[x == y] = gray_mosaic[x == y] 

    red = th.zeros_like(green)
    blue = th.zeros_like(green)

    output = th.cat([red, green, blue], 1)

    return output
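
The x/y grids built above encode each pixel's column and row index modulo the Bayer period, so boolean masks such as (x == 1) & (y == 0) select one site per 2x2 tile. A small standalone illustration of the same trick (CPU tensors for brevity):

import torch as th

h, w, period = 4, 4, 2
x = th.fmod(th.arange(0, w).float(), period).view(1, 1, 1, w).repeat(1, 1, h, 1)
y = th.fmod(th.arange(0, h).float(), period).view(1, 1, h, 1).repeat(1, 1, 1, w)
site01 = (x == 1) & (y == 0)   # top-right pixel of every 2x2 tile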
Example #5
    def on_batch_end(self, batch, batch_id, num_batches, logs):
        if self.counter % self.period != 0:
            self.counter += 1
            return

        self.counter = 0
        for (batch_id, batch) in enumerate(self.loader):
            # Get a batch
            batch_v = utils.make_variable(batch, cuda=self._cuda)

            # Forward
            output = self.model(batch_v)
            if self.ref is None:
                output_ref = th.zeros_like(output)
            else:
                output_ref = self.ref(batch_v)
                # make sure sizes match
                output_ref = crop_like(output_ref, output)
                output = crop_like(output, output_ref)

            eps = 1e-8
            mosaic = batch_v["mosaic"].data
            target = batch_v["target"].data
            # noise_variance = batch_v["noise_variance"].data
            output = output.data
            output_ref = output_ref.data
            target = crop_like(target, output)
            mosaic = crop_like(mosaic, output)

            diff = (output - target)

            gdiff = self.grads(diff).abs()
            gdiff = th.nn.functional.pad(gdiff, (0, 1, 0, 1))

            diff = diff.abs()

            gdiff_x = gdiff[:, :3]
            gdiff_y = gdiff[:, 3:]

            gdiff_x = gdiff_x / (target + 1e-4)
            gdiff_y = gdiff_y / (target + 1e-4)

            psnr = self.psnr(batch_v, output).item()
            psnr_ref = self.psnr(batch_v, output_ref).item()

            return  # process only one batch
Example #6
 def forward(self, data, output):
     target = crop_like(data["target"], output)
     crop = self.crop
     if crop > 0:
         output = output[..., crop:-crop, crop:-crop]
         target = target[..., crop:-crop, crop:-crop]
     mse = self.mse(output, target) + 1e-12
     return -10 * th.log(mse) / np.log(10)
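
This is PSNR in decibels, -10 * log10(MSE), with the 1e-12 term guarding against log(0) on a perfect reconstruction. An equivalent standalone formulation using th.log10 (for signals in [0, 1]):

import torch as th
import torch.nn as nn

def psnr_db(output, target, eps=1e-12):
    # PSNR in dB for signals normalized to [0, 1]; eps guards against log(0).
    mse = nn.functional.mse_loss(output, target) + eps
    return -10 * th.log10(mse)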
Example #7
 def forward(self, data, output):
   target = crop_like(data["target"], output)
   gradients_tgt = self.grads(target)
   gradients_out = self.grads(output)
   # import torchlib.debug as D
   # D.tensor(gradients_tgt, key="target")
   # D.tensor(gradients_out, key="out")
   # import ipdb; ipdb.set_trace()
   return self.l2(gradients_out, gradients_tgt)
Example #8
    def forward(self, data, output):
        target = crop_like(data["target"], output)
        output_f = self.get_features(output)
        with th.no_grad():
            target_f = self.get_features(target)

        losses = []
        for o, t in zip(output_f, target_f):
            losses.append(self.mse(o, t))
        loss = sum(losses)
        if self.weight != 1.0:
            loss = loss * self.weight
        return loss
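
get_features is not defined in the excerpt; the loop over per-layer feature pairs compared with MSE suggests a VGG-style perceptual loss. A hedged sketch of such an extractor, assuming torchvision's VGG16 (the layer indices are illustrative, not the project's choice):

import torch.nn as nn
import torchvision

class VGGFeatures(nn.Module):
    # Hypothetical feature extractor returning activations from a few
    # intermediate VGG16 layers, frozen for use in a perceptual loss.
    def __init__(self, layers=(3, 8, 15)):
        super().__init__()
        self.vgg = torchvision.models.vgg16(pretrained=True).features.eval()
        for p in self.vgg.parameters():
            p.requires_grad = False
        self.layers = set(layers)

    def forward(self, x):
        feats = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layers:
                feats.append(x)
        return feats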
Example #9
  def forward(self, samples):
    # 1/4 resolution features
    mosaic = samples["mosaic"]
    features = self.main_processor(mosaic)

    # crop original mosaic to match output size
    cropped = crop_like(mosaic, features)

    # Concatenate the cropped input and the features for further filtering
    packed = th.cat([cropped, features], 1)

    output = self.fullres_processor(packed)

    return output
Example #10
  def forward(self, samples):
    # 1/4 resolution features
    mosaic = samples["mosaic"]
    features = self.main_processor(mosaic)
    filters, masks = features[:, :self.width], features[:, self.width:]
    filtered = filters * masks
    residual = self.residual_predictor(filtered)
    upsampled = self.upsampler(residual)

    # crop original mosaic to match output size
    cropped = crop_like(mosaic, upsampled)

    # Concatenate the cropped input and the upsampled residual for further filtering
    packed = th.cat([cropped, upsampled], 1)

    output = self.fullres_processor(packed)

    return output
Example #11
  def forward(self, samples, kernel_list=None):
    start = time.time()

    self.t += 1

    # self.temperature *= 1.0001

    mosaic = samples["mosaic"]
    mask = samples["mask"]
    gray_mosaic = mosaic.sum(1, keepdim=True)
    bs, _, h, w = gray_mosaic.shape

    x = th.fmod(th.arange(0, w).float().cuda(), self.period).view(1, 1, 1, w).repeat(bs, 1, h, 1)
    y = th.fmod(th.arange(0, h).float().cuda(), self.period).view(1, 1, h, 1).repeat(bs, 1, 1, w)

    indata = th.cat([mosaic, x, y], 1)
    weights = self.net(indata)

    # print(self.temperature)
    weights = F.softmax(weights, 1)

    recons = self.apply_kernels(weights, gray_mosaic.squeeze(1))

    h, w = recons.shape[-2:]

    green = recons.new()
    green.resize_(bs, 1, h, w)
    green.zero_()

    x = crop_like(x, green)
    y = crop_like(y, green)
    gray_mosaic = crop_like(gray_mosaic, green)

    r0 = recons[:, 0:1]
    r1 = recons[:, 1:2]
    
    mask = (x == 1) & (y == 0)
    green[mask] = r0[mask]

    # mask = (x == 0) & (y == 1)
    # green[mask] = r1[mask]
    #
    # green[x == y] = gray_mosaic[x == y] 

    red = th.zeros_like(green)
    blue = th.zeros_like(green)

    output = th.cat([red, green, blue], 1)

    # if self.t % 100 == 0:
    #   self.wviz.update(self.t, weights.min().item(), name="min")
    #   self.wviz.update(self.t, weights.max().item(), name="max")
    #
    #   kview = self.kernels.view(2*self.nkernels, 1, self.ksize, self.ksize).detach()
    #   mu = kview.mean().item()
    #   std = kview.std().item()
    #   kview = th.clamp(((kview-mu) / (2*std) + 1.0) / 2.0, 0, 1)
    #   self.kviz.update(kview.cpu(), caption="{:.5f} ({:.5f})".format(mu, std))
    #
    #
    #   bs, c, h, w = weights.shape
    #   wview = weights.view(bs*c, 1, h, w).detach()
    #   mu = wview.mean().item()
    #   std = wview.std().item()
    #   wview = th.clamp(((wview-mu) / (2*std) + 1.0) / 2.0, 0, 1)
    #   self.wmapviz.update(wview.cpu(), caption="{:.5f} ({:.5f})".format(mu, std))

    return output
Example #12
  def forward(self, samples, kernels=None):
    start = time.time()

    mosaic = samples["mosaic"]
    gray_mosaic = mosaic.sum(1)

    color_samples = gray_mosaic.unfold(2, 2, 2).unfold(1, 2, 2)
    color_samples = color_samples.permute(0, 3, 4, 1, 2)
    bs, _, _, h, w = color_samples.shape
    color_samples = color_samples.contiguous().view(bs, 4, h, w)

    kernels = self.kernels(color_samples)
    bs, _, h, w = kernels.shape

    # th.cuda.synchronize()
    # elapsed = time.time() - start
    # print("Forward {:.0f} ms".format(elapsed*1000))

    # TODO: check what's going on
    g0 = color_samples[:, 0:1]
    b = color_samples[:, 1:2]
    r = color_samples[:, 2:3]
    g1 = color_samples[:, 3:4]

    idx = 0
    ksize = self.ksize

    # Reconstruct 3 reds from known red
    reds = [r]
    for i in range(3):
      k = kernels[:, idx:idx+ksize*ksize]
      k = F.softmax(k, 1)
      idx += ksize*ksize
      reds.append(apply_kernels(k, r))

    # remove unused boundaries
    reds[0] = crop_like(reds[0], reds[1])

    # Reorder 2x2 tile, known red is 0, pattern is:
    # . R  -> 1 0
    # . .     2 3
    reds = [reds[1], reds[0], reds[2], reds[3]]
    red = self.unroll(th.cat(reds, 1))

    # Reconstruct 3 blues from known blue
    blues = [b]
    for i in range(3):
      k = kernels[:, idx:idx+ksize*ksize]
      k = F.softmax(k, 1)
      idx += ksize*ksize
      blues.append(apply_kernels(k, b))

    # remove unused boundaries
    blues[0] = crop_like(blues[0], blues[1])

    # Reorder 2x2 tile, known blue is 0, pattern is:
    # . .  -> 1 2
    # B .     0 3
    blues = [blues[1], blues[2], blues[0], blues[3]]
    blue = self.unroll(th.cat(blues, 1))

    # Reconstruct 2 greens from known greens
    greens = [g0, g1]
    for i in range(2):
      k = kernels[:, idx:idx + 2*ksize*ksize]
      k = F.softmax(k, 1) # jointly normalize the weights

      from_g0 = apply_kernels(k[:, 0:ksize*ksize], g0)
      from_g1 = apply_kernels(k[:, ksize*ksize:2*ksize*ksize], g1)

      greens.append(from_g0+from_g1)

      idx += 2*ksize*ksize

    # remove unused boundaries
    greens[0] = crop_like(greens[0], greens[2])
    greens[1] = crop_like(greens[1], greens[2])

    # Reorder 2x2 tile, known greens are 0 and 1, pattern is:
    # G .  -> 0 2
    # . G     3 1
    greens = [greens[0], greens[2], greens[3], greens[1]]
    green = self.unroll(th.cat(greens, 1))

    output = th.cat([red, green, blue], 1)

    # th.cuda.synchronize()
    # elapsed = time.time() - start
    # print("Forward+Apply {:.0f} ms".format(elapsed*1000))

    return output
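
unroll is not shown in this excerpt; it takes four quarter-resolution planes ordered top-left, top-right, bottom-left, bottom-right and interleaves them into one full-resolution plane, which matches a 2x pixel shuffle. A minimal sketch under that assumption:

import torch as th
import torch.nn as nn

# Hypothetical "unroll": interleave 4 quarter-resolution planes
# (ordered top-left, top-right, bottom-left, bottom-right) into one
# full-resolution plane, i.e. a 2x pixel shuffle.
unroll = nn.PixelShuffle(2)

quarter = th.randn(1, 4, 8, 8)   # the 4 planes of each 2x2 tile
full = unroll(quarter)           # -> (1, 1, 16, 16)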
Example #13
 def forward(self, data, output):
     target = crop_like(data["target"], output)
     mse = self.mse(output, target)
     return -10 * th.log(mse) / np.log(10)
Example #14
 def forward(self, data, output):
     target = crop_like(data["target"], output)
     return self.mse(output, target) * self.weight
Example #15
def main(args, params):
  data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())
  val_data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())

  if len(data) == 0:
    log.info("no input files found, aborting.")
    return

  dataloader = DataLoader(data, 
      batch_size=1,
      shuffle=True, num_workers=4)

  val_dataloader = DataLoader(val_data, 
      batch_size=1, shuffle=True, num_workers=0)

  log.info("Training with {} samples".format(len(data)))

  # Starting checkpoint file
  checkpoint = os.path.join(args.output, "checkpoint.ph")
  if args.checkpoint is not None:
    checkpoint = args.checkpoint

  chkpt = None
  if os.path.isfile(checkpoint):
    log.info("Resuming from checkpoint {}".format(checkpoint))
    chkpt = th.load(checkpoint)
    params = chkpt['params']  # override params

  log.info("Model parameters: {}".format(params))

  model = modules.get(params)

  # loss_fn = modules.CharbonnierLoss()
  loss_fn = modules.AlphaLoss()
  optimizer = optim.Adam(model.parameters(), lr=args.lr,
                         weight_decay=args.weight_decay)

  if not os.path.exists(args.output):
    os.makedirs(args.output)

  global_step = 0

  if chkpt is not None:
    model.load_state_dict(chkpt['model_state'])
    optimizer.load_state_dict(chkpt['optimizer'])
    global_step = chkpt['step']

  # Destination checkpoint file
  checkpoint = os.path.join(args.output, "checkpoint.ph")

  name = os.path.basename(args.output)
  loss_viz = viz.ScalarVisualizer("loss", env=name)
  image_viz = viz.BatchVisualizer("images", env=name)
  matte_viz = viz.BatchVisualizer("mattes", env=name)
  weights_viz = viz.BatchVisualizer("weights", env=name)
  trimap_viz = viz.BatchVisualizer("trimap", env=name)

  log.info("Model: {}\n".format(model))

  model.cuda()
  loss_fn.cuda()

  log.info("Starting training from step {}".format(global_step))

  smooth_loss = 0
  smooth_loss_ifm = 0
  smooth_time = 0
  ema_alpha = 0.9
  last_checkpoint_time = time.time()
  try:
    epoch = 0
    while True:
      # Train for one epoch
      for step, batch in enumerate(dataloader):
        batch_start = time.time()
        frac_epoch = epoch + 1.0 * step / len(dataloader)

        batch_v = make_variable(batch, cuda=True)

        optimizer.zero_grad()
        output = model(batch_v)
        target = crop_like(batch_v['matte'], output)
        ifm = crop_like(batch_v['vanilla'], output)
        loss = loss_fn(output, target)
        loss_ifm = loss_fn(ifm, target)

        loss.backward()
        # th.nn.utils.clip_grad_norm(model.parameters(), 1e-1)
        optimizer.step()
        global_step += 1

        batch_end = time.time()
        smooth_loss = (1.0-ema_alpha)*loss.data[0] + ema_alpha*smooth_loss
        smooth_loss_ifm = (1.0-ema_alpha)*loss_ifm.data[0] + ema_alpha*smooth_loss_ifm
        smooth_time = (1.0-ema_alpha)*(batch_end-batch_start) + ema_alpha*smooth_time

        if global_step % args.log_step == 0:
          log.info("Epoch {:.1f} | loss = {:.7f} | {:.1f} samples/s".format(
            frac_epoch, smooth_loss, target.shape[0]/smooth_time))

        if args.viz_step > 0 and global_step % args.viz_step == 0:
          model.train(False)
          for val_batch in val_dataloader:
            val_batchv = make_variable(val_batch, cuda=True)
            output = model(val_batchv)
            target = crop_like(val_batchv['matte'], output)
            vanilla = crop_like(val_batchv['vanilla'], output)
            val_loss = loss_fn(output, target)

            mini, maxi = target.min(), target.max()

            diff = (th.abs(output-target))
            vizdata = th.cat((target, output, vanilla, diff), 0)
            vizdata = (vizdata-mini)/(maxi-mini)
            imgs = np.power(np.clip(vizdata.cpu().data, 0, 1), 1.0/2.2)

            image_viz.update(val_batchv['image'].cpu().data, per_row=1)
            trimap_viz.update(val_batchv['trimap'].cpu().data, per_row=1)
            weights = model.predicted_weights.permute(1, 0, 2, 3)
            new_w = []
            means = []
            var = []
            for ii in range(weights.shape[0]):
              w = weights[ii:ii+1, ...]
              mu = w.mean()
              sigma = w.std()
              new_w.append(0.5*((w-mu)/(2*sigma)+1.0))
              means.append(mu.data.cpu()[0])
              var.append(sigma.data.cpu()[0])
            weights = th.cat(new_w, 0)
            weights = th.clamp(weights, 0, 1)
            weights_viz.update(weights.cpu().data,
                caption="CM {:.4f} ({:.4f})| LOC {:.4f} ({:.4f}) | IU {:.4f} ({:.4f}) | KU {:.4f} ({:.4f})".format(
                  means[0], var[0],
                  means[1], var[1],
                  means[2], var[2],
                  means[3], var[3]), per_row=4)
            matte_viz.update(
                imgs,
                caption="Epoch {:.1f} | loss = {:.6f} | target, output, vanilla, diff".format(
                  frac_epoch, val_loss.data[0]), per_row=4)
            log.info("  viz at step {}, loss = {:.6f}".format(global_step, val_loss.cpu().data[0]))
            break  # Only one batch for validation

          losses = [smooth_loss, smooth_loss_ifm]
          legend = ["ours", "ref_ifm"]
          loss_viz.update(frac_epoch, losses, legend=legend)

          model.train(True)

        if batch_end-last_checkpoint_time > args.checkpoint_interval:
          last_checkpoint_time = time.time()
          save(checkpoint, model, params, optimizer, global_step)


      epoch += 1
      if args.epochs > 0 and epoch >= args.epochs:
        log.info("Ending training at epoch {} of {}".format(epoch, args.epochs))
        break

  except KeyboardInterrupt:
    log.info("training interrupted at step {}".format(global_step))
    checkpoint = os.path.join(args.output, "on_stop.ph")
    save(checkpoint, model, params, optimizer, global_step)
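
The save helper is not included in the excerpt; from the keys read back when resuming (model_state, optimizer, params, step), it presumably wraps th.save around a dict. A sketch consistent with that usage (an assumption, not the original helper):

import torch as th

def save(path, model, params, optimizer, step):
    # Checkpoint layout matching the keys read at resume time.
    th.save({
        "model_state": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "params": params,
        "step": step,
    }, path)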
Example #16
    def forward(self, data, output):
        target = crop_like(data["target"], output)
        gradients_tgt = self.grads(target)
        gradients_out = self.grads(output)

        return self.l2(gradients_out, gradients_tgt)