예제 #1
0
def visualize_missclassifications(root_dir: Path,
                                  cfg_file: Path,
                                  net_file: Path,
                                  save_dir: Path = None):
    """Save a colour-coded error map for every sample of the 'test' split.

    The network output is passed through a sigmoid, thresholded at
    ``cfg.THRESH`` and compared pixel-wise with the label. Each map is written
    as ``missclassfications_<cfg stem>_<city>.png``.

    Colour coding of the saved image:
        white   -- true positive
        green   -- false positive
        magenta -- false negative
        black   -- everything else

    Args:
        root_dir: Base directory; used only to derive the default output
            directory ``root_dir / 'evaluation' / cfg_file.stem``.
        cfg_file: Configuration file handed to ``load_cfg``.
        net_file: Network weights file handed to ``load_net``.
        save_dir: Output directory. Defaults to the directory derived from
            ``root_dir`` when ``None``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # loading cfg and network
    # NOTE(review): inputs are moved to `device` below; this assumes load_net
    # returns a network that already lives on the same device -- confirm.
    cfg = load_cfg(cfg_file)
    net = load_net(cfg, net_file)

    # The output directory does not depend on the sample, so resolve it once.
    # parents/exist_ok make this work when 'evaluation' does not exist yet.
    if save_dir is None:
        save_dir = root_dir / 'evaluation' / cfg_file.stem
    save_dir.mkdir(parents=True, exist_ok=True)

    dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True)
    dataloader = torch_data.DataLoader(dataset,
                                       batch_size=1,
                                       num_workers=0,
                                       shuffle=False,
                                       pin_memory=True)

    net.eval()
    with torch.no_grad():
        for batch in dataloader:
            city = batch['city'][0]
            print(city)

            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)

            # prediction: first sample, first channel -> boolean 2D mask
            probs = torch.sigmoid(net(t1_img, t2_img))
            y_pred = probs.cpu().detach().numpy()[0, 0] > cfg.THRESH

            # label: the uint8 cast truncates fractional values before the
            # mask is interpreted as boolean by the logical ops below
            y_true = batch['label'].cpu().detach().numpy()[0, 0]
            y_true = y_true.astype('uint8')

            true_positives = np.logical_and(y_pred, y_true)
            false_positives = np.logical_and(y_pred, np.logical_not(y_true))
            false_negatives = np.logical_and(np.logical_not(y_pred), y_true)

            img = np.zeros((*y_true.shape, 3))
            img[true_positives] = [1, 1, 1]
            img[false_positives] = [0, 1, 0]
            img[false_negatives] = [1, 0, 1]

            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.set_axis_off()

            file = save_dir / f'missclassfications_{cfg_file.stem}_{city}.png'
            fig.savefig(file, dpi=300, bbox_inches='tight')
            plt.close(fig)
예제 #2
0
def model_evaluation(net, cfg, device, thresholds, run_type, epoch, step):
    """Evaluate ``net`` on one split and report the best F1 over ``thresholds``.

    Every sample of the ``run_type`` split is pushed through the network, the
    sigmoid output is fed to an ``eval.MultiThresholdMetric`` and the
    threshold with the highest F1 is selected. Unless ``cfg.DEBUG`` is set,
    the result is logged to wandb together with ``step`` and ``epoch``.

    Note: the network is switched to eval mode and left in that state.

    Args:
        net: Network called as ``net(t1_img, t2_img)``.
        cfg: Configuration object (``DEBUG`` and ``DATALOADER`` are read).
        device: Device the batches and thresholds are moved to.
        thresholds: Tensor of candidate thresholds.
        run_type: Split name passed to ``OSCDDataset`` and used as log prefix.
        epoch: Epoch number, logged as-is.
        step: Global training step, logged as-is.

    Returns:
        Tuple ``(max_f1, best_threshold)`` as Python scalars.
    """
    thresholds = thresholds.to(device)
    measurer = eval.MultiThresholdMetric(thresholds)

    dataset = datasets.OSCDDataset(cfg, run_type, no_augmentation=True)
    dataloader_kwargs = {
        'batch_size': 1,
        'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER,
        'shuffle': cfg.DATALOADER.SHUFFLE,
        'pin_memory': True,
    }
    dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)

    net.eval()
    with torch.no_grad():
        # Iterate without rebinding `step`: the original loop variable
        # shadowed the parameter, so the batch index was logged instead of
        # the caller's global step.
        for batch in dataloader:
            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)
            y_true = batch['label'].to(device)

            y_pred = torch.sigmoid(net(t1_img, t2_img))

            # Samples are consumed by the measurer right away; nothing else
            # needs them, so they are not kept in memory.
            measurer.add_sample(y_true.detach(), y_pred.detach())

    print(f'Computing {run_type} F1 score ', end=' ', flush=True)

    f1 = measurer.compute_f1()
    fpr, fnr = measurer.compute_basic_metrics()
    maxF1 = f1.max()
    argmaxF1 = f1.argmax()  # index into `thresholds`
    best_fpr = fpr[argmaxF1]
    best_fnr = fnr[argmaxF1]
    best_thresh = thresholds[argmaxF1]

    if not cfg.DEBUG:
        wandb.log({
            f'{run_type} max F1': maxF1,
            f'{run_type} argmax F1': argmaxF1,
            f'{run_type} false positive rate': best_fpr,
            f'{run_type} false negative rate': best_fnr,
            'step': step,
            'epoch': epoch,
        })

    print(f'{maxF1.item():.3f}', flush=True)

    return maxF1.item(), best_thresh.item()
예제 #3
0
def visualize_images(root_dir: Path, save_dir: Path = None):
    """Save an RGB rendering of both acquisition dates for every test sample.

    The configuration is read from ``configs/optical_visualization.yaml``
    relative to the current working directory. For each sample two files,
    ``<city>_img1.png`` and ``<city>_img2.png``, are written.

    Args:
        root_dir: Base directory; used only to derive the default output
            directory ``root_dir / 'evaluation' / 'images'``.
        save_dir: Output directory. Defaults to the directory derived from
            ``root_dir`` when ``None``.
    """
    cfg_file = Path.cwd() / 'configs' / 'optical_visualization.yaml'
    cfg = load_cfg(cfg_file)

    # Resolve the output directory once; parents/exist_ok make this work
    # when the 'evaluation' directory does not exist yet.
    if save_dir is None:
        save_dir = root_dir / 'evaluation' / 'images'
    save_dir.mkdir(parents=True, exist_ok=True)

    dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True)
    dataloader = torch_data.DataLoader(dataset,
                                       batch_size=1,
                                       num_workers=0,
                                       shuffle=False,
                                       pin_memory=True)

    # assumes channels 3, 2, 1 hold red, green, blue for this config
    # -- TODO confirm against the band order in the yaml
    rgb_indices = [3, 2, 1]

    for batch in dataloader:
        city = batch['city'][0]
        print(city)

        # No network is involved here, so the images are converted to numpy
        # directly instead of being moved to the GPU and back.
        for i, img in enumerate([batch['t1_img'], batch['t2_img']]):
            img = img.cpu().detach().numpy()[0, ]
            img = img.transpose((1, 2, 0))  # channels-first -> channels-last

            # divide by 0.3 and clip at 1 for display
            rgb = np.minimum(img[:, :, rgb_indices] / 0.3, 1)

            fig, ax = plt.subplots()
            ax.imshow(rgb)
            ax.set_axis_off()

            file = save_dir / f'{city}_img{i + 1}.png'
            fig.savefig(file, dpi=300, bbox_inches='tight')
            plt.close(fig)
예제 #4
0
def visual_evaluation(root_dir: Path,
                      cfg_file: Path,
                      net_file: Path,
                      dataset: str = 'test',
                      n: int = 10,
                      save_dir: Path = None,
                      label_pred_only: bool = False):
    """Save a side-by-side figure of label and prediction for every sample.

    The prediction is the sigmoid of the network output thresholded at
    ``cfg.THRESH``. Panels are, left to right: label, prediction and -- unless
    ``label_pred_only`` is set -- RGB renderings of the first and second
    image. Figures are written as ``eval_<cfg stem>_<city>.png``.

    Args:
        root_dir: Base directory; used only to derive the default output
            directory ``root_dir / 'evaluation' / cfg_file.stem``.
        cfg_file: Configuration file handed to ``load_cfg``.
        net_file: Network weights file handed to ``load_net``.
        dataset: Name of the split passed to ``OSCDDataset``.
        n: Currently unused; every sample of the split is processed.
            NOTE(review): presumably meant as a sample limit -- confirm.
        save_dir: Output directory. Defaults to the directory derived from
            ``root_dir`` when ``None``.
        label_pred_only: If true, only the label and prediction panels are
            drawn.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # loading cfg and network
    # NOTE(review): assumes load_net places the network on `device` -- confirm.
    cfg = load_cfg(cfg_file)
    net = load_net(cfg, net_file)

    # bands for visualization: positions of red/green/blue in the stacked
    # Sentinel-1 + Sentinel-2 band list; looked up once, and only when the
    # RGB panels are actually drawn
    all_bands = cfg.DATASET.SENTINEL1_BANDS + cfg.DATASET.SENTINEL2_BANDS
    rgb_indices = None
    if not label_pred_only:
        rgb_indices = [all_bands.index(band) for band in ('B04', 'B03', 'B02')]

    # Resolve the output directory once; parents/exist_ok make this work
    # when the 'evaluation' directory does not exist yet.
    if save_dir is None:
        save_dir = root_dir / 'evaluation' / cfg_file.stem
    save_dir.mkdir(parents=True, exist_ok=True)

    samples = datasets.OSCDDataset(cfg, dataset, no_augmentation=True)
    dataloader = torch_data.DataLoader(samples,
                                       batch_size=1,
                                       num_workers=0,
                                       shuffle=False,
                                       pin_memory=True)

    net.eval()
    with torch.no_grad():
        for batch in dataloader:
            city = batch['city'][0]
            print(city)

            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)

            # prediction: first sample, first channel, thresholded
            probs = torch.sigmoid(net(t1_img, t2_img))
            probs = probs.cpu().detach().numpy()[0, 0]
            y_pred = (probs > cfg.THRESH).astype('uint8')

            # label
            y_true = batch['label'].cpu().detach().numpy()[0, 0]
            y_true = y_true.astype('uint8')

            if label_pred_only:
                fig, axs = plt.subplots(1, 2, figsize=(10, 10))
            else:
                fig, axs = plt.subplots(1, 4, figsize=(20, 10))
                for i, img in enumerate([t1_img, t2_img]):
                    img = img.cpu().detach().numpy()[0, ]
                    img = img.transpose((1, 2, 0))
                    # divide by 0.3 and clip at 1 for display
                    rgb = np.minimum(img[:, :, rgb_indices] / 0.3, 1)
                    axs[i + 2].imshow(rgb)

            axs[0].imshow(y_true)
            axs[1].imshow(y_pred)

            for ax in axs:
                ax.set_axis_off()

            file = save_dir / f'eval_{cfg_file.stem}_{city}.png'
            fig.savefig(file, dpi=300, bbox_inches='tight')
            plt.close(fig)
예제 #5
0
def numeric_evaluation(cfg_file: Path, net_file: Path):
    """Print test-set metrics with and without test-time augmentation (TTA).

    For every sample of the 'test' split the function

    * evaluates the plain prediction (sigmoid thresholded at ``cfg.THRESH``),
    * averages the binary predictions of six augmented inputs (four 90-degree
      rotations, one vertical and one horizontal flip) and evaluates the
      average at eleven thresholds between 0 and 1.

    True/false positives and false negatives are summed over all samples
    before precision, recall and F1 are computed. The results are printed and
    drawn into a new matplotlib figure (F1 versus TTA threshold, with the
    plain F1 as a horizontal reference). The figure is neither shown nor
    saved here.

    Args:
        cfg_file: Configuration file handed to ``load_cfg``.
        net_file: Network weights file handed to ``load_net``.
    """
    tta_thresholds = np.linspace(0, 1, 11)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # loading cfg and network
    # NOTE(review): assumes load_net places the network on `device` -- confirm.
    cfg = load_cfg(cfg_file)
    net = load_net(cfg, net_file)
    dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True)

    dataloader_kwargs = {
        'batch_size': 1,
        'num_workers': 0,
        'shuffle': cfg.DATALOADER.SHUFFLE,
        'pin_memory': True,
    }
    dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)

    def predict(t1, t2):
        """Return the thresholded prediction as a float tensor of 0s and 1s."""
        pred = torch.sigmoid(net(t1, t2)) > cfg.THRESH
        return pred.detach().float()

    def evaluate(true, pred):
        """Return (f1, true positives, false positives, false negatives)."""
        true_flat, pred_flat = true.flatten(), pred.flatten()
        f1 = eval.f1_score(true_flat, pred_flat, dim=0).item()
        tp = eval.true_pos(true_flat, pred_flat, dim=0).item()
        fp = eval.false_pos(true_flat, pred_flat, dim=0).item()
        fn = eval.false_neg(true_flat, pred_flat, dim=0).item()
        return f1, tp, fp, fn

    cities, f1_scores = [], []
    true_positives, false_positives, false_negatives = [], [], []
    tta = []  # per sample: one evaluate() tuple per TTA threshold

    net.eval()
    with torch.no_grad():
        for batch in dataloader:

            city = batch['city'][0]
            print(city)
            cities.append(city)

            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)
            y_true = batch['label'].to(device)

            # plain prediction
            f1_score, tp, fp, fn = evaluate(y_true, predict(t1_img, t2_img))
            f1_scores.append(f1_score)
            true_positives.append(tp)
            false_positives.append(fp)
            false_negatives.append(fn)

            sum_preds = torch.zeros(y_true.shape,
                                    dtype=torch.float,
                                    device=device)
            n_augs = 0

            # rotations: predict on the rotated pair, rotate the result back
            for k in range(4):
                t1_img_rot = torch.rot90(t1_img, k, (2, 3))
                t2_img_rot = torch.rot90(t2_img, k, (2, 3))
                y_pred = predict(t1_img_rot, t2_img_rot)
                sum_preds += torch.rot90(y_pred, -k, (2, 3))
                n_augs += 1

            # flips: one along each spatial axis. The original used the dims
            # (2, 3) and (3, 2), which both flip *both* axes (identical to
            # the k=2 rotation), and built the second input from t1_img.
            for flip_dims in ((2,), (3,)):
                t1_img_flip = torch.flip(t1_img, flip_dims)
                t2_img_flip = torch.flip(t2_img, flip_dims)
                y_pred = predict(t1_img_flip, t2_img_flip)
                sum_preds += torch.flip(y_pred, flip_dims)
                n_augs += 1

            # fraction of augmentations that voted "positive" per pixel
            pred_tta = sum_preds / n_augs
            tta.append([
                evaluate(y_true, (pred_tta > ts).float())
                for ts in tta_thresholds
            ])

    # metrics without TTA, from counts pooled over all samples
    tp_total = np.sum(true_positives)
    precision = tp_total / (tp_total + np.sum(false_positives))
    recall = tp_total / (tp_total + np.sum(false_negatives))
    f1_score = 2 * (precision * recall / (precision + recall))
    print(
        f'precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1_score:.3f}'
    )

    # metrics with TTA, one F1 per threshold (eps avoids division by zero)
    tta_f1_scores = []
    for i, ts in enumerate(tta_thresholds):
        tta_ts = [city_evals[i] for city_evals in tta]
        tp = np.sum([eval_ts[1] for eval_ts in tta_ts])
        fp = np.sum([eval_ts[2] for eval_ts in tta_ts])
        fn = np.sum([eval_ts[3] for eval_ts in tta_ts])
        pre_tta = tp / (tp + fp + 1e-5)
        re_tta = tp / (tp + fn + 1e-5)
        f1_score_tta = 2 * (pre_tta * re_tta / (pre_tta + re_tta + 1e-5))
        tta_f1_scores.append(f1_score_tta)
        print(f'{ts:.2f}: {f1_score_tta:.3f}')

    fig, ax = plt.subplots()
    ax.plot(tta_thresholds, tta_f1_scores)
    ax.plot(tta_thresholds, [f1_score] * len(tta_thresholds),
            label=f'without tta ({f1_score:.3f})')
    ax.legend()
    ax.set_xlabel('tta threshold (gt)')
    ax.set_ylabel('f1 score')
    ax.set_title(cfg_file.stem)
def visual_evaluation(net_cfg_file: Path,
                      net_file: Path,
                      ds_cfg_file: Path,
                      dataset: str = 'test',
                      save_path: Path = None):
    """Save per-date probability maps, the label and their difference.

    Each image of a sample is classified on its own with ``classify``; the
    figure written as ``urban_extraction_<city>.png`` shows, left to right:
    probability map of the first image, probability map of the second image,
    the label, and the difference "second minus first" displayed on a 0..1
    colour range (so decreases are clipped to the lowest colour).

    Args:
        net_cfg_file: Configuration of the network; ``THRESH`` is read from it.
        net_file: Network weights file handed to ``load_net``.
        ds_cfg_file: Configuration of the dataset.
        dataset: Name of the split passed to ``OSCDDataset``.
        save_path: Existing directory the figures are written to. Required
            despite the ``None`` default.

    Raises:
        ValueError: If ``save_path`` is ``None``.
        FileNotFoundError: If ``save_path`` does not exist.
    """
    # Validate the output location before any expensive loading. The original
    # asserted inside the loop, which failed with AttributeError for the
    # default None and is skipped entirely under `python -O`.
    if save_path is None:
        raise ValueError('save_path must be provided')
    if not save_path.exists():
        raise FileNotFoundError(f'save_path does not exist: {save_path}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # loading network
    net_cfg = load_cfg(net_cfg_file)
    net = load_net(net_cfg, net_file, device)
    threshold = net_cfg.THRESH

    # loading dataset
    ds_cfg = load_cfg(ds_cfg_file)
    samples = datasets.OSCDDataset(ds_cfg, dataset, no_augmentation=True)
    dataloader = torch_data.DataLoader(samples,
                                       batch_size=1,
                                       num_workers=0,
                                       shuffle=False,
                                       pin_memory=True)

    net.eval()
    with torch.no_grad():
        for batch in dataloader:
            city = batch['city'][0]
            print(city)

            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)
            y_true = batch['label'].to(device)

            fig, axs = plt.subplots(1, 4, figsize=(20, 10))

            # one probability map per acquisition date
            probs = []
            for i, img in enumerate([t1_img, t2_img]):
                _, y_prob = classify(img, net, threshold, return_numpy=True)
                prob_map = y_prob[0, :, :, 0]  # first sample, first channel
                probs.append(prob_map)
                axs[i].imshow(prob_map, vmin=0, vmax=1)

            label_arr = torch2numpy(y_true, 'uint8')
            axs[2].imshow(label_arr[0, :, :, 0])

            # change in probability from the first to the second image
            axs[3].imshow(probs[1] - probs[0], vmin=0, vmax=1)

            for ax in axs:
                ax.set_axis_off()

            file = save_path / f'urban_extraction_{city}.png'
            fig.savefig(file, dpi=300, bbox_inches='tight')
            plt.close(fig)
예제 #7
0
def _build_criterion(cfg, device):
    """Return the loss callable selected by ``cfg.MODEL.LOSS_TYPE``.

    'SoftDiceLoss' and any unrecognised name fall back to
    ``lf.soft_dice_loss``.
    """
    loss_type = cfg.MODEL.LOSS_TYPE
    bce = F.binary_cross_entropy_with_logits

    if loss_type == 'BCEWithLogitsLoss':
        return torch.nn.BCEWithLogitsLoss()
    if loss_type == 'WeightedBCEWithLogitsLoss':
        positive_weight = torch.tensor([cfg.MODEL.POSITIVE_WEIGHT]).float().to(device)
        return torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight)
    if loss_type == 'SoftDiceBalancedLoss':
        return lf.soft_dice_loss_balanced
    if loss_type == 'JaccardLikeLoss':
        return lf.jaccard_like_loss
    if loss_type == 'ComboLoss':
        return lambda pred, gts: bce(pred, gts) + lf.soft_dice_loss(pred, gts)
    if loss_type == 'WeightedComboLoss':
        return lambda pred, gts: 2 * bce(pred, gts) + lf.soft_dice_loss(pred, gts)
    if loss_type == 'FrankensteinLoss':
        return lambda pred, gts: bce(pred, gts) + lf.jaccard_like_balanced_loss(pred, gts)
    if loss_type == 'WeightedFrankensteinLoss':
        positive_weight = torch.tensor([cfg.MODEL.POSITIVE_WEIGHT]).float().to(device)
        return lambda pred, gts: (bce(pred, gts, pos_weight=positive_weight)
                                  + 5 * lf.jaccard_like_balanced_loss(pred, gts))
    return lf.soft_dice_loss


def train(net, cfg):
    """Train ``net`` on the 'train' split and evaluate it periodically.

    Every ``cfg.LOGGING`` epochs the average training loss and the positive
    pixel ratio are reported, the best F1 threshold is determined on the
    training split and then applied to the test split. After the last epoch
    the weights are written to ``<OUTPUT_BASE_DIR>/<NAME>/final_net.pkl`` when
    ``cfg.SAVE_MODEL`` is set and ``cfg.DEBUG`` is not.

    Args:
        net: Network called as ``net(t1_img, t2_img)``; moved to the GPU when
            one is available.
        cfg: Configuration object (``TRAINER``, ``MODEL``, ``DATALOADER``,
            ``AUGMENTATION``, ``LOGGING``, ``DEBUG``, ``SAVE_MODEL``,
            ``OUTPUT_BASE_DIR`` and ``NAME`` are read).
    """
    # setting device on GPU if available, else CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    net.to(device)

    if cfg.TRAINER.OPTIMIZER == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=cfg.TRAINER.LR, weight_decay=0.0005)
    else:
        optimizer = torch.optim.SGD(net.parameters(), lr=cfg.TRAINER.LR, momentum=0.9)

    criterion = _build_criterion(cfg, device)

    dataset = datasets.OSCDDataset(cfg, 'train')
    dataloader_kwargs = {
        'batch_size': cfg.TRAINER.BATCH_SIZE,
        'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER,
        'shuffle': cfg.DATALOADER.SHUFFLE,
        'drop_last': True,
        'pin_memory': True,
    }
    if cfg.AUGMENTATION.OVERSAMPLING != 'none':
        # a custom sampler and shuffle=True are mutually exclusive
        dataloader_kwargs['sampler'] = dataset.sampler()
        dataloader_kwargs['shuffle'] = False

    dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)

    save_path = Path(cfg.OUTPUT_BASE_DIR) / cfg.NAME
    save_path.mkdir(parents=True, exist_ok=True)

    best_test_f1 = 0
    # accumulated since the last logging epoch
    positive_pixels = 0
    pixels = 0
    global_step = 0
    epochs = cfg.TRAINER.EPOCHS
    for epoch in range(epochs):

        loss_tracker = 0
        # Count the batches actually processed. The original derived the
        # count as len(dataloader) // batch_size, but len(dataloader) is
        # already the number of batches, so the average loss was wrong.
        n_batches = 0
        net.train()

        for batch in dataloader:

            t1_img = batch['t1_img'].to(device)
            t2_img = batch['t2_img'].to(device)
            label = batch['label'].to(device)

            optimizer.zero_grad()

            output = net(t1_img, t2_img)

            loss = criterion(output, label)
            loss_tracker += loss.item()
            loss.backward()
            optimizer.step()

            positive_pixels += torch.sum(label).item()
            pixels += torch.numel(label)

            n_batches += 1
            global_step += 1

        if epoch % cfg.LOGGING == 0:
            print(f'epoch {epoch} / {cfg.TRAINER.EPOCHS}')

            # printing and logging loss (max(..., 1) guards an empty loader)
            avg_loss = loss_tracker / max(n_batches, 1)
            print(f'avg training loss {avg_loss:.5f}')

            # positive pixel ratio used to check oversampling
            positive_ratio = positive_pixels / max(pixels, 1)
            if cfg.DEBUG:
                print(f'positive pixel ratio: {positive_ratio:.3f}')
            else:
                wandb.log({'positive pixel ratio': positive_ratio})
            positive_pixels = 0
            pixels = 0

            # model evaluation
            # train (different thresholds are tested)
            train_thresholds = torch.linspace(0, 1, 101).to(device)
            train_maxF1, train_maxTresh = model_evaluation(net, cfg, device, train_thresholds, run_type='train',
                                                           epoch=epoch, step=global_step)
            # test (using the best training threshold)
            test_threshold = torch.tensor([train_maxTresh])
            test_f1, _ = model_evaluation(net, cfg, device, test_threshold, run_type='test', epoch=epoch,
                                          step=global_step)

            if test_f1 > best_test_f1:
                print('BEST PERFORMANCE SO FAR! <--------------------', flush=True)
                best_test_f1 = test_f1

                if cfg.SAVE_MODEL and not cfg.DEBUG:
                    # NOTE(review): writing a best-model checkpoint was
                    # disabled in the original; only the message is printed.
                    print('saving network', flush=True)

        # Save the final weights after the last configured epoch. The
        # original compared against the hard-coded epoch 390 and only did so
        # on logging epochs, so the file was skipped for most configurations.
        if epoch + 1 == epochs and cfg.SAVE_MODEL and not cfg.DEBUG:
            print('saving network', flush=True)
            model_file = save_path / 'final_net.pkl'
            torch.save(net.state_dict(), model_file)