def log_images(self, pl_module):
    """Plot predictions vs. ground truth for the first ``self.n`` images and upload them.

    Builds a non-augmented ``dataset.TreeDataset`` from ``self.csv_file`` and
    ``self.root_dir``. Runs ``pl_module.model`` on the first ``n`` images, one at a
    time. Saves one prediction-vs-target figure per image into ``self.savedir``.
    Finally, it tries to upload the saved figures with
    ``pl_module.logger.experiment.log_image``.

    Args:
        pl_module: module whose ``model``, ``device``, ``label_dict`` and
            ``logger`` attributes are read.

    Upload failures are reported with ``print`` rather than raised. The figures
    stay on disk in ``self.savedir`` either way.
    """
    ds = dataset.TreeDataset(
        csv_file=self.csv_file,
        root_dir=self.root_dir,
        transforms=dataset.get_transform(augment=False),
        label_dict=pl_module.label_dict)

    # Clamp locally instead of overwriting self.n. A later call on a larger
    # dataset then still honours the originally requested number of images.
    n = min(self.n, len(ds))
    ds = torch.utils.data.Subset(ds, range(n))
    data_loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=False, collate_fn=utilities.collate_fn)

    model = pl_module.model
    was_training = model.training
    model.eval()
    saved_plots = []
    try:
        # Inference only: no autograd graph is needed to draw predictions.
        with torch.no_grad():
            for batch in data_loader:
                paths, images, targets = batch
                if pl_module.device.type != "cpu":
                    images = [x.to(pl_module.device) for x in images]
                predictions = model(images)
                for path, image, prediction, target in zip(paths, images, predictions, targets):
                    image = image.permute(1, 2, 0).cpu()
                    saved_path = visualize.plot_prediction_and_targets(
                        image=image,
                        predictions=prediction,
                        targets=target,
                        image_name=path,
                        savedir=self.savedir)
                    plt.close()
                    if saved_path is not None:
                        saved_plots.append(saved_path)
    finally:
        # Restore the model's previous mode. Otherwise a call made while
        # training would silently leave the model in eval mode.
        model.train(was_training)

    # Upload only the figures produced by this call, so stale PNGs in savedir
    # are not re-uploaded every time. Fall back to scanning the directory if
    # the plotting helper did not report any paths.
    if not saved_plots:
        saved_plots = glob.glob("{}/*.png".format(self.savedir))

    try:
        for x in saved_plots:
            pl_module.logger.experiment.log_image(x)
    except Exception as e:
        # Deliberate best-effort: a missing/unsupported logger must not abort training.
        print(
            "Could not find logger in lightning module, skipping upload, images were saved to {}, error was raised {}"
            .format(self.savedir, e))
def test_plot_predictions_and_targets(m, tmpdir):
    """Plot every prediction/target pair of one validation batch and check each figure is written.

    ``m`` supplies ``val_dataloader()`` and ``model``; ``tmpdir`` is the
    directory the figures are saved into.
    """
    first_batch = next(iter(m.val_dataloader()))
    image_paths, batch_images, batch_targets = first_batch

    m.model.eval()
    batch_predictions = m.model(batch_images)

    for image_path, img, truth, pred in zip(image_paths, batch_images, batch_targets, batch_predictions):
        # Move dim 0 to the end (presumably channels-first -> channels-last for plotting).
        img = img.permute(1, 2, 0)
        figure_path = visualize.plot_prediction_and_targets(
            img,
            pred,
            truth,
            image_name=os.path.basename(image_path),
            savedir=tmpdir)
        assert os.path.exists(figure_path)