Esempio n. 1
0
def train_bimcvpadchest(seed, alexnet=False, freeze_features=False):
    """Train a CXR classifier on the PadChest / BIMCV-COVID-19+ pairing.

    Training and validation data are DomainConfoundedDataset instances whose
    first member is PadChest and second member is BIMCV-COVID-19+, both loaded
    with ChestX-ray14-style labels and ``random_state=seed``.

    Args:
        seed: seed forwarded as ``random_state`` to every dataset; also written
            into the log/checkpoint filenames with a ``{:d}`` spec, so it must
            be an integer.
        alexnet: when truthy, outputs are tagged 'alexnet' and
            ``scratch_train=alexnet`` is passed to ``CXRClassifier.train``.
        freeze_features: forwarded to ``CXRClassifier.train``; when ``alexnet``
            is falsy it also selects the 'densenet121frozen' tag.
    """
    shared = dict(labels='chestx-ray14', random_state=seed)

    def _paired(fold):
        # PadChest is the first member, BIMCV-COVID-19+ the second.
        return DomainConfoundedDataset(
            PadChestH5Dataset(fold=fold, **shared),
            BIMCVCOVIDDataset(fold=fold, **shared))

    trainds = _paired('train')
    valds = _paired('val')

    # Network tag used to keep outputs of different variants apart on disk.
    netstring = ('alexnet' if alexnet
                 else 'densenet121frozen' if freeze_features
                 else 'densenet121')
    logpath = 'logs/bimcvpadchest.{:s}.{:d}.log'.format(netstring, seed)
    checkpointpath = 'checkpoints/bimcvpadchest.{:s}.{:d}.pkl'.format(
        netstring, seed)

    hyperparams = dict(max_epochs=30, lr=0.01, weight_decay=1e-4)
    classifier = CXRClassifier()
    classifier.train(trainds, valds,
                     logpath=logpath,
                     checkpoint_path=checkpointpath,
                     verbose=True,
                     scratch_train=alexnet,
                     freeze_features=freeze_features,
                     **hyperparams)
Esempio n. 2
0
def train_dataset_3(seed, alexnet=False, freeze_features=False):
    """Train a CXR classifier on BIMCV-COVID-19- / BIMCV-COVID-19+ ("dataset 3").

    Unlike the other pairings, BIMCV-COVID-19+ and BIMCV-COVID-19- share
    patients, so both collections are loaded unsplit (fold='all') and then
    partitioned per patient with ``ds3_grouped_split``:

    1. into train+val versus test;
    2. train+val into train versus val.

    The test portion is not used by this function.

    Args:
        seed: seed forwarded as ``random_state`` to the datasets and to both
            splits; also written into output filenames with a ``{:d}`` spec.
        alexnet: when truthy, outputs are tagged 'alexnet' and
            ``scratch_train=alexnet`` is passed to ``CXRClassifier.train``.
        freeze_features: forwarded to ``CXRClassifier.train``; when ``alexnet``
            is falsy it also selects the 'densenet121frozen' tag.
    """
    def _full_bimcv():
        # The *full* unsplit data; the patient-level split is applied below.
        return DomainConfoundedDataset(
            BIMCVNegativeDataset(fold='all', labels='chestx-ray14', random_state=seed),
            BIMCVCOVIDDataset(fold='all', labels='chestx-ray14', random_state=seed)
            )

    trainds = _full_bimcv()
    valds = _full_bimcv()

    # First hold out the test patients, then carve validation patients out of
    # what remains (ds1 = BIMCV-COVID-19-, ds2 = BIMCV-COVID-19+).
    trainval_neg, _, trainval_pos, _ = ds3_grouped_split(
        trainds.ds1.df, trainds.ds2.df, random_state=seed)
    train_neg, val_neg, train_pos, val_pos = ds3_grouped_split(
        trainval_neg, trainval_pos, random_state=seed)

    # Install the split frames and refresh the cached member lengths.
    for target, neg_df, pos_df in ((trainds, train_neg, train_pos),
                                   (valds, val_neg, val_pos)):
        target.ds1.df = neg_df
        target.ds2.df = pos_df
        target.len1 = len(target.ds1)
        target.len2 = len(target.ds2)

    # Network tag used to keep outputs of different variants apart on disk.
    netstring = ('alexnet' if alexnet
                 else 'densenet121frozen' if freeze_features
                 else 'densenet121')
    logpath = 'logs/dataset3.{:s}.{:d}.log'.format(netstring, seed)
    checkpointpath = 'checkpoints/dataset3.{:s}.{:d}.pkl'.format(netstring, seed)

    hyperparams = dict(max_epochs=30, lr=0.01, weight_decay=1e-4)
    classifier = CXRClassifier()
    classifier.train(trainds, valds,
                     logpath=logpath,
                     checkpoint_path=checkpointpath,
                     verbose=True,
                     scratch_train=alexnet,
                     freeze_features=freeze_features,
                     **hyperparams)
Esempio n. 3
0
def test_githubcxr14(seed, alexnet=False, freeze_features=False):
    """Print internal and external COVID AUROC for a ChestX-ray14/GitHub-COVID model.

    Loads the ``.best_auroc`` checkpoint for the selected network variant, then:

    - evaluates it on the internal test set (ChestX-ray14 + GitHub-COVID);
    - evaluates it on the external test set (PadChest + BIMCV-COVID-19+);
    - prints the COVID AUROC for each.

    Nothing is returned.

    Args:
        seed: seed forwarded as ``random_state`` to every dataset and written
            into the checkpoint filename with a ``{:d}`` spec.
        alexnet: selects the 'alexnet' checkpoint tag when truthy.
        freeze_features: selects the 'densenet121frozen' tag when ``alexnet``
            is falsy.
    """
    shared = dict(labels='chestx-ray14', random_state=seed)
    internal_testds = DomainConfoundedDataset(
        ChestXray14H5Dataset(fold='test', **shared),
        GitHubCOVIDDataset(fold='test', **shared))
    external_testds = DomainConfoundedDataset(
        PadChestH5Dataset(fold='test', **shared),
        BIMCVCOVIDDataset(fold='test', **shared))

    netstring = ('alexnet' if alexnet
                 else 'densenet121frozen' if freeze_features
                 else 'densenet121')
    checkpointpath = 'checkpoints/githubcxr14.{:s}.{:d}.pkl.best_auroc'.format(
        netstring, seed)

    classifier = CXRClassifier()
    classifier.load_checkpoint(checkpointpath)

    # Internal test set: the same column index serves truth and prediction.
    internal_probs = classifier.predict(internal_testds)
    internal_true = internal_testds.get_all_labels()
    internal_idx = _find_index(internal_testds, 'COVID')
    internal_auroc = sklearn.metrics.roc_auc_score(
        internal_true[:, internal_idx], internal_probs[:, internal_idx])
    print("internal auroc: ", internal_auroc)

    # External test set: ground truth is read with the external label order
    # *before* the dataset is relabelled with the internal labels.  The
    # prediction column is therefore internal_idx on purpose.
    external_idx = _find_index(external_testds, 'COVID')
    external_true = external_testds.get_all_labels()
    external_testds.labels = internal_testds.labels  # deliberate hack
    external_probs = classifier.predict(external_testds)
    covid_truth = external_true[:, external_idx]
    covid_score = external_probs[:, internal_idx]
    external_auroc = sklearn.metrics.roc_auc_score(covid_truth, covid_score)
    print("external auroc: ", external_auroc)
Esempio n. 4
0
def plot(ax, checkpointpath, seed, legend=False):
    """Plot COVID ROC curves for one checkpoint on two test sets.

    The classifier stored at ``checkpointpath`` is evaluated on:

    - the ChestX-ray14 / GitHub-COVID test set;
    - the PadChest / BIMCV-COVID-19+ test set.

    For each set the COVID AUROC is printed and the ROC curve is drawn on ``ax``.

    Args:
        ax: object exposing ``plot(x, y, **kwargs)``.  Presumably a matplotlib
            Axes — confirm with callers.
        checkpointpath: path passed to ``CXRClassifier.load_checkpoint``.
        seed: seed forwarded as ``random_state`` to every dataset.
        legend: when truthy, each curve gets a ``label`` for a legend.

    Returns:
        Tuple ``(githubcxr14_auroc, bimcvpadchest_auroc)``.
    """
    githubcxr14_testds = DomainConfoundedDataset(
        ChestXray14H5Dataset(fold='test',
                             labels='chestx-ray14',
                             random_state=seed),
        GitHubCOVIDDataset(fold='test',
                           labels='chestx-ray14',
                           random_state=seed))

    bimcvpadchest_testds = DomainConfoundedDataset(
        PadChestH5Dataset(fold='test',
                          labels='chestx-ray14',
                          random_state=seed),
        BIMCVCOVIDDataset(fold='test',
                          labels='chestx-ray14',
                          random_state=seed))

    classifier = CXRClassifier()
    classifier.load_checkpoint(checkpointpath)

    # --- ChestX-ray14 / GitHub-COVID -------------------------------------
    githubcxr14_probs = classifier.predict(githubcxr14_testds)
    githubcxr14_true = githubcxr14_testds.get_all_labels()
    githubcxr14_idx = _find_index(githubcxr14_testds, 'COVID')
    githubcxr14_y = githubcxr14_true[:, githubcxr14_idx]
    githubcxr14_score = githubcxr14_probs[:, githubcxr14_idx]

    githubcxr14_auroc = sklearn.metrics.roc_auc_score(githubcxr14_y,
                                                      githubcxr14_score)
    print("githubcxr14 auroc: ", githubcxr14_auroc)
    fpr, tpr, _ = sklearn.metrics.roc_curve(githubcxr14_y, githubcxr14_score)
    kwargs = {'color': '#b43335', 'linewidth': 1}
    if legend:
        kwargs['label'] = 'ChestX-ray14/\nGitHub-COVID'
    ax.plot(fpr, tpr, **kwargs)

    # --- PadChest / BIMCV-COVID-19+ --------------------------------------
    bimcvpadchest_probs = classifier.predict(bimcvpadchest_testds)
    bimcvpadchest_true = bimcvpadchest_testds.get_all_labels()
    bimcvpadchest_idx = _find_index(bimcvpadchest_testds, 'COVID')
    bimcvpadchest_y = bimcvpadchest_true[:, bimcvpadchest_idx]
    # Ground truth uses this dataset's own COVID index, but the score column
    # is githubcxr14_idx.  This is the same convention test_githubcxr14 uses
    # for this dataset pair.
    # NOTE(review): assumes the model's output columns follow the
    # ChestX-ray14/GitHub-COVID label order — confirm for checkpoints trained
    # on other datasets.
    # The AUROC and the ROC curve must use this one column so the printed
    # value matches the plotted curve.
    bimcvpadchest_score = bimcvpadchest_probs[:, githubcxr14_idx]

    bimcvpadchest_auroc = sklearn.metrics.roc_auc_score(bimcvpadchest_y,
                                                        bimcvpadchest_score)
    print("bimcvpadchest auroc: ", bimcvpadchest_auroc)
    fpr, tpr, _ = sklearn.metrics.roc_curve(bimcvpadchest_y,
                                            bimcvpadchest_score)
    kwargs = {'color': '#107f80', 'linewidth': 1}
    if legend:
        kwargs['label'] = 'PadChest/\nBIMCV-COVID-19+'
    ax.plot(fpr, tpr, **kwargs)
    return githubcxr14_auroc, bimcvpadchest_auroc
Esempio n. 5
0
def plot(ax, checkpointpath, seed, legend=False):
    """Plot COVID ROC curves for one checkpoint on GitHub/CXR14 and BIMCV test sets.

    The classifier stored at ``checkpointpath`` is evaluated on:

    - the ChestX-ray14 / GitHub-COVID test set;
    - the BIMCV-COVID-19- / BIMCV-COVID-19+ test split.

    For each set the COVID AUROC is printed and the ROC curve is drawn on ``ax``.

    Args:
        ax: object exposing ``plot(x, y, **kwargs)``.  Presumably a matplotlib
            Axes — confirm with callers.
        checkpointpath: path passed to ``CXRClassifier.load_checkpoint``.
        seed: seed forwarded as ``random_state`` to the datasets and to the
            patient-level split.
        legend: when truthy, each curve gets a ``label`` for a legend.

    Returns:
        Tuple ``(githubcxr14_auroc, bimcv_auroc)``.
    """
    githubcxr14_testds = DomainConfoundedDataset(
        ChestXray14H5Dataset(fold='test',
                             labels='chestx-ray14',
                             random_state=seed),
        GitHubCOVIDDataset(fold='test',
                           labels='chestx-ray14',
                           random_state=seed))

    # BIMCV-COVID-19+ and BIMCV-COVID-19- share patients, so the test split is
    # made per patient *after* loading the full (fold='all') data.  It uses
    # the same ds3_grouped_split call and seed as train_dataset_3's first split.
    bimcv_testds = DomainConfoundedDataset(
        BIMCVNegativeDataset(fold='all',
                             labels='chestx-ray14',
                             random_state=seed),
        BIMCVCOVIDDataset(fold='all', labels='chestx-ray14',
                          random_state=seed))
    _, testdf1, _, testdf2 = ds3_grouped_split(
        bimcv_testds.ds1.df, bimcv_testds.ds2.df, random_state=seed)

    # Keep only the test patients and refresh the cached member lengths.
    bimcv_testds.ds1.df = testdf1
    bimcv_testds.ds2.df = testdf2
    bimcv_testds.len1 = len(bimcv_testds.ds1)
    bimcv_testds.len2 = len(bimcv_testds.ds2)

    classifier = CXRClassifier()
    classifier.load_checkpoint(checkpointpath)

    # --- ChestX-ray14 / GitHub-COVID -------------------------------------
    githubcxr14_probs = classifier.predict(githubcxr14_testds)
    githubcxr14_true = githubcxr14_testds.get_all_labels()
    githubcxr14_idx = _find_index(githubcxr14_testds, 'COVID')
    githubcxr14_y = githubcxr14_true[:, githubcxr14_idx]
    githubcxr14_score = githubcxr14_probs[:, githubcxr14_idx]

    githubcxr14_auroc = sklearn.metrics.roc_auc_score(githubcxr14_y,
                                                      githubcxr14_score)
    print("githubcxr14 auroc: ", githubcxr14_auroc)
    fpr, tpr, _ = sklearn.metrics.roc_curve(githubcxr14_y, githubcxr14_score)
    kwargs = {'color': '#b43335', 'linewidth': 1}
    if legend:
        kwargs['label'] = 'ChestX-ray14/\nGitHub-COVID'
    ax.plot(fpr, tpr, **kwargs)

    # --- BIMCV-COVID-19- / BIMCV-COVID-19+ -------------------------------
    bimcv_probs = classifier.predict(bimcv_testds)
    bimcv_true = bimcv_testds.get_all_labels()
    bimcv_idx = _find_index(bimcv_testds, 'COVID-19')
    bimcv_y = bimcv_true[:, bimcv_idx]
    bimcv_score = bimcv_probs[:, bimcv_idx]

    bimcv_auroc = sklearn.metrics.roc_auc_score(bimcv_y, bimcv_score)
    print("bimcv auroc: ", bimcv_auroc)
    fpr, tpr, _ = sklearn.metrics.roc_curve(bimcv_y, bimcv_score)
    kwargs = {'color': '#1e579a', 'linewidth': 1}
    if legend:
        kwargs['label'] = 'BIMCV-COVID-19−/\nBIMCV-COVID-19+'
    ax.plot(fpr, tpr, **kwargs)
    return githubcxr14_auroc, bimcv_auroc