# Example no. 1
# 0
def main(**kwargs):
    """Parse the command-line options and launch training.

    Builds an ``argparse`` parser for the restore, optimizer, label, dataset
    and Langevin-iteration options, formats a one-line summary of the run
    from ``config_train`` plus the chosen optimizer, calls
    ``Diagnostics.setup_dataset`` with the selected dataset and finally
    hands everything to ``train``.

    Args:
        **kwargs: Accepted but not read anywhere in this function.

    NOTE(review): the summary string reports ``config_train.L`` for the
    SGLD iteration count, not the ``-L/--langevin_iterations`` value parsed
    from the command line -- confirm which one ``train`` actually uses.
    """
    # (flags, add_argument options) pairs, registered in this order so the
    # generated --help output lists them exactly as before.
    option_table = (
        (("-rl", "--restore_last"),
         {"help": "restore last saved model", "action": "store_true"}),
        (("-r", "--restore_path"),
         {"help": "path to model to be restored", "type": str}),
        (("-opt", "--optimizer"),
         {"default": "entropy-sgd", "help": "Selected optimizer", "type": str,
          "choices": ['entropy-sgd', 'adam', 'momentum', 'sgd']}),
        (("-n", "--name"),
         {"default": "entropy-sgd", "help": "Checkpoint/Tensorboard label"}),
        (("-d", "--dataset"),
         {"default": "cifar10", "help": "Dataset to train on (cifar10 || cifar100)",
          "type": str, "choices": ['cifar10', 'cifar100']}),
        (("-L", "--langevin_iterations"),
         {"default": 20, "help": "Number of Langevin iterations in inner loop.",
          "type": int}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    args = cli.parse_args()

    # Collect the values shown in the run summary, in template order.
    cfg = config_train
    summary_fields = (
        cfg.n_layers,
        cfg.conv_keep_prob,
        cfg.learning_rate,
        cfg.L,
        cfg.num_epochs,
        args.optimizer,
    )
    summary_template = 'Layers: {} | Conv dropout: {} | Base LR: {} | SGLD Iterations {} | Epochs: {} | Optimizer: {}'
    architecture = summary_template.format(*summary_fields)

    Diagnostics.setup_dataset(args.dataset)

    # Launch training
    train(config_train, architecture, args)