def add_args(parser: argparse.ArgumentParser, defaults: 'LotteryDesc' = None):
    """Register every command-line argument for a lottery-ticket experiment.

    Adds, in order: the rewinding/pretraining argument group, the dataset,
    model, training and pruning hyperparameter groups, and -- only when
    ``pretrain`` is set -- an extra training/dataset group with the
    ``pretrain`` prefix.

    Args:
        parser: The parser that receives the arguments.
        defaults: Optional description whose hyperparameters are used as the
            default values of the added arguments.

    Raises:
        ValueError: If both ``rewinding_steps`` and ``pretrain`` are set.
    """
    # Add the rewinding/pretraining arguments; the two options are mutually exclusive.
    rewinding_steps = arg_utils.maybe_get_arg('rewinding_steps')
    pretrain = arg_utils.maybe_get_arg('pretrain', boolean_arg=True)
    if rewinding_steps is not None and pretrain:
        raise ValueError('Cannot set --rewinding_steps and --pretrain')
    pretraining_parser = parser.add_argument_group(
        'Rewinding/Pretraining Arguments', 'Arguments that control how the network is pre-trained')
    LotteryDesc._add_rewinding_argument(pretraining_parser)
    LotteryDesc._add_pretrain_argument(pretraining_parser)

    # Get the proper pruning hparams. A strategy given on the command line wins;
    # otherwise fall back to the strategy stored in `defaults`.
    pruning_strategy = arg_utils.maybe_get_arg('pruning_strategy')
    if defaults and not pruning_strategy:
        pruning_strategy = defaults.pruning_hparams.pruning_strategy
    if pruning_strategy:
        pruning_hparams = pruning.registry.get_pruning_hparams(pruning_strategy)
        # Reuse the default pruning hparams only when they belong to the same
        # strategy. Both branches must bind `def_ph`, otherwise a differing
        # strategy leaves it unbound.
        if defaults and defaults.pruning_hparams.pruning_strategy == pruning_strategy:
            def_ph = defaults.pruning_hparams
        else:
            def_ph = None
    else:
        pruning_hparams = hparams.PruningHparams
        def_ph = None

    # Add the main arguments.
    hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None)
    hparams.ModelHparams.add_args(parser, defaults=defaults.model_hparams if defaults else None)
    hparams.TrainingHparams.add_args(parser, defaults=defaults.training_hparams if defaults else None)
    # def_ph is already None whenever `defaults` is falsy.
    pruning_hparams.add_args(parser, defaults=def_ph)

    # Handle pretraining.
    if pretrain:
        # Pretraining defaults copy the main training hparams with training_steps='0ep'.
        def_th = replace(defaults.training_hparams, training_steps='0ep') if defaults else None
        hparams.TrainingHparams.add_args(parser, defaults=def_th,
                                         name='Training Hyperparameters for Pretraining', prefix='pretrain')
        hparams.DatasetHparams.add_args(parser, defaults=defaults.dataset_hparams if defaults else None,
                                        name='Dataset Hyperparameters for Pretraining', prefix='pretrain')
def create_from_args(args: argparse.Namespace):
    """Build a BranchRunner around the branch chosen on the command line.

    The runner and branch names are looked up through
    ``arg_utils.maybe_get_arg`` as positional arguments 1 and 2. The matching
    branch class from ``registry`` is then constructed from ``args``.
    """
    runner = arg_utils.maybe_get_arg('runner', positional=True, position=1)
    branch = arg_utils.maybe_get_arg('branch', positional=True, position=2)
    branch_cls = registry.get(runner, branch)
    branch_instance = branch_cls.create_from_args(args)
    return BranchRunner(branch_instance)
def main():
    """Command-line entry point: choose a runner, parse its arguments, and run it.

    Prints a help listing and exits with status 1 when the subcommand or the
    platform is not registered. With ``--display_output_location``, runs the
    runner's ``display_output_location`` job and exits with status 0 instead of
    running the job itself.
    """
    # The welcome message.
    welcome = '=' * 82 + '\nOpenLTH: A Framework for Research on Lottery Tickets and Beyond\n' + '-' * 82

    # Choose an initial command.
    helptext = welcome + "\nChoose a command to run:"
    for name, runner in runner_registry.registered_runners.items():
        helptext += "\n * {} {} [...] => {}".format(sys.argv[0], name, runner.description())
    helptext += '\n' + '=' * 82

    runner_name = arg_utils.maybe_get_arg('subcommand', positional=True)
    if runner_name not in runner_registry.registered_runners:
        print(helptext)
        sys.exit(1)
    runner_cls = runner_registry.get(runner_name)

    # Add the arguments for that command.
    usage = '\n' + welcome + '\n'
    usage += 'open_lth.py {} [...] => {}'.format(runner_name, runner_cls.description())
    usage += '\n' + '=' * 82 + '\n'

    # 'resolve' lets later add_args calls override earlier arguments with the same name.
    parser = argparse.ArgumentParser(usage=usage, conflict_handler='resolve')
    parser.add_argument('subcommand')
    parser.add_argument('--platform', default='local', help='The platform on which to run the job.')
    parser.add_argument('--display_output_location', action='store_true',
                        help='Display the output location for this job.')

    # Get the platform arguments. `or 'local'` guarantees a non-empty name.
    platform_name = arg_utils.maybe_get_arg('platform') or 'local'
    if platform_name not in platforms.registry.registered_platforms:
        print(f'Invalid platform name: {platform_name}')
        sys.exit(1)
    platform_cls = platforms.registry.get(platform_name)
    platform_cls.add_args(parser)

    # Add arguments for the various runners.
    runner_cls.add_args(parser)

    # Parse only after all platform and runner arguments are registered.
    args = parser.parse_args()
    platform = platform_cls.create_from_args(args)

    if args.display_output_location:
        platform.run_job(runner_cls.create_from_args(args).display_output_location)
        sys.exit(0)

    platform.run_job(runner_cls.create_from_args(args).run)
def add_args(parser: argparse.ArgumentParser, defaults: 'DistillDesc' = None):
    """Register the command-line arguments for a distillation experiment.

    Adds, in order, the dataset, model, training, pruning and distillation
    hyperparameter groups. The pruning hparams class is chosen from the
    ``pruning_strategy`` argument, falling back to the strategy in ``defaults``.

    Args:
        parser: The parser that receives the arguments.
        defaults: Optional description whose hyperparameters become the
            default values of the added arguments.
    """
    def _default(field):
        # Read one hparams field off `defaults`, or None if no defaults were given.
        return getattr(defaults, field) if defaults else None

    # Resolve which pruning strategy's hyperparameters to expose.
    strategy = arg_utils.maybe_get_arg('pruning_strategy')
    if not strategy and defaults:
        strategy = defaults.pruning_hparams.pruning_strategy

    ph_class = hparams.PruningHparams
    ph_defaults = None
    if strategy:
        ph_class = pruning.registry.get_pruning_hparams(strategy)
        # Default pruning hparams only apply when they describe the same strategy.
        if defaults and defaults.pruning_hparams.pruning_strategy == strategy:
            ph_defaults = defaults.pruning_hparams

    hparams.DatasetHparams.add_args(parser, defaults=_default('dataset_hparams'))
    hparams.ModelHparams.add_args(parser, defaults=_default('model_hparams'))
    hparams.TrainingHparams.add_args(parser, defaults=_default('training_hparams'))
    # ph_defaults can only be non-None when `defaults` is truthy.
    ph_class.add_args(parser, defaults=ph_defaults)
    hparams.DistillHparams.add_args(parser, defaults=_default('distill_hparams'))
def add_args(parser):
    """Add the positional runner/branch arguments plus the chosen branch's arguments.

    When the runner or branch name from the command line is missing or not
    registered, prints a help message listing the valid choices and exits with
    status 1.
    """
    # Banner that heads both help messages below.
    helptext = '=' * 82 + '\nOpenLTH: A Library for Research on Lottery Tickets and Beyond\n' + '-' * 82
    runner_name = arg_utils.maybe_get_arg('runner', positional=True, position=1)

    # If the runner name is not present. Append (not assign) so the banner is kept,
    # matching the branch help message below.
    if runner_name is None or runner_name not in registry.registered_runners():
        helptext += '\nChoose a runner on which to branch:\n'
        helptext += '\n'.join([
            f' * {sys.argv[0]} branch {runner}'
            for runner in registry.registered_runners()
        ])
        helptext += '\n' + '=' * 82
        print(helptext)
        sys.exit(1)

    # If the branch name is not present.
    branch_names = registry.registered_branches(runner_name)
    branch_name = arg_utils.maybe_get_arg('branch', positional=True, position=2)
    if branch_name is None or branch_name not in branch_names:
        helptext += '\nChoose a branch to run:'
        for bn in branch_names:
            helptext += "\n * {} {} {} [...] => {}".format(
                sys.argv[0], sys.argv[1], bn, registry.get(runner_name, bn).description())
        helptext += '\n' + '=' * 82
        print(helptext)
        sys.exit(1)

    # Add the arguments for the branch.
    parser.add_argument('runner_name', type=str)
    parser.add_argument('branch_name', type=str)
    registry.get(runner_name, branch_name).add_args(parser)
def add_args(parser):
    """Add the positional branch argument and the selected branch's own arguments.

    If the branch named on the command line is not registered, prints a help
    message listing every registered branch and exits with status 1.
    """
    choices = sorted(registry.registered_branches.keys())

    # Assemble the help text piece by piece, then join once.
    pieces = [
        '=' * 82 + '\nOpenLTH: A Library for Research on Lottery Tickets and Beyond\n' + '-' * 82,
        '\nChoose a branch to run:',
    ]
    pieces.extend(
        "\n * {} {} {} [...] => {}".format(sys.argv[0], sys.argv[1], name, registry.get(name).description())
        for name in choices
    )
    pieces.append('\n' + '=' * 82)
    helptext = ''.join(pieces)

    # Bail out with the help text when the requested branch is unknown.
    selected = arg_utils.maybe_get_arg('subcommand', positional=True, position=1)
    if selected not in choices:
        print(helptext)
        sys.exit(1)

    # Register the positional argument, then the branch-specific ones.
    parser.add_argument('branch_name', type=str)
    registry.get(selected).add_args(parser)
def maybe_get_default_hparams():
    """Return the registered default hparams named by the ``default_hparams`` argument.

    Returns None when that argument is absent or empty.
    """
    name = arg_utils.maybe_get_arg('default_hparams')
    if not name:
        return None
    return models.registry.get_default_hparams(name)