def main():
    """Evaluate a snapshot ensemble of pneumothorax segmentation models.

    Options come from ``parse_args``. The test split is either the outer
    fold (``args.outer_only``) or one inner fold of the non-test data. Up to
    ``args.num_snapshots`` ``*.pth`` checkpoints from ``args.model_folder``
    are ranked by the metric encoded in their file names. The foreground
    softmax probabilities of the selected checkpoints are averaged with
    ``args.ss_weights`` and the result is written to ``args.save_file``:

    * ``args.class_mode``: a CSV with one "TopX" column per entry of ``tops``
      (mean of the top X% of pixel scores, rescaled to 0-1), plus ``y_true``
      and ``sop``.
    * otherwise: a pickle mapping each ``sop`` value to its prediction dict,
      with keys ``pred_mask`` (uint8, 0-100), ``gt_mask`` (uint8) and
      ``y_true``.

    Raises:
        FileNotFoundError: if no ``*.pth`` checkpoint exists.
        ValueError: if no snapshot is selected, or if there are fewer
            snapshot weights than selected snapshots.
        RuntimeError: if the number of predictions does not match the
            number of test images.
    """
    args = parse_args()

    set_reproducibility(args.seed)

    resize_me = resize_aug(imsize_x=args.imsize_x, imsize_y=args.imsize_y)
    pad_func = partial(pad_image, ratio=args.imratio)

    print("Testing the PNEUMOTHORAX SEGMENTATION model...")

    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True

    # dirname() is '' for a bare file name. os.makedirs('') would raise,
    # and writing to the current directory needs no directory creation.
    save_dir = os.path.dirname(args.save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print("Reading labels from {}".format(args.labels_df))

    df = pd.read_csv(args.labels_df)

    if args.outer_only:
        # Test on outer fold
        test_df = df[df['outer'] == args.outer_fold]
    else:
        # Get rid of outer fold test set, then test on one inner fold
        df = df[df['outer'] != args.outer_fold]
        inner_col = 'inner{}'.format(args.outer_fold)
        # NOTE(review): 888 looks like a sentinel for rows outside the inner
        # split; none are expected once the outer test fold is removed.
        assert np.sum(df[inner_col] == 888) == 0
        test_df = df[df[inner_col] == args.inner_fold]

    print('TEST: n={}'.format(len(test_df)))

    print("Reading images from directory {}".format(args.data_dir))
    test_sops = list(test_df['sop'])
    test_images = [
        os.path.join(args.data_dir, '{}.png'.format(sop)) for sop in test_sops
    ]
    test_masks = [
        os.path.join(args.mask_dir, '{}.png'.format(sop)) for sop in test_sops
    ]
    test_labels = list(test_df['ptx_binary'])

    # Get models in snapshot ensemble. Sorting makes tie-breaking in the
    # ranking below independent of filesystem listing order.
    snapshots = sorted(glob.glob(os.path.join(args.model_folder, '*.pth')))
    if not snapshots:
        raise FileNotFoundError('No *.pth checkpoints found in {}'.format(
            args.model_folder))

    num_snapshots = args.num_snapshots
    # Pick the best `num_snapshots` models, then weight them by
    # `args.ss_weights` (e.g. 3, 1, 1). This assumes a certain formatting of
    # the checkpoint file name: the Kaggle metric is the last '-'-separated
    # part of one '_'-separated field. Which field depends on the mode.
    if args.class_mode:
        kag_field = 4
    elif args.pos_only:
        kag_field = 2
    else:
        kag_field = 6

    def extract_kag(ckpt):
        field = os.path.basename(ckpt).split('_')[kag_field]
        return float(field.split('-')[-1])

    snapshot_kags = [extract_kag(ckpt) for ckpt in snapshots]
    # Highest metric first, truncated to `num_snapshots`.
    kag_order = np.argsort(snapshot_kags)[::-1][:num_snapshots]
    snapshots = [snapshots[i] for i in kag_order]
    if not snapshots:
        raise ValueError(
            'No snapshots selected (num_snapshots={})'.format(num_snapshots))

    # Only the weights of the snapshots actually used may enter the average.
    # Normalising by the sum of *all* weights would shrink every probability
    # whenever fewer checkpoints exist than weights were given.
    if len(args.ss_weights) < len(snapshots):
        raise ValueError('Got {} snapshot weights for {} snapshots'.format(
            len(args.ss_weights), len(snapshots)))
    snapshot_weights = list(args.ss_weights)[:len(snapshots)]

    def load_model(ckpt):
        model = DeepLab(args.model,
                        args.output_stride,
                        args.gn,
                        center=args.center,
                        jpu=args.jpu,
                        use_maxpool=not args.no_maxpool)
        model.load_state_dict(torch.load(ckpt))
        model = model.cuda()
        model.eval()
        return model

    # Get models
    print('Loading checkpoints ...')
    model_list = [load_model(ss) for ss in snapshots]

    # Set up preprocessing function with model
    ppi = partial(preprocess_input, model=model_list[0])

    print('Setting up data loaders ...')

    params = {
        'batch_size': 1 if args.tta else args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }

    test_set = XrayMaskDataset(imgfiles=test_images,
                               maskfiles=test_masks,
                               dicom=False,
                               labels=test_labels,
                               preprocess=ppi,
                               pad=pad_func,
                               crop=None,
                               resize=resize_me,
                               test_mode=True)
    test_gen = DataLoader(test_set, **params)

    def foreground_prob(logits):
        # Softmax over dim 1, keeping only channel 1 (the foreground class).
        return torch.softmax(logits, dim=1).cpu().numpy()[:, 1]

    # Test
    def get_test_predictions(mod):
        """Run `mod` over `test_gen`; return a list of per-image dicts."""
        list_of_pred_dicts = []
        with torch.no_grad():
            for batch, masks, classes in tqdm(test_gen, total=len(test_gen)):
                gt_mask = masks.cpu().numpy().astype('uint8')
                y_true = classes.cpu().numpy()
                if args.tta:
                    # Loader batch size is 1 in this mode (see `params`), so
                    # batch[0] is the single sample, fed to the model as-is.
                    list_of_pred_dicts.append({
                        'pred_mask': foreground_prob(mod(batch[0].cuda())),
                        'gt_mask': gt_mask,
                        'y_true': y_true,
                    })
                    continue
                # Horizontal-flip averaging: predict on the image and on its
                # mirror, flip the latter back, and average the probabilities.
                output = mod(batch.cuda())
                output_flipped = mod(torch.flip(batch, dims=(-1, )).cuda())
                output_flipped = torch.flip(output_flipped, dims=(-1, ))
                pred_mask = (foreground_prob(output) +
                             foreground_prob(output_flipped)) / 2.
                # Split the batch so every entry describes exactly one image
                # with a leading dimension of 1. The averaging, get_top_X and
                # the sop mapping below all index predictions per image.
                for b in range(pred_mask.shape[0]):
                    list_of_pred_dicts.append({
                        'pred_mask': pred_mask[b:b + 1],
                        'gt_mask': gt_mask[b:b + 1],
                        'y_true': y_true[b:b + 1],
                    })
        return list_of_pred_dicts

    y_pred_list = [
        get_test_predictions(model)
        for model in tqdm(model_list, total=len(model_list))
    ]

    if len(y_pred_list[0]) != len(test_sops):
        raise RuntimeError('Got {} predictions for {} test images'.format(
            len(y_pred_list[0]), len(test_sops)))

    # Need to average predictions across models. The weighted mean is stored
    # back into the first model's dicts as uint8 percentages (0-100).
    weight_total = float(np.sum(snapshot_weights))
    for i, base_pred in enumerate(y_pred_list[0]):
        indiv_pred = np.zeros_like(base_pred['pred_mask'])
        for weight, model_preds in zip(snapshot_weights, y_pred_list):
            indiv_pred += weight * model_preds[i]['pred_mask']
        indiv_pred /= weight_total
        assert np.min(indiv_pred) >= 0 and np.max(indiv_pred) <= 1
        base_pred['pred_mask'] = (indiv_pred * 100).astype('uint8')

    def get_top_X(segmentation, tops=(0, 0.5, 1.0, 2.5, 5.0)):
        """Score one image by the mean of its top-t% pixel values.

        For each t in `tops`, average the largest t% of pixel values
        (t == 0 means the single largest value) and divide by 100. Always at
        least one pixel is used.
        """
        # Assumes segmentation.shape is (1, H, W)
        assert segmentation.shape[0] == 1
        # int8 holds 0-100 and allows negation for a descending sort.
        flat = segmentation.reshape(segmentation.shape[0], -1).astype('int8')
        flat = -np.sort(-flat, axis=1)
        scores = []
        for t in tops:
            size = max(1, int(t / 100. * flat.size)) if t > 0 else 1
            scores.append(np.mean(flat[:, :size]) / 100.)
        return scores

    if args.class_mode:
        # Turn segmentation output into class scores
        tops = [0, 0.5, 1.0, 2.5, 5.0]
        class_scores = np.vstack(
            [get_top_X(pred['pred_mask'], tops) for pred in y_pred_list[0]])
        # Make a DataFrame
        class_scores = pd.DataFrame(class_scores)
        class_scores.columns = ['Top{}'.format(t) for t in tops]
        class_scores['y_true'] = [pred['y_true'][0] for pred in y_pred_list[0]]
        class_scores['sop'] = test_sops
        class_scores.to_csv(args.save_file, index=False)
    else:
        # One prediction dict per image, keyed by sop, in test_df order.
        y_pred_to_pickle = {
            sop: pred for sop, pred in zip(test_sops, y_pred_list[0])
        }

        with open(args.save_file, 'wb') as f:
            pickle.dump(y_pred_to_pickle, f)
# Example no. 2
def main():
    """Train a DeepLab pneumothorax segmentation model.

    Options come from ``parse_args``. The train/validation split is either
    the outer fold (``args.outer_only``) or one inner fold of the non-test
    data; ``args.pos_only`` keeps only rows with ``ptx_binary == 1``. The
    function then builds the model, loss, optimizer, LR scheduler and data
    loaders, and hands them to a trainer (``Trainer``, ``AllTrainer`` or
    ``BalancedTrainer``) that writes checkpoints under ``args.save_dir``.

    Raises:
        ValueError: if ``pos_only`` and ``balanced`` are both set, or if
            ``args.loss`` / ``args.optimizer`` is not recognised.
        TypeError: if the optimizer (after optional Apex wrapping) is not an
            optimizer instance.
        Exception: if ``grad_accum`` is negative and ``steps_per_epoch`` is
            0, so that steps per epoch must be derived.
    """
    args = parse_args()

    # Reject incompatible options before any expensive setup.
    if args.pos_only and args.balanced:
        raise ValueError('`pos-only` and `balanced` cannot both be specified')

    set_reproducibility(args.seed)

    train_aug = simple_aug(p=args.augment_p)
    resize_me = resize_aug(imsize_x=args.imsize_x, imsize_y=args.imsize_y)
    pad_func = partial(pad_image, ratio=args.imratio)

    print("Training the PNEUMOTHORAX SEGMENTATION model...")

    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True

    os.makedirs(args.save_dir, exist_ok=True)

    print("Saving model to {}".format(args.save_dir))
    print("Reading labels from {}".format(args.labels_df))

    df = pd.read_csv(args.labels_df)
    if args.pos_only:
        df = df[df['ptx_binary'] == 1]

    if args.outer_only:
        # We may want to only use outer splits
        train_df = df[df['outer'] != args.outer_fold]
        valid_df = df[df['outer'] == args.outer_fold]
    else:
        # Get rid of outer fold test set
        df = df[df['outer'] != args.outer_fold]
        inner_col = 'inner{}'.format(args.outer_fold)
        # NOTE(review): 888 looks like a sentinel for rows outside the inner
        # split; none are expected once the outer test fold is removed.
        assert np.sum(df[inner_col] == 888) == 0
        train_df = df[df[inner_col] != args.inner_fold]
        valid_df = df[df[inner_col] == args.inner_fold]

    print('TRAIN: n={}'.format(len(train_df)))
    print('% PTX: {:.1f}'.format(np.mean(train_df['ptx_binary']) * 100))
    print('VALID: n={}'.format(len(valid_df)))
    print('% PTX: {:.1f}'.format(np.mean(valid_df['ptx_binary']) * 100))

    train_sops = list(train_df['sop'])
    train_labels = list(train_df['ptx_binary'])
    valid_sops = list(valid_df['sop'])
    valid_labels = list(valid_df['ptx_binary'])

    def png_paths(directory, sops):
        # '<directory>/<sop>.png' for every sop, preserving order.
        return [os.path.join(directory, '{}.png'.format(sop)) for sop in sops]

    print("Reading images from directory {}".format(args.data_dir))
    train_images = png_paths(args.data_dir, train_sops)
    pos_train_images = [
        f for f, label in zip(train_images, train_labels) if label == 1
    ]
    neg_train_images = [
        f for f, label in zip(train_images, train_labels) if label == 0
    ]
    valid_images = png_paths(args.data_dir, valid_sops)

    print("Reading masks from directory {}".format(args.mask_dir))
    train_masks = png_paths(args.mask_dir, train_sops)
    pos_train_masks = [
        f for f, label in zip(train_masks, train_labels) if label == 1
    ]
    valid_masks = png_paths(args.mask_dir, valid_sops)

    model = DeepLab(args.model,
                    args.output_stride,
                    args.gn,
                    center=args.center,
                    jpu=args.jpu,
                    use_maxpool=not args.no_maxpool)
    if args.load_model != '':
        print('Loading trained model {} ...'.format(args.load_model))
        model.load_state_dict(torch.load(args.load_model))
    model = model.cuda()
    model.train()

    if args.loss == 'lovasz_softmax':
        criterion = LL.LovaszSoftmax().cuda()
    elif args.loss == 'soft_dice':
        criterion = SoftDiceLoss().cuda()
    elif args.loss == 'soft_dicev2':
        criterion = SoftDiceLossV2().cuda()
    elif args.loss == 'dice_bce':
        criterion = DiceBCELoss().cuda()
    elif args.loss == 'lovasz_hinge':
        criterion = LL.LovaszHinge().cuda()
    elif args.loss == 'weighted_bce':
        criterion = WeightedBCE(pos_frac=args.pos_frac,
                                neg_frac=args.neg_frac).cuda()
    elif args.loss == 'weighted_bce_v2':
        criterion = WeightedBCEv2().cuda()
    elif args.loss == 'focal_loss':
        criterion = FocalLoss().cuda()
    else:
        # Without this, `criterion` would be unbound and fail much later.
        raise ValueError('`{}` is not a valid loss'.format(args.loss))

    train_params = model.parameters()

    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'adam':
        optimizer = optim.Adam(train_params,
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(train_params,
                              lr=args.initial_lr,
                              weight_decay=args.weight_decay,
                              momentum=args.momentum,
                              nesterov=args.nesterov)
    elif optimizer_name == 'adabound':
        optimizer = adabound.AdaBound(train_params,
                                      lr=args.initial_lr,
                                      final_lr=args.initial_lr *
                                      args.final_lr_scale,
                                      weight_decay=args.weight_decay,
                                      gamma=args.gamma)
    else:
        # The message was previously built but never raised, leaving
        # `optimizer` unbound.
        raise ValueError('`{}` is not a valid optimizer'.format(
            args.optimizer))

    if APEX_AVAILABLE and args.mixed:
        print('Using NVIDIA Apex for mixed precision training ...')
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")

    # An Apex FP16_Optimizer wrapper is accepted in place of a torch Optimizer.
    if not isinstance(optimizer, Optimizer):
        flag = False
        try:
            from apex.fp16_utils.fp16_optimizer import FP16_Optimizer
            if isinstance(optimizer, FP16_Optimizer):
                flag = True
        except ModuleNotFoundError:
            pass
        if not flag:
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))

    if args.cosine_anneal:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=int(args.total_epochs / args.num_snapshots),
            eta_min=args.eta_min)
        scheduler.T_cur = 0.
        # These attributes mirror ReduceLROnPlateau's; presumably the trainer
        # reads them. TODO confirm.
        scheduler.mode = 'max'
        scheduler.threshold = args.min_delta
    else:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            'max',
            factor=args.annealing_factor,
            patience=args.lr_patience,
            threshold=args.min_delta,
            threshold_mode='abs',
            verbose=True)

    # Set up preprocessing function with model
    ppi = partial(preprocess_input, model=model)

    print('Setting up data loaders ...')

    params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.num_workers,
        'drop_last': True
    }

    valid_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }

    if args.balanced:
        train_set = XrayEqualMaskDataset(posfiles=pos_train_images,
                                         negfiles=neg_train_images,
                                         maskfiles=pos_train_masks,
                                         dicom=False,
                                         labels=None,
                                         preprocess=ppi,
                                         transform=train_aug,
                                         pad=pad_func,
                                         resize=resize_me,
                                         inversion=args.invert)
    else:
        train_set = XrayMaskDataset(imgfiles=train_images,
                                    maskfiles=train_masks,
                                    dicom=False,
                                    labels=train_labels,
                                    preprocess=ppi,
                                    transform=train_aug,
                                    pad=pad_func,
                                    resize=resize_me,
                                    inversion=args.invert)

    if args.pos_neg_ratio > 0:
        # DataLoader forbids shuffle=True together with a custom sampler.
        params['shuffle'] = False
        params['sampler'] = RatioSampler(train_set, args.num_samples,
                                         args.pos_neg_ratio)

    train_gen = DataLoader(train_set, **params)

    valid_set = XrayMaskDataset(imgfiles=valid_images,
                                maskfiles=valid_masks,
                                dicom=False,
                                labels=valid_labels,
                                preprocess=ppi,
                                pad=pad_func,
                                resize=resize_me,
                                test_mode=True,
                                inversion=args.invert)
    valid_gen = DataLoader(valid_set, **valid_params)

    loss_tracker = LossTracker()

    steps_per_epoch = args.steps_per_epoch
    if steps_per_epoch == 0:
        # Derive steps per epoch from the dataset size and effective batch.
        if args.grad_accum == 0:
            effective_batch_size = args.batch_size
        elif args.grad_accum > 0:
            effective_batch_size = args.batch_size * args.grad_accum
        else:
            raise Exception('`grad-accum` cannot be negative')
        if args.balanced:
            effective_batch_size *= 2
            # Hack for steps_per_epoch calculation (mutates the dataset)
            train_set.imgfiles = train_set.negfiles
        steps_per_epoch = int(
            np.ceil(len(train_set.imgfiles) / effective_batch_size))
        if args.pos_neg_ratio > 0:
            steps_per_epoch = int(
                np.ceil(args.num_samples / effective_batch_size))

    trainer_class = Trainer if args.pos_only else AllTrainer
    if args.balanced:
        trainer_class = BalancedTrainer
    trainer = trainer_class(model,
                            'DeepLab',
                            optimizer,
                            criterion,
                            loss_tracker,
                            args.save_dir,
                            args.save_best,
                            multiclass=train_set.multiclass)
    trainer.grad_accum = args.grad_accum
    if APEX_AVAILABLE and args.mixed:
        trainer.use_amp = True
    trainer.set_dataloaders(train_gen, valid_gen)
    trainer.set_thresholds(args.thresholds)

    if args.train_head > 0:
        # NOTE(review): `classifier` is not defined in this function; unless
        # a module-level `classifier` exists, this raises NameError. It
        # possibly should refer to a sub-module of `model` — confirm.
        trainer.train_head(optim.Adam(classifier.parameters()),
                           steps_per_epoch, args.train_head)

    trainer.train(args.total_epochs,
                  steps_per_epoch,
                  scheduler,
                  args.stop_patience,
                  verbosity=args.verbosity)