def save_images(image, mask, output_path, image_file, palette, original_size):
    """Save the colorized prediction mask and a 3-panel composite (image | mask | blend).

    Args:
        image: PIL image of the input.
        mask: 2-D class-index array for the prediction.
        output_path: directory to write the PNGs into.
        image_file: source filename; its basename (up to the first '.') names the outputs.
        palette: flat RGB palette list used by colorize_mask.
        original_size: (w, h) to resize outputs to, or a falsy value to keep sizes as-is.
    """
    # Pad a COPY of the palette to 256 RGB entries; the original appended to
    # the caller's list in place, mutating a shared argument.
    palette = palette + [0] * (256 * 3 - len(palette))
    w, h = image.size
    if original_size:
        w, h = original_size
    image_file = os.path.basename(image_file).split('.')[0]
    colorized_mask = colorize_mask(mask, palette)
    # Only resize when a target size was actually given; the original compared
    # sizes against a possibly-None original_size and would then call
    # resize(size=None), which raises.
    if original_size:
        if image.size != original_size:
            image = image.resize(size=original_size, resample=Image.BILINEAR)
        if colorized_mask.size != original_size:
            colorized_mask = colorized_mask.resize(size=original_size, resample=Image.NEAREST)
    blend = Image.blend(image, colorized_mask.convert('RGB'), 0.5)
    colorized_mask.save(os.path.join(output_path, image_file + '.png'))
    # Composite panel: input | colorized mask | 50/50 blend.
    output_im = Image.new('RGB', (w * 3, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w * 1, 0))
    output_im.paste(blend, (w * 2, 0))
    output_im.save(os.path.join(output_path, image_file + '_colorized.png'))
def save_images(output, output_path, name, palette):
    """Save the predicted segmentation mask for a single image.

    Args:
        output: model logits of shape (1, C, H, W) (batch dim is squeezed).
        output_path: directory to write the PNG into.
        name: output filename stem.
        palette: flat RGB palette list used by colorize_mask.
    """
    # argmax over the class dimension directly on the tensor. The original
    # moved the tensor to numpy, back to a tensor, applied softmax, then took
    # argmax -- but argmax(softmax(x)) == argmax(x), so the round trip and the
    # softmax were both redundant.
    mask = output.detach().squeeze(0).argmax(0).cpu().numpy()
    colorized_mask = colorize_mask(mask, palette)
    colorized_mask.save(os.path.join(output_path, name + '.png'))
def save_images(image, mask, output_path, image_file, palette):
    """Save the input image and its colorized prediction mask as two PNGs."""
    base_name = os.path.basename(image_file).split('.')[0]
    mask_im = colorize_mask(mask, palette)
    mask_im.save(os.path.join(output_path, base_name + '_mask.png'))
    image.save(os.path.join(output_path, base_name + '_orig.png'))
def main():
    """Run supervised CCT inference over a folder of images and save colorized predictions."""
    args = parse_arguments()

    # CONFIG -- use a context manager; the original leaked the file handle
    # via json.load(open(...)).
    assert args.config
    with open(args.config) as f:
        config = json.load(f)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5]

    # DATA
    testdataset = testDataset(args.images)
    loader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=1)
    num_classes = config['num_classes']
    palette = get_voc_pallete(num_classes)

    # MODEL -- force fully supervised mode for testing.
    config['model']['supervised'] = True
    config['model']['semi'] = False
    model = models.CCT(num_classes=num_classes, conf=config['model'], testing=True)
    checkpoint = torch.load(args.model)
    model = torch.nn.DataParallel(model)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except Exception as e:
        # Fall back to a partial (non-strict) load when the checkpoint does
        # not match the model exactly.
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    model.cuda()

    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(args.save, exist_ok=True)

    # LOOP OVER THE DATA
    tbar = tqdm(loader, ncols=100)
    for index, data in enumerate(tbar):
        image, image_id = data
        image = image.cuda()
        # PREDICT
        with torch.no_grad():
            output = multi_scale_predict(model, image, scales, num_classes)
        prediction = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)
        # SAVE RESULTS
        prediction_im = colorize_mask(prediction, palette)
        prediction_im.save(os.path.join(args.save, image_id[0] + '.png'))
def _add_img_tb(self, val_visual, wrt_mode):
    """Write a grid of (input, target, prediction) rows to tensorboard."""
    palette = self.val_loader.dataset.palette
    grid_items = []
    for row in val_visual:
        for item in row:
            # 3-channel tensors are de-normalized back to inputs; everything
            # else is treated as a label map and colorized.
            if isinstance(item, torch.Tensor) and len(item.shape) == 3:
                pil_img = self.restore_transform(item)
            else:
                pil_img = colorize_mask(item, palette)
            grid_items.append(self.viz_transform(pil_img.convert('RGB')))
    stacked = torch.stack(grid_items, 0)
    grid = make_grid(stacked.cpu(), nrow=stacked.size(0) // len(val_visual), padding=5)
    self.writer.add_image(f'{wrt_mode}/inputs_targets_predictions', grid, self.wrt_step)
def save_image(image, epoch, id, palette):
    """Dump one tensor to ../visualiseImages/ as '<epoch><id>.png'.

    A 3-channel tensor is de-normalized and saved as an RGB image; any other
    shape is treated as a label map and colorized with the palette.
    """
    out_path = os.path.join('../visualiseImages/', str(epoch) + id + '.png')
    with torch.no_grad():
        if image.shape[0] == 3:
            to_pil = transforms.Compose([
                DeNormalize(IMG_MEAN),
                transforms.ToPILImage()])
            to_pil(image).save(out_path)
        else:
            colorize_mask(image.numpy(), palette).save(out_path)
def save_images(image, mask, output_path, image_file, palette, original_size, output=None):
    """Save the colorized mask plus a 3-panel composite (image | mask | blend).

    Args:
        image: PIL image of the input.
        mask: 2-D class-index array for the prediction.
        output_path: directory to write the PNGs into.
        image_file: source filename; its basename (up to the first '.') names the outputs.
        palette: flat RGB palette list used by colorize_mask.
        original_size: (w, h) to resize outputs to, or a falsy value to keep sizes as-is.
        output: optional dict that receives per-class pixel counts
            ('pc_0'..'pc_3', 'pc_total') and the saved file paths ('mask', 'blend').
    """
    # Pad a COPY of the palette to 256 RGB entries; the original appended to
    # the caller's list in place, mutating a shared argument.
    palette = palette + [0] * (256 * 3 - len(palette))
    w, h = image.size
    if original_size:
        w, h = original_size
    # 'is not None': the original used truthiness, so a caller passing an
    # empty dict to collect results would silently get nothing back.
    if output is not None:
        # Count pixels per class (classes 0-3) at the original resolution.
        # Guard the resize: cv2.resize(dsize=None) without fx/fy raises.
        if original_size:
            resized = cv2.resize(mask, dsize=original_size, interpolation=cv2.INTER_NEAREST)
        else:
            resized = mask
        counts = [int(np.count_nonzero(resized == c)) for c in range(4)]
        for c, count in enumerate(counts):
            output[f'pc_{c}'] = count
        output['pc_total'] = sum(counts)
    image_file = os.path.basename(image_file).split('.')[0]
    colorized_mask = colorize_mask(mask, palette)
    # Only resize when a target size was actually given (see guard above).
    if original_size:
        if image.size != original_size:
            image = image.resize(size=original_size, resample=Image.BILINEAR)
        if colorized_mask.size != original_size:
            colorized_mask = colorized_mask.resize(size=original_size, resample=Image.NEAREST)
    blend = Image.blend(image, colorized_mask.convert('RGB'), 0.5)
    mask_path = os.path.join(output_path, image_file + '.png')
    colorized_mask.save(mask_path)
    # Composite panel: input | colorized mask | 50/50 blend.
    output_im = Image.new('RGB', (w * 3, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w * 1, 0))
    output_im.paste(blend, (w * 2, 0))
    blend_path = os.path.join(output_path, image_file + '_colorized.png')
    output_im.save(blend_path)
    if output is not None:
        output['mask'] = mask_path
        output['blend'] = blend_path
def save_images(image, mask, output_path, image_file, palette):
    """Save the colorized mask and a 2-panel composite (image | mask).

    Args:
        image: PIL image of the input.
        mask: 2-D class-index array for the prediction.
        output_path: directory to write the PNGs into.
        image_file: source filename; its basename (up to the first '.') names the outputs.
        palette: flat RGB palette list used by colorize_mask.
    """
    # Removed a dead O(H*W) pure-Python loop that collected the coordinates
    # of non-zero mask pixels into loc_x/loc_y -- the lists were never used.
    w, h = image.size
    image_file = os.path.basename(image_file).split('.')[0]
    colorized_mask = colorize_mask(mask, palette)
    colorized_mask.save(os.path.join(output_path, image_file + '.png'))
    # Composite panel: input | colorized mask.
    output_im = Image.new('RGB', (w * 2, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w, 0))
    output_im.save(os.path.join(output_path, image_file + '_colorized.png'))
def _valid_epoch(self, epoch):
    """Run one full validation pass.

    Computes the loss and segmentation metrics over the validation loader,
    writes an (input, target, prediction) image grid and scalar metrics to
    tensorboard, and returns a log dict.

    Args:
        epoch (int): current epoch index; used for progress display and as
            the basis of the tensorboard step.

    Returns:
        dict: 'val_loss' plus the entries of self._get_seg_metrics();
        an empty dict when no validation loader was configured.
    """
    if self.val_loader is None:
        self.logger.warning(
            'Not data loader was passed for the validation step, No validation is performed !'
        )
        return {}
    self.logger.info('\n###### EVALUATION ######')
    self.model.eval()
    self.wrt_mode = 'val'
    self._reset_metrics()
    tbar = tqdm(self.val_loader, ncols=130)
    with torch.no_grad():
        val_visual = []
        for batch_idx, (data, target) in enumerate(tbar):
            # data, target = data.to(self.device), target.to(self.device)
            # LOSS
            output = self.model(data)
            loss = self.loss(output, target)
            if isinstance(self.loss, torch.nn.DataParallel):
                # DataParallel returns one loss per replica; reduce to a scalar.
                loss = loss.mean()
            self.total_loss.update(loss.item())
            seg_metrics = eval_metrics(output, target, self.num_classes)
            self._update_seg_metrics(*seg_metrics)
            # Collect the first sample of up to 15 batches for visualization.
            if len(val_visual) < 15:
                target_np = target.data.cpu().numpy()
                output_np = output.data.max(1)[1].cpu().numpy()
                val_visual.append(
                    [data[0].data.cpu(), target_np[0], output_np[0]])
            # Update the progress-bar description with running metrics.
            pixAcc, mIoU, _ = self._get_seg_metrics().values()
            tbar.set_description(
                'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                .format(epoch, self.total_loss.average, pixAcc, mIoU))
        # WRITING & VISUALIZING THE MASKS: one row (input, target, prediction)
        # per collected sample.
        val_img = []
        palette = self.train_loader.dataset.palette
        for d, t, o in val_visual:
            d = self.restore_transform(d)
            t, o = colorize_mask(t, palette), colorize_mask(o, palette)
            d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB')
            [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]
            val_img.extend([d, t, o])
        val_img = torch.stack(val_img, 0)
        val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
        # NOTE(review): self.wrt_step is read here but only updated a few
        # lines below, so this image is logged under the step value from the
        # previous call -- confirm whether that ordering is intentional.
        self.writer.add_image(
            '{}/inputs_targets_predictions'.format(self.wrt_mode),
            val_img, self.wrt_step)
        # METRICS TO TENSORBOARD
        self.wrt_step = (epoch) * len(self.val_loader)
        self.writer.add_scalar('{}/loss'.format(self.wrt_mode),
                               self.total_loss.average, self.wrt_step)
        seg_metrics = self._get_seg_metrics()
        # The last seg_metrics entry is skipped -- presumably a non-scalar
        # (e.g. per-class IoU); verify against _get_seg_metrics().
        for k, v in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar('{}/{}'.format(self.wrt_mode, k), v,
                                   self.wrt_step)
    log = {'val_loss': self.total_loss.average, **seg_metrics}
    return log
def _valid_epoch(self, epoch):
    """Run one validation pass for the kernel/SDF model variant.

    Evaluates the combined loss (5 * segmentation loss + distance loss),
    collects both segmentation and kernel/SDF visualizations, writes image
    grids and scalar losses/metrics to tensorboard, and returns a log dict.

    Args:
        epoch (int): current epoch index; used for progress display and as
            the basis of the tensorboard step.

    Returns:
        dict: 'val_loss' plus the entries of self._get_seg_metrics();
        an empty dict when no validation loader was configured.
    """
    if self.val_loader is None:
        self.logger.warning(
            'Not data loader was passed for the validation step, No validation is performed !'
        )
        return {}
    self.logger.info('\n###### EVALUATION ######')
    self.model.eval()
    self.wrt_mode = 'val'
    self._reset_metrics()
    tbar = tqdm(self.val_loader, ncols=130)
    with torch.no_grad():
        val_visual = []
        kernel_visual = []
        for batch_idx, (data, target, distance) in enumerate(tbar):
            # Forward pass returns logits and the distance-field loss.
            output, distance_loss = self.model(data, distance)
            # Extra forward pass exposing intermediate kernel / coarse /
            # learned-SDF tensors purely for visualization.
            kernel, coarse_estimation, learned_sdf = self.model.module.kernel_visualization(
                data)
            seg_loss = self.loss(output, target)
            # Fixed weighting: segmentation loss scaled by 5.
            loss = 5 * seg_loss + distance_loss
            self.total_loss.update(loss.item())
            seg_metrics = eval_metrics(output, target, self.num_classes)
            self._update_seg_metrics(*seg_metrics)
            # Collect the first sample of up to 15 batches for visualization.
            if len(val_visual) < 15:
                target_np = target.data.cpu().numpy()
                # Channel 1 of the kernel tensor is visualized -- presumably
                # the foreground class; confirm against the model.
                kernel_np = kernel[:, 1, :, :, ].data.cpu().numpy()
                output_np = output.data.max(1)[1].cpu().numpy()
                learned_sdf_np = learned_sdf.data.cpu().numpy()  # B*k*W*H
                target_sdf_np = distance.data.cpu().numpy()  # B*k*W*H
                coarse_estimation_np = coarse_estimation[:, 1, :, :, ].data.cpu(
                ).numpy()
                # Row: RGB input, segmentation label, segmentation prediction.
                val_visual.append([
                    data[0].data.cpu(), target_np[0], output_np[0]
                ])
                # Row: RGB, kernel, coarse estimation, learned SDF, target SDF.
                kernel_visual.append([
                    data[0].data.cpu(), kernel_np[0], coarse_estimation_np[0],
                    learned_sdf_np[0], target_sdf_np[0]
                ])
            # Update the progress-bar description with running metrics.
            pixAcc, mIoU, _ = self._get_seg_metrics().values()
            tbar.set_description(
                'EVAL ({}) | Seg_Loss: {:.3f}, Dis_Loss: {:.3f},PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                .format(epoch, seg_loss, distance_loss, pixAcc, mIoU))
        # WRITING & VISUALIZING THE MASKS: segmentation grid rows of
        # (input, label, prediction).
        val_img = []
        palette = self.train_loader.dataset.palette
        for rgb, label, seg in val_visual:
            rgb = self.restore_transform(rgb)
            label, seg = colorize_mask(label, palette), colorize_mask(
                seg, palette)
            rgb, label, seg = rgb.convert('RGB'), label.convert(
                'RGB'), seg.convert('RGB')
            [rgb, label, seg] = [self.viz_transform(x)
                                 for x in [rgb, label, seg]]
            val_img.extend([rgb, label, seg])
        # Kernel grid rows of (RGB, kernel, coarse prediction, learned field,
        # target field), each mapped through its dedicated transform.
        kernel_img = []
        for rgb, kernel, pro, learned_sdf, sdf in kernel_visual:
            rgb, kernel, pro = self.restore_transform(
                rgb), self.kernel_transform(kernel), self.kernel_transform(
                    pro)
            learned_sdf, sdf = self.learned_transform(learned_sdf, sdf)
            learned_sdf, sdf = self.to_PIL(learned_sdf), self.to_PIL(sdf)
            rgb, kernel, pro,learned_sdf,sdf= rgb.convert('RGB'), kernel.convert('RGB'), pro.convert('RGB'),\
                learned_sdf.convert('RGB'),sdf.convert('RGB')
            [rgb, kernel, pro, learned_sdf, sdf] = [
                self.viz_transform(x)
                for x in [rgb, kernel, pro, learned_sdf, sdf]
            ]
            kernel_img.extend([rgb, kernel, pro, learned_sdf, sdf])
        val_img = torch.stack(val_img, 0)
        val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
        kernel_img = torch.stack(kernel_img, 0)
        kernel_img = make_grid(kernel_img.cpu(), nrow=5, padding=5)
        # NOTE(review): self.wrt_step is read here but only updated below, so
        # these images are logged under the previous step value -- confirm.
        self.writer.add_image(
            f'{self.wrt_mode}/inputs_targets_predictions', val_img,
            self.wrt_step)
        self.writer.add_image(f'{self.wrt_mode}/kernel_predictions',
                              kernel_img, self.wrt_step)
        # METRICS TO TENSORBOARD. NOTE(review): these scalars are the values
        # from the LAST batch only, not epoch averages -- confirm intent.
        self.wrt_step = (epoch) * len(self.val_loader)
        self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(),
                               self.wrt_step)
        self.writer.add_scalar(f'{self.wrt_mode}/seg_loss', 5 * seg_loss,
                               self.wrt_step)
        self.writer.add_scalar(f'{self.wrt_mode}/dis_loss', distance_loss,
                               self.wrt_step)
        seg_metrics = self._get_seg_metrics()
        # The last seg_metrics entry is skipped -- presumably a non-scalar
        # (e.g. per-class IoU); verify against _get_seg_metrics().
        for k, v in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step)
    log = {'val_loss': self.total_loss.average, **seg_metrics}
    return log
def _valid_epoch(self, epoch):
    """Run one validation pass with architecture-specific loss dispatch.

    The loss computation branches on self.config['arch']['type'] (IC*, *OCR,
    *Nearest*, Enc*, DANet*, or plain single-output models); in every branch
    the primary output used for metrics is the first head. Also tracks
    per-class pixel accuracy, writes an image grid and scalar metrics to
    tensorboard, and returns a log dict.

    Args:
        epoch (int): current epoch index; used for progress display and as
            the basis of the tensorboard step.

    Returns:
        dict: 'val_loss' plus the entries of self._get_seg_metrics();
        an empty dict when no validation loader was configured.
    """
    if self.val_loader is None:
        self.logger.warning(
            'Not data loader was passed for the validation step, No validation is performed !'
        )
        return {}
    self.logger.info('\n###### EVALUATION ######')
    self.model.eval()
    self.wrt_mode = 'val'
    self._reset_metrics()
    tbar = tqdm(self.val_loader, ncols=130)
    with torch.no_grad():
        val_visual = []
        # Per-class running pixel-accuracy accumulators.
        cls_total_pix_correct = np.zeros(self.num_classes)
        cls_total_pix_labeled = np.zeros(self.num_classes)
        for batch_idx, (data, target) in enumerate(tbar):
            #data, target = data.to(self.device), target.to(self.device)
            # LOSS -- branch on the architecture family; multi-head models
            # add auxiliary losses with fixed weights, then keep output[0].
            output = self.model(data)
            if self.config['arch']['type'][:2] == 'IC':
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output, target)
                output = output[0]
            elif self.config['arch']['type'][-3:] == 'OCR':
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output[0], target)
                # Auxiliary OCR head weighted by 0.4.
                loss += self.loss(output[1], target) * 0.4
                output = output[0]
            elif 'Nearest' in self.config['arch']['type']:
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output[0], target)
                loss += self.loss(output[1], target) * 0.4
                output = output[0]
            elif self.config['arch']['type'][:3] == 'Enc':
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output, target)
                output = output[0]
            elif self.config['arch']['type'][:5] == 'DANet':
                assert output[0].size()[2:] == target.size()[1:]
                assert output[0].size()[1] == self.num_classes
                loss = self.loss(output[0], target)
                # Two attention-branch auxiliary losses weighted by 0.2.
                loss += self.loss(output[1], target) * 0.2
                loss += self.loss(output[2], target) * 0.2
                output = output[0]
            else:
                assert output.size()[2:] == target.size()[1:]
                assert output.size()[1] == self.num_classes
                loss = self.loss(output, target)
            if isinstance(self.loss, torch.nn.DataParallel):
                # DataParallel returns one loss per replica; reduce to a scalar.
                loss = loss.mean()
            self.total_loss.update(loss.item())
            seg_metrics = eval_metrics(output, target, self.num_classes)
            self._update_seg_metrics(*seg_metrics)
            # Accumulate per-class correct/labeled pixel counts.
            for i in range(self.num_classes):
                cls_pix_correct, cls_pix_labeled = batch_class_pixel_accuracy(
                    output, target, i)
                cls_total_pix_correct[i] += cls_pix_correct
                cls_total_pix_labeled[i] += cls_pix_labeled
            # Collect the first sample of up to 15 batches for visualization.
            if len(val_visual) < 15:
                target_np = target.data.cpu().numpy()
                output_np = output.data.max(1)[1].cpu().numpy()
                val_visual.append(
                    [data[0].data.cpu(), target_np[0], output_np[0]])
            # Update the progress-bar description with running metrics.
            # NOTE(review): this division produces nan for classes with no
            # labeled pixels yet -- confirm that is acceptable for display.
            pixAcc, mIoU, _ = self._get_seg_metrics().values()
            cls_pix_acc = np.round(
                cls_total_pix_correct / cls_total_pix_labeled, 3)
            tbar.set_description(
                'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f}, cls_pix_acc: {} |'
                .format(epoch, self.total_loss.average, pixAcc, mIoU,
                        str(cls_pix_acc)))
        # WRITING & VISUALIZING THE MASKS: one row (input, target, prediction)
        # per collected sample.
        val_img = []
        palette = self.train_loader.dataset.palette
        for d, t, o in val_visual:
            d = self.restore_transform(d)
            t, o = colorize_mask(t, palette), colorize_mask(o, palette)
            d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB')
            [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]
            val_img.extend([d, t, o])
        val_img = torch.stack(val_img, 0)
        val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
        # NOTE(review): self.wrt_step is read here but only updated below, so
        # this image is logged under the previous step value -- confirm.
        self.writer.add_image(
            f'{self.wrt_mode}/inputs_targets_predictions', val_img,
            self.wrt_step)
        # METRICS TO TENSORBOARD
        self.wrt_step = (epoch) * len(self.val_loader)
        self.writer.add_scalar(f'{self.wrt_mode}/loss',
                               self.total_loss.average, self.wrt_step)
        seg_metrics = self._get_seg_metrics()
        # The last seg_metrics entry is skipped -- presumably a non-scalar
        # (e.g. per-class IoU); verify against _get_seg_metrics().
        for k, v in list(seg_metrics.items())[:-1]:
            self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step)
    log = {'val_loss': self.total_loss.average, **seg_metrics}
    return log