Exemplo n.º 1
0
def test_autoencode_vgg16_celeba():
    """Smoke-train the VGG16 autoencoder config for CelebA.

    Builds the command-line arguments (5 epochs, ``--dataset_train_len`` and
    ``--dataset_test_len`` both set to 256), runs ``train_autoencoder.main``
    and requires the value it returns to be below 0.04.
    """
    # One flag/value pair per line so individual overrides are easy to spot.
    cli_args = [
        '--config', '../configs/autoencoder/vgg16/celeba.yaml',
        '--epochs', '5',
        '--dataroot', '../data',
        '--dataset_test_len', '256',
        '--dataset_train_len', '256',
    ]
    parsed_args = config.config(cli_args)

    best_loss = train_autoencoder.main(parsed_args)
    assert best_loss < 0.04
Exemplo n.º 2
0
def test_autoencode_fc_mnist():
    """Smoke-train the fully-connected autoencoder config for MNIST.

    Builds the command-line arguments (80 epochs, ``--dataset_train_len`` and
    ``--dataset_test_len`` both set to 256), runs ``train_autoencoder.main``
    and requires the value it returns to be below 1.3.
    """
    # One flag/value pair per line so individual overrides are easy to spot.
    cli_args = [
        '--config', '../configs/autoencoder/fc/mnist.yaml',
        '--epochs', '80',
        '--dataroot', '../data',
        '--dataset_test_len', '256',
        '--dataset_train_len', '256',
    ]
    parsed_args = config.config(cli_args)

    best_loss = train_autoencoder.main(parsed_args)
    assert best_loss < 1.3
Exemplo n.º 3
0
def test_autoencode_resnet_cifar10():
    """Smoke-train the ResNet autoencoder config for CIFAR-10.

    Builds the command-line arguments (20 epochs, learning rate ``1e-5``,
    seed 0, ``--dataset_train_len`` and ``--dataset_test_len`` both set to
    256), runs ``train_autoencoder.main`` and requires the value it returns
    to be below 1.2.
    """
    # One flag/value pair per line so individual overrides are easy to spot.
    cli_args = [
        '--config', '../configs/autoencoder/resnet/cifar10.yaml',
        '--epochs', '20',
        '--optim_lr', '1e-5',
        '--dataroot', '../data',
        '--dataset_test_len', '256',
        '--dataset_train_len', '256',
        '--seed', '0',
    ]
    parsed_args = config.config(cli_args)

    best_loss = train_autoencoder.main(parsed_args)
    assert best_loss < 1.2
Exemplo n.º 4
0
def test_mnist_fc():
    """Smoke-train the fully-connected classifier config for MNIST.

    Runs ``train_classifier.main`` for 3 epochs with seed 0 and run id 3, then
    requires both the average and the best precision it returns to exceed 0.2.
    """
    # Insertion order of the mapping is the order the flags are passed in.
    options = {
        '--config': '../configs/classify/fc/mnist.yaml',
        '--epochs': '3',
        '--dataroot': '../data',
        '--dataset_test_len': '256',
        '--dataset_train_len': '256',
        '--seed': '0',
        '--run_id': '3',
    }
    cli_args = [token for pair in options.items() for token in pair]
    parsed_args = config.config(cli_args)

    # Unpack all four values so a change in the return arity still fails here;
    # the two accuracy figures are not checked by this test.
    ave_precision, best_precision, _train_acc, _test_acc = train_classifier.main(parsed_args)
    assert ave_precision > 0.2
    assert best_precision > 0.2
Exemplo n.º 5
0
def test_cifar10_vgg16():
    """Smoke-train the VGG16 classifier config for CIFAR-10.

    Runs ``train_classifier.main`` for 6 epochs with learning rate 0.05,
    seed 0 and run id 1, then requires both the average and the best
    precision it returns to exceed 0.2.
    """
    # Insertion order of the mapping is the order the flags are passed in.
    options = {
        '--config': '../configs/classify/vgg16/cifar10.yaml',
        '--optim_lr': '0.05',
        '--epochs': '6',
        '--dataroot': '../data',
        '--dataset_test_len': '256',
        '--dataset_train_len': '256',
        '--seed': '0',
        '--run_id': '1',
    }
    cli_args = [token for pair in options.items() for token in pair]
    parsed_args = config.config(cli_args)

    # Unpack all four values so a change in the return arity still fails here;
    # the two accuracy figures are not checked by this test.
    ave_precision, best_precision, _train_acc, _test_acc = train_classifier.main(parsed_args)
    assert ave_precision > 0.2
    assert best_precision > 0.2
Exemplo n.º 6
0
def test_cifar10_resnet():
    """Smoke-train the batch-norm ResNet classifier config for CIFAR-10.

    Runs ``train_classifier.main`` for 80 epochs with learning rate 0.01,
    seed 0 and run id 4, then requires the best precision to exceed 0.13 and
    the training accuracy to exceed 20.0.
    """
    # Insertion order of the mapping is the order the flags are passed in.
    options = {
        '--config': '../configs/classify/resnet/cifar10-batchnorm.yaml',
        '--epochs': '80',
        '--optim_lr': '0.01',
        '--dataroot': '../data',
        '--dataset_test_len': '256',
        '--dataset_train_len': '256',
        '--seed': '0',
        '--run_id': '4',
    }
    cli_args = [token for pair in options.items() for token in pair]
    parsed_args = config.config(cli_args)

    # Unpack all four values so a change in the return arity still fails here;
    # only the best precision and the training accuracy are checked.
    _ave_precision, best_precision, train_accuracy, _test_acc = train_classifier.main(parsed_args)

    # WARNING: this model does not run reliably due to the shortcut
    # containing convnets.
    assert best_precision > 0.13
    assert train_accuracy > 20.0
Exemplo n.º 7
0
                'optimizer': optim.state_dict(),
                'amp': amp.state_dict()
            }
            torch.save(best, run_dir + '/best_amp.pt')

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                print('PYGAME QUIT')
                pygame.quit()
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_c and pygame.key.get_mods(
                ) & pygame.KMOD_CTRL:
                    print("pressed CTRL-C as an event")
                    pygame.quit()

    return ave_precision, best_precision, train_accuracy, test_accuracy


if __name__ == '__main__':
    # Local import: only the script entry point needs the environment lookup.
    import os

    # Configuration: parse the command line, then start the pygame, wandb and
    # Digo sessions before training begins.
    args = config.config()
    pygame.init()
    wandb.init(project='bald-classification', name=args.name)
    wandb.config.update(args)

    # NOTE(review): the Digo API key is a credential and should not live in
    # source control. It can be supplied through the DIGO_API_KEY environment
    # variable; the literal is kept only as a fallback so existing runs keep
    # working, and should be rotated and removed.
    digo_api_key = os.environ.get('DIGO_API_KEY',
                                  'qpqD6cHYvGG2XNHeGlHog8360T8uWHya')
    Digo.init(api_key=digo_api_key,
              project_name='HumanSegmantation',
              workspace_name='Yunsang')

    # Select the CUDA device named by the parsed arguments, then train.
    torch.cuda.set_device(args.device)
    main(args)