Example 1
def inference_loop2(net,
                    cfg,
                    device,
                    callback=None,
                    run_type='TEST',
                    max_samples=999999999,
                    dataset=None):

    net.to(device)
    net.eval()

    # reset the generators
    if dataset is None:
        # pick the data split based on run_type (same logic as inference_loop below)
        dset_source = cfg.DATASETS.TEST[0] if run_type == 'TEST' else cfg.DATASETS.TRAIN[0]
        dataset = Xview2Detectron2Dataset(dset_source, 0, cfg)
    dataloader = torch_data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,  # BATCH_SIZE is a constant defined outside this snippet
        num_workers=cfg.DATALOADER.NUM_WORKER,
        shuffle=False,
        drop_last=False,
    )

    dataset_length = np.minimum(len(dataset), max_samples)
    with torch.no_grad():
        for step, batch in enumerate(dataloader):

            imgs = batch['x'].to(device)
            y_label = batch['y'].to(device)
            sample_name = batch['img_name']
            index = batch['index']

            y_pred = net(imgs)

            if step % 10 == 0 or step == dataset_length - 1:
                print(
                    f'Processed {step+1}/{dataset_length}',
                    f', max cuda usage: {torch.cuda.max_memory_allocated() / 1e6 :.2f} MB',
                    flush=True)

            if cfg.MODEL.LOSS_TYPE == 'CrossEntropyLoss':
                # In two-class cross-entropy mode the positive class sits in channel index 1
                y_pred = torch.softmax(y_pred, dim=1)
                y_pred = y_pred[:, 1, ...]
                y_pred = y_pred[:, None, ...]  # restore the channel dimension
            else:
                y_pred = torch.sigmoid(y_pred)

            if callback is not None:
                callback(imgs, y_label, y_pred, sample_name, index)

            if (max_samples is not None) and step >= max_samples:
                break
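
A minimal usage sketch for inference_loop2, assuming a trained net, a loaded cfg, and a hypothetical collect_predictions callback (the name and the storage logic are illustrative only; the signature matches the one invoked inside the loop):

import torch

collected = []

def collect_predictions(imgs, y_label, y_pred, sample_name, index):
    # Move results off the GPU so they can be inspected after the loop.
    collected.append({
        'name': sample_name,
        'pred': y_pred.cpu().numpy(),
        'label': y_label.cpu().numpy(),
    })

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inference_loop2(net, cfg, device, callback=collect_predictions, max_samples=100)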
Example 2
    def train_dataloader(self):
        # REQUIRED
        cfg = self.cfg
        use_edge_loss = cfg.MODEL.LOSS_TYPE == 'FrankensteinEdgeLoss'
        trfm = []
        trfm.append(BGR2RGB())
        if cfg.DATASETS.USE_CLAHE_VARI: trfm.append(VARI())
        if cfg.AUGMENTATION.RESIZE:
            trfm.append(Resize(scale=cfg.AUGMENTATION.RESIZE_RATIO))
        if cfg.AUGMENTATION.CROP_TYPE == 'uniform':
            trfm.append(UniformCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE))
        elif cfg.AUGMENTATION.CROP_TYPE == 'importance':
            trfm.append(
                ImportanceRandomCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE))
        if cfg.AUGMENTATION.RANDOM_FLIP_ROTATE: trfm.append(RandomFlipRotate())

        trfm.append(Npy2Torch())
        trfm = transforms.Compose(trfm)

        dataset = Xview2Detectron2Dataset(
            cfg.DATASETS.TRAIN[0],
            pre_or_post=cfg.DATASETS.PRE_OR_POST,
            include_image_weight=True,
            transform=trfm,
            include_edge_mask=use_edge_loss,
            use_clahe=cfg.DATASETS.USE_CLAHE_VARI,
        )

        dataloader_kwargs = {
            'batch_size': cfg.TRAINER.BATCH_SIZE,
            'num_workers': cfg.DATALOADER.NUM_WORKER,
            'shuffle': cfg.DATALOADER.SHUFFLE,
            'drop_last': True,
            'pin_memory': True,
        }
        # sampler
        if cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'simple':
            image_p = self.image_sampling_weight(dataset.dataset_metadata)
            sampler = torch_data.WeightedRandomSampler(
                weights=image_p, num_samples=len(image_p))
            dataloader_kwargs['sampler'] = sampler
            dataloader_kwargs['shuffle'] = False
        dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)
        return dataloader
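
The 'simple' oversampling branch relies on an image_sampling_weight helper that is not part of this snippet. A minimal sketch of what such a helper could look like, assuming the dataset metadata is a list of per-image dicts carrying a positive-pixel count (both the field name and the weighting rule are assumptions):

import numpy as np

def image_sampling_weight(dataset_metadata):
    # Weight each image by its (assumed) number of positive pixels so that
    # images containing buildings are drawn more often; the +1 keeps images
    # without any positive pixels from being excluded entirely.
    positive_pixels = np.array([m['positive_pixels'] for m in dataset_metadata],
                               dtype=np.float64)
    return positive_pixels + 1

WeightedRandomSampler accepts any sequence of non-negative weights, so the result does not need to be normalized.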
Example 3
def train_net(net, cfg):

    log_path = cfg.OUTPUT_DIR
    writer = SummaryWriter(log_path)

    run_config = {}
    run_config['CONFIG_NAME'] = cfg.NAME
    run_config['device'] = device
    run_config['log_path'] = cfg.OUTPUT_DIR
    run_config['training_set'] = cfg.DATASETS.TRAIN
    run_config['test set'] = cfg.DATASETS.TEST
    run_config['epochs'] = cfg.TRAINER.EPOCHS
    run_config['learning rate'] = cfg.TRAINER.LR
    run_config['batch size'] = cfg.TRAINER.BATCH_SIZE
    table = {
        'run config name': run_config.keys(),
        ' ': run_config.values(),
    }
    print(tabulate(
        table,
        headers='keys',
        tablefmt="fancy_grid",
    ))

    optimizer = optim.Adam(net.parameters(),
                           lr=cfg.TRAINER.LR,
                           weight_decay=0.0005)
    if cfg.MODEL.LOSS_TYPE == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss()
    elif cfg.MODEL.LOSS_TYPE == 'CrossEntropyLoss':
        balance_weight = [cfg.MODEL.NEGATIVE_WEIGHT, cfg.MODEL.POSITIVE_WEIGHT]
        balance_weight = torch.tensor(balance_weight).float().to(device)
        criterion = nn.CrossEntropyLoss(weight=balance_weight)
    elif cfg.MODEL.LOSS_TYPE == 'SoftDiceLoss':
        criterion = soft_dice_loss
    elif cfg.MODEL.LOSS_TYPE == 'SoftDiceBalancedLoss':
        criterion = soft_dice_loss_balanced
    elif cfg.MODEL.LOSS_TYPE == 'JaccardLikeLoss':
        criterion = jaccard_like_loss
    elif cfg.MODEL.LOSS_TYPE == 'ComboLoss':
        criterion = lambda pred, gts: F.binary_cross_entropy_with_logits(
            pred, gts) + soft_dice_loss(pred, gts)
    elif cfg.MODEL.LOSS_TYPE == 'WeightedComboLoss':
        criterion = lambda pred, gts: 2 * F.binary_cross_entropy_with_logits(
            pred, gts) + soft_dice_loss(pred, gts)
    elif cfg.MODEL.LOSS_TYPE == 'FrankensteinLoss':
        criterion = lambda pred, gts: F.binary_cross_entropy_with_logits(
            pred, gts) + jaccard_like_balanced_loss(pred, gts)
    elif cfg.MODEL.LOSS_TYPE == 'FrankensteinEdgeLoss':
        criterion = frankenstein_edge_loss
    else:
        raise ValueError(f'Unknown loss type: {cfg.MODEL.LOSS_TYPE}')

    if torch.cuda.device_count() > 1:
        print(torch.cuda.device_count(), " GPUs!")
        net = nn.DataParallel(net)
    net.to(device)
    global_step = 0
    epochs = cfg.TRAINER.EPOCHS

    use_edge_loss = cfg.MODEL.LOSS_TYPE == 'FrankensteinEdgeLoss'

    for name, _ in net.named_parameters():
        print(name)

    trfm = []
    trfm.append(BGR2RGB())
    if cfg.DATASETS.USE_CLAHE_VARI: trfm.append(VARI())
    if cfg.AUGMENTATION.RESIZE:
        trfm.append(Resize(scale=cfg.AUGMENTATION.RESIZE_RATIO))
    if cfg.AUGMENTATION.CROP_TYPE == 'uniform':
        trfm.append(UniformCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE))
    elif cfg.AUGMENTATION.CROP_TYPE == 'importance':
        trfm.append(ImportanceRandomCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE))
    if cfg.AUGMENTATION.RANDOM_FLIP_ROTATE: trfm.append(RandomFlipRotate())

    trfm.append(Npy2Torch())
    trfm = transforms.Compose(trfm)

    # reset the generators
    dataset = Xview2Detectron2Dataset(
        cfg.DATASETS.TRAIN[0],
        pre_or_post=cfg.DATASETS.PRE_OR_POST,
        include_image_weight=True,
        transform=trfm,
        include_edge_mask=use_edge_loss,
        edge_mask_type=cfg.MODEL.EDGE_WEIGHTED_LOSS.TYPE,
        use_clahe=cfg.DATASETS.USE_CLAHE_VARI,
    )

    dataloader_kwargs = {
        'batch_size': cfg.TRAINER.BATCH_SIZE,
        'num_workers': cfg.DATALOADER.NUM_WORKER,
        'shuffle': cfg.DATALOADER.SHUFFLE,
        'drop_last': True,
        'pin_memory': True,
    }

    # sampler
    if cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'simple':
        image_p = image_sampling_weight(dataset.dataset_metadata)
        sampler = torch_data.WeightedRandomSampler(weights=image_p,
                                                   num_samples=len(image_p))
        dataloader_kwargs['sampler'] = sampler
        dataloader_kwargs['shuffle'] = False

    dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)

    for epoch in range(epochs):
        start = timeit.default_timer()
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        epoch_loss = 0

        net.train()
        # mean AP, mean AUC, max F1
        mAP_set_train, mAUC_set_train, maxF1_train = [], [], []
        loss_set, f1_set = [], []
        positive_pixels_set = []  # used to evaluate image oversampling techniques
        for i, batch in enumerate(dataloader):
            optimizer.zero_grad()

            x = batch['x'].to(device)
            y_gts = batch['y'].to(device)
            image_weight = batch['image_weight']

            y_pred = net(x)

            if cfg.MODEL.LOSS_TYPE == 'CrossEntropyLoss':
                # CrossEntropyLoss expects long-typed class-index targets
                # (and a prediction without a singleton channel dimension).
                y_gts = y_gts.long()

            if use_edge_loss:
                edge_mask = y_gts[:, [0]]
                y_gts = y_gts[:, 1:]
                edge_loss_scale = edge_loss_warmup_schedule(cfg, global_step)
                loss, ce_loss, jaccard_loss, edge_loss = criterion(
                    y_pred, y_gts, edge_mask, edge_loss_scale)
                wandb.log({
                    'ce_loss': ce_loss,
                    'jaccard_loss': jaccard_loss,
                    'edge_loss': edge_loss,
                    'step': global_step,
                    'edge_loss_scale': edge_loss_scale,
                })
            else:
                loss = criterion(y_pred, y_gts)

            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            loss_set.append(loss.item())
            positive_pixels_set.extend(image_weight.cpu().numpy())

            if global_step % 100 == 0 or global_step == 0:
                # time per 100 steps
                stop = timeit.default_timer()
                time_per_n_batches = stop - start

                if global_step % 10000 == 0 and global_step > 0:
                    check_point_name = f'cp_{global_step}.pkl'
                    save_path = os.path.join(log_path, check_point_name)
                    torch.save(net.state_dict(), save_path)

                # Averaged loss and f1 writer

                # writer.add_scalar('f1/train', np.mean(f1_set), global_step)

                max_mem, max_cache = gpu_stats()
                print(
                    f'step {global_step},  avg loss: {np.mean(loss_set):.4f}, cuda mem: {max_mem} MB, cuda cache: {max_cache} MB, time: {time_per_n_batches:.2f}s',
                    flush=True)

                wandb.log({
                    'loss': np.mean(loss_set),
                    'gpu_memory': max_mem,
                    'time': time_per_n_batches,
                    'total_positive_pixels': np.mean(positive_pixels_set),
                    'step': global_step,
                })

                loss_set = []
                positive_pixels_set = []

                start = stop

            # torch.cuda.empty_cache()
            global_step += 1

        if epoch % 2 == 0:
            # Evaluation after every other epoch
            model_eval(net,
                       cfg,
                       device,
                       max_samples=100,
                       step=global_step,
                       epoch=epoch)
            model_eval(net,
                       cfg,
                       device,
                       max_samples=100,
                       run_type='TRAIN',
                       step=global_step,
                       epoch=epoch)
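
train_net calls two small helpers that are not included in the example, gpu_stats and edge_loss_warmup_schedule. Minimal sketches, assuming gpu_stats reports peak CUDA memory in megabytes and the warmup schedule ramps the edge-loss weight linearly (the 10000-step ramp length is an assumption, and the real schedule presumably reads cfg.MODEL.EDGE_WEIGHTED_LOSS):

import torch

def gpu_stats():
    # Peak memory allocated by tensors and peak memory held by the CUDA
    # caching allocator, both converted from bytes to megabytes.
    max_mem = int(torch.cuda.max_memory_allocated() / 1e6)
    max_cache = int(torch.cuda.max_memory_reserved() / 1e6)
    return max_mem, max_cache

def edge_loss_warmup_schedule(cfg, global_step, warmup_steps=10000):
    # Linear ramp of the edge-loss weight from 0 to 1 over `warmup_steps`;
    # the actual schedule used by the project is not shown in this snippet.
    return min(1.0, global_step / warmup_steps)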
Example 4
def inference_loop(
    net,
    cfg,
    device,
    callback=None,
    batch_size=1,
    run_type='TEST',
    max_samples=999999999,
    dataset=None,
    callback_include_x=False,
):

    net.to(device)
    net.eval()

    # reset the generators

    dset_source = (cfg.DATASETS.TEST[0]
                   if run_type == 'TEST' else cfg.DATASETS.TRAIN[0])
    if dataset is None:
        trfm = []
        if cfg.AUGMENTATION.RESIZE:
            trfm.append(Resize(scale=cfg.AUGMENTATION.RESIZE_RATIO))
        trfm.append(BGR2RGB())
        if cfg.DATASETS.USE_CLAHE_VARI: trfm.append(VARI())
        if cfg.MODEL.IN_CHANNELS == 4:
            trfm.append(AddCanny())
        trfm.append(Npy2Torch())
        trfm = transforms.Compose(trfm)

        dataset = Xview2Detectron2Dataset(dset_source,
                                          pre_or_post=cfg.DATASETS.PRE_OR_POST,
                                          transform=trfm)

    dataloader = torch_data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=cfg.DATALOADER.NUM_WORKER,
        shuffle=cfg.DATALOADER.SHUFFLE,
        drop_last=True,
    )

    dlen = len(dataset)
    dataset_length = np.minimum(len(dataset), max_samples)
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            imgs = batch['x'].to(device)
            y_label = batch['y'].to(device)
            sample_name = batch['img_name']

            y_pred = net(imgs)

            if step % 100 == 0 or step == dataset_length - 1:
                print(f'Processed {step+1}/{dataset_length}')

            if y_pred.shape[1] > 1:  # multi-class
                # In two-class cross-entropy mode the positive class probabilities are in channel index 1
                y_pred = torch.softmax(y_pred, dim=1)
            else:
                y_pred = torch.sigmoid(y_pred)

            if callback:
                if callback_include_x:
                    callback(imgs, y_label, y_pred, sample_name)
                else:
                    callback(y_label, y_pred, sample_name)

            if (max_samples is not None) and step >= max_samples:
                break
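
A usage sketch for inference_loop with the default calling convention (callback_include_x=False); save_probability_map is a hypothetical callback that writes each predicted probability map to disk:

import numpy as np
import torch

def save_probability_map(y_label, y_pred, sample_name):
    # One file per sample; y_pred has shape (batch, channels, H, W).
    for name, prob in zip(sample_name, y_pred.cpu().numpy()):
        np.save(f'{name}_prob.npy', prob)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inference_loop(net, cfg, device,
               callback=save_probability_map,
               batch_size=4,
               run_type='TEST')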
Example 5
# Per image  ===========

print('================= Running ablation per image ===============',
      flush=True)

trfm = []
if cfg.AUGMENTATION.RESIZE:
    trfm.append(Resize(scale=cfg.AUGMENTATION.RESIZE_RATIO,
                       resize_label=False))
trfm.append(BGR2RGB())
trfm.append(Npy2Torch())
trfm = transforms.Compose(trfm)

dataset = Xview2Detectron2Dataset(dset_source,
                                  pre_or_post=cfg.DATASETS.PRE_OR_POST,
                                  include_index=True,
                                  transform=trfm)
results_table = []


def compute_sample(x, Y_true, Y_pred, img_filenames, indices):
    # interp image if scaling was originally enabled
    if cfg.AUGMENTATION.RESIZE:
        upscale_ratio = 1 / cfg.AUGMENTATION.RESIZE_RATIO
        Y_pred = torch.nn.functional.interpolate(Y_pred,
                                                 scale_factor=upscale_ratio,
                                                 mode='bilinear')
    # drop the singleton channel dimension and binarize the prediction
    Y_pred = Y_pred.squeeze(1) >= THRESHOLD
    Y_true = Y_true.squeeze(1).type(torch.bool)