Example #1
def load_data(PATH_TO_IMAGES, LABEL, PATH_TO_MODEL, POSITIVE_FINDINGS_ONLY,
              STARTER_IMAGES):
    """
    Loads dataloader and torchvision model

    Args:
        PATH_TO_IMAGES: path to NIH CXR images
        LABEL: finding of interest (must exactly match one of the FINDINGS defined below,
               or an error is raised)
        PATH_TO_MODEL: path to downloaded pretrained model or your own retrained model
        POSITIVE_FINDINGS_ONLY: if True, the dataloader shows only examples positive for the
                                LABEL pathology; if False, it shows both positive and negative examples
        STARTER_IMAGES: if True, restricts the dataset to the small starter image subset

    Returns:
        dataloader: dataloader with test examples to show
        model: fine-tuned torchvision densenet-121
    """

    checkpoint = torch.load(PATH_TO_MODEL,
                            map_location=lambda storage, loc: storage)
    model = checkpoint['model']
    # patch BatchNorm modules saved with an older torch version
    # (recursion_change_bn walks the whole model, so one call suffices)
    recursion_change_bn(model)

    del checkpoint
    model.cpu()

    # build dataloader on test
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    FINDINGS = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
        'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
        'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
    ]

    data_transform = transforms.Compose([
        transforms.Resize(224),  # transforms.Scale was deprecated in torchvision
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    if not POSITIVE_FINDINGS_ONLY:
        finding = "any"
    else:
        finding = LABEL

    dataset = CXR.CXRDataset(path_to_images=PATH_TO_IMAGES,
                             fold='test',
                             transform=data_transform,
                             finding=finding,
                             starter_images=STARTER_IMAGES)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)

    return iter(dataloader), model
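
A minimal usage sketch for this loader follows; the paths and label are hypothetical placeholders, and the dataset is assumed to yield (image, label_vector, index) tuples as in the other examples here.

# Illustrative only: all paths below are placeholders, not values from the source.
dataloader, model = load_data(
    PATH_TO_IMAGES="nih/images",
    LABEL="Cardiomegaly",               # must exactly match one of FINDINGS
    PATH_TO_MODEL="pretrained/checkpoint",
    POSITIVE_FINDINGS_ONLY=True,
    STARTER_IMAGES=False)

image, label_vec, _ = next(dataloader)  # one positive example (batch size 1)
print(image.shape)                      # expected: torch.Size([1, 3, 224, 224])
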
Example #2
    def train(self):
        torch.manual_seed(0)
        cudnn.benchmark = True
        self.model.train()

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        data_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ColorJitter(brightness=0.1,
                                   contrast=0.1,
                                   saturation=0.1,
                                   hue=0.1),
            transforms.RandomAffine(degrees=10, scale=(.95, 1.05), shear=0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        transformed_dataset = CXR.CXRDataset(path_to_images=self.data_path,
                                             fold='train',
                                             transform=data_transforms)

        dataloader = torch.utils.data.DataLoader(transformed_dataset,
                                                 batch_size=16,
                                                 shuffle=True,
                                                 num_workers=8)

        lr = 0.01
        print(f"About to begin training on device: {self.device}")
        for epoch in range(20):
            for x, y, _ in dataloader:
                # Variable is a deprecated no-op wrapper in modern torch
                x, y = x.to(self.device), y.to(self.device).float()
                output = self.model(x)

                loss = self.loss_fn(output, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            lr *= 0.95
            self.optimizer = get_optimizer(self.model, lr)
            print(f"Completed epoch on {self.device}")
Example #3
def load_data(label='any'):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    PATH_TO_IMAGES = "dataset/images"
    finding = 'any'
    data_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    labels = [
        'Atelectasis',
        'Cardiomegaly',
        'Effusion',
        'Infiltration',
        'Mass',
        'Nodule',
        'Pneumonia',
        'Pneumothorax',
        'Consolidation',
        'Edema',
        'Emphysema',
        'Fibrosis',
        'Pleural_Thickening',
        'Hernia']
    if label in labels:
        finding = label

    dataset = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='test',
        transform=data_transform, finding=finding)
    print('Total number of images is: {0}'.format(len(dataset.df)))

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    return iter(dataloader), len(dataset.df)
Example #4
def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY):
    """
    Train a torchvision model on NIH data, given high-level hyperparameters.

    Args:
        PATH_TO_IMAGES: path to NIH images
        LR: learning rate
        WEIGHT_DECAY: weight decay parameter for SGD

    Returns:
        preds: torchvision model predictions on test fold with ground truth for comparison
        aucs: AUCs for each train,test tuple

    """
    NUM_EPOCHS = 100
    BATCH_SIZE = 32

    try:
        rmtree('results/')
    except FileNotFoundError:
        pass  # directory doesn't yet exist, no need to clear it
    os.makedirs("results/")

    # use imagenet mean,std for normalization
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    N_LABELS = 15  # we are predicting 15 labels. Originally 14 before adding Covid

    # load labels
    df = pd.read_csv("nih_labels_modified.csv", index_col=0)

    # define torchvision transforms
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(224),  # was transforms.Scale
            # Resize alone doesn't always give 224 x 224; CenterCrop below enforces it
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        'val':
        transforms.Compose([
            transforms.Resize(224),  #was transforms.Scale
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }

    # create train/val dataloaders
    transformed_datasets = {}
    transformed_datasets['train'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='train',
        transform=data_transforms['train'])
    transformed_datasets['val'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='val',
        transform=data_transforms['val'])

    dataloaders = {}
    dataloaders['train'] = torch.utils.data.DataLoader(
        transformed_datasets['train'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8)
    dataloaders['val'] = torch.utils.data.DataLoader(
        transformed_datasets['val'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8)

    # please do not attempt to train without a GPU, as it will take excessively long
    if not use_gpu:
        raise ValueError("Error, requires GPU")
    model = models.densenet121(pretrained=True)
    num_ftrs = model.classifier.in_features
    # add a final layer whose output dimension matches the number of labels,
    # with sigmoid activation
    model.classifier = nn.Sequential(nn.Linear(num_ftrs, N_LABELS),
                                     nn.Sigmoid())

    # put model on GPU
    model = model.cuda()

    # define criterion, optimizer for training
    criterion = nn.BCELoss()
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=LR,
                          momentum=0.9,
                          weight_decay=WEIGHT_DECAY)
    dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']}

    # train model
    model, best_epoch = train_model(model,
                                    criterion,
                                    optimizer,
                                    LR,
                                    num_epochs=NUM_EPOCHS,
                                    dataloaders=dataloaders,
                                    dataset_sizes=dataset_sizes,
                                    weight_decay=WEIGHT_DECAY)

    # get preds and AUCs on test fold
    preds, aucs = E.make_pred_multilabel(data_transforms, model,
                                         PATH_TO_IMAGES)

    return preds, aucs
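
A hypothetical invocation of this trainer; the image directory and hyperparameter values below are placeholders, not values from the source.

# Placeholder path and hyperparameters, for illustration only
preds, aucs = train_cnn(
    PATH_TO_IMAGES="nih/images",
    LR=0.01,
    WEIGHT_DECAY=1e-4)
print(aucs.head())
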
Example #5
def make_pred_multilabel(data_transforms, model, PATH_TO_IMAGES):
    """
    Gives predictions for test fold and calculates AUCs using previously trained model

    Args:
        data_transforms: torchvision transforms to preprocess raw images; same as validation transforms
        model: densenet-121 from torchvision previously fine tuned to training data
        PATH_TO_IMAGES: path at which NIH images can be found
    Returns:
        pred_df: dataframe containing individual predictions and ground truth for each test image
        auc_df: dataframe containing aggregate AUCs by train/test tuples
    """

    # calc preds in batches of 16, can reduce if your GPU has less RAM
    BATCH_SIZE = 16

    # set model to eval mode; required for proper predictions given use of batchnorm
    model.train(False)

    # create dataloader
    dataset = CXR.CXRDataset(path_to_images=PATH_TO_IMAGES,
                             fold="test",
                             transform=data_transforms['val'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=8)
    size = len(dataset)

    # accumulate rows in plain lists (DataFrame.append was removed in pandas 2.0)
    pred_rows = []
    true_rows = []

    # iterate over dataloader
    for i, data in enumerate(dataloader):

        inputs, labels, _ = data
        inputs, labels = inputs.cuda(), labels.cuda()  # Variable wrapper is deprecated

        true_labels = labels.cpu().data.numpy()
        batch_size = true_labels.shape

        outputs = model(inputs)
        probs = outputs.cpu().data.numpy()

        # get predictions and true values for each item in batch
        for j in range(0, batch_size[0]):
            thisrow = {}
            truerow = {}
            thisrow["Image Index"] = dataset.df.index[BATCH_SIZE * i + j]
            truerow["Image Index"] = dataset.df.index[BATCH_SIZE * i + j]

            # iterate over each entry in prediction vector; each corresponds to
            # individual label
            for k in range(len(dataset.PRED_LABEL)):
                thisrow["prob_" + dataset.PRED_LABEL[k]] = probs[j, k]
                truerow[dataset.PRED_LABEL[k]] = true_labels[j, k]

            pred_rows.append(thisrow)
            true_rows.append(truerow)

        if (i % 10 == 0):
            print(str(i * BATCH_SIZE))

    pred_df = pd.DataFrame(pred_rows)
    true_df = pd.DataFrame(true_rows)
    auc_rows = []

    # calc AUCs
    for column in true_df:

        if column not in [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia'
        ]:
            continue
        actual = true_df[column]
        pred = pred_df["prob_" + column]
        thisrow = {}
        thisrow['label'] = column
        thisrow['auc'] = np.nan
        try:
            thisrow['auc'] = sklm.roc_auc_score(actual.to_numpy().astype(int),
                                                pred.to_numpy())
        except ValueError:
            # raised when only one class is present in the column
            print("can't calculate auc for " + str(column))
        auc_rows.append(thisrow)

    auc_df = pd.DataFrame(auc_rows)
    pred_df.to_csv("results/preds.csv", index=False)
    auc_df.to_csv("results/aucs.csv", index=False)
    return pred_df, auc_df
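
The per-label AUC step above reduces to a single scikit-learn call; here is a toy, self-contained example with made-up values.

import numpy as np
import sklearn.metrics as sklm

# Toy data: four images, one label column
actual = np.array([0, 0, 1, 1])
pred = np.array([0.1, 0.4, 0.35, 0.8])
print(sklm.roc_auc_score(actual, pred))  # 0.75
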
Example #6
def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY, fine_tune=False, regression=False, freeze=False, adam=False,
              initial_model_path=None, initial_brixia_model_path=None, weighted_cross_entropy_batchwise=False,
              modification=None, weighted_cross_entropy=False):
    """
    Train a torchvision model on NIH data, given high-level hyperparameters.

    Args:
        PATH_TO_IMAGES: path to NIH images
        LR: learning rate
        WEIGHT_DECAY: weight decay parameter for SGD
        (the remaining keyword flags select fine-tuning, regression targets, layer
        freezing, the Adam optimizer, loss weighting, and architecture modifications)

    Returns:
        preds: torchvision model predictions on test fold with ground truth for comparison
        aucs: AUCs for each train,test tuple

    """
    NUM_EPOCHS = 100
    BATCH_SIZE = 32

    try:
        rmtree('results/')
    except FileNotFoundError:
        pass  # directory doesn't yet exist, no need to clear it
    os.makedirs("results/")

    # use imagenet mean,std for normalization
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    N_LABELS = 14  # we are predicting 14 labels
    N_COVID_LABELS = 3  # we are predicting 3 COVID labels

    # define torchvision transforms
    data_transforms = {
        'train': transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }

    # create train/val dataloaders
    transformed_datasets = {}
    transformed_datasets['train'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='train',
        transform=data_transforms['train'],
        fine_tune=fine_tune,
        regression=regression)
    transformed_datasets['val'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='val',
        transform=data_transforms['val'],
        fine_tune=fine_tune,
        regression=regression)

    dataloaders = {}
    dataloaders['train'] = torch.utils.data.DataLoader(
        transformed_datasets['train'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8)
    dataloaders['val'] = torch.utils.data.DataLoader(
        transformed_datasets['val'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8)

    # please do not attempt to train without a GPU, as it will take excessively long
    if not use_gpu:
        raise ValueError("Error, requires GPU")

    if initial_model_path or initial_brixia_model_path:
        if initial_model_path:
            saved_model = torch.load(initial_model_path)
        else:
            saved_model = torch.load(initial_brixia_model_path)
        model = saved_model['model']
        del saved_model
        if fine_tune and not initial_brixia_model_path:
            num_ftrs = model.module.classifier.in_features
            if freeze:
                for feature in model.module.features:
                    for param in feature.parameters():
                        param.requires_grad = False
                    if feature == model.module.features.transition2:
                        break
            if not regression:
                model.module.classifier = nn.Linear(num_ftrs, N_COVID_LABELS)
            else:
                model.module.classifier = nn.Sequential(
                    nn.Linear(num_ftrs, 1),
                    nn.ReLU(inplace=True)
                )
    else:
        model = models.densenet121(pretrained=True)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, N_LABELS)

        if modification == 'transition_layer':
            # num_ftrs = model.features.norm5.num_features
            up1 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
                                      torch.nn.BatchNorm2d(num_ftrs),
                                      torch.nn.ReLU(True))
            up2 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
                                      torch.nn.BatchNorm2d(num_ftrs))

            transition_layer = torch.nn.Sequential(up1, up2)
            model.features.add_module('transition_chestX', transition_layer)

        if modification == 'remove_last_block':
            model.features.denseblock4 = nn.Sequential()
            model.features.transition3 = nn.Sequential()
            # model.features.norm5 = nn.BatchNorm2d(512)
            # model.classifier = nn.Linear(512, N_LABELS)
        if modification == 'remove_last_two_block':
            model.features.denseblock4 = nn.Sequential()
            model.features.transition3 = nn.Sequential()

            model.features.transition2 = nn.Sequential()
            model.features.denseblock3 = nn.Sequential()

            model.features.norm5 = nn.BatchNorm2d(512)
            model.classifier = nn.Linear(512, N_LABELS)

    print(model)

    # put model on GPU
    if not initial_model_path:
        model = nn.DataParallel(model)
    model.to(device)

    if regression:
        criterion = nn.MSELoss()
    else:
        if weighted_cross_entropy:
            pos_weights = transformed_datasets['train'].pos_neg_balance_weights()
            print(pos_weights)
            # pos_weights[pos_weights>40] = 40
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
        else:
            criterion = nn.BCEWithLogitsLoss()

    if adam:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
    else:
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY, momentum=0.9)

    dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']}

    # train model (note: the regression branch only trains; it implicitly returns
    # None, since no preds/AUCs are computed for regression targets)
    if regression:
        model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
                                        dataloaders=dataloaders, dataset_sizes=dataset_sizes,
                                        weight_decay=WEIGHT_DECAY, fine_tune=fine_tune, regression=regression)
    else:
        model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
                                        dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY,
                                        weighted_cross_entropy_batchwise=weighted_cross_entropy_batchwise,
                                        fine_tune=fine_tune)
        # get preds and AUCs on test fold
        preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune)
        return preds, aucs
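
A hypothetical call exercising the fine-tuning path of this variant; the checkpoint path and hyperparameters are placeholders.

# Placeholder values: fine-tune a saved NIH model on the COVID labels with Adam
preds, aucs = train_cnn(
    "nih/images", LR=1e-4, WEIGHT_DECAY=1e-5,
    fine_tune=True, freeze=True, adam=True,
    initial_model_path="results/checkpoint")
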
Example #7
def run(PATH_TO_IMAGES, LR, WEIGHT_DECAY, opt):
    """
    Train a torchvision model on NIH data, given high-level hyperparameters.

    Args:
        PATH_TO_IMAGES: path to NIH images
        LR: learning rate
        WEIGHT_DECAY: weight decay parameter for SGD
        opt: parsed options (wandb project and run names, batch size, run_path,
             crop/input-size settings, and scheduler parameters)

    Returns:
        preds: torchvision model predictions on test fold with ground truth for comparison
        aucs: AUCs for each train,test tuple

    """

    use_gpu = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    print("Available GPU count:" + str(gpu_count))

    wandb.init(project=opt.project, name=opt.run_name)
    wandb.config.update(opt, allow_val_change=True)

    NUM_EPOCHS = 60
    BATCH_SIZE = opt.batch_size

    if opt.eval_only:
        # eval only; reusing an existing run_path is fine
        os.makedirs(opt.run_path, exist_ok=True)
    else:
        # training from scratch must not reuse an existing run_path, or it would overwrite a previous run
        try:
            os.makedirs(opt.run_path)
        except FileExistsError:
            print("[ERROR] run_path {} exists. try to assign a unique run_path".format(opt.run_path))
            return None, None
        except Exception as e:
            print("exception while creating run_path {}".format(opt.run_path))
            print(str(e))
            return None, None

    # use imagenet mean,std for normalization
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    N_LABELS = 14  # we are predicting 14 labels

    # define torchvision transforms
    if opt.random_crop:

        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=opt.input_size, scale=(0.8, 1.0)),  # crop then resize
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(int(opt.input_size * 1.05)),
                transforms.CenterCrop(opt.input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
        }

    else:
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize(opt.input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(opt.input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
        }
    # create train/val dataloaders
    transformed_datasets = {}
    transformed_datasets['train'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='train',
        transform=data_transforms['train'])
    transformed_datasets['val'] = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='val',
        transform=data_transforms['val'])

    worker_init_fn = set_seed(opt)

    dataloaders = {}
    dataloaders['train'] = torch.utils.data.DataLoader(
        transformed_datasets['train'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=30,
        drop_last=True,
        worker_init_fn=worker_init_fn
    )
    dataloaders['val'] = torch.utils.data.DataLoader(
        transformed_datasets['val'],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=30,
        drop_last=True,
        worker_init_fn=worker_init_fn
    )

    # please do not attempt to train without a GPU, as it will take excessively long
    if not use_gpu:
        raise ValueError("Error, requires GPU")

    # load model
    model = load_model(N_LABELS, opt)

    # define criterion, optimizer for training
    criterion = nn.BCELoss()

    optimizer = create_optimizer(model, LR, WEIGHT_DECAY, opt)

    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        factor=opt.lr_decay_ratio,
        patience=opt.patience,
        verbose=True
    )

    dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']}

    if opt.eval_only:
        print("loading best model statedict")
        # load best model weights to return
        checkpoint_best = torch.load(os.path.join(opt.run_path, 'checkpoint'))
        model = load_model(N_LABELS, opt=opt)
        model.load_state_dict(checkpoint_best['state_dict'])

    else:
        # train model
        model, best_epoch = train_model(
            model,
            criterion,
            optimizer,
            LR,
            scheduler=scheduler,
            num_epochs=NUM_EPOCHS,
            dataloaders=dataloaders,
            dataset_sizes=dataset_sizes,
            PATH_TO_IMAGES=PATH_TO_IMAGES,
            data_transforms=data_transforms,
            opt=opt,
        )

    # get preds and AUCs on test fold
    preds, aucs = E.make_pred_multilabel(
        data_transforms,
        model,
        PATH_TO_IMAGES,
        fold="test",
        opt=opt,
    )

    wandb.log({
        'val_official': np.average(list(aucs.auc))
    })

    return preds, aucs
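
A sketch of calling run(); the opt fields below mirror only the attributes read in this snippet, and helpers such as set_seed, load_model, and create_optimizer may require additional fields.

from types import SimpleNamespace

# Hypothetical option object; all values are placeholders
opt = SimpleNamespace(
    project="cxr-experiments", run_name="baseline",
    run_path="runs/baseline", batch_size=32,
    eval_only=False, random_crop=True, input_size=224,
    lr_decay_ratio=0.5, patience=3)

preds, aucs = run("nih/images", LR=0.01, WEIGHT_DECAY=1e-4, opt=opt)
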
Example #8
def load_data(PATH_TO_IMAGES,
              LABEL,
              PATH_TO_MODEL,
              fold,
              POSITIVE_FINDINGS_ONLY=None,
              covid=False):
    """
    Loads dataloader and torchvision model

    Args:
        PATH_TO_IMAGES: path to NIH CXR images
        LABEL: finding of interest (must exactly match one of the dataset findings,
               or an error is raised)
        PATH_TO_MODEL: path to downloaded pretrained model or your own retrained model
        fold: dataset fold to load ('train', 'val', 'test', or 'BBox')
        POSITIVE_FINDINGS_ONLY: if True, the dataloader shows only examples positive for the
                                LABEL pathology; if False, it shows both positive and negative examples
        covid: if True, loads the COVID fine-tuning dataset variant instead of filtering by finding

    Returns:
        dataloader: dataloader with test examples to show
        model: fine-tuned torchvision densenet-121
    """

    checkpoint = torch.load(PATH_TO_MODEL,
                            map_location=lambda storage, loc: storage)
    model = checkpoint['model']
    del checkpoint
    model = model.module.to(device)
    # model.eval()
    # for param in model.parameters():
    #  param.requires_grad_(False)

    # build dataloader on test
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    data_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    if not covid:
        bounding_box_transform = CXR.RescaleBB(224, 1024)

        if not POSITIVE_FINDINGS_ONLY:
            finding = "any"
        else:
            finding = LABEL

        dataset = CXR.CXRDataset(path_to_images=PATH_TO_IMAGES,
                                 fold=fold,
                                 transform=data_transform,
                                 transform_bb=bounding_box_transform,
                                 finding=finding)
    else:
        dataset = CXR.CXRDataset(path_to_images=PATH_TO_IMAGES,
                                 fold=fold,
                                 transform=data_transform,
                                 fine_tune=True)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)

    return iter(dataloader), model
Example #9
def samples_display(LABEL, beta=6):

    dataset = CXR.CXRDataset(
        path_to_images=PATH_TO_IMAGES,
        fold='BBox',  #fold='train'
        transform=data_transforms['train'],
        transform_bb=bounding_box_transform,
        fine_tune=False,
        label_path=label_path,
        finding=LABEL)  #finding=LABEL
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)
    global _label
    _label = get_label(LABEL)

    seed = np.random.randint(0, 1000000000)

    np.random.seed(seed)

    length = len(dataset)
    print("20 Ransom samples of " + LABEL + " from " + str(length) + " images")

    for sample_idx in np.random.choice(length,
                                       20):  #change with dataset length
        iba = IBA(model.features.denseblock2)
        iba.reset_estimate()
        iba.estimate(model,
                     dataloader,
                     device=dev,
                     n_samples=length,
                     progbar=False)  #change with dataset length
        img, target, idx, bbox = dataset[sample_idx]
        img = img[None].to(dev)

        # reverse the data pre-processing for plotting the original image
        np_img = tensor_to_np_img(img[0])

        size = 4
        rows = 1
        cols = 5
        fig, (ax, ax2, ax3, ax4,
              ax5) = plt.subplots(1,
                                  5,
                                  figsize=(cols * (size + 1), rows * size))

        cxr = img.data.cpu().numpy().squeeze().transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        cxr = std * cxr + mean
        cxr = np.clip(cxr, 0, 1)

        #    rect_original = patches.Rectangle((bbox[0, 0], bbox[0, 1]), bbox[0, 2], bbox[0, 3], linewidth=2, edgecolor='r', facecolor='none', zorder=2)
        rect_original = patches.Rectangle((bbox[0], bbox[1]),
                                          bbox[2],
                                          bbox[3],
                                          linewidth=2,
                                          edgecolor='r',
                                          facecolor='none',
                                          zorder=2)

        ax.imshow(cxr)
        ax.axis('off')
        ax.set_title(idx)
        ax.add_patch(rect_original)

        iba.reverse_lambda = False
        iba.beta = beta
        heatmap = iba.analyze(img, model_loss_closure)
        # show the heatmap
        ax2 = plot_saliency_map(heatmap, np_img, ax=ax2)
        _ = ax2.set_title("Old method, Block2")
        ax2.add_patch(
            patches.Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              linewidth=2,
                              edgecolor='r',
                              facecolor='none',
                              zorder=2))

        iba.reverse_lambda = True
        iba.beta = beta
        heatmap = iba.analyze(img, model_loss_closure)
        # show the heatmap
        ax3 = plot_saliency_map(heatmap, np_img, ax=ax3)
        _ = ax3.set_title("New method, Block2")
        ax3.add_patch(
            patches.Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              linewidth=2,
                              edgecolor='r',
                              facecolor='none',
                              zorder=2))

        iba = IBA(model.features.denseblock3)
        iba.reset_estimate()
        iba.estimate(model,
                     dataloader,
                     device=dev,
                     n_samples=length,
                     progbar=False)

        iba.reverse_lambda = False
        iba.beta = beta
        heatmap = iba.analyze(img, model_loss_closure)
        # show the heatmap
        ax4 = plot_saliency_map(heatmap, np_img, ax=ax4)
        _ = ax4.set_title("Old method, Block3")
        ax4.add_patch(
            patches.Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              linewidth=2,
                              edgecolor='r',
                              facecolor='none',
                              zorder=2))

        iba.reverse_lambda = True
        iba.beta = beta
        heatmap = iba.analyze(img, model_loss_closure)
        # show the heatmap
        ax5 = plot_saliency_map(heatmap, np_img, ax=ax5)
        _ = ax5.set_title("New method, Block3")
        ax5.add_patch(
            patches.Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              linewidth=2,
                              edgecolor='r',
                              facecolor='none',
                              zorder=2))

        plt.show()
    return length


"""# Cardiomegaly"""

LABEL = 'Cardiomegaly'

length_data = samples_display(LABEL)

dataset = CXR.CXRDataset(
    path_to_images=PATH_TO_IMAGES,
    fold='BBox',  #fold='train'
    transform=data_transforms['train'],
    transform_bb=bounding_box_transform,
    fine_tune=False,
    label_path=label_path,
    finding=LABEL)  #finding=LABEL
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
img, target, idx, bbox = dataset[0]
"""## Compare different betas (So I chose beta=6 instead of 10, the default)

###IB in denseblock3 with different betas
"""

iba = IBA(model.features.denseblock3)
iba.reset_estimate()
iba.estimate(model,
             dataloader,
             device=dev,
             n_samples=length_data,  # assumed: completes the truncated call to match the earlier estimate() calls
             progbar=False)
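
A sketch of the beta comparison the heading above describes, reusing only calls already shown in this script; the beta values are illustrative.

# Illustrative beta sweep over the IB in denseblock3 (beta values assumed)
img_t = img[None].to(dev)
np_img = tensor_to_np_img(img_t[0])
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, beta in zip(axes, (1, 6, 10)):
    iba.beta = beta
    heatmap = iba.analyze(img_t, model_loss_closure)
    plot_saliency_map(heatmap, np_img, ax=ax)
    ax.set_title("beta = {}".format(beta))
plt.show()
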
Example #10
def make_pred_multilabel(model, PATH_TO_IMAGES, BATCH_SIZE=4):
    """
    Gives predictions for test fold and calculates AUCs using previously trained model

    Args:
        model: the model to calculate the AUCs for
        PATH_TO_IMAGES: path at which NIH images can be found
        BATCH_SIZE: prediction batch size (default 4)
    Returns:
        pred_df: dataframe containing individual predictions and ground truth for each test image
        auc_df: dataframe containing aggregate AUCs by train/test tuples
    """

    # preds are computed in batches of BATCH_SIZE; reduce it if your GPU has less RAM

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            # transforms.Resize(224),
            # because scale doesn't always give 224 x 224, this ensures 224 x
            # 224
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        'val':
        transforms.Compose([
            # transforms.Resize(224),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }

    # set model to eval mode; required for proper predictions given use of batchnorm
    model.train(False)

    # create dataloader
    dataset = CXR.CXRDataset(path_to_images=PATH_TO_IMAGES,
                             fold="test",
                             sample=200,
                             transform=data_transforms['val'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=8)
    size = len(dataset)
    print(size)
    # print('datasetsize', dataset_sizes)
    print(dataset.df)

    # accumulate rows in plain lists (DataFrame.append was removed in pandas 2.0)
    pred_rows = []
    true_rows = []

    # these lists collect raw outputs and labels for a second, independent AUC calculation
    outputList = []
    labelList = []

    # iterate over dataloader
    for idx, data in enumerate(dataloader):

        inputs, labels, _ = data
        inputs = inputs.to(device)  # Variable wrapper is deprecated
        labels = labels.float().to(device)

        true_labels = labels.cpu().data.numpy()
        batch_size = true_labels.shape
        outputs = model(inputs)
        outputs = torch.sigmoid(outputs)
        probs = outputs.cpu().data.numpy()

        for i in range(outputs.shape[0]):
            outputList.append(outputs[i].tolist())
            labelList.append(labels[i].tolist())

        # get predictions and true values for each item in batch
        # here the dataframe 'true' and 'preds' are created by adding rows into them respectively
        for j in range(0, batch_size[0]):
            thisrow = {}
            truerow = {}
            thisrow["Image Index"] = dataset.df.index[BATCH_SIZE * idx + j]
            truerow["Image Index"] = dataset.df.index[BATCH_SIZE * idx + j]

            # iterate over each entry in prediction vector; each corresponds to
            # individual label
            for k in range(len(dataset.PRED_LABEL)):
                thisrow["prob_" + dataset.PRED_LABEL[k]] = probs[j, k]
                truerow[dataset.PRED_LABEL[k]] = true_labels[j, k]
            pred_rows.append(thisrow)
            true_rows.append(truerow)
            # print(pred_df)
            # print(head(true_df))
        if (idx % 100 == 0):
            print(str(idx * BATCH_SIZE))

    # A second way to calculate the AUCs, computed step by step over the collected outputs
    print('Scores - Method2 -----------------------')
    epoch_auc_ave = sklm.roc_auc_score(np.array(labelList),
                                       np.array(outputList))
    epoch_auc = sklm.roc_auc_score(np.array(labelList),
                                   np.array(outputList),
                                   average=None)
    for i, c in enumerate(dataset.PRED_LABEL):
        fpr, tpr, _ = sklm.roc_curve(
            np.array(labelList)[:, i],
            np.array(outputList)[:, i])
        plt.plot(fpr, tpr, label=c)
        print('{}: {:.4f} '.format(c, epoch_auc[i]))
    print('Scores - Method2 -----------------------')

    # build the prediction/truth dataframes, then calculate the per-label AUC table
    pred_df = pd.DataFrame(pred_rows)
    true_df = pd.DataFrame(true_rows)
    auc_rows = []

    # calc AUCs
    for column in true_df:

        if column not in [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia'
        ]:
            continue
        actual = true_df[column]
        pred = pred_df["prob_" + column]
        thisrow = {}
        thisrow['label'] = column
        thisrow['auc'] = np.nan
        try:
            thisrow['auc'] = sklm.roc_auc_score(actual.values.astype(int),
                                                pred.values)
        except ValueError as e:  # raised when a column has only one class
            print('-------------------')
            print(e)
            print(actual.values)
            print(pred.values)
            print('-------------------')

        auc_rows.append(thisrow)

    auc_df = pd.DataFrame(auc_rows)
    pred_df.to_csv("results/preds.csv", index=False)
    auc_df.to_csv("results/aucs.csv", index=False)
    true_df.to_csv('results/true.csv', index=False)
    return pred_df, auc_df, true_df
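
A hypothetical evaluation call for this variant; the image path is a placeholder, and model is assumed to be a trained network whose raw (pre-sigmoid) outputs follow dataset.PRED_LABEL ordering.

# Placeholder path; `model` is assumed to be already trained and on `device`
pred_df, auc_df, true_df = make_pred_multilabel(model, "nih/images", BATCH_SIZE=4)
print(auc_df)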