# NOTE: the functions below are gathered from several files in this repo.
# They assume imports along these lines (repo-internal paths are a guess):
import os
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import cluster, manifold

# Repo-internal helpers assumed in scope: dset / dataset, modules, losses,
# viz, log, make_variable, crop_like, save, ImageGradients.


def main(args):
    linearize = False
    if args.xtrans:
        data = dset.XtransDataset(args.data_dir, transform=None,
                                  augment=True, linearize=linearize)
    else:
        data = dset.BayerDataset(args.data_dir, transform=None,
                                 augment=True, linearize=linearize)
    loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                        num_workers=8)

    # CFA period: 6x6 for X-Trans, 2x2 for Bayer.
    if args.xtrans:
        period = 6
    else:
        period = 2

    mask_viz = viz.BatchVisualizer("mask", env="demosaic_inspect")
    mos_viz = viz.BatchVisualizer("mosaic", env="demosaic_inspect")
    diff_viz = viz.BatchVisualizer("diff", env="demosaic_inspect")
    target_viz = viz.BatchVisualizer("target", env="demosaic_inspect")
    input_hist = viz.HistogramVisualizer("color_hist", env="demosaic_inspect")

    for sample in loader:
        mosaic = sample["mosaic"]
        mask = sample["mask"]
        target = sample["target"]
        # for c in [0, 2]:
        #     target[:, c] = 0
        #     mosaic[:, c] = 0
        diff = target - mosaic

        # Tile mosaic+mask into non-overlapping patches of two CFA periods.
        mos_mask = th.cat([mosaic, mask], 1)
        mos_mask = mos_mask.unfold(2, 2 * period, 2 * period) \
                           .unfold(3, 2 * period, 2 * period)
        bs, c, h, w, kh, kw = mos_mask.shape
        mos_mask = mos_mask.permute(0, 2, 3, 1, 4, 5).contiguous().view(
            bs * h * w, c, kh * kw)

        import ipdb; ipdb.set_trace()  # breakpoint: inspect the tiled patches

        mask_viz.update(mask)
        mos_viz.update(mosaic)
        diff_viz.update(diff)
        target_viz.update(target)
        input_hist.update(target[:, 1].contiguous().view(-1).numpy())

        import ipdb; ipdb.set_trace()  # breakpoint: inspect visualizations
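# A minimal, self-contained sketch of the double-`unfold` tiling used above,
# on a toy tensor (the function name and sizes here are illustrative, not
# from the repo):
def _unfold_demo():
    period = 2
    x = th.arange(4 * 8 * 8, dtype=th.float32).view(1, 4, 8, 8)
    # Non-overlapping (2*period) x (2*period) tiles along H, then W.
    tiles = x.unfold(2, 2 * period, 2 * period).unfold(3, 2 * period, 2 * period)
    bs, c, h, w, kh, kw = tiles.shape  # (1, 4, 2, 2, 4, 4)
    # Flatten the tile grid into the batch dimension, one row per tile.
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous().view(bs * h * w, c, kh * kw)
    print(tiles.shape)  # torch.Size([4, 4, 16])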
class BayerLog(nn.Module):  # class header reconstructed from the super() call
    def __init__(self, fov=7):
        super(BayerLog, self).__init__()
        self.fov = fov
        self.net = nn.Sequential(
            nn.Conv2d(4, 16, 3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 16, 3),
            nn.LeakyReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(16, 32, 3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 3, 3),
        )
        self.debug_viz = viz.BatchVisualizer("batch", env="mosaic_debug")
        self.debug_viz2 = viz.BatchVisualizer("batch2", env="mosaic_debug")
        self.debug = False
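# Quick shape check for the unpadded conv stack above (a sketch; I assume the
# 4-channel input is a packed 2x2 Bayer representation at half resolution).
# Each unpadded 3x3 conv trims 2 px and the upsample doubles H and W, so
# H_out = 2 * (H_in - 4) - 4:
def _bayerlog_shape_check():
    net = nn.Sequential(
        nn.Conv2d(4, 16, 3), nn.LeakyReLU(inplace=True),
        nn.Conv2d(16, 16, 3), nn.LeakyReLU(inplace=True),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(16, 32, 3), nn.LeakyReLU(inplace=True),
        nn.Conv2d(32, 3, 3))
    x = th.zeros(1, 4, 16, 16)
    print(net(x).shape)  # torch.Size([1, 3, 20, 20])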
def __init__(self, model, reference, num_batches, val_loader, env=None):
    self.model = model
    self.reference = reference
    self.num_batches = num_batches
    self.val_loader = val_loader
    self.viz = viz.BatchVisualizer("demosaick", env=env)
    self.loss_viz = viz.ScalarVisualizer(
        "loss", opts={"legend": ["train", "val"]}, env=env)
    self.psnr_viz = viz.ScalarVisualizer(
        "psnr", opts={"legend": ["train", "train_g", "val"]}, env=env)
    self.ssim_viz = viz.ScalarVisualizer(
        "1-ssim", opts={"legend": ["train", "val"]}, env=env)
    self.l1_viz = viz.ScalarVisualizer(
        "l1", opts={"legend": ["train", "val"]}, env=env)
    self.current_epoch = 0
def __init__(self, data, model, ref, env=None, batch_size=16, shuffle=False,
             cuda=True, period=500):
    super(DemosaicVizCallback, self).__init__()
    self.batch_size = batch_size
    self.model = model
    self.ref = ref
    self.batch_viz = viz.BatchVisualizer("batch", env=env)
    self._cuda = cuda
    self.loader = DataLoader(
        data, batch_size=batch_size, shuffle=shuffle, num_workers=0,
        drop_last=True)
    self.period = period
    self.counter = 0
    self.psnr = losses.PSNR(crop=8)
    self.grads = ImageGradients(3).cuda()
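# `losses.PSNR(crop=8)` above is repo-internal. A minimal sketch of a
# border-cropped PSNR module under the usual definition
# PSNR = -10 * log10(MSE), assuming signals in [0, 1] (my assumption, not
# necessarily the repo's exact implementation):
class CroppedPSNR(nn.Module):
    def __init__(self, crop=8):
        super(CroppedPSNR, self).__init__()
        self.crop = crop

    def forward(self, output, target):
        # Drop a `crop`-pixel border before computing the error, so boundary
        # artifacts do not dominate the score.
        c = self.crop
        if c > 0:
            output = output[..., c:-c, c:-c]
            target = target[..., c:-c, c:-c]
        mse = th.mean((output - target) ** 2)
        return -10 * th.log10(mse + 1e-12)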
def __init__(self, data, model, env=None, batch_size=8, shuffle=False,
             cuda=True, period=100):
    super(DemosaicVizCallback, self).__init__()
    self.batch_size = batch_size
    self.model = model
    self.batch_viz = viz.BatchVisualizer("batch", env=env)
    self._cuda = cuda
    self.loader = DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                             num_workers=0, drop_last=True)
    self.period = period
    self.counter = 0
def __init__(self, model, ref_model, val_loader, cuda, env=None):
    self.model = model
    self.ref_model = ref_model
    self.val_loader = val_loader
    self.cuda = cuda
    # NOTE: relies on a module-level `args` for the Visdom port.
    self.viz = viz.BatchVisualizer("denoise", port=args.port, env=env)
    self.loss_viz = viz.ScalarVisualizer("loss", port=args.port, env=env)
    self.psnr_viz = viz.ScalarVisualizer("psnr", port=args.port, env=env)
    self.val_loss_viz = viz.ScalarVisualizer("val_loss", port=args.port, env=env)
    self.val_psnr_viz = viz.ScalarVisualizer("val_psnr", port=args.port, env=env)
    self.ref_loss_viz = viz.ScalarVisualizer("ref_loss", port=args.port, env=env)
    self.ref_psnr_viz = viz.ScalarVisualizer("ref_psnr", port=args.port, env=env)
def __init__(self, model, ref_model, val_loader, cuda, env=None):
    self.model = model
    self.ref_model = ref_model
    self.val_loader = val_loader
    self.cuda = cuda
    self.viz = viz.BatchVisualizer("deconv", port=args.port, env=env)
    self.psf_viz = viz.BatchVisualizer("psf", port=args.port, env=env)
    self.data_kernels0_viz = viz.BatchVisualizer("data_kernels0", port=args.port, env=env)
    self.data_kernels1_viz = viz.BatchVisualizer("data_kernels1", port=args.port, env=env)
    self.data_kernel_weights0_viz = viz.ScalarVisualizer(
        "data_kernel_weights0",
        ntraces=self.model.data_kernel_weights.shape[1],
        port=args.port, env=env)
    self.data_kernel_weights1_viz = viz.ScalarVisualizer(
        "data_kernel_weights1",
        ntraces=self.model.data_kernel_weights.shape[1],
        port=args.port, env=env)
    self.reg_kernels0_viz = viz.BatchVisualizer("reg_kernels0", port=args.port, env=env)
    self.reg_kernels1_viz = viz.BatchVisualizer("reg_kernels1", port=args.port, env=env)
    self.reg_kernel_weights0_viz = viz.ScalarVisualizer(
        "reg_kernel_weights0",
        ntraces=self.model.reg_kernel_weights.shape[1],
        port=args.port, env=env)
    self.reg_kernel_weights1_viz = viz.ScalarVisualizer(
        "reg_kernel_weights1",
        ntraces=self.model.reg_kernel_weights.shape[1],
        port=args.port, env=env)
    self.filter_s_viz = viz.ScalarVisualizer(
        "filter_s0", ntraces=self.model.filter_s.shape[1],
        port=args.port, env=env)
    self.filter_r_viz = viz.ScalarVisualizer(
        "filter_r0", ntraces=self.model.filter_r.shape[1],
        port=args.port, env=env)
    self.reg_thresholds_viz = viz.ScalarVisualizer(
        "reg_thresholds", ntraces=self.model.reg_thresholds.shape[1],
        port=args.port, env=env)
    self.loss_viz = viz.ScalarVisualizer("loss", port=args.port, env=env)
    self.psnr_viz = viz.ScalarVisualizer("psnr", port=args.port, env=env)
    self.val_loss_viz = viz.ScalarVisualizer("val_loss", port=args.port, env=env)
    self.val_psnr_viz = viz.ScalarVisualizer("val_psnr", port=args.port, env=env)
    self.ref_loss_viz = viz.ScalarVisualizer("ref_loss", port=args.port, env=env)
    self.ref_psnr_viz = viz.ScalarVisualizer("ref_psnr", port=args.port, env=env)
def main(args, params):
    data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())
    val_data = dataset.MattingDataset(args.data_dir, transform=dataset.ToTensor())

    if len(data) == 0:
        log.info("no input files found, aborting.")
        return

    dataloader = DataLoader(data, batch_size=1, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=0)

    log.info("Training with {} samples".format(len(data)))

    # Starting checkpoint file
    checkpoint = os.path.join(args.output, "checkpoint.ph")
    if args.checkpoint is not None:
        checkpoint = args.checkpoint

    chkpt = None
    if os.path.isfile(checkpoint):
        log.info("Resuming from checkpoint {}".format(checkpoint))
        chkpt = th.load(checkpoint)
        params = chkpt['params']  # override params

    log.info("Model parameters: {}".format(params))
    model = modules.get(params)

    # loss_fn = modules.CharbonnierLoss()
    loss_fn = modules.AlphaLoss()

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    global_step = 0
    if chkpt is not None:
        model.load_state_dict(chkpt['model_state'])
        optimizer.load_state_dict(chkpt['optimizer'])
        global_step = chkpt['step']

    # Destination checkpoint file
    checkpoint = os.path.join(args.output, "checkpoint.ph")

    name = os.path.basename(args.output)
    loss_viz = viz.ScalarVisualizer("loss", env=name)
    image_viz = viz.BatchVisualizer("images", env=name)
    matte_viz = viz.BatchVisualizer("mattes", env=name)
    weights_viz = viz.BatchVisualizer("weights", env=name)
    trimap_viz = viz.BatchVisualizer("trimap", env=name)

    log.info("Model: {}\n".format(model))
    model.cuda()
    loss_fn.cuda()

    log.info("Starting training from step {}".format(global_step))

    smooth_loss = 0
    smooth_loss_ifm = 0
    smooth_time = 0
    ema_alpha = 0.9
    last_checkpoint_time = time.time()
    try:
        epoch = 0
        while True:
            # Train for one epoch
            for step, batch in enumerate(dataloader):
                batch_start = time.time()
                frac_epoch = epoch + 1.0 * step / len(dataloader)

                batch_v = make_variable(batch, cuda=True)
                optimizer.zero_grad()
                output = model(batch_v)
                target = crop_like(batch_v['matte'], output)
                ifm = crop_like(batch_v['vanilla'], output)
                loss = loss_fn(output, target)
                loss_ifm = loss_fn(ifm, target)
                loss.backward()
                # th.nn.utils.clip_grad_norm(model.parameters(), 1e-1)
                optimizer.step()
                global_step += 1
                batch_end = time.time()

                # Exponential moving averages for logging.
                smooth_loss = (1.0 - ema_alpha) * loss.data[0] + ema_alpha * smooth_loss
                smooth_loss_ifm = (1.0 - ema_alpha) * loss_ifm.data[0] + ema_alpha * smooth_loss_ifm
                smooth_time = (1.0 - ema_alpha) * (batch_end - batch_start) + ema_alpha * smooth_time

                if global_step % args.log_step == 0:
                    log.info("Epoch {:.1f} | loss = {:.7f} | {:.1f} samples/s".format(
                        frac_epoch, smooth_loss, target.shape[0] / smooth_time))

                if args.viz_step > 0 and global_step % args.viz_step == 0:
                    model.train(False)
                    for val_batch in val_dataloader:
                        val_batchv = make_variable(val_batch, cuda=True)
                        output = model(val_batchv)
                        target = crop_like(val_batchv['matte'], output)
                        vanilla = crop_like(val_batchv['vanilla'], output)
                        val_loss = loss_fn(output, target)

                        mini, maxi = target.min(), target.max()
                        diff = th.abs(output - target)
                        vizdata = th.cat((target, output, vanilla, diff), 0)
                        vizdata = (vizdata - mini) / (maxi - mini)
                        imgs = np.power(np.clip(vizdata.cpu().data, 0, 1), 1.0 / 2.2)

                        image_viz.update(val_batchv['image'].cpu().data, per_row=1)
                        trimap_viz.update(val_batchv['trimap'].cpu().data, per_row=1)

                        # Normalize each predicted weight map for display.
                        weights = model.predicted_weights.permute(1, 0, 2, 3)
                        new_w = []
                        means = []
                        var = []
                        for ii in range(weights.shape[0]):
                            w = weights[ii:ii + 1, ...]
                            mu = w.mean()
                            sigma = w.std()
                            new_w.append(0.5 * ((w - mu) / (2 * sigma) + 1.0))
                            means.append(mu.data.cpu()[0])
                            var.append(sigma.data.cpu()[0])
                        weights = th.cat(new_w, 0)
                        weights = th.clamp(weights, 0, 1)
                        weights_viz.update(
                            weights.cpu().data,
                            caption="CM {:.4f} ({:.4f}) | LOC {:.4f} ({:.4f}) | IU {:.4f} ({:.4f}) | KU {:.4f} ({:.4f})".format(
                                means[0], var[0], means[1], var[1],
                                means[2], var[2], means[3], var[3]),
                            per_row=4)

                        matte_viz.update(
                            imgs,
                            caption="Epoch {:.1f} | loss = {:.6f} | target, output, vanilla, diff".format(
                                frac_epoch, val_loss.data[0]),
                            per_row=4)
                        log.info("  viz at step {}, loss = {:.6f}".format(
                            global_step, val_loss.cpu().data[0]))
                        break  # Only one batch for validation

                    losses = [smooth_loss, smooth_loss_ifm]
                    legend = ["ours", "ref_ifm"]
                    loss_viz.update(frac_epoch, losses, legend=legend)
                    model.train(True)

                if batch_end - last_checkpoint_time > args.checkpoint_interval:
                    last_checkpoint_time = time.time()
                    save(checkpoint, model, params, optimizer, global_step)

            epoch += 1
            if args.epochs > 0 and epoch >= args.epochs:
                log.info("Ending training at epoch {} of {}".format(epoch, args.epochs))
                break
    except KeyboardInterrupt:
        log.info("training interrupted at step {}".format(global_step))
        checkpoint = os.path.join(args.output, "on_stop.ph")
        save(checkpoint, model, params, optimizer, global_step)
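# The loop above calls a repo-internal `save(...)` and on resume reads back
# 'params', 'model_state', 'optimizer' and 'step'. A sketch consistent with
# those keys (the real helper may store more):
def save(checkpoint, model, params, optimizer, global_step):
    th.save({
        "params": params,
        "model_state": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": global_step,
    }, checkpoint)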
def main(args):
    linearize = False
    if args.xtrans:
        period = 6
        data = dset.XtransDataset(args.data_dir, transform=None,
                                  augment=False, linearize=linearize)
    else:
        period = 2
        data = dset.BayerDataset(args.data_dir, transform=None,
                                 augment=False, linearize=linearize)
    loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                        num_workers=8)

    mask_viz = viz.BatchVisualizer("mask", env="demosaic_inspect")
    mos_viz = viz.BatchVisualizer("mosaic", env="demosaic_inspect")
    diff_viz = viz.BatchVisualizer("diff", env="demosaic_inspect")
    target_viz = viz.BatchVisualizer("target", env="demosaic_inspect")
    input_hist = viz.HistogramVisualizer("color_hist", env="demosaic_inspect")

    for sample in loader:
        mosaic = sample["mosaic"]
        mask = sample["mask"]

        # Shift so the patch grid aligns with the requested CFA offset.
        pad = args.ksize // 2
        dx = (pad - args.offset_x) % period
        dy = (pad - args.offset_y) % period
        print("dx {} dy {}".format(dx, dy))
        mosaic = mosaic[..., dy:, dx:]
        mask = mask[..., dy:, dx:]

        def to_patches(arr):
            # ksize x ksize patches with stride equal to one CFA period.
            patches = arr.unfold(2, args.ksize, period).unfold(3, args.ksize, period)
            bs, c, h, w, _, _ = patches.shape
            patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
            patches = patches.view(bs * h * w, c, args.ksize, args.ksize)
            return patches

        patches = to_patches(mosaic)

        # Standardize each patch.
        bs = patches.shape[0]
        means = patches.view(bs, -1).mean(-1).view(bs, 1, 1, 1)
        std = patches.view(bs, -1).std(-1).view(bs, 1, 1, 1)
        print(means.min().item(), means.max().item())
        patches -= means
        patches /= std + 1e-8

        # Subsample a random batch of patches.
        new_bs = 1024
        idx = np.random.randint(0, patches.shape[0], (new_bs, ))
        patches = patches[idx]

        import torchlib.debug as D
        D.tensor(patches)

        # Cluster the flattened patches and embed them in 2D for inspection.
        flat = patches.view(new_bs, -1).cpu().numpy()
        nclusts = 16
        clst = cluster.MiniBatchKMeans(n_clusters=nclusts)
        # clst.fit(flat)
        clst_idx = clst.fit_predict(flat)

        # One random color per cluster for the t-SNE scatter plot.
        colors = np.random.uniform(size=(nclusts, 3))
        manif = manifold.TSNE(n_components=2)
        new_coords = manif.fit_transform(flat)
        color = (colors[clst_idx, :] * 255).astype(np.uint8)
        print(color.shape)
        D.scatter(th.from_numpy(new_coords[:, 0]),
                  th.from_numpy(new_coords[:, 1]),
                  color=color, key="tsne")

        centers = th.from_numpy(clst.cluster_centers_).view(
            nclusts, 3, args.ksize, args.ksize)
        D.tensor(centers, "centers")

        for cidx in range(nclusts):
            idx = clst_idx == cidx
            p = th.from_numpy(patches.numpy()[idx])
            D.tensor(p, key="cluster_{:02d}".format(cidx))

        import ipdb; ipdb.set_trace()  # breakpoint: inspect the clusters
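# Self-contained sketch of the clustering/embedding step above on random
# data, to confirm the shapes involved (illustrative names; no Visdom or
# dataset needed):
def _cluster_demo(ksize=7, nclusts=16, n=1024):
    flat = np.random.randn(n, 3 * ksize * ksize).astype(np.float32)
    clst = cluster.MiniBatchKMeans(n_clusters=nclusts)
    clst_idx = clst.fit_predict(flat)                           # (n,) labels
    coords = manifold.TSNE(n_components=2).fit_transform(flat)  # (n, 2)
    centers = clst.cluster_centers_.reshape(nclusts, 3, ksize, ksize)
    print(clst_idx.shape, coords.shape, centers.shape)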