Skip to content

UCIDataLab/bayesian-blackbox

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Active Bayesian Assessment for Black-Box Classifiers

This repo contains an implementation of the active Bayesian assessment models described in "Active Bayesian Assessment for Black-Box Classifiers", Disi Ji, Robert L. Logan IV, Padhraic Smyth, Mark Steyvers. [arXiv].

Setup

You will need Python 3.5+. Dependencies can be installed by running:

pip install -r requirements.txt

Data

We assess performance characteristics of neural models on several standard image and text classification datasets. The image datasets we use are: CIFAR-100 (Krizhevsky and Hinton, 2009), SVHN (Netzer et al., 2011) and ImageNet (Russakovsky et al., 2015). The text datasets we use are: 20 Newsgroups (Lang, 1995) and DBpedia (Zhang et al., 2015).

For image classification we use ResNet (He et al., 2016) architectures with either 110 layers (CIFAR-100) or 152 layers (SVHN and ImageNet). For ImageNet we use the pretrained model provided by PyTorch, and for CIFAR and SVHN we use the pretrained model checkpoints provided at: https://github.com/bearpaw/pytorch-classification. For text classification tasks we use fine-tuned BERTBASE (Devlin et al., 2019) models. These models were all trained on standard training sets in the literature, independent from the datasets used for assessment.

The assessment datasets are based on standard test sets used for each dataset in the literature. We build our assessment models on predictions that the classifiers made on the assessment datasets. Predictions can be downloaded from here.

Run the Active Bayesian Assessor

After downloading the data, update DATA_DIR, RESULTS_DIR and FIGURE_DIR and src/data_utils.py accordingly, to specify the input directory to read data from and the output output directory to write results and figures to.

To reproduce all the experimental results and figures we reported in the paper, run commands in script.

For example, to identify the extreme classes, navigate to src directory and run:

python active_learning_topk.py [dataset] 
    -metric [metric] 
    -mode [mode] 
    -topk [topk] 
    -pseudocount=[pseudocount]

where

  • dataset: name of the dataset. Path to the corresponding input data is hard-coded in src/data_utils.py.
  • metric: 'accuracy' or 'calibration_error', the metric used to rank classes.
  • mode: 'min' or 'max', to identify the most/least accurate/calibrated classes.
  • topk: the number of extreme classes to identify.
  • pseudocount: the strength of the prior.

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%