mafanhe/pbt

PBT: Population Based Training

Original paper: Population Based Training of Neural Networks, Jaderberg et al.
This repo is forked from MattKleinsmith's pbt repo. The differences from, and enhancements over, that repo are:

1. It is implemented in TensorFlow.
2. It speeds up training by using multiple processes.

What this code is

A TensorFlow implementation of PBT.

What this code is for

Finding a good hyperparameter schedule.

How to use this code

Warning: This implementation isn't user-friendly yet. If you have any questions, create a GitHub issue and I'll try to help you.

Steps:

  1. Wrestle with dependencies.
  2. Edit config.py to set your options.
  3. Store your data as bcolz carrays. See datasets.py for an example.
  4. In a terminal, enter: python main.py --exploiter
  5. Run parameters:
       --population_id -1: "-1" means to work on the most recently created population.
       --exploiter: set this process as the exploiter. It is responsible for running the exploit step over the entire population at the end of each interval. When finished, the process prints the path to the weights of the best-performing model.
     Where a GPU ID such as "1" is given, it refers to your GPU's ID in nvidia-smi.

Figures for intuition

[Two figures: accuracy plots and hyperparameter scatter plots]

These figures are for building intuition for PBT. They aren't the results of a rigorous experiment. In the accuracy plots, the best model is shown in purple. In the hyperparameter scatter plots, the size of the dots grows as the models train. The hyperparameter configurations of the best model from each population are shown as purple stars.

Notice how the hyperparameter configurations evolve in PBT, but stay the same in random search.

How does PBT work?

PBT partially trains each model and assesses it on the validation set. It then transfers the parameters and hyperparameters from the top-performing models to the bottom-performing models (exploitation). After transferring the hyperparameters, PBT perturbs them (exploration). Each model is then trained some more, and the process repeats. This allows PBT to learn a hyperparameter schedule instead of only a fixed hyperparameter configuration. PBT can be used with different selection methods, that is, different ways of defining "top" and "bottom" (e.g. top 5 or top 5%).
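The train, exploit, explore cycle described above can be sketched on a toy problem. This is not the repo's code: the objective, the function name pbt_toy, the dictionary keys, the perturbation factors, and the "worst copies best" selection rule are all assumptions chosen to keep the example small and free of TensorFlow.

```python
import random


def pbt_toy(population_size=10, intervals=20, seed=0):
    """Toy PBT loop: maximize f(theta) = -(theta - 3)^2 by gradient ascent.

    Hypothetical sketch. "theta" stands in for model weights and "lr" for
    the hyperparameter whose schedule PBT adapts.
    """
    rng = random.Random(seed)
    population = [
        {"theta": 0.0, "lr": rng.uniform(0.001, 0.5)}
        for _ in range(population_size)
    ]

    def evaluate(member):
        return -(member["theta"] - 3.0) ** 2

    for _ in range(intervals):
        # Train each model partially (one gradient step here), then assess it.
        for member in population:
            gradient = -2.0 * (member["theta"] - 3.0)
            member["theta"] += member["lr"] * gradient
            member["score"] = evaluate(member)

        ranked = sorted(population, key=lambda m: m["score"], reverse=True)
        best, worst = ranked[0], ranked[-1]

        # Exploit: the worst model copies parameters and hyperparameters
        # from the best model.
        worst["theta"] = best["theta"]
        worst["lr"] = best["lr"]
        worst["score"] = best["score"]

        # Explore: perturb the copied hyperparameter (factors are assumed).
        worst["lr"] *= rng.choice([0.8, 1.2])

    return max(population, key=lambda m: m["score"])


if __name__ == "__main__":
    print(pbt_toy())
```

Because the learning rate of a model can change at every interval, the sequence of values it takes over training is the hyperparameter schedule that a fixed configuration cannot express.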

For more information, see the paper or blog post.

Selection method in this code

Truncation selection: For each model in the bottom 20% of performance, sample a model from the top 20% and transfer its parameters and hyperparameters to the worse model. One can think of the models in the bottom 20% as being truncated during each exploitation step. Leave the top 80% unchanged. This selection method was used in the paper.
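As a rough illustration of truncation selection, the sketch below ranks a population by score and overwrites each bottom-20% model with a randomly sampled top-20% model. The function name truncation_select and the dictionary keys ("id", "score", "weights", "hyperparams") are hypothetical and do not reflect how this repo stores models.

```python
import copy
import random


def truncation_select(population, fraction=0.2, rng=random):
    """Overwrite each bottom-`fraction` model with a sampled top-`fraction` model.

    `population` is a list of dicts with hypothetical keys "id", "score",
    "weights" and "hyperparams". Returns a list of (loser_id, winner_id).
    """
    ranked = sorted(population, key=lambda m: m["score"], reverse=True)
    cutoff = int(len(ranked) * fraction)
    if cutoff == 0 or 2 * cutoff > len(ranked):
        return []  # population too small for disjoint top and bottom groups

    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    transfers = []
    for loser in bottom:
        winner = rng.choice(top)
        # Deep copies keep later changes to the loser from touching the winner.
        loser["weights"] = copy.deepcopy(winner["weights"])
        loser["hyperparams"] = copy.deepcopy(winner["hyperparams"])
        transfers.append((loser["id"], winner["id"]))
    return transfers
```

With a population of 10 and the default fraction, two models are overwritten per exploitation step and the other eight are left untouched. The sketch does not perturb the copied hyperparameters; in PBT that exploration step would follow.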

About the figures

The figures above were produced with a naive PBT selection method: select the best model each time. The accuracy improves to around 99.35% with the selection method in the paper: truncation selection. Seeds will change results. A simple conv net was used. Dataset: MNIST.

I produced these figures using an old and very different version of this repo. I haven't yet re-added logging and plotting.

The essence of this code

Managing tasks via a sqlite database table.

Each task corresponds to training a model for an epoch. These tasks run in parallel. Once in a while the exploiter process truncates the worst-performing models, which blocks other processes from training models for a bit. That makes it 99% parallel instead of 100% parallel like random search.
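To make the task-table idea concrete, here is a minimal sketch using Python's built-in sqlite3 module. The table name, column names, and helper functions are hypothetical and are not the schema used in this repo. Each row represents one model, and a worker claims the unclaimed model with the fewest completed epochs inside a BEGIN IMMEDIATE transaction, so that two processes sharing the same database file cannot claim the same row.

```python
import sqlite3


def open_db(path=":memory:"):
    # isolation_level=None puts the connection in autocommit mode, so
    # transactions are controlled explicitly with BEGIN / COMMIT below.
    conn = sqlite3.connect(path, isolation_level=None)
    conn.execute(
        """CREATE TABLE IF NOT EXISTS tasks (
               model_id    INTEGER PRIMARY KEY,
               epochs_done INTEGER NOT NULL DEFAULT 0,
               score       REAL,
               claimed     INTEGER NOT NULL DEFAULT 0
           )"""
    )
    return conn


def add_models(conn, population_size):
    conn.executemany(
        "INSERT INTO tasks (model_id) VALUES (?)",
        [(i,) for i in range(population_size)],
    )


def claim_task(conn):
    """Claim the unclaimed model with the fewest epochs, or return None."""
    conn.execute("BEGIN IMMEDIATE")  # take the write lock before reading
    try:
        row = conn.execute(
            "SELECT model_id FROM tasks WHERE claimed = 0 "
            "ORDER BY epochs_done, model_id LIMIT 1"
        ).fetchone()
        if row is not None:
            conn.execute(
                "UPDATE tasks SET claimed = 1 WHERE model_id = ?", (row[0],)
            )
        conn.execute("COMMIT")
    except Exception:
        conn.execute("ROLLBACK")
        raise
    return None if row is None else row[0]


def finish_task(conn, model_id, score):
    """Record one more trained epoch and release the model."""
    conn.execute(
        "UPDATE tasks SET epochs_done = epochs_done + 1, score = ?, claimed = 0 "
        "WHERE model_id = ?",
        (score, model_id),
    )
```

A worker process would loop over claim_task, train the claimed model for an epoch, and call finish_task. An exploiter could hold the same kind of write transaction while it rewrites the bottom-performing rows, which is one way the brief blocking described above can arise.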

Since this code is mostly about task management, it isn't closely tied to a particular deep learning framework. With a little work, one could replace the TensorFlow ties with those of another framework. However, this assumes you have your hyperparameter space defined in your framework of choice, which is what you need for any hyperparameter optimization algorithm, including random search. As of this writing, the hyperparameter space in this code has only two dimensions: learning rate and momentum coefficient.
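A two-dimensional hyperparameter space of this kind can be written down independently of any framework. The sketch below is an assumption, not the repo's definition: the sampling ranges, the function names, and the upper bound on momentum are invented for illustration, and the perturbation factors of 0.8 and 1.2 follow the perturb strategy described in the PBT paper rather than anything confirmed in this code.

```python
import random


def sample_hyperparams(rng):
    """Draw an initial configuration (ranges are assumptions)."""
    return {
        # Log-uniform learning rate between 1e-4 and 1e-1.
        "learning_rate": 10 ** rng.uniform(-4, -1),
        # Uniform momentum coefficient between 0.5 and 0.99.
        "momentum": rng.uniform(0.5, 0.99),
    }


def perturb(hyperparams, rng, factors=(0.8, 1.2)):
    """Explore step: scale each hyperparameter independently by a random factor."""
    new = {name: value * rng.choice(factors) for name, value in hyperparams.items()}
    # Keep the momentum coefficient strictly below 1.
    new["momentum"] = min(new["momentum"], 0.999)
    return new


if __name__ == "__main__":
    rng = random.Random(0)
    start = sample_hyperparams(rng)
    print(start, perturb(start, rng))
```

Random search would call only sample_hyperparams, once per model. PBT additionally calls something like perturb after each exploitation step, which is why its configurations move over time.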
