This repository contains the code for our paper "Semi-Supervised Learning with Variational Bayesian Inference and Maximum Uncertainty Regularization"

clarken92/ConsistencySSL

Semi-Supervised Learning with Variational Bayesian Inference and Maximum Uncertainty Regularization

This repository contains the official implementation of our paper:

Semi-Supervised Learning with Variational Bayesian Inference and Maximum Uncertainty Regularization

Kien Do, Truyen Tran, Svetha Venkatesh

Accepted at AAAI 2021.

Contents

  1. Requirements
  2. Features
  3. Repository structure
  4. Setup
  5. Downloading and preprocessing data
  6. Training
  7. Citation

Requirements

TensorFlow >= 1.8

The code has not been tested with TensorFlow 2.

This repository is designed to be self-contained. If a package turns out to be missing when you run the code, you can install it via pip or conda. Please email me if you run into any problems with this.

Features

  • Support model saving
  • Support logging
  • Support tensorboard visualization

Repository structure

Our code is organized into 5 main parts:

  • models: Contains the models used in our paper, including Pi, MT, MT+VD, MT+MUR, and others.
  • components: Contains the implementation of the CNN13 classifier.
  • my_utils: Contains utility functions.
  • data_preparation: Contains code for downloading and preprocessing datasets.
  • working: Contains scripts for training models.

IMPORTANT NOTE: Since this repository is organized as a Python project, I strongly encourage you to import it as a project into an IDE (e.g., PyCharm). The IDE will then automatically add the path to the project's root folder to PYTHONPATH when you run the code. Otherwise, you have to add it explicitly when you run the code from a terminal. Please check run_cifar10.sh (or run_cifar100.sh, run_svhn.sh) to see how it works.
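
Entries in PYTHONPATH end up on Python's sys.path, so the same effect can be had from inside a script. The snippet below is only a minimal sketch of that idea; the helper name add_project_root and the path /path/to/ConsistencySSL are placeholders for illustration and are not part of this repository.

```python
import os
import sys


def add_project_root(project_dir):
    """Put the project root at the front of sys.path if it is not already there."""
    root = os.path.abspath(project_dir)
    if root not in sys.path:
        sys.path.insert(0, root)
    return root


# Hypothetical location of the cloned repository; replace it with your own path.
root = add_project_root("/path/to/ConsistencySSL")
print(root in sys.path)
```

Calling the helper a second time leaves sys.path unchanged, so it is safe to put at the top of several scripts.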

Setup

The setup for training is very simple. All you need to do is open the global_settings.py file and change the values of the global variables to match your environment. The meanings of the global variables are given below:

  • PYTHON_EXE: Path to your python interpreter.
  • PROJECT_NAME: Name of the project, which I set to be 'ConsistencySSL'.
  • PROJECT_DIR: Path to the root folder containing the code of this project.
  • RESULTS_DIR: Path to the root folder that will be used to store results for this project.
  • RAW_DATA_DIR: Path to the root folder that contains raw datasets. By default, the root directory of the CIFAR10/CIFAR100/SVHN dataset is $RAW_DATA_DIR/ComputerVision/[dataset_name].
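
As an illustration of the variables above, a filled-in settings file might look roughly like the sketch below. Every path is a made-up example, the helper raw_dataset_dir is hypothetical, and the folder name "cifar10" is an assumption; the actual global_settings.py in the repository may be laid out differently.

```python
from os.path import join

# All values below are placeholders; adapt them to your own machine.
PYTHON_EXE = "/usr/bin/python3"
PROJECT_NAME = "ConsistencySSL"
PROJECT_DIR = "/home/user/ConsistencySSL"
RESULTS_DIR = "/home/user/results/ConsistencySSL"
RAW_DATA_DIR = "/home/user/datasets"


def raw_dataset_dir(dataset_name):
    """Hypothetical helper: build $RAW_DATA_DIR/ComputerVision/[dataset_name]."""
    return join(RAW_DATA_DIR, "ComputerVision", dataset_name)


# The folder name "cifar10" is an assumption made for this example.
print(raw_dataset_dir("cifar10"))
```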

Downloading and preprocessing data

Before training, you need to download and preprocess the datasets. Scripts for each dataset are provided in data_preparation/[dataset name]. Run them in the order given by their numeric prefixes.

For example, to prepare the CIFAR10 dataset, run the following commands:

export PYTHONPATH="[path to this project]:$PYTHONPATH"
python 1_process_data.py
python 2_generate_zca.py
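
If a dataset folder holds several numbered scripts, the run order can be read off the numeric prefix. The helper below is a small sketch of that convention only; the name ordered_scripts is hypothetical and nothing like it is claimed to exist in the repository.

```python
import re


def ordered_scripts(filenames):
    """Sort script names by their leading integer prefix, e.g. the 2 in '2_generate_zca.py'."""
    def prefix(name):
        match = re.match(r"(\d+)_", name)
        if match is None:
            raise ValueError("no numeric prefix in %r" % name)
        return int(match.group(1))

    return sorted(filenames, key=prefix)


# The two script names come from the CIFAR10 example above.
scripts = ordered_scripts(["2_generate_zca.py", "1_process_data.py"])
print(scripts)
```

Sorting on the integer rather than on the raw string keeps a script prefixed with 10 after one prefixed with 2.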

Training

Once you have set up everything in global_settings.py, you can start training by running the following commands in your terminal:

export PYTHONPATH="[path to this project]:$PYTHONPATH"
python train.py [required arguments]

IMPORTANT NOTE: If you run using the commands above, please remember to provide all required arguments specified in train.py; otherwise, errors will be raised.

However, if you are too lazy to type arguments in the terminal (like me 😅), you can set these arguments in the run_config dictionary in run_cifar10.py (or run_cifar100.py, run_svhn.py) and simply run this file:

export PYTHONPATH="[path to this project]:$PYTHONPATH"
python run_cifar10.py

I also provide a run_cifar10.sh file as an example for you.
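
To give a feel for how a dictionary of settings can stand in for typed arguments, here is a minimal sketch that turns such a dictionary into a command line. It is not taken from run_cifar10.py, which may work differently; the function build_train_command and the argument names output_dir, batch_size and use_zca are hypothetical, so check train.py for the real argument names.

```python
def build_train_command(python_exe, script, run_config):
    """Turn a dict of settings into an argument list for a training script."""
    cmd = [python_exe, script]
    for name, value in run_config.items():
        if isinstance(value, bool):
            # Assumption: booleans map to bare flags that appear only when True.
            if value:
                cmd.append("--" + name)
        else:
            cmd.extend(["--" + name, str(value)])
    return cmd


# Hypothetical argument names and values, used for illustration only.
run_config = {"output_dir": "/tmp/results", "batch_size": 100, "use_zca": True}
cmd = build_train_command("python", "train.py", run_config)
print(" ".join(cmd))
```

The resulting list can be passed to subprocess.run(cmd) to launch the training script.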

Citation

If you find this repository useful for your research, please consider citing our paper:

@inproceedings{do2021semi,
  title={Semi-Supervised Learning with Variational Bayesian Inference and Maximum Uncertainty Regularization},
  author={Do, Kien and Tran, Truyen and Venkatesh, Svetha},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={8},
  pages={7236--7244},
  year={2021}
}
