Example #1
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='eigensolver_xxz1d-4-1.0_30_10',
            model_init='kaiming_normal',
            batchnorm_init='uniform',  # Does not matter; this model has no module named 'BatchNorm2d'.
        )

        dataset_hparams = hparams.DatasetHparams(dataset_name='placeholder',
                                                 batch_size=3)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            lr=1e-1,
            # weight_decay=1e-4,
            training_steps='100ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Example #2
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='cifar_vgg_16',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(dataset_name='cifar10',
                                                 batch_size=128)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='80ep,120ep',
            lr=0.1,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='160ep')

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight')

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
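The milestone schedule above multiplies the learning rate by gamma at each milestone, so lr is 0.1 for epochs [0, 80), 0.01 for [80, 120), and 0.001 for [120, 160). A minimal sketch of that rule (an illustrative helper, not the library's scheduler):

    def lr_at_epoch(epoch, lr=0.1, gamma=0.1, milestones=(80, 120)):
        # One factor of gamma per milestone already passed.
        return lr * gamma ** sum(epoch >= m for m in milestones)

    print(lr_at_epoch(0), lr_at_epoch(100), lr_at_epoch(150))  # 0.1, 0.01, 0.001 (up to float rounding)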
Example #3
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='cifar10_simplecnn_16_32',
            model_init='kaiming_normal',
            batchnorm_init='uniform'
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='cifar10',
            batch_size=128
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=1e-2,
            training_steps='5ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams)
Example #4
    def default_hparams():
        """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet.

        To get these results with a smaller batch size, scale the learning rate linearly.
        That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc.
        """

        model_hparams = hparams.ModelHparams(
            model_name='imagenet_resnet_50',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='imagenet',
            batch_size=1024,
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='30ep,60ep,80ep',
            lr=0.4,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='90ep',
            warmup_steps='5ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
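The docstring's linear-scaling rule is plain arithmetic; a minimal sketch under the base values used above (the helper name is illustrative, not part of the framework):

    def scaled_lr(batch_size, base_batch_size=1024, base_lr=0.4):
        # Scale lr proportionally to batch size: 512 -> 0.2, 256 -> 0.1.
        return base_lr * batch_size / base_batch_size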
Example #5
    def default_hparams():
        model_hparams = hparams.ModelHparams(model_name='mnist_lenet5',
                                             model_init='kaiming_normal',
                                             batchnorm_init='uniform')

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='mnist',
            batch_size=128,
            # resize_input=False
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=0.001,
            training_steps='20ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            # pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Example #6
    def default_hparams(model_name):
        """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet and XX.X% top-1 accuracy on TinyImageNet.

        To get these results with a smaller batch size, scale the learning rate linearly.
        That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc.
        """

        # Model hyperparameters.
        model_hparams = hparams.ModelHparams(
            model_name=model_name,
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        # Dataset hyperparameters.
        if model_name.startswith('imagenet'):
            dataset_hparams = hparams.DatasetHparams(dataset_name='imagenet',
                                                     batch_size=1024)
        elif model_name.startswith('tinyimagenet'):
            dataset_hparams = hparams.DatasetHparams(
                dataset_name='tinyimagenet', batch_size=256)
        else:
            # Without this branch, dataset_hparams would be undefined below.
            raise ValueError('Unrecognized model name: {}'.format(model_name))

        # Training hyperparameters.
        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='30ep,60ep,80ep',
            lr=0.4,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='90ep',
            warmup_steps='5ep',
        )

        if model_name.startswith('tinyimagenet'):
            training_hparams.training_steps = '200ep'
            training_hparams.milestone_steps = '100ep,150ep'
            training_hparams.lr = 0.2

        # Pruning hyperparameters.
        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
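A brief usage sketch of the dispatch above ('imagenet_resnet_50' appears in Example #4; 'tinyimagenet_resnet_50' is a hypothetical name chosen only to exercise the 'tinyimagenet' branch):

    desc = default_hparams('imagenet_resnet_50')       # 90ep schedule, lr 0.4, batch 1024
    tiny = default_hparams('tinyimagenet_resnet_50')   # 200ep schedule, lr 0.2, batch 256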
Example #7
    def default_hparams(model_name):
        model_hparams = hparams.ModelHparams(model_name=model_name,
                                             model_init='kaiming_normal',
                                             batchnorm_init='uniform')

        dataset_hparams = hparams.DatasetHparams(dataset_name='mnist',
                                                 batch_size=128)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            lr=0.1,
            training_steps='16ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Example #8
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='fashionmnist_resnet_20',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='fashionmnist',
            batch_size=128,
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=5e-2,
            training_steps='4ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Example #9
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='cifar_score-resnet_20',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='cifar10',
            batch_size=128,
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='80ep,120ep',
            lr=0.1,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='160ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.9)

        distill_hparams = hparams.DistillHparams(
            teacher_model_name='cifar_resnet_20',
            teacher_ckpt=None,
            teacher_mask=None,
            alpha_ce=1.0,
            alpha_mse=1.0,
            alpha_cls=0.0,
            alpha_cos=0.0,
            temperature=2.0)

        return DistillDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams, distill_hparams)
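Example #9's DistillHparams carries weights for several distillation terms. A minimal sketch of one standard way such weights combine into a loss (an assumption about the intent, not the repository's implementation; it covers the ce/mse/cls terms and omits the cosine term, which needs hidden states):

    import torch.nn.functional as F

    def distill_loss(student_logits, teacher_logits, labels, hp):
        # Hinton-style soft-target term on temperature-scaled logits.
        T = hp.temperature
        loss_ce = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                           F.softmax(teacher_logits / T, dim=-1),
                           reduction='batchmean') * (T * T)
        # Direct regression of student logits onto teacher logits.
        loss_mse = F.mse_loss(student_logits, teacher_logits)
        # Ordinary supervised loss against the hard labels.
        loss_cls = F.cross_entropy(student_logits, labels)
        return (hp.alpha_ce * loss_ce + hp.alpha_mse * loss_mse +
                hp.alpha_cls * loss_cls)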