Exemplo n.º 1
0
    def add_args(parser: argparse.ArgumentParser, defaults: 'LotteryDesc' = None):
        """Register every command-line argument needed to describe a lottery job.

        Args:
            parser: Parser that receives the argument groups.
            defaults: Optional ``LotteryDesc`` whose hyperparameters become the
                default values of the corresponding arguments.

        Raises:
            ValueError: If both ``--rewinding_steps`` and ``--pretrain`` are given.
        """
        # Add the rewinding/pretraining arguments. These are peeked at before parsing
        # because they decide which extra argument groups get registered below.
        rewinding_steps = arg_utils.maybe_get_arg('rewinding_steps')
        pretrain = arg_utils.maybe_get_arg('pretrain', boolean_arg=True)

        if rewinding_steps is not None and pretrain:
            raise ValueError('Cannot set --rewinding_steps and --pretrain')
        pretraining_parser = parser.add_argument_group(
            'Rewinding/Pretraining Arguments', 'Arguments that control how the network is pre-trained')
        LotteryDesc._add_rewinding_argument(pretraining_parser)
        LotteryDesc._add_pretrain_argument(pretraining_parser)

        # Get the proper pruning hparams. A strategy named on the command line takes
        # precedence over the one stored in the defaults.
        pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy')
        if defaults and not pruning_strategy:
            pruning_strategy = defaults.pruning_hparams.pruning_strategy

        # The defaults' pruning hparams are only usable when they describe the same
        # strategy; in every other case there are no pruning defaults. def_ph must be
        # bound on all paths because it is passed to add_args below.
        def_ph = None
        if pruning_strategy:
            pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy)
            if defaults and defaults.pruning_hparams.pruning_strategy == pruning_strategy:
                def_ph = defaults.pruning_hparams
        else:
            pruning_hparams = hparams.PruningHparams

        # Add the main arguments.
        hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None)
        hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None)
        hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None)
        pruning_hparams.add_args(parser, defaults=def_ph)

        # Handle pretraining: a second, 'pretrain'-prefixed set of training and dataset
        # arguments. Its training defaults copy the main ones with training_steps='0ep'.
        if pretrain:
            def_th = replace(defaults.training_hparams, training_steps='0ep') if defaults else None
            hparams.TrainingHparams.add_args(parser, defaults=def_th,
                                             name='Training Hyperparameters for Pretraining', prefix='pretrain')
            hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None,
                                            name='Dataset Hyperparameters for Pretraining', prefix='pretrain')
Exemplo n.º 2
0
 def create_from_args(args: argparse.Namespace):
     """Build a BranchRunner for the runner/branch pair named on the command line.

     The runner name comes from positional argument 1 and the branch name from
     positional argument 2. The registered branch builds itself from ``args``,
     and the result is wrapped in a ``BranchRunner``.
     """
     selected_runner = arg_utils.maybe_get_arg('runner', positional=True, position=1)
     selected_branch = arg_utils.maybe_get_arg('branch', positional=True, position=2)
     branch_entry = registry.get(selected_runner, selected_branch)
     wrapped = branch_entry.create_from_args(args)
     return BranchRunner(wrapped)
Exemplo n.º 3
0
def main():
    """Dispatch the subcommand named on the command line to its registered runner.

    The first positional argument selects a runner from ``runner_registry``, and
    ``--platform`` (default ``'local'``) selects a platform. If either one is not
    registered, a message is printed and the process exits with status 1.

    With ``--display_output_location``, the platform runs the runner's
    ``display_output_location`` and the process exits with status 0. Otherwise
    the platform runs the runner's ``run``.
    """
    # The welcome message.
    welcome = '=' * 82 + '\nOpenLTH: A Framework for Research on Lottery Tickets and Beyond\n' + '-' * 82

    # Choose an initial command.
    helptext = welcome + "\nChoose a command to run:"
    for name, registered in runner_registry.registered_runners.items():
        helptext += "\n    * {} {} [...] => {}".format(sys.argv[0], name, registered.description())
    helptext += '\n' + '=' * 82

    runner_name = arg_utils.maybe_get_arg('subcommand', positional=True)
    if runner_name not in runner_registry.registered_runners:
        print(helptext)
        sys.exit(1)
    # Look the runner up once: it supplies the usage text, its arguments, and the job.
    runner_type = runner_registry.get(runner_name)

    # Add the arguments for that command.
    usage = '\n' + welcome + '\n'
    usage += 'open_lth.py {} [...] => {}'.format(runner_name, runner_type.description())
    usage += '\n' + '=' * 82 + '\n'

    # conflict_handler='resolve' lets a later add_argument override an earlier one
    # registered under the same option string.
    parser = argparse.ArgumentParser(usage=usage, conflict_handler='resolve')
    parser.add_argument('subcommand')
    parser.add_argument('--platform', default='local', help='The platform on which to run the job.')
    parser.add_argument('--display_output_location', action='store_true',
                        help='Display the output location for this job.')

    # Get the platform arguments. The `or 'local'` fallback guarantees a non-empty name.
    platform_name = arg_utils.maybe_get_arg('platform') or 'local'
    if platform_name not in platforms.registry.registered_platforms:
        print(f'Invalid platform name: {platform_name}')
        sys.exit(1)
    platform_type = platforms.registry.get(platform_name)
    platform_type.add_args(parser)

    # Add arguments for the various runners.
    runner_type.add_args(parser)

    args = parser.parse_args()
    platform = platform_type.create_from_args(args)
    runner = runner_type.create_from_args(args)

    if args.display_output_location:
        platform.run_job(runner.display_output_location)
        sys.exit(0)

    platform.run_job(runner.run)
Exemplo n.º 4
0
    def add_args(parser: argparse.ArgumentParser,
                 defaults: 'DistillDesc' = None):
        """Register dataset, model, training, pruning and distillation arguments.

        Args:
            parser: Parser that receives the argument groups.
            defaults: Optional ``DistillDesc`` supplying default values for each group.
        """
        # A strategy named on the command line wins; otherwise use the defaults' strategy.
        pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy')
        if defaults and not pruning_strategy:
            pruning_strategy = defaults.pruning_hparams.pruning_strategy

        if not pruning_strategy:
            pruning_hparams, def_ph = hparams.PruningHparams, None
        else:
            pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy)
            # Only reuse the defaults' pruning hparams when they describe the same strategy.
            same_strategy = bool(defaults) and defaults.pruning_hparams.pruning_strategy == pruning_strategy
            def_ph = defaults.pruning_hparams if same_strategy else None

        # Add the main argument groups in order, each with its matching defaults (if any).
        main_groups = (
            (hparams.DatasetHparams, 'dataset_hparams'),
            (hparams.ModelHparams, 'model_hparams'),
            (hparams.TrainingHparams, 'training_hparams'),
        )
        for group_type, field_name in main_groups:
            group_type.add_args(parser, defaults=getattr(defaults, field_name) if defaults else None)

        # def_ph is already None whenever defaults is falsy.
        pruning_hparams.add_args(parser, defaults=def_ph)
        hparams.DistillHparams.add_args(parser, defaults=defaults.distill_hparams if defaults else None)
Exemplo n.º 5
0
    def add_args(parser):
        """Register the runner/branch positionals and the selected branch's arguments.

        Positional argument 1 names the runner and positional argument 2 names the
        branch. If either is missing or unregistered, a help listing is printed
        beneath the OpenLTH banner and the process exits with status 1.
        """
        # Produce help text for selecting the branch.
        helptext = '=' * 82 + '\nOpenLTH: A Library for Research on Lottery Tickets and Beyond\n' + '-' * 82
        runner_name = arg_utils.maybe_get_arg('runner', positional=True, position=1)

        # If the runner name is not present. Append to the banner (rather than replacing
        # it) so this message is framed like the branch listing below.
        if runner_name is None or runner_name not in registry.registered_runners():
            helptext += '\nChoose a runner on which to branch:\n'
            helptext += '\n'.join(
                f'    * {sys.argv[0]} branch {runner}' for runner in registry.registered_runners())
            helptext += '\n' + '=' * 82
            print(helptext)
            sys.exit(1)

        # If the branch name is not present.
        branch_names = registry.registered_branches(runner_name)
        branch_name = arg_utils.maybe_get_arg('branch', positional=True, position=2)
        if branch_name is None or branch_name not in branch_names:
            helptext += '\nChoose a branch to run:'
            for bn in branch_names:
                helptext += "\n    * {} {} {} [...] => {}".format(
                    sys.argv[0], sys.argv[1], bn, registry.get(runner_name, bn).description())
            helptext += '\n' + '=' * 82
            print(helptext)
            sys.exit(1)

        # Add the arguments for the branch.
        parser.add_argument('runner_name', type=str)
        parser.add_argument('branch_name', type=str)
        registry.get(runner_name, branch_name).add_args(parser)
Exemplo n.º 6
0
    def add_args(parser):
        """Register the branch positional and the selected branch's own arguments.

        If positional argument 1 does not name a registered branch, a sorted listing
        of all branches is printed and the process exits with status 1.
        """
        # Build the help listing up front, one line per registered branch (sorted).
        available = sorted(registry.registered_branches.keys())
        listing = ''.join(
            "\n    * {} {} {} [...] => {}".format(
                sys.argv[0], sys.argv[1], name, registry.get(name).description())
            for name in available)
        helptext = ('=' * 82 + '\nOpenLTH: A Library for Research on Lottery Tickets and Beyond\n' + '-' * 82
                    + '\nChoose a branch to run:' + listing + '\n' + '=' * 82)

        # Bail out with the listing when the requested branch is unknown.
        chosen = arg_utils.maybe_get_arg('subcommand', positional=True, position=1)
        if chosen not in available:
            print(helptext)
            sys.exit(1)

        # Register the positional, then delegate to the branch.
        parser.add_argument('branch_name', type=str)
        registry.get(chosen).add_args(parser)
Exemplo n.º 7
0
def maybe_get_default_hparams():
    """Return the models registry's default hparams named by ``--default_hparams``.

    Returns None when that argument is absent or empty.
    """
    name = arg_utils.maybe_get_arg('default_hparams')
    if not name:
        return None
    return models.registry.get_default_hparams(name)