def grad_flow_check(self, model, loss, G=False):
    """Take one optimiser step on ``loss`` and print whether ``model`` changed.

    The weight sum of the first ``nn.Conv2d`` / ``nn.ConvTranspose2d`` found
    in ``model.modules()`` is read before and after the step; the two values
    are compared to decide which message is printed.

    Note that this is not a dry run: every optimiser in ``self.optimizers``
    is zeroed and then stepped.

    Args:
        model: module whose first (transposed) convolution is used as probe.
        loss: object exposing ``backward``; it is back-propagated once.
        G: when truthy, ``backward`` is called with ``retain_graph=True``.

    Raises:
        IndexError: if ``model`` contains no matching convolution layer.
    """
    for optimizer in self.optimizers:
        optimizer.zero_grad()

    # Positions (in modules() order) of every layer usable as a probe.
    conv_positions = []
    for position, module in enumerate(model.modules()):
        if instance_checker(module, nn.Conv2d) or instance_checker(
                module, nn.ConvTranspose2d):
            conv_positions.append(position)
    probe_position = conv_positions[0]

    def probe_weight_sum():
        return list(model.modules())[probe_position].weight.sum().item()

    # store current param value
    before_step = probe_weight_sum()

    # perform backprop
    backward_kwargs = {"retain_graph": True} if G else {}
    loss.backward(**backward_kwargs)

    for optimizer in self.optimizers:
        optimizer.step()

    after_step = probe_weight_sum()

    if model.training:
        if before_step != after_step:
            working = "working"
        else:
            working = "\033[1;32m not working \033[0m"
        print(
            "\n", model.__class__.__name__ +
            f" is in training mode, BACKPROP is {working} \n")
    else:
        if before_step == after_step:
            working = "blocked"
        else:
            working = "\033[1;32m not blocked \033[0m"
        print(
            "\n", model.__class__.__name__ +
            f" is in eval mode, BACKPROP is {working} \n")
def visualize_results(dataloader, generator, calinration_net, discriminator, args, number_of_batches=1):
    """Display qualitative results for the first batches of ``dataloader``.

    The generator is switched to eval mode. For a ``GeneralVAE`` generator
    with ``args.z_dim == 2`` the latent manifold is shown first. Then, for at
    most ``number_of_batches`` batches, a test forward pass is run and the
    comparison figure is shown; ``GeneralVAE`` generators additionally get a
    figure of sampled predictions.

    Args:
        dataloader: iterable of batches; must support ``len``.
        generator: generator network.
        calinration_net: calibration network (parameter name kept as-is for
            callers).
        discriminator: discriminator network.
        args: run configuration; ``dataset``, ``class_flip``, ``z_dim`` and
            ``n_generator_samples_test`` are read here.
        number_of_batches: how many batches to visualise before stopping.
    """
    generator.eval()

    if isinstance(generator, GeneralVAE) and args.z_dim == 2:
        manifold = generator.plot_manifold(20, x1=0.1, x2=0.9)
        plt.imshow(manifold, interpolation="none")
        plt.show()
        plt.close()

    def show_without_axes(figure_number, picture):
        # Open the numbered figure, draw the picture, show it, close it.
        figure = plt.figure(figure_number)
        axes = plt.gca()
        axes.axis('off')
        plt.imshow(picture, interpolation="none")
        plt.show()
        plt.close(figure)

    batches = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_number, batch in batches:
        if batch_number >= number_of_batches:
            break

        # Only LIDC batches unpack into three items.
        if args.dataset == "LIDC":
            images, labels, gt_dist = unpack_batch(batch)
        else:
            images, labels = unpack_batch(batch)
            gt_dist = None
        # gt_labels is assigned but not read again inside this function.
        gt_labels = None

        bb_preds = None
        if args.dataset == "CITYSCAPES19" and args.class_flip:
            gt_labels = labels.clone()
            label_ids = labels[:, 1, :, :].long()
            labels = torch.eye(LABELS_CHANNELS)[label_ids].permute(0, 3, 1, 2)
            raw_bb_preds = batch["bb_preds"].to(DEVICE).float()
            bb_ids = raw_bb_preds[:, 1, :, :].long()
            bb_preds = torch.eye(LABELS_CHANNELS)[bb_ids].permute(
                0, 3, 1, 2).to(DEVICE)

        forward_outputs = test_forward_pass(images, labels, bb_preds,
                                            generator, calinration_net,
                                            discriminator, args)
        (calnet_preds, calnet_labelled_imgs, fake_labels, pred_dist, al_maps,
         gan_al_maps) = forward_outputs

        comparison_figure = plot_comparison_figure(batch, calnet_preds,
                                                   fake_labels, al_maps,
                                                   gan_al_maps, generator,
                                                   calinration_net,
                                                   discriminator, args)
        show_without_axes(1, comparison_figure)

        if instance_checker(generator, GeneralVAE):
            plotted_samples = generator.plot_sample_preds(
                images,
                labels,
                calnet_preds,
                pred_dist,
                gt_dist,
                n_preds=args.n_generator_samples_test,
                dataset=args.dataset)
            show_without_axes(3, plotted_samples)
def assert_type(expectedType, content):
    """Make sure ``content`` passes ``instance_checker`` for ``expectedType``.

    The check is written as an explicit ``raise`` rather than an ``assert``
    statement so that it is still enforced when Python runs with ``-O``
    (which removes ``assert`` statements). The exception type and message
    are the same as the ``assert`` form produced.

    Args:
        expectedType: type (or whatever ``instance_checker`` accepts as its
            second argument) that ``content`` must satisfy.
        content: object to validate.

    Raises:
        AssertionError: if ``instance_checker(content, expectedType)`` is
            falsy.
    """
    if not instance_checker(content, expectedType):
        raise AssertionError(f"{content} is not {expectedType}")
def _mean_iou(predictions, labels, args):
    """Return the ``nanmean`` of ``compute_iou`` over the entries of ``predictions``."""
    per_sample_iou = [
        compute_iou(prediction, labels[i], args)
        for i, prediction in enumerate(predictions)
    ]
    if args.dataset == "LIDC":
        # LIDC results are packed with torch.tensor, all others with
        # torch.stack. NOTE(review): presumably compute_iou returns plain
        # scalars for LIDC and tensors otherwise -- confirm.
        return nanmean(torch.tensor(per_sample_iou))
    return nanmean(torch.stack(per_sample_iou, dim=0))


def _mean_ged(pred_dist, gt_dist, calnet_preds, args, log_to_wandb):
    """Run ``compute_ged``, average its outputs, optionally log them, return the mean GED."""
    ged, _, d_YS, d_SS = compute_ged(
        pred_dist,
        gt_dist,
        calnet_preds,
        args=args,
        n_samples=args.n_generator_samples_test)
    ged = ged.mean()
    d_YS = d_YS.mean()
    d_SS = d_SS.mean()
    if log_to_wandb:
        wandb.log({"meanGED": ged.cpu()})  # todo make sure GED is correct!
        wandb.log({"meanYS": d_YS.cpu()})
        wandb.log({"meanSS": d_SS.cpu()})
    return ged


def compute_stats(args, generator, images, calnet_preds, calnet_labelled_imgs, fake_labels, pred_dist, gan_al_maps, labels, gt_dist, gt_labels, ignore_mask, b_index=0):
    """Compute validation statistics for one batch and log them to wandb.

    Logging only happens when ``args.debug`` is falsy and ``args.mode`` is
    ``"train"``. Depending on the configuration this computes:

    * mean IoU of the calibration-net predictions (skipped for
      ``"EmptyCalNet"``) and of the final predictions (skipped for
      ``"EmptyGenerator"``);
    * for ``"LIDC"`` with a real generator: the mean GED;
    * for ``"CITYSCAPES19"``: summed absolute differences between predicted
      and ground-truth class probabilities, and, for a ``GeneralVAE``
      generator, the mean GED against the flipped modes.

    Args:
        args: run configuration (``calibration_net``, ``generator``,
            ``dataset``, ``debug``, ``mode``, ``flip_experiment`` and
            ``n_generator_samples_test`` are read).
        generator: generator network; only its type is inspected.
        images: not read by this function.
        calnet_preds: calibration-net predictions.
        calnet_labelled_imgs: not read by this function.
        fake_labels: final predictions.
        pred_dist: sampled predictions; ``pred_dist.mean(0)`` is used for the
            CITYSCAPES19 class probabilities.
        gan_al_maps: not read by this function.
        labels: ground-truth labels.
        gt_dist: ground-truth distribution for LIDC; recomputed from
            ``gt_labels`` on the CITYSCAPES19 path.
        gt_labels: original labels, or ``None``. When ``None`` the
            CITYSCAPES19 GED step is skipped (it used to fail on subscripting
            ``None``).
        ignore_mask: forwarded to ``compute_pred_class_probs``.
        b_index: not read by this function.

    Returns:
        The mean GED, or ``None`` when no GED was computed.
    """
    # todo log in tensorboardx
    log_to_wandb = not args.debug and args.mode == "train"
    ged = None

    if args.calibration_net != "EmptyCalNet":
        mean_calnet_iou = _mean_iou(calnet_preds, labels, args)
        if log_to_wandb:
            wandb.log({
                "Calibration net predictions Mean IoU": mean_calnet_iou.cpu()
            })

    if args.generator != "EmptyGenerator":
        mean_gen_iou = _mean_iou(fake_labels, labels, args)
        if log_to_wandb:
            wandb.log({"Final predictions Mean IoU": mean_gen_iou.cpu()})

    if args.dataset == "LIDC" and args.generator != "EmptyGenerator":
        ged = _mean_ged(pred_dist, gt_dist, calnet_preds, args, log_to_wandb)

    if args.dataset == "CITYSCAPES19":
        # get classes, corresponding flip classes and probabilities
        # NOTE(review): eval() runs on a string built from
        # args.flip_experiment; make sure that value can never come from an
        # untrusted source, or switch to an explicit lookup table.
        flip_args = eval(f"CITYSCAPES19_{args.flip_experiment}FLIP")
        class_1 = flip_args[0]
        class_2 = flip_args[1]
        flip_probs = flip_args[2]

        # Without a generator the calibration net is scored, otherwise the
        # mean of the generator samples.
        if args.generator == "EmptyGenerator":
            scored_preds = calnet_preds
            precision_key = "calnet_precision"
            flip_precision_key = "calnet_flip_precision"
        else:
            scored_preds = pred_dist.mean(0)
            precision_key = "gen_precision"
            flip_precision_key = "gen_flip_precision"

        pred_class_probs = compute_pred_class_probs(labels, scored_preds,
                                                    ignore_mask, args)
        gt_class_probs = compute_gt_class_probs(labels, args)
        with torch.no_grad():
            precision = torch.abs(pred_class_probs - gt_class_probs).sum()
            f_classes = class_1 + class_2
            flip_precision = torch.abs(
                np.take(pred_class_probs, f_classes) -
                np.take(gt_class_probs, f_classes)).sum()

        if log_to_wandb:
            wandb.log({precision_key: precision.cpu()})
            wandb.log({flip_precision_key: flip_precision.cpu()})

        if instance_checker(generator, GeneralVAE) and gt_labels is not None:
            # rgb channels are identical so we extract only one of them
            label = gt_labels[:, 1, :, :]
            gt_dist = get_all_modes(label,
                                    input_classes=class_1,
                                    target_classes=class_2,
                                    flip_probs=flip_probs,
                                    n_flipped_modes=len(flip_probs))
            ged = _mean_ged(pred_dist, gt_dist, calnet_preds, args,
                            log_to_wandb)

    return ged
def validation_plots(batch, generator, calibration_net, discriminator, args, batch_idx=0):
    """Run validation on one batch: score, checkpoint, plot and log.

    Steps, in order:

    1. unpack the batch and, for CITYSCAPES19, build one-hot box predictions,
       one-hot labels and the ignore mask;
    2. run ``test_forward_pass``;
    3. for LIDC, compute a KL-based calibration-net score and save the models
       when it improves on ``BEST_CALNET_SCORE``;
    4. call ``compute_stats`` and save the models when the returned GED
       improves on ``BEST_GED``;
    5. build the comparison / calibration / sample figures and, unless
       ``args.debug`` is set, save and log them to wandb.

    Args:
        batch: one batch from the validation dataloader.
        generator: must pass ``instance_checker(..., GeneralGenerator)``.
        calibration_net: must pass ``instance_checker(..., GeneralGenerator)``.
        discriminator: must pass
            ``instance_checker(..., GeneralDiscriminator)``.
        args: run configuration.
        batch_idx: forwarded to ``compute_stats`` and ``save_example_images``.

    Raises:
        AssertionError: if one of the three networks fails its type check.

    Side effects:
        May rebind the module globals ``BEST_CALNET_SCORE`` and ``BEST_GED``.
    """
    global BEST_CALNET_SCORE, BEST_GED

    assert instance_checker(generator, GeneralGenerator)
    assert instance_checker(calibration_net, GeneralGenerator)
    assert instance_checker(discriminator, GeneralDiscriminator)

    def one_hot_from_channel_1(tensor):
        # Channel index 1 holds the class ids used to build the one-hot map.
        return torch.eye(LABELS_CHANNELS)[tensor[:, 1, :, :].long()].permute(
            0, 3, 1, 2)

    gt_labels = None
    if args.dataset == "LIDC":
        images, labels, gt_dist = unpack_batch(batch)
    else:
        images, labels = unpack_batch(batch)
        gt_dist = None

    if args.dataset == "CITYSCAPES19":
        bb_preds = batch["bb_preds"].to(DEVICE).float()
        bb_preds = one_hot_from_channel_1(bb_preds).to(DEVICE)
        one_hot_labels = one_hot_from_channel_1(labels).to(DEVICE)
        overlapped_mask = get_cs_ignore_mask(bb_preds, one_hot_labels)
    else:
        bb_preds = None
        overlapped_mask = None

    if args.dataset == "CITYSCAPES19" and args.class_flip:
        gt_labels = labels.clone()
        labels = one_hot_from_channel_1(labels)

    calnet_preds, calnet_labelled_imgs, fake_labels, pred_dist, al_maps, gan_al_maps = test_forward_pass(
        images, labels, bb_preds, generator, calibration_net, discriminator,
        args)

    # save best calibration net
    # Start from None so the "is not None" guard below is meaningful and the
    # name is bound for every dataset, not only LIDC.
    calnet_score = None
    if args.dataset == "LIDC":
        lab_dist = torch.eye(LABELS_CHANNELS)[(gt_dist).long()].permute(
            1, 0, 4, 2, 3).to(DEVICE).mean(0)
        eps = 1e-7

        def kl(p, q):
            # sum over dim 1 of p * (log p - log q), with both inputs clamped
            # away from 0 and 1.
            p = p.clamp(min=eps, max=1 - eps)
            q = q.clamp(min=eps, max=1 - eps)
            return (-p * torch.log(q) + p * torch.log(p)).sum(1)

        calnet_score = kl(calnet_preds.detach(), lab_dist).mean()
        if args.generator == "EmptyGenerator" and (not args.debug):
            wandb.log({"Calnet score": calnet_score})

    if (args.mode == "train" and args.generator == "EmptyGenerator"
            and (not args.debug) and calnet_score is not None
            and calnet_score < BEST_CALNET_SCORE):
        BEST_CALNET_SCORE = calnet_score
        print(
            f"{PRINTCOLOR_GREEN} Saved New Best Calibration Net! {PRINTCOLOR_END}"
        )
        save_models(discriminator, generator, calibration_net, "Best_Model")

    # log stats
    ged = compute_stats(args,
                        generator,
                        images,
                        calnet_preds,
                        calnet_labelled_imgs,
                        fake_labels,
                        pred_dist,
                        gan_al_maps,
                        labels,
                        gt_dist,
                        gt_labels,
                        overlapped_mask,
                        b_index=batch_idx)

    if (args.mode == "train" and (not args.debug) and ged is not None
            and ged < BEST_GED):
        BEST_GED = ged
        print(f"{PRINTCOLOR_GREEN} Saved New Best Model! {PRINTCOLOR_END}")
        save_models(discriminator, generator, calibration_net, "Best_Model")

    # Plots
    comparison_figure = plot_comparison_figure(batch, calnet_preds,
                                               fake_labels, al_maps,
                                               gan_al_maps, generator,
                                               calibration_net, discriminator,
                                               args)

    if args.dataset == "CAMVID" or args.dataset == "CITYSCAPES19":
        calibration_figure = plot_calibration_figure(labels, calnet_preds,
                                                     pred_dist,
                                                     overlapped_mask, args)

    if instance_checker(generator, GeneralVAE):
        plotted_samples = generator.plot_sample_preds(
            images,
            labels,
            calnet_preds,
            pred_dist,
            gt_dist,
            n_preds=args.n_generator_samples_test,
            dataset=args.dataset)

    if not args.debug:
        # save and log
        # Figures are moved from channels-last to channels-first before being
        # handed to make_grid.
        comparison_figure = torch.from_numpy(
            np.moveaxis(comparison_figure, -1, 0)).float()
        save_example_images(comparison_figure, batch_idx, "comparison", "png")
        wandb.log({
            "Results":
            wandb.Image(vutils.make_grid(comparison_figure, normalize=True))
        })

        # The calibration figure is also built for CAMVID above, but only
        # logged for CITYSCAPES19.
        if args.dataset == "CITYSCAPES19":
            calibration_figure = torch.from_numpy(
                np.moveaxis(calibration_figure, -1, 0)).float()
            wandb.log({
                "Calibration":
                wandb.Image(
                    vutils.make_grid(calibration_figure, normalize=True))
            })

        if instance_checker(generator, GeneralVAE):
            plotted_samples = torch.from_numpy(
                np.moveaxis(plotted_samples, -1, 0)).float()
            wandb.log({
                "Plotted samples":
                wandb.Image(vutils.make_grid(plotted_samples, normalize=True))
            })
def evaluation(dataloader_test, generator, calibration_net, discriminator, args, number_of_batches=1, b_index=0, visualize=True, save=False, load=False, print_stats=True):
    """Evaluate the networks on the test dataloader.

    With ``visualize=True`` this only delegates to ``visualize_results``.
    Otherwise up to ``number_of_batches`` batches are processed: predictions
    are either loaded from disk (``args.mode == "test"`` and ``load``) or
    produced by the calibration net and generator under ``torch.no_grad()``,
    optionally saved, and GED statistics are accumulated for LIDC (with a
    real generator) and for CITYSCAPES19 class-flip runs with a
    ``GeneralVAE`` generator. The accumulated statistics are then averaged,
    saved / logged / printed.

    If no batch contributed a GED value (no other dataset/generator
    combination appends to the accumulators), a notice is printed and the
    function returns instead of failing on ``torch.stack`` of an empty list.

    Args:
        dataloader_test: iterable of batches; must support ``len``.
        generator: generator network.
        calibration_net: calibration network.
        discriminator: discriminator; only forwarded to ``visualize_results``.
        args: run configuration.
        number_of_batches: maximum number of batches to process.
        b_index: not read by this function.
        visualize: show figures instead of computing statistics.
        save: store the per-batch arrays via ``save_numpy_arrays`` (ignored
            when ``load`` is set).
        load: in test mode, read per-batch arrays via ``load_numpy_arrays``
            instead of running the networks.
        print_stats: print the aggregated statistics.

    Returns:
        None.
    """
    if visualize:
        visualize_results(dataloader_test,
                          generator,
                          calibration_net,
                          discriminator,
                          args,
                          number_of_batches=number_of_batches)
        return

    avg_GED = []
    avgYS = []
    avgSS = []
    total_pixel_mode_counts = 0
    avg_calnet_class_probs = []

    class_flip_run = args.dataset == "CITYSCAPES19" and args.class_flip
    flip_ged_run = class_flip_run and instance_checker(generator, GeneralVAE)
    if flip_ged_run:
        # get classes, corresponding flip classes and probabilities
        # (loop-invariant, so looked up once instead of once per batch)
        # NOTE(review): eval() runs on a string built from args fields; make
        # sure those values can never come from an untrusted source, or
        # switch to an explicit lookup table.
        flip_args = eval(f"{args.dataset}_{args.flip_experiment}FLIP")
        class_1 = flip_args[0]
        class_2 = flip_args[1]
        flip_probs = flip_args[2]

    for i, batch in tqdm(enumerate(dataloader_test),
                         total=len(dataloader_test)):
        if i >= number_of_batches:
            break

        gt_labels = None
        if args.dataset == "LIDC":
            images, labels, gt_dist = unpack_batch(batch)
        else:
            images, labels = unpack_batch(batch)
            gt_dist = None

        if class_flip_run:
            gt_labels = labels.clone()
            labels = torch.eye(LABELS_CHANNELS)[
                labels[:, 1, :, :].long()].permute(0, 3, 1, 2)
            bb_preds = batch["bb_preds"].to(DEVICE).float()
            bb_preds = torch.eye(LABELS_CHANNELS)[
                bb_preds[:, 1, :, :].long()].permute(0, 3, 1, 2).to(DEVICE)
            # get indexes of correct bb preds
            overlapped_mask = get_cs_ignore_mask(bb_preds, labels)
            # get indexes of unlabelled pixels
            unlabelled_idxs = torch.where(labels.argmax(1) == 24)
        else:
            bb_preds = None
            overlapped_mask = None
            unlabelled_idxs = None

        if args.mode == "test" and load:
            images, labels, calnet_preds, pred_dist = load_numpy_arrays(
                i, args, to_tensor=True)
            gt_labels = labels.argmax(1, keepdim=True)
        else:
            calibration_net.eval()
            generator.eval()
            with torch.no_grad():
                # forward pass
                _, calnet_preds, calnet_labelled_imgs = calibration_net_forward_pass(
                    calibration_net, images, bb_preds, unlabelled_idxs, args)
                # Without a calibration net the generator sees the raw images.
                if args.calibration_net == "EmptyCalNet":
                    g_input = images
                else:
                    g_input = calnet_labelled_imgs
                pred_dist, _, _ = generator.sample(
                    g_input,
                    ign_idxs=unlabelled_idxs,
                    n_samples=args.n_generator_samples_test)

        if save and not load:
            s_labels = gt_dist if args.dataset == "LIDC" else labels
            save_numpy_arrays(images,
                              s_labels,
                              calnet_preds,
                              pred_dist,
                              batch_id=i,
                              args=args)

        if args.dataset == "LIDC" and args.generator != "EmptyGenerator":
            ged, _, d_YS, d_SS = compute_ged(
                pred_dist,
                gt_dist,
                calnet_preds,
                args=args,
                g_input=images,
                n_samples=args.n_generator_samples_test)
            avg_GED.append(ged.mean())
            avgYS.append(d_YS)
            avgSS.append(d_SS)
            print(f"\nGED_batch_{i} = {ged.mean().item()}")

        if flip_ged_run:
            # rgb channels are identical so we extract only one of them
            label = gt_labels[:, 0, :, :]

            calnet_class_probs = compute_pred_class_probs(
                labels, calnet_preds, overlapped_mask, args)
            avg_calnet_class_probs.append(calnet_class_probs)

            # get all modes
            gt_dist = get_all_modes(label,
                                    input_classes=class_1,
                                    target_classes=class_2,
                                    flip_probs=flip_probs,
                                    n_flipped_modes=len(flip_probs))

            ged, d_matrices, d_YS, d_SS = compute_ged(
                pred_dist,
                gt_dist,
                calnet_preds,
                args=args,
                n_samples=args.n_generator_samples_test)
            avg_GED.append(ged.mean())
            print(f"\nGED_batch_{i} = ", ged.mean().item())
            avgYS.append(d_YS)
            avgSS.append(d_SS)

            pixel_mode_counts = count_pixel_modes(labels,
                                                  pred_dist,
                                                  input_classes=class_1,
                                                  target_classes=class_2)
            # sum over the batch dim
            total_pixel_mode_counts += pixel_mode_counts.sum(0)

    if not avg_GED:
        # torch.stack rejects an empty list, so there is nothing to
        # aggregate, save, log or print.
        print("\nNo GED statistics were collected; skipping evaluation summary.")
        return

    # mGED = torch.stack(avg_GED, dim=0).mean(0)
    mGED = nanmean(torch.stack(avg_GED, dim=0), dim=0)
    avgYS = torch.stack(avgYS, dim=0).mean(0)
    avgSS = torch.stack(avgSS, dim=0).mean(0)
    mYS = nanmean(avgYS)
    mSS = nanmean(avgSS)

    if (args.dataset == "CAMVID"
            or args.dataset == "CITYSCAPES19") and args.class_flip:
        f_classes_idxs = torch.stack(
            (torch.LongTensor(class_1), torch.LongTensor(class_2)), dim=1)
        calnet_class_probs = torch.stack(avg_calnet_class_probs,
                                         dim=0).mean(0)[f_classes_idxs]

        # Divide every column by column 0 (the total), then drop that column.
        normalized_pixel_mode_counts = total_pixel_mode_counts / total_pixel_mode_counts[:, 0].unsqueeze(
            1).expand(total_pixel_mode_counts.shape)
        pred_class_probs = normalized_pixel_mode_counts[:, 1:]  # 0 dim = total

        gt_mode_probs = get_mode_statistics(
            probabilities=flip_probs,
            n_flipped_modes=len(flip_probs))["probs"]

        # if save:
        save_results(gt_mode_probs, flip_probs, calnet_class_probs,
                     pred_class_probs, mGED, mYS, mSS, args)

    # log stats
    if not args.debug and args.mode == "train":
        wandb.log({"meanGED": mGED.cpu()})
        wandb.log({"meanYS": mYS.cpu()})
        wandb.log({"meanSS": mSS.cpu()})

    if print_stats:
        print("\n---------------------------------------------")
        print("\nmGED = ", mGED)
        print("\nmYS = ", mYS)
        print("\nmSS = ", mSS)
        if args.dataset == "CITYSCAPES19" and args.class_flip:
            print("\n---------------------------------------------")
            print("\nGt class probs = ", flip_probs)
            print("\nCalibration Net class probs = ", calnet_class_probs)
            print("\nRefinement Net class probs = ", pred_class_probs)
        print("\n---------------------------------------------")