def fuse_data(self, imgs):
    """Fuse augmented predictions back into a single prediction map.

    Every consecutive group of ``self.copy_per_img`` entries in ``imgs`` is
    assumed to be one source image under the flip/rotation test-time
    augmentations; each entry is resized, mapped back through the inverse
    transform, averaged within its group, then the groups are combined by
    max or mean depending on ``self.use_max``.

    :param imgs: list of predictions, each with shape 1 * C * H * W
    :return: fused prediction with shape 1 * C * H * W (channel first)
    """
    # Inverse transforms, indexed by position within an augmentation group:
    # identity, vertical flip, horizontal flip, then the rot90(k=-1) variants.
    inverse_ops = (
        lambda a: a,
        lambda a: a[::-1, :, :],
        lambda a: a[:, ::-1, :],
        lambda a: np.rot90(a, k=-1),
        lambda a: np.rot90(a[::-1, :, :], k=-1),
        lambda a: np.rot90(a[:, ::-1, :], k=-1),
    )
    grouped = [[] for _ in range(len(self.aug_size))]
    for cnt, img in enumerate(imgs):
        rgb = skimage.transform.resize(
            data_utils.change_channel_order(img[0, :, :, :], to_channel_last=True),
            (self.fuse_size, self.fuse_size))
        op_id = cnt % self.copy_per_img
        # positions beyond the known augmentations are ignored, as before
        if op_id < len(inverse_ops):
            grouped[cnt // self.copy_per_img].append(inverse_ops[op_id](rgb))
    fused_groups = [np.mean(np.stack(g, axis=0), axis=0) for g in grouped]
    stacked = np.stack(fused_groups, axis=0)
    pred = np.max(stacked, axis=0) if self.use_max else np.mean(stacked, axis=0)
    return np.expand_dims(
        data_utils.change_channel_order(pred, to_channel_last=False), axis=0)
def make_image_banner(imgs, n_class, mean, std, max_ind=(2, ), decode_ind=(1, 2), chanel_first=True):
    """ Make image banner for the tensorboard
    :param imgs: list of images to display, each element has shape N * C * H * W
    :param n_class: the number of classes
    :param mean: mean used in normalization
    :param std: std used in normalization
    :param max_ind: indices of element in imgs to take max across the channel dimension
    :param decode_ind: indicies of element in imgs to decode the labels
    :param chanel_first: if True, the inputs are in channel first format
    :return: uint8 banner with the processed elements concatenated along width
    """
    # work on a shallow copy so the caller's list is not mutated in place
    # (the original version overwrote the caller's elements as a side effect)
    imgs = list(imgs)
    for cnt in range(len(imgs)):
        if cnt in max_ind:
            # pred: N * C * H * W -> N * H * W class-index map
            imgs[cnt] = np.argmax(imgs[cnt], 1)
        if cnt in decode_ind:
            # lbl map: N * 1 * H * W -> decoded color visualization
            imgs[cnt] = decode_label_map(imgs[cnt], n_class)
        if (cnt not in max_ind) and (cnt not in decode_ind):
            # rgb image: N * 3 * H * W -> de-normalized to [0, 255]
            imgs[cnt] = inv_normalize(
                data_utils.change_channel_order(imgs[cnt]), mean, std) * 255
    banner = np.concatenate(imgs, axis=2).astype(np.uint8)
    if chanel_first:
        banner = data_utils.change_channel_order(banner, False)
    return banner
def infer_tile(self, model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin):
    """Run the model over one tile patch-by-patch, fuse augmented
    predictions per patch, and stitch the patches back into a full tile.

    :param model: model exposing ``inference`` and ``lbl_margin``
    :param rgb: tile image to predict on
    :param grid_list: patch grid produced by patch_extractor.make_grid
    :param patch_size: (h, w) of each patch
    :param tile_dim: original tile dimensions
    :param tile_dim_pad: tile dimensions padded by the label margin
    :param lbl_margin: label margin trimmed from each patch when stitching
    :return: stitched per-class prediction map for the tile
    """
    patch_results = []
    # NOTE(review): extraction uses model.lbl_margin while stitching uses the
    # lbl_margin argument — presumably these are equal; verify at call sites.
    for patch in patch_extractor.patch_block(rgb, model.lbl_margin, grid_list,
                                             patch_size, False):
        aug_preds = []
        for aug_patch in self.ensembler.augment_data(patch):
            # apply the preprocessing transform chain sequentially
            for tsfm in self.tsfm:
                aug_patch = tsfm(image=aug_patch)['image']
            batch = torch.unsqueeze(aug_patch, 0).to(self.device)
            prob = F.softmax(model.inference(batch), 1).detach().cpu().numpy()
            aug_preds.append(prob)
        fused = self.ensembler.fuse_data(aug_preds)
        patch_results.append(
            data_utils.change_channel_order(fused, True)[0, :, :, :])
    # stitch back to tiles, trimming the label margin off every patch
    crop_size = [patch_size[0] - 2 * lbl_margin, patch_size[1] - 2 * lbl_margin]
    return patch_extractor.unpatch_block(
        np.array(patch_results), tile_dim_pad, patch_size, tile_dim,
        crop_size, overlap=2 * lbl_margin)
def make_tb_image(img, lbl, pred, n_class, mean, std, chanel_first=True):
    """ Make validation image for tensorboard
    :param img: the image to display, has shape N * 3 * H * W
    :param lbl: the label to display, has shape N * C * H * W
    :param pred: the pred to display has shape N * C * H * W
    :param n_class: the number of classes
    :param mean: mean used in normalization
    :param std: std used in normalization
    :param chanel_first: if True, the inputs are in channel first format
    :return: uint8 [image | label | prediction] banner
    """
    class_map = np.argmax(pred, 1)
    # build the three display panels in [image, label, prediction] order
    panels = [
        inv_normalize(data_utils.change_channel_order(img), mean, std) * 255,
        decode_label_map(lbl, n_class),
        decode_label_map(class_map, n_class),
    ]
    banner = np.concatenate(panels, axis=2).astype(np.uint8)
    if chanel_first:
        banner = data_utils.change_channel_order(banner, False)
    return banner
def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False, densecrf=False, crf_params=None):
    """Run inference on every tile in self.rgb_files and save predictions.

    :param model: a single model, or a list/tuple of models ensembled by
        summing their per-class prediction maps
    :param pred_dir: directory where prediction images are written
    :param patch_size: (h, w) of the sliding-window patches
    :param overlap: overlap between adjacent patches in the grid
    :param ext: suffix appended to each output file name
    :param file_ext: output file extension
    :param visualize: if True, show rgb/prediction side by side
    :param densecrf: if True, refine predictions with a dense CRF
    :param crf_params: kwargs for addPairwiseBilateral; defaults are filled
        in when densecrf is requested and none are given
    """
    # idiom fix: isinstance accepts a tuple of types (was two separate checks)
    if isinstance(model, (list, tuple)):
        lbl_margin = model[0].lbl_margin
    else:
        lbl_margin = model.lbl_margin
    if crf_params is None and densecrf:
        crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

    misc_utils.make_dir_if_not_exist(pred_dir)
    pbar = tqdm(self.rgb_files)
    for rgb_file in pbar:
        file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
        pbar.set_description('Inferring {}'.format(file_name))

        # read data
        rgb = misc_utils.load_file(rgb_file)[:, :, :3]

        # evaluate on tiles
        tile_dim = rgb.shape[:2]
        tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
        grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)
        if isinstance(model, (list, tuple)):
            # ensemble: sum per-class prediction maps across models
            tile_preds = 0
            for m in model:
                tile_preds = tile_preds + self.infer_tile(
                    m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)
        else:
            tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                         tile_dim, tile_dim_pad, lbl_margin)

        if densecrf:
            # NOTE(review): DenseCRF2D expects (width, height, n_labels) but
            # tile_preds.shape is presumably (H, W, C) — verify for
            # non-square tiles before relying on this path
            d = dcrf.DenseCRF2D(*tile_preds.shape)
            U = unary_from_softmax(np.ascontiguousarray(
                data_utils.change_channel_order(tile_preds, False)))
            d.setUnaryEnergy(U)
            d.addPairwiseBilateral(rgbim=rgb, **crf_params)
            Q = d.inference(5)
            tile_preds = np.argmax(Q, axis=0).reshape(*tile_preds.shape[:2])
        else:
            tile_preds = np.argmax(tile_preds, -1)

        # optionally encode class indices back into the dataset's label colors
        pred_img = self.encode_func(tile_preds) if self.encode_func else tile_preds
        if visualize:
            vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))
        misc_utils.save_file(
            os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)),
            pred_img)
# Tail of the enclosing method (its `def` line is outside this chunk —
# presumably __len__, judging by the bare length return; confirm upstream).
        return self.ds_len[0]

    def __iter__(self):
        """Yield sample indices, interleaving member datasets per self.ratio.

        Each batch-sized step draws ``n_sample`` indices from every dataset
        (as given by ``self.ratio``), offsetting them into the concatenated
        index space via ``self.offset``.
        """
        # an independent random permutation per member dataset
        rand_idx = [np.random.permutation(np.arange(x)) for x in self.ds_len]
        # number of steps is driven by the first (primary) dataset's length
        for cnt in range(self.ds_len[0] // self.batch_size):
            for ds_cnt, n_sample in enumerate(self.ratio):
                for curr_cnt in range(n_sample):
                    # modulo wraps around datasets shorter than the primary one
                    yield self.offset[ds_cnt] + rand_idx[ds_cnt][
                        (cnt * n_sample + curr_cnt) % self.ds_len[ds_cnt]]


if __name__ == '__main__':
    # Smoke test: load one dataset and visualize image/mask pairs.
    from data import data_utils
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    tsfms = [A.RandomCrop(512, 512), ToTensorV2()]
    ds1 = RSDataLoader(r'/hdd/mrs/inria/ps512_pd0_ol0/patches',
                       r'/hdd/mrs/inria/ps512_pd0_ol0/file_list_train_kt.txt',
                       transforms=tsfms, n_class=2)
    for cnt, data_dict in enumerate(ds1):
        from mrs_utils import vis_utils
        rgb, gt, cls = data_dict['image'], data_dict['mask'], data_dict['cls']
        print(cls)
        vis_utils.compare_figures([
            data_utils.change_channel_order(rgb.cpu().numpy()),
            gt.cpu().numpy()
        ], (1, 2), fig_size=(12, 5))
def evaluate(self, model, patch_size, overlap, pred_dir=None, report_dir=None, save_conf=False, delta=1e-6, eval_class=(1, ), visualize=False, densecrf=False, crf_params=None, verbose=True):
    """Evaluate model(s) on the tiles in self.rgb_files / self.lbl_files.

    :param model: a single model, or a list/tuple of models ensembled by
        summing their per-class prediction maps
    :param patch_size: (h, w) of the sliding-window patches
    :param overlap: overlap between adjacent patches in the grid
    :param pred_dir: if given, prediction images (and confidences when
        save_conf is True) are written here
    :param report_dir: if given, a result.txt report is written here
    :param save_conf: if True, also save per-pixel softmax confidence of
        class 1 as an .npy file (requires pred_dir to be set)
    :param delta: small constant to avoid division by zero in IoU
    :param eval_class: class indices to evaluate IoU on
    :param visualize: if True, show rgb/label/prediction side by side
    :param densecrf: if True, refine predictions with a dense CRF
    :param crf_params: kwargs for addPairwiseBilateral; defaults are filled
        in when densecrf is requested and none are given
    :param verbose: if True, print per-tile and overall results
    :return: mean IoU over eval_class, as a percentage
    """
    # idiom fix: isinstance accepts a tuple of types (was two separate checks)
    if isinstance(model, (list, tuple)):
        lbl_margin = model[0].lbl_margin
    else:
        lbl_margin = model.lbl_margin
    if crf_params is None and densecrf:
        crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

    # iou_a accumulates intersections, iou_b accumulates unions (per class)
    iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
    report = []
    if pred_dir:
        misc_utils.make_dir_if_not_exist(pred_dir)
    for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
        file_name = os.path.splitext(os.path.basename(lbl_file))[0]

        # read data
        rgb = misc_utils.load_file(rgb_file)[:, :, :3]
        lbl = misc_utils.load_file(lbl_file)
        if self.decode_func:
            lbl = self.decode_func(lbl)

        # evaluate on tiles
        tile_dim = rgb.shape[:2]
        tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
        grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)
        if isinstance(model, (list, tuple)):
            # ensemble: sum per-class prediction maps across models
            tile_preds = 0
            for m in model:
                tile_preds = tile_preds + self.infer_tile(
                    m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)
        else:
            tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                         tile_dim, tile_dim_pad, lbl_margin)

        if save_conf:
            # NOTE(review): requires pred_dir to be set — os.path.join(None, ...)
            # raises otherwise; confirm callers always pass pred_dir with save_conf
            misc_utils.save_file(
                os.path.join(pred_dir, '{}.npy'.format(file_name)),
                scipy.special.softmax(tile_preds, axis=-1)[:, :, 1])

        if densecrf:
            # NOTE(review): DenseCRF2D expects (width, height, n_labels) but
            # tile_preds.shape is presumably (H, W, C) — verify for
            # non-square tiles before relying on this path
            d = dcrf.DenseCRF2D(*tile_preds.shape)
            U = unary_from_softmax(np.ascontiguousarray(
                data_utils.change_channel_order(
                    scipy.special.softmax(tile_preds, axis=-1), False)))
            d.setUnaryEnergy(U)
            d.addPairwiseBilateral(rgbim=rgb, **crf_params)
            Q = d.inference(5)
            tile_preds = np.argmax(Q, axis=0).reshape(*tile_preds.shape[:2])
        else:
            tile_preds = np.argmax(tile_preds, -1)

        iou_score = metric_utils.iou_metric(lbl / self.truth_val, tile_preds,
                                            eval_class=eval_class)
        pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        iou_a += iou_score[0, :]
        iou_b += iou_score[1, :]
        if visualize:
            if self.encode_func:
                vis_utils.compare_figures(
                    [rgb, self.encode_func(lbl), self.encode_func(tile_preds)],
                    (1, 3), fig_size=(15, 5))
            else:
                vis_utils.compare_figures([rgb, lbl, tile_preds], (1, 3),
                                          fig_size=(15, 5))
        if pred_dir:
            if self.encode_func:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.png'.format(file_name)),
                    self.encode_func(tile_preds))
            else:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.png'.format(file_name)),
                    tile_preds)

    pstr, rstr = self.get_result_strings('Overall',
                                         np.stack([iou_a, iou_b], axis=0), delta)
    tm.misc_utils.verb_print(pstr, verbose)
    report.append(rstr)
    if report_dir:
        misc_utils.make_dir_if_not_exist(report_dir)
        misc_utils.save_file(os.path.join(report_dir, 'result.txt'), report)
    return np.mean(iou_a / (iou_b + delta)) * 100
def fit_helper(self, target_ds_dir, target_ds_list, device, save_dir, batch_size=1, num_workers=4, total_epoch=20):
    """Adversarially train the generator/discriminator pair.

    The generator maps source-domain images toward the target domain; the
    discriminator is trained to tell generated images from real target
    images. Per-epoch losses and image grids are logged to tensorboard and
    both state dicts are saved at the end.

    :param target_ds_dir: directory of the target-domain dataset
    :param target_ds_list: file list of the target-domain dataset
    :param device: torch device to train on
    :param save_dir: directory for tensorboard logs and the final checkpoint
    :param batch_size: batch size for both loaders
    :param num_workers: DataLoader worker count
    :param total_epoch: number of epochs to train
    """
    # create writer directory
    writer = SummaryWriter(log_dir=save_dir)
    self.generator.to(device)
    self.discriminator.to(device)
    self.criterion.to(device)
    d_meter, g_meter, all_meter = metric_utils.LossMeter(
    ), metric_utils.LossMeter(), metric_utils.LossMeter()
    source_loader = DataLoader(self.source_loader, batch_size=batch_size,
                               shuffle=True, num_workers=num_workers,
                               drop_last=True)
    target_loader = DataLoader(data_loader.RSDataLoader(
        target_ds_dir, target_ds_list, transforms=self.tsfms,
        with_label=False), batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True)
    # the epoch length is driven by the target loader; the source loader is
    # wrapped so it can be drawn from indefinitely
    source_loader = data_loader.infi_loop_loader(source_loader)
    for epoch in range(total_epoch):
        for img_cnt, image_target in enumerate(
                tqdm(target_loader,
                     desc='Epoch: {}/{}'.format(epoch, total_epoch))):
            image_target = image_target['image']
            image_source = next(source_loader)['image']
            # scale pixel values into [-1, 1] (assumes uint8-range inputs —
            # TODO confirm against the loader)
            image_source = (image_source / 127.5) - 1
            image_target = (image_target / 127.5) - 1
            # NOTE(review): torch.autograd.Variable is deprecated, and
            # requires_grad=True on the input images looks unnecessary for
            # the parameter updates below — left as-is to preserve behavior
            image_source = Variable(image_source,
                                    requires_grad=True).to(device)
            image_target = Variable(image_target,
                                    requires_grad=True).to(device)
            # adversarial ground-truth labels: 1 = real, 0 = fake
            valid = Variable(torch.FloatTensor(image_source.shape[0],
                                               1).fill_(1.0),
                             requires_grad=False).to(device)
            fake = Variable(torch.FloatTensor(image_source.shape[0],
                                              1).fill_(0.0),
                            requires_grad=False).to(device)

            # generator: fool the discriminator into calling fakes "valid"
            self.optm_g.zero_grad()
            fake_imgs = self.generator(image_source)
            g_loss = self.criterion(self.discriminator(fake_imgs), valid)
            g_loss.backward()
            self.optm_g.step()

            # discriminator: real target images vs detached generator output
            self.optm_d.zero_grad()
            real_loss = self.criterion(self.discriminator(image_target),
                                       valid)
            fake_loss = self.criterion(
                self.discriminator(fake_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
            d_loss.backward()
            self.optm_d.step()

            g_meter.update(g_loss, image_source.size(0))
            d_meter.update(d_loss, image_source.size(0))
            all_meter.update(g_loss + d_loss, image_source.size(0))
        # per-epoch summary uses the last batch's tensors for the image grids
        print('loss_d: {:.3f}\tloss_g: {:.3f}\tloss_total: {:.3f} '.format(
            d_meter.get_loss(), g_meter.get_loss(), all_meter.get_loss()))
        # map [-1, 1] outputs back to [0, 255] for display
        banner_orig = np.floor(
            (image_source.detach().cpu().numpy() + 1) * 127.5)
        banner_fake = np.floor(
            (fake_imgs.detach().cpu().numpy() + 1) * 127.5)
        banner_real = np.floor(
            (image_target.detach().cpu().numpy() + 1) * 127.5)
        grid_orig = torchvision.utils.make_grid(
            torch.from_numpy(banner_orig)).cpu().numpy().astype(np.uint8)
        grid_fake = torchvision.utils.make_grid(
            torch.from_numpy(banner_fake)).cpu().numpy().astype(np.uint8)
        grid_real = torchvision.utils.make_grid(
            torch.from_numpy(banner_real)).cpu().numpy().astype(np.uint8)
        # resize the real grid if its size differs from the source grid
        # NOTE(review): grids are presumably C*H*W, so indices 0 and 1 compare
        # channels and height while the resize targets (H, W) — confirm the
        # intended comparison indices
        if grid_real.shape[0] != grid_orig.shape[0] or grid_real.shape[
                1] != grid_orig.shape[1]:
            grid_real = data_utils.change_channel_order(
                skimage.transform.resize(
                    data_utils.change_channel_order(grid_real),
                    (grid_orig.shape[1], grid_orig.shape[2]),
                    preserve_range=True).astype(np.uint8), False)
        grid_img = np.concatenate([grid_orig, grid_fake, grid_real], axis=2)
        writer.add_image('img', grid_img, epoch)
        writer.add_scalar('loss_d', d_meter.get_loss(), epoch)
        writer.add_scalar('loss_g', g_meter.get_loss(), epoch)
        writer.add_scalar('loss_total', all_meter.get_loss(), epoch)
        g_meter.reset()
        d_meter.reset()
        all_meter.reset()
    # save both networks' weights once training finishes
    save_name = os.path.join(save_dir, 'model.pth.tar')
    torch.save(
        {
            'state_dict_d': self.discriminator.state_dict(),
            'state_dict_g': self.generator.state_dict(),
        }, save_name)
    print('Saved model at {}'.format(save_name))