Пример #1
0
    def fuse_data(self, imgs):
        """Fuse augmented prediction copies into a single prediction map.

        :param imgs: iterable of predictions, each of shape 1 * C * H * W; every
            ``self.copy_per_img`` consecutive entries are spatially transformed
            copies of the same image
        :return: fused prediction of shape 1 * C * fuse_size * fuse_size
        """
        # Per-copy spatial transform, indexed by the copy's position within its
        # group (presumably inverting the augmentation applied in augment_data —
        # TODO confirm against that method).
        aug_ops = [
            lambda x: x,
            lambda x: x[::-1, :, :],
            lambda x: x[:, ::-1, :],
            lambda x: np.rot90(x, k=-1),
            lambda x: np.rot90(x[::-1, :, :], k=-1),
            lambda x: np.rot90(x[:, ::-1, :], k=-1),
        ]
        groups = [[] for _ in range(len(self.aug_size))]
        for cnt, img in enumerate(imgs):
            rgb = skimage.transform.resize(
                data_utils.change_channel_order(img[0, :, :, :],
                                                to_channel_last=True),
                (self.fuse_size, self.fuse_size))
            op_idx = cnt % self.copy_per_img
            # Copies beyond the six known transforms are silently skipped, as in
            # the original if/elif chain.
            if op_idx < len(aug_ops):
                groups[cnt // self.copy_per_img].append(aug_ops[op_idx](rgb))
        # Average the copies within each group, then fuse across groups.
        per_group = [np.mean(np.stack(g, axis=0), axis=0) for g in groups]
        stacked = np.stack(per_group, axis=0)
        if self.use_max:
            pred = np.max(stacked, axis=0)
        else:
            pred = np.mean(stacked, axis=0)
        return np.expand_dims(
            data_utils.change_channel_order(pred, to_channel_last=False),
            axis=0)
Пример #2
0
def make_image_banner(imgs,
                      n_class,
                      mean,
                      std,
                      max_ind=(2, ),
                      decode_ind=(1, 2),
                      chanel_first=True):
    """
    Make image banner for the tensorboard
    :param imgs: list of images to display, each element has shape N * C * H * W
    :param n_class: the number of classes
    :param mean: mean used in normalization
    :param std: std used in normalization
    :param max_ind: indices of element in imgs to take max across the channel dimension
    :param decode_ind: indices of element in imgs to decode the labels
    :param chanel_first: if True, the inputs are in channel first format
    :return: uint8 banner with the processed elements concatenated side by side
    """
    # Work on a shallow copy so the caller's list is not mutated in place
    # (the original wrote the processed elements back into the argument).
    imgs = list(imgs)
    for cnt in range(len(imgs)):
        if cnt in max_ind:
            # pred: N * C * H * W -> N * H * W class indices
            imgs[cnt] = np.argmax(imgs[cnt], 1)
        if cnt in decode_ind:
            # lbl map: N * 1 * H * W -> color-coded label image
            imgs[cnt] = decode_label_map(imgs[cnt], n_class)
        if (cnt not in max_ind) and (cnt not in decode_ind):
            # rgb image: N * 3 * H * W; undo normalization, scale to [0, 255]
            imgs[cnt] = inv_normalize(
                data_utils.change_channel_order(imgs[cnt]), mean, std) * 255
    banner = np.concatenate(imgs, axis=2).astype(np.uint8)
    if chanel_first:
        banner = data_utils.change_channel_order(banner, False)
    return banner
Пример #3
0
 def infer_tile(self, model, rgb, grid_list, patch_size, tile_dim,
                tile_dim_pad, lbl_margin):
     """Run patch-wise inference over one tile and stitch predictions back.

     :param model: model exposing ``inference`` and ``lbl_margin``
     :param rgb: the tile image to run inference on
     :param grid_list: patch grid produced by ``patch_extractor.make_grid``
     :param patch_size: (height, width) of each extracted patch
     :param tile_dim: spatial size of the original tile
     :param tile_dim_pad: tile size padded by twice the label margin
     :param lbl_margin: label margin cropped from each patch border
     :return: stitched per-pixel prediction for the whole tile
     """
     patch_results = []
     for patch in patch_extractor.patch_block(rgb, model.lbl_margin,
                                              grid_list, patch_size, False):
         aug_preds = []
         for aug_patch in self.ensembler.augment_data(patch):
             # Apply the configured transforms sequentially to each copy.
             for tsfm in self.tsfm:
                 aug_patch = tsfm(image=aug_patch)['image']
             batch = torch.unsqueeze(aug_patch, 0).to(self.device)
             prob = F.softmax(model.inference(batch), 1)
             aug_preds.append(prob.detach().cpu().numpy())
         # Fuse the augmented copies back into one patch prediction.
         fused = self.ensembler.fuse_data(aug_preds)
         patch_results.append(
             data_utils.change_channel_order(fused, True)[0, :, :, :])
     # stitch back to tiles
     return patch_extractor.unpatch_block(
         np.array(patch_results),
         tile_dim_pad,
         patch_size,
         tile_dim,
         [patch_size[0] - 2 * lbl_margin, patch_size[1] - 2 * lbl_margin],
         overlap=2 * lbl_margin)
Пример #4
0
def make_tb_image(img, lbl, pred, n_class, mean, std, chanel_first=True):
    """
    Make validation image for tensorboard
    :param img: the image to display, has shape N * 3 * H * W
    :param lbl: the label to display, has shape N * C * H * W
    :param pred: the pred to display has shape N * C * H * W
    :param n_class: the number of classes
    :param mean: mean used in normalization
    :param std: std used in normalization
    :param chanel_first: if True, the inputs are in channel first format
    :return: uint8 banner with image, label and prediction side by side
    """
    # Color-code the label map and the argmax-ed prediction.
    label_image = decode_label_map(lbl, n_class)
    pred_image = decode_label_map(np.argmax(pred, 1), n_class)
    # Undo normalization on the input image and scale back to [0, 255].
    img_image = inv_normalize(
        data_utils.change_channel_order(img), mean, std) * 255
    banner = np.concatenate([img_image, label_image, pred_image], axis=2)
    banner = banner.astype(np.uint8)
    if chanel_first:
        banner = data_utils.change_channel_order(banner, False)
    return banner
Пример #5
0
    def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False,
              densecrf=False, crf_params=None):
        """Run tiled inference on every file in ``self.rgb_files`` and save predictions.

        :param model: a single model, or a list/tuple of models whose tile
            predictions are ensembled by summation
        :param pred_dir: directory where prediction images are written
        :param patch_size: (height, width) of the patches fed to the model
        :param overlap: overlap used when building the patch grid
        :param ext: suffix appended to each output file name
        :param file_ext: extension of the saved prediction files
        :param visualize: if True, display each input/prediction pair
        :param densecrf: if True, refine predictions with a dense CRF
        :param crf_params: pairwise-bilateral parameters for the CRF; defaults
            are applied when None and densecrf is on
        """
        # For an ensemble, all models are assumed to share the same label margin.
        if isinstance(model, list) or isinstance(model, tuple):
            lbl_margin = model[0].lbl_margin
        else:
            lbl_margin = model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        misc_utils.make_dir_if_not_exist(pred_dir)
        pbar = tqdm(self.rgb_files)
        for rgb_file in pbar:
            file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
            pbar.set_description('Inferring {}'.format(file_name))
            # read data; keep only the first three channels (drops e.g. alpha)
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]

            # evaluate on tiles: pad the tile by the label margin on every side
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)

            if isinstance(model, list) or isinstance(model, tuple):
                # Ensemble: accumulate per-model tile predictions by summation.
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                                                              lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)

            if densecrf:
                # NOTE(review): tile_preds is fed to unary_from_softmax without a
                # softmax here, while evaluate() applies scipy.special.softmax
                # first; for an ensemble the summed scores are not a probability
                # distribution — confirm this is intended.
                d = dcrf.DenseCRF2D(*tile_preds.shape)
                U = unary_from_softmax(np.ascontiguousarray(
                    data_utils.change_channel_order(tile_preds, False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(*tile_preds.shape[:2])
            else:
                # Per-pixel class index from the channel-last prediction.
                tile_preds = np.argmax(tile_preds, -1)

            if self.encode_func:
                # Map class indices back to the dataset's label encoding.
                pred_img = self.encode_func(tile_preds)
            else:
                pred_img = tile_preds

            if visualize:
                vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))

            misc_utils.save_file(os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)), pred_img)
Пример #6
0
        return self.ds_len[0]

    def __iter__(self):
        """Yield dataset indices, interleaving samples from every dataset.

        For each batch step, ``self.ratio[i]`` indices are drawn from dataset
        ``i`` (shifted into the concatenated index space by ``self.offset[i]``),
        walking through a fresh random permutation of each dataset and wrapping
        around with modulo when a dataset is exhausted.
        """
        perms = [np.random.permutation(np.arange(length))
                 for length in self.ds_len]
        n_steps = self.ds_len[0] // self.batch_size
        for step in range(n_steps):
            for ds_idx, n_sample in enumerate(self.ratio):
                base = step * n_sample
                for k in range(n_sample):
                    pos = (base + k) % self.ds_len[ds_idx]
                    yield self.offset[ds_idx] + perms[ds_idx][pos]


if __name__ == '__main__':
    from data import data_utils
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    # Quick visual sanity check of the loader: random-crop, tensorize, display.
    tsfms = [A.RandomCrop(512, 512), ToTensorV2()]
    ds1 = RSDataLoader(r'/hdd/mrs/inria/ps512_pd0_ol0/patches',
                       r'/hdd/mrs/inria/ps512_pd0_ol0/file_list_train_kt.txt',
                       transforms=tsfms,
                       n_class=2)

    for cnt, data_dict in enumerate(ds1):
        from mrs_utils import vis_utils

        rgb = data_dict['image']
        gt = data_dict['mask']
        cls = data_dict['cls']
        print(cls)
        figures = [
            data_utils.change_channel_order(rgb.cpu().numpy()),
            gt.cpu().numpy(),
        ]
        vis_utils.compare_figures(figures, (1, 2), fig_size=(12, 5))
Пример #7
0
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False,
                 delta=1e-6,
                 eval_class=(1, ),
                 visualize=False,
                 densecrf=False,
                 crf_params=None,
                 verbose=True):
        """Evaluate the model(s) tile by tile and report per-file and overall IoU.

        :param model: a single model, or a list/tuple of models whose tile
            predictions are ensembled by summation
        :param patch_size: (height, width) of the patches fed to the model
        :param overlap: overlap used when building the patch grid
        :param pred_dir: if given, save prediction images (and confidences) here
        :param report_dir: if given, write a ``result.txt`` IoU report here
        :param save_conf: if True, also save softmax confidence maps as .npy
        :param delta: small constant to avoid division by zero in the IoU
        :param eval_class: class indices to compute IoU for
        :param visualize: if True, show image/label/prediction side by side
        :param densecrf: if True, refine predictions with a dense CRF
        :param crf_params: pairwise-bilateral parameters for the CRF; defaults
            are applied when None and densecrf is on
        :param verbose: if True, print per-file and overall result strings
        :return: mean IoU over ``eval_class``, in percent
        """
        # For an ensemble, all models are assumed to share the same label margin.
        if isinstance(model, list) or isinstance(model, tuple):
            lbl_margin = model[0].lbl_margin
        else:
            lbl_margin = model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        # iou_a / iou_b accumulate the two rows returned by iou_metric per file
        # (presumably intersection and union sums — confirm in metric_utils).
        iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]

            # read data; keep only the first three channels of the image
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]
            lbl = misc_utils.load_file(lbl_file)
            if self.decode_func:
                lbl = self.decode_func(lbl)

            # evaluate on tiles: pad the tile by the label margin on every side
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [
                tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin
            ]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size,
                                                  overlap)

            if isinstance(model, list) or isinstance(model, tuple):
                # Ensemble: accumulate per-model tile predictions by summation.
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(
                        m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                        lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                             tile_dim, tile_dim_pad,
                                             lbl_margin)

            if save_conf:
                # Save the class-1 softmax confidence map alongside predictions.
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    scipy.special.softmax(tile_preds, axis=-1)[:, :, 1])

            if densecrf:
                # NOTE(review): DenseCRF2D takes (width, height, n_labels) while
                # tile_preds looks channel-last (H, W, C) — confirm the
                # dimension order is intended.
                d = dcrf.DenseCRF2D(*tile_preds.shape)
                U = unary_from_softmax(
                    np.ascontiguousarray(
                        data_utils.change_channel_order(
                            scipy.special.softmax(tile_preds, axis=-1),
                            False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q,
                                       axis=0).reshape(*tile_preds.shape[:2])
            else:
                # Per-pixel class index from the channel-last prediction.
                tile_preds = np.argmax(tile_preds, -1)
            # Normalize the label values before comparing with class indices.
            iou_score = metric_utils.iou_metric(lbl / self.truth_val,
                                                tile_preds,
                                                eval_class=eval_class)
            pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
            tm.misc_utils.verb_print(pstr, verbose)
            report.append(rstr)
            iou_a += iou_score[0, :]
            iou_b += iou_score[1, :]
            if visualize:
                if self.encode_func:
                    vis_utils.compare_figures([
                        rgb,
                        self.encode_func(lbl),
                        self.encode_func(tile_preds)
                    ], (1, 3),
                                              fig_size=(15, 5))
                else:
                    vis_utils.compare_figures([rgb, lbl, tile_preds], (1, 3),
                                              fig_size=(15, 5))
            if pred_dir:
                if self.encode_func:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        self.encode_func(tile_preds))
                else:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        tile_preds)
        # Aggregate scores across all files for the overall report line.
        pstr, rstr = self.get_result_strings('Overall',
                                             np.stack([iou_a, iou_b], axis=0),
                                             delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
        return np.mean(iou_a / (iou_b + delta)) * 100
Пример #8
0
    def fit_helper(self,
                   target_ds_dir,
                   target_ds_list,
                   device,
                   save_dir,
                   batch_size=1,
                   num_workers=4,
                   total_epoch=20):
        """Adversarially train the generator/discriminator pair, source to target.

        Each step trains the generator to fool the discriminator on translated
        source images, then trains the discriminator on real target vs. fake
        images. Losses and an image banner are logged to tensorboard, and the
        final weights are saved under ``save_dir``.

        :param target_ds_dir: directory of the target-domain patches
        :param target_ds_list: file list describing the target-domain dataset
        :param device: torch device to train on
        :param save_dir: directory for tensorboard logs and the final checkpoint
        :param batch_size: batch size for both source and target loaders
        :param num_workers: worker processes per data loader
        :param total_epoch: number of epochs to train
        """
        # create writer directory
        writer = SummaryWriter(log_dir=save_dir)

        self.generator.to(device)
        self.discriminator.to(device)
        self.criterion.to(device)

        # Running averages for discriminator, generator and combined loss.
        d_meter, g_meter, all_meter = metric_utils.LossMeter(
        ), metric_utils.LossMeter(), metric_utils.LossMeter()

        source_loader = DataLoader(self.source_loader,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   drop_last=True)
        target_loader = DataLoader(data_loader.RSDataLoader(
            target_ds_dir,
            target_ds_list,
            transforms=self.tsfms,
            with_label=False),
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   drop_last=True)
        # The target loader drives the epoch; the source loader cycles forever.
        source_loader = data_loader.infi_loop_loader(source_loader)

        for epoch in range(total_epoch):
            for img_cnt, image_target in enumerate(
                    tqdm(target_loader,
                         desc='Epoch: {}/{}'.format(epoch, total_epoch))):
                image_target = image_target['image']
                image_source = next(source_loader)['image']

                # Map pixel values to [-1, 1] (assumes 8-bit [0, 255] input).
                image_source = (image_source / 127.5) - 1
                image_target = (image_target / 127.5) - 1
                # NOTE(review): torch.autograd.Variable is deprecated, and
                # requires_grad=True on the inputs should not be needed for
                # GAN training — confirm before relying on input gradients.
                image_source = Variable(image_source,
                                        requires_grad=True).to(device)
                image_target = Variable(image_target,
                                        requires_grad=True).to(device)

                # Adversarial labels: 1 for real, 0 for fake.
                valid = Variable(torch.FloatTensor(image_source.shape[0],
                                                   1).fill_(1.0),
                                 requires_grad=False).to(device)
                fake = Variable(torch.FloatTensor(image_source.shape[0],
                                                  1).fill_(0.0),
                                requires_grad=False).to(device)

                # generator: make translated source images look real
                self.optm_g.zero_grad()
                fake_imgs = self.generator(image_source)
                g_loss = self.criterion(self.discriminator(fake_imgs), valid)
                g_loss.backward()
                self.optm_g.step()

                # discriminator: real target vs. detached fakes
                self.optm_d.zero_grad()
                real_loss = self.criterion(self.discriminator(image_target),
                                           valid)
                fake_loss = self.criterion(
                    self.discriminator(fake_imgs.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)
                d_loss.backward()
                self.optm_d.step()

                # NOTE(review): losses are passed to the meters as tensors, not
                # .item() values — if LossMeter stores them, graph references
                # may be retained; confirm LossMeter's behavior.
                g_meter.update(g_loss, image_source.size(0))
                d_meter.update(d_loss, image_source.size(0))
                all_meter.update(g_loss + d_loss, image_source.size(0))

            print('loss_d: {:.3f}\tloss_g: {:.3f}\tloss_total: {:.3f} '.format(
                d_meter.get_loss(), g_meter.get_loss(), all_meter.get_loss()))

            # Build an image banner from the LAST batch of the epoch, mapping
            # [-1, 1] values back to [0, 255].
            banner_orig = np.floor(
                (image_source.detach().cpu().numpy() + 1) * 127.5)
            banner_fake = np.floor(
                (fake_imgs.detach().cpu().numpy() + 1) * 127.5)
            banner_real = np.floor(
                (image_target.detach().cpu().numpy() + 1) * 127.5)

            grid_orig = torchvision.utils.make_grid(
                torch.from_numpy(banner_orig)).cpu().numpy().astype(np.uint8)
            grid_fake = torchvision.utils.make_grid(
                torch.from_numpy(banner_fake)).cpu().numpy().astype(np.uint8)
            grid_real = torchvision.utils.make_grid(
                torch.from_numpy(banner_real)).cpu().numpy().astype(np.uint8)
            # Resize the real grid to match the source grid before stacking.
            if grid_real.shape[0] != grid_orig.shape[0] or grid_real.shape[
                    1] != grid_orig.shape[1]:
                grid_real = data_utils.change_channel_order(
                    skimage.transform.resize(
                        data_utils.change_channel_order(grid_real),
                        (grid_orig.shape[1], grid_orig.shape[2]),
                        preserve_range=True).astype(np.uint8), False)
            grid_img = np.concatenate([grid_orig, grid_fake, grid_real],
                                      axis=2)

            writer.add_image('img', grid_img, epoch)
            writer.add_scalar('loss_d', d_meter.get_loss(), epoch)
            writer.add_scalar('loss_g', g_meter.get_loss(), epoch)
            writer.add_scalar('loss_total', all_meter.get_loss(), epoch)
            g_meter.reset()
            d_meter.reset()
            all_meter.reset()

        save_name = os.path.join(save_dir, 'model.pth.tar')
        torch.save(
            {
                'state_dict_d': self.discriminator.state_dict(),
                'state_dict_g': self.generator.state_dict(),
            }, save_name)
        print('Saved model at {}'.format(save_name))