def parse_cmd_arguments(mode='regression_ewc', default=False, argv=None):
    """Build the argument parser for an EWC script and parse its arguments.

    The set of registered command-line options and their default values
    depends on ``mode``. After parsing, the resulting configuration is passed
    through the general and EWC-specific argument checks.

    Args:
        mode (str): Selects the script for which the parser is assembled.
            Options:

                - ``'regression_ewc'``
                - ``'gmm_ewc'``
                - ``'split_mnist_ewc'``
                - ``'perm_mnist_ewc'``
                - ``'cifar_resnet_ewc'``
        default (bool, optional): If ``True``, an empty argument list is
            parsed, i.e., only default values are used. A given ``argv`` is
            ignored in this case (a warning is issued).
        argv (list, optional): If provided, this list is handed to the parser
            instead of :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The parsed (and checked) configuration.

    Raises:
        ValueError: If ``mode`` is not one of the options listed above.
        NotImplementedError: If ``mode`` is ``'regression_ewc'`` and the
            parsed configuration has ``batchnorm`` enabled.
    """
    mode_descriptions = (
        ('regression_ewc', 'Toy regression with tasks trained via EWC'),
        ('gmm_ewc', 'Probabilistic CL on GMM Datasets via EWC'),
        ('split_mnist_ewc', 'Probabilistic CL on Split MNIST via EWC'),
        ('perm_mnist_ewc', 'Probabilistic CL on Permuted MNIST via EWC'),
        ('cifar_resnet_ewc',
         'Probabilistic CL on CIFAR-10/100 via EWC on a Resnet'),
    )
    for known_mode, text in mode_descriptions:
        if mode == known_mode:
            description = text
            break
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # Every mode writes to a time-stamped default output directory.
    run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    is_regression = mode == 'regression_ewc'

    # Keyword arguments that are identical for all modes.
    cl_flags = dict(show_beta=False, show_from_scratch=True)
    optim_flags = dict(
        show_lr=True,
        show_use_adam=True,
        show_use_rmsprop=True,
        show_use_adadelta=True,
        show_use_adagrad=True,
        show_epochs=True,
        show_clip_grad_value=True,
        show_clip_grad_norm=True,
    )
    net_flags = dict(
        show_no_bias=True,
        show_bn_no_running_stats=True,
        show_bn_distill_stats=False,
        show_bn_no_stats_checkpointing=True,
        show_specnorm=False,
        show_dropout_rate=True,
        ddropout_rate=-1,
        show_net_act=True,
    )
    misc_flags = dict(big_data=False, no_cuda=False)

    # Mode-specific part. The order of the calls below is kept per mode, as
    # it determines the order in which options get registered.
    if is_regression:
        out_dir = './out_ewc/run_' + run_stamp

        cl_group = cli.cl_args(parser, show_multi_head=True, **cl_flags)
        train_group = cli.train_args(
            parser, dbatch_size=32, dn_iter=5001, dlr=1e-3, **optim_flags)
        cli.main_net_args(
            parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_batchnorm=True, show_no_batchnorm=False,
            **net_flags)
        eval_group = cli.eval_args(
            parser, show_val_batch_size=False, dval_batch_size=10000,
            dval_iter=250)
        rta.data_args(parser)
        misc_group = cli.miscellaneous_args(
            parser, synthetic_data=True, show_plots=True, dout_dir=out_dir,
            show_publication_style=True, **misc_flags)

    elif mode == 'gmm_ewc':
        out_dir = './out_gmm_ewc/run_' + run_stamp

        cl_group = cli.cl_args(
            parser, show_multi_head=False, show_cl_scenario=True,
            show_split_head_cl3=True, show_num_tasks=False,
            show_num_classes_per_task=False, **cl_flags)
        train_group = cli.train_args(
            parser, dbatch_size=32, dn_iter=2000, dlr=1e-3, **optim_flags)
        cli.main_net_args(
            parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_batchnorm=True, show_no_batchnorm=False,
            **net_flags)
        eval_group = cli.eval_args(
            parser, show_val_batch_size=True, dval_batch_size=10000,
            dval_iter=100)
        misc_group = cli.miscellaneous_args(
            parser, synthetic_data=True, show_plots=True, dout_dir=out_dir,
            **misc_flags)

    elif mode == 'split_mnist_ewc':
        out_dir = './out_split_ewc/run_' + run_stamp

        cl_group = cli.cl_args(
            parser, show_multi_head=False, show_cl_scenario=True,
            show_split_head_cl3=True, show_num_tasks=True, dnum_tasks=5,
            show_num_classes_per_task=True, dnum_classes_per_task=2,
            **cl_flags)
        train_group = cli.train_args(
            parser, dbatch_size=128, dn_iter=2000, dlr=1e-3, **optim_flags)
        cli.main_net_args(
            parser, allowed_nets=['mlp', 'lenet', 'resnet', 'wrn'],
            dmlp_arch='400,400', dlenet_type='mnist_small',
            show_batchnorm=True, show_no_batchnorm=False, **net_flags)
        eval_group = cli.eval_args(
            parser, show_val_batch_size=True, dval_batch_size=1000,
            show_val_set_size=True, dval_set_size=0)
        misc_group = cli.miscellaneous_args(
            parser, synthetic_data=False, show_plots=False, dout_dir=out_dir,
            **misc_flags)

    elif mode == 'perm_mnist_ewc':
        out_dir = './out_perm_ewc/run_' + run_stamp

        cl_group = cli.cl_args(
            parser, show_multi_head=False, show_cl_scenario=True,
            show_split_head_cl3=True, show_num_tasks=True, dnum_tasks=10,
            show_num_classes_per_task=False, **cl_flags)
        train_group = cli.train_args(
            parser, dbatch_size=128, dn_iter=5000, dlr=1e-4, **optim_flags)
        cli.main_net_args(
            parser, allowed_nets=['mlp'], dmlp_arch='1000,1000',
            show_batchnorm=True, show_no_batchnorm=False, **net_flags)
        eval_group = cli.eval_args(
            parser, show_val_batch_size=True, dval_batch_size=1000,
            show_val_set_size=True, dval_set_size=0)
        misc_group = cli.miscellaneous_args(
            parser, synthetic_data=True, show_plots=False, dout_dir=out_dir,
            **misc_flags)

    elif mode == 'cifar_resnet_ewc':
        out_dir = './out_resnet_ewc/run_' + run_stamp

        cl_group = cli.cl_args(
            parser, show_multi_head=False, show_cl_scenario=True,
            show_split_head_cl3=True, show_num_tasks=True, dnum_tasks=6,
            show_num_classes_per_task=True, dnum_classes_per_task=10,
            **cl_flags)
        pcta.extra_cl_args(cl_group)
        train_group = cli.train_args(
            parser, dbatch_size=32, dn_iter=2000, dlr=1e-3, depochs=200,
            **optim_flags)
        cli.main_net_args(
            parser,
            allowed_nets=['resnet', 'wrn', 'iresnet', 'lenet', 'zenke',
                          'mlp'],
            dmlp_arch='10,10', dlenet_type='cifar', show_batchnorm=False,
            show_no_batchnorm=True, **net_flags)
        cli.data_args(parser, show_disable_data_augmentation=True)
        eval_group = cli.eval_args(
            parser, show_val_batch_size=True, dval_batch_size=1000,
            show_val_set_size=True, dval_set_size=0)
        misc_group = cli.miscellaneous_args(
            parser, synthetic_data=False, show_plots=False, dout_dir=out_dir,
            **misc_flags)

    # Part that all modes share (with small mode-dependent deviations).
    ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

    pmta.special_train_options(train_group, show_soft_targets=False)
    if not is_regression:
        # Only the non-regression modes extend the CL and evaluation groups.
        cl_args(cl_group, show_det_multi_head=True)
        pmta.eval_args(eval_group)

    rta.mc_args(train_group, eval_group, show_train_sample_size=False,
                dval_sample_size=10)
    # Only the regression mode exposes the "ll_dist_std" option.
    rta.train_args(train_group, show_ll_dist_std=is_regression,
                   show_local_reparam_trick=False, show_kl_scale=False,
                   show_radial_bnn=False)

    rta_misc_flags = dict(
        show_mnet_only=False,
        show_use_logvar_enc=False,
        show_disable_lrt_test=False,
        show_mean_only=False,
    )
    if not is_regression:
        # The regression mode leaves this keyword at the callee's default.
        rta_misc_flags['show_during_acc_criterion'] = True
    rta.miscellaneous_args(misc_group, **rta_misc_flags)

    if mode == 'perm_mnist_ewc':
        pmta.perm_args(parser)

    # Decide what to parse: nothing (defaults only), the given list, or
    # ``None``, in which case argparse falls back to ``sys.argv``.
    if default:
        if argv is not None:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        cmd_args = []
    else:
        cmd_args = argv

    config = parser.parse_args(args=cmd_args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    rta.check_invalid_args_general(config)
    pmta.check_invalid_args_general(config)
    check_invalid_args_ewc(config)

    if is_regression and config.batchnorm:
        # Not properly handled in test and eval function!
        raise NotImplementedError()

    return config
def parse_cmd_arguments(mode='resnet_cifar', default=False, argv=None):
    """Build the argument parser for a CIFAR script and parse its arguments.

    Both modes register the same groups of options in the same order; they
    only differ in the default values and in which options are shown. The
    parsed configuration is validated before it is returned.

    Args:
        mode (str): Selects the script for which the parser is assembled.
            Options:

                - ``resnet_cifar``
                - ``zenke_cifar``
        default (bool, optional): If ``True``, an empty argument list is
            parsed, i.e., only default values are used. A given ``argv`` is
            ignored in this case (a warning is issued).
        argv (list, optional): If provided, this list is handed to the parser
            instead of :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The parsed (and checked) configuration.

    Raises:
        ValueError: If ``mode`` is unknown, if ``num_tasks`` is outside of
            ``[1, 11]``, if a learning rate scheduler flag is set while
            ``epochs`` is ``-1`` or if ``no_lookahead`` and ``backprop_dt``
            are both active.
        NotImplementedError: If ``cl_scenario`` is not ``1``.
    """
    mode_descriptions = (
        ('resnet_cifar', 'CIFAR-10/100 CL experiment using a Resnet-32'),
        ('zenke_cifar', 'CIFAR-10/100 CL experiment using the Zenkenet'),
    )
    for known_mode, text in mode_descriptions:
        if mode == known_mode:
            description = text
            break
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    general_options(parser)

    # Every mode writes to a time-stamped default output directory.
    run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Collect the mode-dependent settings first; the calls that consume them
    # are identical for both modes and are issued once further below.
    if mode == 'resnet_cifar':
        out_dir = './out_resnet/run_' + run_stamp
        beta_default = 0.05
        net_flags = dict(
            allowed_nets=['resnet'],
            show_batchnorm=False,
            show_no_batchnorm=True,
            show_bn_no_running_stats=True,
            show_bn_distill_stats=True,
            show_bn_no_stats_checkpointing=True,
            show_specnorm=False,
            show_dropout_rate=False,
            show_net_act=False,
        )
        hnet_flags = dict(dhyper_chunks=7000, dhnet_arch='', dtemb_size=32,
                          demb_size=32)
        train_flags = dict(dlr=0.001, depochs=200, dbatch_size=32)
    elif mode == 'zenke_cifar':
        out_dir = './out_zenke/run_' + run_stamp
        beta_default = 0.01
        net_flags = dict(
            allowed_nets=['zenke'],
            show_batchnorm=False,
            show_no_batchnorm=False,
            show_dropout_rate=True,
            ddropout_rate=0.25,
            show_specnorm=False,
            show_net_act=False,
        )
        hnet_flags = dict(dhyper_chunks=5500, dhnet_arch='100,150,200',
                          dtemb_size=48, demb_size=80)
        train_flags = dict(dlr=0.0001, depochs=80, dbatch_size=256,
                           dadam_beta1=0.5)

    cl_group = cli.cl_args(
        parser, show_beta=True, dbeta=beta_default, show_from_scratch=True,
        show_multi_head=False, show_cl_scenario=True,
        show_split_head_cl3=False, show_num_tasks=True, dnum_tasks=6)
    cli.main_net_args(parser, **net_flags)
    cli.hypernet_args(parser, **hnet_flags)
    cli.data_args(parser, show_disable_data_augmentation=True)
    train_group = cli.train_args(
        parser, show_lr=True, show_epochs=True, dn_iter=2000,
        show_use_adam=True, show_use_rmsprop=True, show_use_adadelta=False,
        show_use_adagrad=False, show_clip_grad_value=False,
        show_clip_grad_norm=False, **train_flags)

    special_cl_options(cl_group)
    special_train_options(train_group)
    init_group = cli.init_args(parser, custom_option=True)
    special_init_options(init_group)
    cli.eval_args(parser, show_val_batch_size=True, dval_batch_size=1000)
    cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                           show_plots=False, no_cuda=False, dout_dir=out_dir)

    # Decide what to parse: nothing (defaults only), the given list, or
    # ``None``, in which case argparse falls back to ``sys.argv``.
    if default:
        if argv is not None:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        cmd_args = []
    else:
        cmd_args = argv

    config = parser.parse_args(args=cmd_args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)

    ### ... insert additional checks if necessary
    n_tasks = config.num_tasks
    if n_tasks < 1 or n_tasks > 11:
        raise ValueError('Argument "num_tasks" must be between 1 and 11!')

    if config.cl_scenario != 1:
        raise NotImplementedError(
            'CIFAR experiments are currently only implemented for CL1.')

    # The schedulers can only be used if the number of epochs is specified.
    if config.plateau_lr_scheduler and config.epochs == -1:
        raise ValueError(
            'Flag "plateau_lr_scheduler" can only be used if "epochs" was ' +
            'set.')
    if config.lambda_lr_scheduler and config.epochs == -1:
        raise ValueError(
            'Flag "lambda_lr_scheduler" can only be used if "epochs" was ' +
            'set.')

    if config.no_lookahead and config.backprop_dt:
        raise ValueError(
            'Can\'t activate "no_lookahead" and "backprop_dt" ' +
            'simultaneously.')

    return config