Code Example #1
File: runner.py  Project: sahibsin/Pruning
    def create_from_args(args: argparse.Namespace) -> 'LotteryRunner':
        d = LotteryDesc.create_from_args(args)
        if args.weight_save_steps:
            weight_save_steps = [d.str_to_step(s) for s in args.weight_save_steps.split(',')]
        else:
            weight_save_steps = []
        # Reuse the LotteryDesc built above instead of parsing the args a second time.
        return LotteryRunner(args.replicate, args.levels, d,
                             not args.quiet, not args.evaluate_only_at_end, weight_save_steps)
Code Example #2
File: runner.py  Project: sahibsin/Pruning
    def add_args(parser: argparse.ArgumentParser) -> None:
        # Get preliminary information.
        defaults = shared_args.maybe_get_default_hparams()

        # Add the job arguments.
        shared_args.JobArgs.add_args(parser)
        lottery_parser = parser.add_argument_group(
            'Lottery Ticket Hyperparameters', 'Hyperparameters that control the lottery ticket process.')
        LotteryRunner._add_levels_argument(lottery_parser)
        LotteryDesc.add_args(parser, defaults)
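
For context, here is a minimal sketch of how add_args and create_from_args fit together in a command-line entry point. The parser wiring and the main function are illustrative, not taken from the repository, and runner.run() is an assumed entry method:

    import argparse

    def main():
        # Register the job and lottery ticket arguments, then build the runner
        # from the parsed namespace.
        parser = argparse.ArgumentParser(description='Lottery ticket experiment')
        LotteryRunner.add_args(parser)
        args = parser.parse_args()
        runner = LotteryRunner.create_from_args(args)
        runner.run()  # assumed entry method on the runner

    if __name__ == '__main__':
        main()
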
Code Example #3
File: vqe_tfim1d.py  Project: xucheny/open_lth
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='vqe_tfim1d_8_100_0.9',
            model_init='hva_normal',
            batchnorm_init='uniform'  # Does not matter; this model has no 'BatchNorm2d' module.
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='placeholder',
            batch_size=3
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=1e-3,
            weight_decay=1e-4,
            training_steps='3000ep',
        )

        pruning_hparams = parallel_sparse_global.PruningHparams(
            pruning_strategy='parallel_sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight',  # TODO: what if there is nothing to ignore?
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams)
Code Example #4
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='cifar_vgg_16',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(dataset_name='cifar10',
                                                 batch_size=128)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='80ep,120ep',
            lr=0.1,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='160ep')

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight')

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
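
Assuming milestone_steps and gamma follow the standard multi-step decay semantics they suggest, the schedule above drops the learning rate by 10x at epochs 80 and 120. A minimal sketch; lr_at_epoch is an illustrative helper, not part of the codebase:

    def lr_at_epoch(base_lr, gamma, milestones, epoch):
        # Multiply the learning rate by gamma at every milestone already passed.
        lr = base_lr
        for m in milestones:
            if epoch >= m:
                lr *= gamma
        return lr

    # lr is 0.1 until epoch 80, 0.01 until epoch 120, then 0.001 through epoch 160.
    assert lr_at_epoch(0.1, 0.1, [80, 120], 100) == 0.1 * 0.1
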
Code Example #5
File: eigensolver.py  Project: xucheny/open_lth
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='eigensolver_xxz1d-4-1.0_30_10',
            model_init='kaiming_normal',
            batchnorm_init='uniform'  # Does not matter; this model has no 'BatchNorm2d' module.
        )

        dataset_hparams = hparams.DatasetHparams(dataset_name='placeholder',
                                                 batch_size=3)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            lr=1e-1,
            # weight_decay=1e-4,
            training_steps='100ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Code Example #6
File: cifar10_simplecnn.py  Project: npnbpb/open_lth
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='cifar10_simplecnn_16_32',
            model_init='kaiming_normal',
            batchnorm_init='uniform'
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='cifar10',
            batch_size=128
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=1e-2,
            training_steps='5ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams)
Code Example #7
    def default_hparams():
        """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet.

        To get these results with a smaller batch size, scale the learning rate linearly.
        That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc.
        """

        model_hparams = hparams.ModelHparams(
            model_name='imagenet_resnet_50',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='imagenet',
            batch_size=1024,
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='30ep,60ep,80ep',
            lr=0.4,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='90ep',
            warmup_steps='5ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
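
The docstring's linear scaling rule is mechanical to apply. A minimal sketch; the helper name and the base_batch_size parameter are illustrative:

    def scale_lr(base_lr, base_batch_size, batch_size):
        # Scale the learning rate in proportion to the batch size.
        return base_lr * batch_size / base_batch_size

    # Base: lr 0.4 at batch size 1024, as in the hyperparameters above.
    assert scale_lr(0.4, 1024, 512) == 0.2
    assert scale_lr(0.4, 1024, 256) == 0.1
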
Code Example #8
    def default_hparams():
        model_hparams = hparams.ModelHparams(model_name='mnist_lenet5',
                                             model_init='kaiming_normal',
                                             batchnorm_init='uniform')

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='mnist',
            batch_size=128,
            # resize_input=False
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=0.001,
            training_steps='20ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            # pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Code Example #9
File: imagenet_resnet.py  Project: sahibsin/Pruning
    def default_hparams(model_name):
        """These hyperparameters will reach 76.1% top-1 accuracy on ImageNet and XX.X% top-1 accuracy on TinyImageNet.

        To get these results with a smaller batch size, scale the learning rate linearly.
        That is, batch size 512 -> lr 0.2, 256 -> 0.1, etc.
        """

        # Model hyperparameters.
        model_hparams = hparams.ModelHparams(
            model_name=model_name,
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        # Dataset hyperparameters.
        if model_name.startswith('imagenet'):
            dataset_hparams = hparams.DatasetHparams(dataset_name='imagenet',
                                                     batch_size=1024)
        elif model_name.startswith('tinyimagenet'):
            dataset_hparams = hparams.DatasetHparams(
                dataset_name='tinyimagenet', batch_size=256)
        else:
            # Fail fast rather than hit an unbound dataset_hparams below.
            raise ValueError(f'No default dataset for model {model_name}')

        # Training hyperparameters.
        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='30ep,60ep,80ep',
            lr=0.4,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='90ep',
            warmup_steps='5ep',
        )

        if model_name.startswith('tinyimagenet'):
            training_hparams.training_steps = '200ep'
            training_hparams.milestone_steps = '100ep,150ep'
            training_hparams.lr = 0.2

        # Pruning hyperparameters.
        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
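
A brief usage sketch of the prefix dispatch above. The enclosing class name Model and the exact model names are assumptions; the repository's model files typically expose default_hparams as a static method:

    # ImageNet prefix: batch size 1024, lr 0.4, 90 epochs with milestones at 30/60/80.
    imagenet_desc = Model.default_hparams('imagenet_resnet_50')

    # TinyImageNet prefix: the overrides apply, giving batch size 256, lr 0.2, 200 epochs.
    tiny_desc = Model.default_hparams('tinyimagenet_resnet_18')
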
Code Example #10
File: mnist_lenet.py  Project: sahibsin/Pruning
    def default_hparams(model_name):
        model_hparams = hparams.ModelHparams(model_name=model_name,
                                             model_init='kaiming_normal',
                                             batchnorm_init='uniform')

        dataset_hparams = hparams.DatasetHparams(dataset_name='mnist',
                                                 batch_size=128)

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            lr=0.1,
            training_steps='16ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            pruning_layers_to_ignore='fc.weight',
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Code Example #11
    def default_hparams():
        model_hparams = hparams.ModelHparams(
            model_name='fashionmnist_resnet_20',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='fashionmnist',
            batch_size=128,
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='adam',
            lr=5e-2,
            training_steps='4ep',
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global', pruning_fraction=0.2)

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams,
                           pruning_hparams)
Code Example #12
File: runner.py  Project: liuguoyou/ElasticLTH
    def create_from_args(args: argparse.Namespace) -> 'LotteryRunner':
        return LotteryRunner(args.replicate, args.levels,
                             LotteryDesc.create_from_args(args),
                             not args.quiet, not args.evaluate_only_at_end)
Code Example #13
File: test_desc.py  Project: liuguoyou/ElasticLTH
    def test_hashes_regression(self):
        """Test that the hashes are fixed and repeatable.

        None of these hashes should change through any modification you make to the code;
        if they do, you will no longer be able to access your old models. You can avoid
        this by ensuring you don't change existing hyperparameters and only add new
        hyperparameters that have default values.
        """
        desc = LotteryDesc(
            dataset_hparams=hparams.DatasetHparams('cifar10', 128),
            model_hparams=hparams.ModelHparams('cifar_resnet_20', 'kaiming_normal', 'uniform'),
            training_hparams=hparams.TrainingHparams('sgd', 0.1, '160ep'),
            pruning_hparams=Strategy.get_pruning_hparams()('sparse_global')
        )
        self.assertEqual(desc.hashname, 'lottery_da8fd50859ba6d59aceca9d50ebcbf7e')

        with self.subTest():
            desc.training_hparams.momentum = 0.9
            self.assertEqual(desc.hashname, 'lottery_028eb999ecd1980cd012589829c945a3')

        with self.subTest():
            desc.training_hparams.milestone_steps = '80ep,120ep'
            desc.training_hparams.gamma = 0.1
            self.assertEqual(desc.hashname, 'lottery_e696cbf42d8758b8afdf2a16fad1de15')

        with self.subTest():
            desc.training_hparams.weight_decay = 1e-4
            self.assertEqual(desc.hashname, 'lottery_93bc65d66dfa64ffaf2a0ab105433a2c')

        with self.subTest():
            desc.training_hparams.warmup_steps = '20ep'
            self.assertEqual(desc.hashname, 'lottery_4e7b9ee929e8b1c911c5295233e6828f')

        with self.subTest():
            desc.training_hparams.data_order_seed = 0
            self.assertEqual(desc.hashname, 'lottery_d51482c0d378de4cc71b87b38df2ea84')

        with self.subTest():
            desc.dataset_hparams.do_not_augment = True
            self.assertEqual(desc.hashname, 'lottery_231b1efe748045875f738d860f4cb547')

        with self.subTest():
            desc.dataset_hparams.transformation_seed = 0
            self.assertEqual(desc.hashname, 'lottery_4dfd57a481be9a2d840f7ad5d1e6f5f0')

        with self.subTest():
            desc.dataset_hparams.subsample_fraction = 0.5
            self.assertEqual(desc.hashname, 'lottery_59ea6f2fab91a9515ae4bccd5de70878')

        with self.subTest():
            desc.dataset_hparams.random_labels_fraction = 0.7
            self.assertEqual(desc.hashname, 'lottery_8b59e5a4d5d72575f1fba67b476899fc')

        with self.subTest():
            desc.dataset_hparams.unsupervised_labels = 'rotation'
            self.assertEqual(desc.hashname, 'lottery_81f340e038ec29ffa9f858d9a8762211')

        with self.subTest():
            desc.dataset_hparams.blur_factor = 4
            self.assertEqual(desc.hashname, 'lottery_4e78e2719ef5c16ba3e0444bc10dfb08')

        with self.subTest():
            desc.model_hparams.batchnorm_frozen = True
            self.assertEqual(desc.hashname, 'lottery_8db76b3c3a08c4a1643f066768ff4e56')

        with self.subTest():
            desc.model_hparams.batchnorm_frozen = False
            desc.model_hparams.others_frozen = True
            self.assertEqual(desc.hashname, 'lottery_3a0f8b86c0813802537aea2ebe723051')

        with self.subTest():
            desc.model_hparams.others_frozen = False
            desc.pruning_hparams.pruning_layers_to_ignore = 'fc.weight'
            self.assertEqual(desc.hashname, 'lottery_d74aca8d02109ec0816739c2f7057433')
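
The regression test depends on hashname being a deterministic digest of the hyperparameters. As a minimal sketch of one way such a hash can be built (md5 over a canonical string of the non-None fields; the repository's exact scheme may differ, so hash_hparams below is illustrative only):

    import hashlib

    def hash_hparams(*hparams_objects, prefix='lottery'):
        # Canonicalize each hparams object as sorted field=value pairs, skipping
        # fields left at None so that newly added optional hyperparameters do not
        # perturb existing hashes, which is the property the docstring asks for.
        parts = []
        for h in hparams_objects:
            fields = sorted((k, v) for k, v in vars(h).items() if v is not None)
            parts.append(';'.join('{}={}'.format(k, v) for k, v in fields))
        return prefix + '_' + hashlib.md5('|'.join(parts).encode('utf-8')).hexdigest()
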
Code Example #14
File: desc.py  Project: liuguoyou/ElasticLTH
    @classmethod
    def create_from_args(cls, args: argparse.Namespace):
        return BranchDesc(LotteryDesc.create_from_args(args),
                          BranchHparams.create_from_args(args))
Code Example #15
File: desc.py  Project: liuguoyou/ElasticLTH
    def add_args(parser: argparse.ArgumentParser,
                 defaults: LotteryDesc = None):
        LotteryDesc.add_args(parser, defaults)
        BranchHparams.add_args(parser)
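
Together, the two desc.py snippets let a branch command reuse the full set of lottery ticket arguments. A minimal wiring sketch, with the parser construction itself being illustrative:

    import argparse

    parser = argparse.ArgumentParser(description='Lottery branch experiment')
    BranchDesc.add_args(parser)  # registers both LotteryDesc and BranchHparams arguments
    args = parser.parse_args()
    branch_desc = BranchDesc.create_from_args(args)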