def plot_images(
    model,
    logger,
    test_set,
    dest_task="normal",
    ood_images=None,
    show_masks=False,
    loss_models=None,
    preds_name=None,
    target_name=None,
    ood_name=None,
):
    """Log predictions, targets and optional extras for ``test_set``.

    Args:
        model: must provide ``predict_with_data(test_set)`` (returning a
            5-tuple whose first three items are inputs, predictions and
            targets) and ``predict(images)``.
        logger: receives images via ``logger.images`` and is passed to the
            task's ``plot_func``.
        test_set: forwarded to ``model.predict_with_data``.
        dest_task: a task object, or a task name resolved with ``get_task``.
        ood_images: optional extra inputs; their predictions are plotted too.
        show_masks: when true and ``dest_task`` is an ``ImageTask``, also log
            the target masks built from ``dest_task.mask_val``.
        loss_models: optional mapping ``name -> callable(preds, targets,
            inputs)``; each output is plotted under ``name``.
        preds_name, target_name, ood_name: optional overrides for the logged
            names; default to ``"<task name>_preds"`` / ``"_target"`` /
            ``"_ood_preds"``.
    """
    from task_configs import get_task, ImageTask

    # The 4th item (losses) and 5th item are not needed for plotting.
    test_images, preds, targets, _losses, _ = model.predict_with_data(test_set)

    if isinstance(dest_task, str):
        dest_task = get_task(dest_task)

    if show_masks and isinstance(dest_task, ImageTask):
        test_masks = ImageTask.build_mask(targets, dest_task.mask_val, tol=1e-3)
        logger.images(test_masks.float(), f"{dest_task}_masks", resize=64)

    dest_task.plot_func(preds, preds_name or f"{dest_task.name}_preds", logger)
    dest_task.plot_func(targets, target_name or f"{dest_task.name}_target", logger)

    if ood_images is not None:
        ood_preds = model.predict(ood_images)
        dest_task.plot_func(ood_preds, ood_name or f"{dest_task.name}_ood_preds", logger)

    # ``None`` instead of a shared ``{}`` default avoids the mutable-default pitfall.
    for name, loss_model in (loss_models or {}).items():
        with torch.no_grad():
            output = loss_model(preds, targets, test_images)
            if hasattr(output, "task"):
                # Outputs that carry their own task know how to plot themselves.
                output.task.plot_func(output, name, logger, resize=128)
            else:
                logger.images(output.clamp(min=0, max=1), name, resize=128)
def plot_paths(self, graph, logger, realities=None, plot_names=None, epochs=0, tr_step=0, prefix=""):
    """Build one stacked image strip per reality for the configured plots.

    For each ``(name, config)`` plot, every path in ``config["paths"]`` is
    evaluated with ``self.compute_paths`` on each reality named in
    ``config["realities"]``.

    Each of the eight hard-coded paths in ``paths_`` is split into:

    * a mean: dim-1 slice 0;
    * a scale: ``exp(dim-1 slice 1) * sqrt(2)``.

    Fusion picks one path per pixel. For each path, compute the density of
    its mean under every path's Laplace distribution, weighted by the
    ``stacked2reshade(x)`` output and summed. The path with the highest such
    value wins, and its mean and scale are stored as
    ``stacked2reshade(x)_m`` / ``_s``. For the ``'test'`` reality a
    colour-mapped error map against ``y^`` is added under ``'error'``.

    Args:
        graph: forwarded to ``self.compute_paths``.
        logger: unused here; kept for interface compatibility.
        realities: objects with a ``.name`` attribute. The names listed in
            ``config["realities"]`` are looked up among them.
        plot_names: optional iterable of ``(name, config)`` pairs; defaults to
            ``self.plots.items()``.
        epochs, tr_step, prefix: unused; kept for interface compatibility.

    Returns:
        dict mapping each reality name of the last processed config to a
        tensor. The tensor stacks one row per output key along dim -2. Each
        row holds every batch item side by side, downsampled by 2.
        Assumes path outputs are 4-D (batch, channels, H, W) — bilinear
        interpolation requires it.
    """
    sqrt2 = math.sqrt(2)
    cmap = get_cmap("jet")
    if realities is None:
        realities = []
    # Paths whose outputs hold a (mean, log-scale) pair along dim 1.
    paths_ = [
        'f(emboss4d(x))', 'f(grey(x))', 'f(sobel(x))', 'f(wav(x))', 'n(x)',
        'f(gauss(x))', 'f(laplace(x))', 'f(sharp(x))'
    ]
    # Outputs that must never be split into mean/scale.
    skip_paths = {
        'x', 'y^', 'bin(x)', 'grey(x)', 'emboss(x)', 'sobel(x)',
        'stacked2reshade(x)', 'gauss(x)', 'laplace(x)', 'sharp(x)'
    }
    path_values = {}
    realities_map = {reality.name: reality for reality in realities}
    for _, config in (plot_names or self.plots.items()):
        paths = config["paths"]
        realities = config["realities"]
        for reality in realities:
            with torch.no_grad():
                values = self.compute_paths(
                    graph, paths={path: self.paths[path] for path in paths},
                    reality=realities_map[reality])

                # Split each path into mean / scale. The order of dict
                # insertions and deletions determines the plot row order.
                mus, sigmas = [], []
                for p in paths_:
                    if p in skip_paths:
                        continue
                    if 'y^' in p and reality == 'ood':
                        continue
                    values[f'{p}_m'] = values[p][:, :1]
                    mus.append(values[f'{p}_m'])
                    values[f'{p}_s'] = values[p][:, 1:2].exp() * sqrt2
                    sigmas.append(values[f'{p}_s'])
                    del values[p]
                # One cat instead of growing a hard-coded CUDA tensor keeps
                # this device-agnostic.
                pred_mu = torch.cat(mus, dim=1)
                pred_sig = torch.cat(sigmas, dim=1)

                weights = values['stacked2reshade(x)']
                lap_dist = torch.distributions.Laplace(loc=pred_mu, scale=pred_sig + 1e-15)
                pdfs = [
                    (lap_dist.log_prob(values[f'{p}_m']).exp() * weights).sum(1, keepdim=True)
                    for p in paths_
                ]
                pi = torch.cat(pdfs, dim=1)
                # One-hot of the best-scoring path per pixel.
                onehot = torch.zeros_like(pi)
                onehot.scatter_(1, pi.argmax(1, keepdim=True), 1.0)
                values['stacked2reshade(x)_m'] = (pred_mu * onehot).sum(1, keepdim=True)
                values['stacked2reshade(x)_s'] = (pred_sig * onehot).sum(1, keepdim=True)
                # Weight 5 (index 4) is emitted first on purpose; keep this order.
                for i in [4, 0, 1, 2, 3, 5, 6, 7]:
                    values[f'stacked2reshade(x)_w{i+1}'] = weights[:, i:i + 1]
                del values['stacked2reshade(x)']
                values['emboss4d(x)'] = values['emboss4d(x)'][:, :3]

                if reality == 'test':
                    # Colour-mapped squared error of the fused mean vs. y^.
                    mask_task = self.paths["y^"][-1]
                    mask = ImageTask.build_mask(values["y^"], val=mask_task.mask_val)
                    errors = ((values["y^"][:, :1] - values["stacked2reshade(x)_m"][:, :1]) ** 2).mean(
                        dim=1, keepdim=True)
                    errors = (3 * errors / mask_task.variance).clamp(min=0, max=1)
                    log_errors = torch.log(errors + 1)
                    # Clamp the divisor so an all-zero error map does not become 0/0 = NaN.
                    log_errors = log_errors / log_errors.max().clamp(min=1e-12)
                    # cmap appends an RGBA axis; move it to dim 1 and keep RGB.
                    log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 0].permute(
                        (0, 3, 1, 2)).float()[:, 0:3]
                    log_errors = log_errors.clamp(min=0, max=1).to(DEVICE)
                    log_errors[~mask.expand_as(log_errors)] = 0.505
                    values['error'] = log_errors

                path_values[reality] = {
                    k: v.clamp(min=0, max=1).cpu() for k, v in values.items()
                }
                del path_values[reality]['wav(x)']

    def reshape_img_to_rows(x_):
        # Halve H and W of every batch item, then lay the items side by side.
        halves = F.interpolate(x_, scale_factor=0.5, mode='bilinear', align_corners=False)
        return torch.cat(halves.unbind(0), dim=-1)

    all_images = {}
    for reality in realities:
        keys_list = list(path_values[reality].keys())
        # Fixed layout: 'x' first, 'y^' second, 'error' just before the last five keys.
        keys_list.remove('x')
        keys_list.insert(0, 'x')
        if 'y^' in keys_list:
            keys_list.remove('y^')
            keys_list.insert(1, 'y^')
        if 'error' in keys_list:
            keys_list.remove('error')
            keys_list.insert(-5, 'error')
        rows = []
        for k in keys_list:
            img_row = reshape_img_to_rows(path_values[reality][k])
            if img_row.size(0) == 1:
                # Single-channel rows are replicated to 3 channels to match the rest.
                img_row = img_row.repeat(3, 1, 1)
            rows.append(img_row)
        all_images[reality] = torch.cat(rows, dim=-2)
    return all_images
def plot_paths(self, graph, logger, realities=None, plot_names=None, epochs=0, tr_step=0, prefix=""):
    """Log a grouped image grid for each configured plot.

    For each ``(name, config)`` plot, every path in ``config["paths"]`` except
    ``'depth'`` is evaluated with ``self.compute_paths`` on every reality named
    in ``config["realities"]``. Each output is:

    * replaced by zeros if missing;
    * clamped to [0, 1];
    * expanded to the first output's shape with dim 1 set to 3.

    Each path listed in ``error_pairs`` also gets an error-map column: the
    squared error against its paired path, normalised and colour-mapped with
    the "jet" colormap. The map is masked with ``ImageTask.build_mask``,
    computed on the ``depth`` output when ``self.target_task == "reshading"``
    and on the paired path otherwise. Masked-out pixels are set to 0.505.

    Every image kind owns one grid column. The realities are concatenated
    along dim 0 within each column, and the columns are sent to
    ``logger.images_grouped``.

    Args:
        graph: forwarded to ``self.compute_paths``.
        logger: receives the grid via ``images_grouped``.
        realities: objects with a ``.name`` attribute. The names listed in
            ``config["realities"]`` are looked up among them.
        plot_names: optional iterable of ``(name, config)`` pairs; defaults to
            ``self.plots.items()``.
        epochs, tr_step: unused; kept for interface compatibility.
        prefix: prepended to the logged grid name.
    """
    error_pairs = {"n(x)": "y^"}
    cmap = get_cmap("jet")
    if realities is None:
        realities = []
    realities_map = {reality.name: reality for reality in realities}

    def column_for(images, columns, key):
        """Return the column list for ``key``, creating it on first use."""
        # A fixed key -> column map keeps every reality in the same layout,
        # whatever order the realities arrive in and wherever 'depth' sits.
        if key not in columns:
            columns[key] = len(images)
            images.append([])
        return images[columns[key]]

    for name, config in (plot_names or self.plots.items()):
        paths = config["paths"]
        realities = config["realities"]
        images = []
        columns = {}
        for reality in realities:
            with torch.no_grad():
                path_values = self.compute_paths(
                    graph, paths={path: self.paths[path] for path in paths},
                    reality=realities_map[reality])
            shape = list(next(iter(path_values.values())).shape)
            shape[1] = 3
            for path in paths:
                if path == 'depth':
                    continue
                X = path_values.get(path, torch.zeros(shape, device=DEVICE))
                column_for(images, columns, path).append(X.clamp(min=0, max=1).expand(*shape))
                if path not in error_pairs:
                    continue

                Y = X
                Y_hat = path_values.get(error_pairs[path], torch.zeros(shape, device=DEVICE))
                out_task = self.paths[path][-1]
                if self.target_task == "reshading":
                    # Use depth mask
                    Y_mask = path_values.get("depth", torch.zeros(shape, device=DEVICE))
                    mask_task = self.paths["r(x)"][-1]
                    mask = ImageTask.build_mask(Y_mask, val=mask_task.mask_val)
                else:
                    mask = ImageTask.build_mask(Y_hat, val=out_task.mask_val)
                errors = ((Y - Y_hat) ** 2).mean(dim=1, keepdim=True)
                errors = (3 * errors / out_task.variance).clamp(min=0, max=1)
                log_errors = torch.log(errors + 1)
                # Clamp the divisor so an all-zero error map does not become 0/0 = NaN.
                log_errors = log_errors / log_errors.max().clamp(min=1e-12)
                # cmap appends an RGBA axis; move it to dim 1 and keep RGB.
                log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 0].permute(
                    (0, 3, 1, 2)).float()[:, 0:3]
                log_errors = log_errors.clamp(min=0, max=1).expand(*shape).to(DEVICE)
                log_errors[~mask.expand_as(log_errors)] = 0.505
                column_for(images, columns, (path, "error")).append(log_errors)

        images = [torch.cat(column, dim=0) for column in images]
        logger.images_grouped(
            images,
            f"{prefix}_{name}_[{', '.join(realities)}]_[{', '.join(paths)}]",
            resize=config["size"])
def plot_paths(self, graph, logger, realities=None, plot_names=None, epochs=0, tr_step=0, prefix=""):
    """Build one stacked image strip per reality for the configured plots.

    For each ``(name, config)`` plot, every path in ``config["paths"]`` is
    evaluated with ``self.compute_paths`` on each reality named in
    ``config["realities"]``.

    The ``f(g(x))`` and ``f0(g(x))`` outputs are split along dim 1:

    * the first half is the mean, stored as ``<path>_m``;
    * the second half is turned into a scale, ``exp(·) * sqrt(2)``, stored as
      ``<path>_s``.

    For the ``'test'`` reality, an ``'error'`` map is added: the squared error
    between ``y^`` and the mean half of ``f(g(x))``, normalised, colour-mapped
    and masked.

    Args:
        graph: forwarded to ``self.compute_paths``.
        logger: unused here; kept for interface compatibility.
        realities: objects with a ``.name`` attribute. The names listed in
            ``config["realities"]`` are looked up among them.
        plot_names: optional iterable of ``(name, config)`` pairs; defaults to
            ``self.plots.items()``.
        epochs, tr_step, prefix: unused; kept for interface compatibility.

    Returns:
        dict mapping each reality name of the last processed config to a
        tensor. The tensor stacks one row per output key along dim -2, with at
        most 3 channels per row. Each row holds every batch item side by side,
        downsampled by 2. Assumes path outputs are 4-D (batch, channels, H, W)
        — bilinear interpolation requires it.
    """
    sqrt2 = math.sqrt(2)
    cmap = get_cmap("jet")
    if realities is None:
        realities = []
    path_values = {}
    realities_map = {reality.name: reality for reality in realities}
    for _, config in (plot_names or self.plots.items()):
        paths = config["paths"]
        realities = config["realities"]
        for reality in realities:
            with torch.no_grad():
                values = self.compute_paths(
                    graph, paths={path: self.paths[path] for path in paths},
                    reality=realities_map[reality])
                # f(g(x)) stacks a mean half and a log-scale half along dim 1.
                nchannels = values['f(g(x))'].size(1) // 2

                if reality == 'test':
                    # Compare y^ with the mean half only, so log-scale channels
                    # never leak into the error when nchannels < 3.
                    mask_task = self.paths["y^"][-1]
                    mask = ImageTask.build_mask(values["y^"], val=mask_task.mask_val)
                    errors = ((values["y^"][:, :nchannels] - values["f(g(x))"][:, :nchannels]) ** 2).mean(
                        dim=1, keepdim=True)
                    errors = (3 * errors / mask_task.variance).clamp(min=0, max=1)
                    log_errors = torch.log(errors + 1)
                    # Clamp the divisor so an all-zero error map does not become 0/0 = NaN.
                    log_errors = log_errors / log_errors.max().clamp(min=1e-12)
                    # cmap appends an RGBA axis; move it to dim 1 and keep RGB.
                    log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 0].permute(
                        (0, 3, 1, 2)).float()[:, 0:3]
                    log_errors = log_errors.clamp(min=0, max=1).to(DEVICE)
                    log_errors[~mask.expand_as(log_errors)] = 0.505
                    values['error'] = log_errors

                # Insertion order here fixes the plot row order.
                for key in ('f(g(x))', 'f0(g(x))'):
                    values[f'{key}_m'] = values[key][:, :nchannels]
                    values[f'{key}_s'] = values[key][:, nchannels:].exp() * sqrt2

                path_values[reality] = {
                    k: v.clamp(min=0, max=1).cpu() for k, v in values.items()
                }
                del path_values[reality]['f(g(x))']
                del path_values[reality]['f0(g(x))']

    def reshape_img_to_rows(x_):
        # Halve H and W of every batch item, then lay the items side by side.
        halves = F.interpolate(x_, scale_factor=0.5, mode='bilinear', align_corners=False)
        return torch.cat(halves.unbind(0), dim=-1)

    all_images = {}
    for reality in realities:
        rows = []
        for img in path_values[reality].values():
            if img.size(1) > 3:
                img = img[:, :3]
            img_row = reshape_img_to_rows(img)
            if img_row.size(0) == 1:
                # Single-channel rows are replicated to 3 channels to match the rest.
                img_row = img_row.repeat(3, 1, 1)
            rows.append(img_row)
        all_images[reality] = torch.cat(rows, dim=-2)
    return all_images