def parse_cmd_arguments(mode='regression_ewc', default=False, argv=None):
    """Parse command-line arguments for EWC experiments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``'regression_ewc'``
                - ``'gmm_ewc'``
                - ``'split_mnist_ewc'``
                - ``'perm_mnist_ewc'``
                - ``'cifar_resnet_ewc'``
        default (bool, optional): If ``True``, command-line arguments will be
            ignored and only the default values will be parsed.
        argv (list, optional): If provided, it will be treated as a list of
            command-line argument that is passed to the parser in place of
            :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The Namespace object containing argument names
        and values.

    Raises:
        ValueError: If ``mode`` is not one of the options listed above.
        NotImplementedError: If ``mode`` is ``'regression_ewc'`` and
            batchnorm was requested (not supported by the regression
            test/eval functions).
    """
    # Pick a help-text description per experiment mode; unknown modes fail
    # fast before any arguments are registered.
    if mode == 'regression_ewc':
        description = 'Toy regression with tasks trained via EWC'
    elif mode == 'gmm_ewc':
        description = 'Probabilistic CL on GMM Datasets via EWC'
    elif mode == 'split_mnist_ewc':
        description = 'Probabilistic CL on Split MNIST via EWC'
    elif mode == 'perm_mnist_ewc':
        description = 'Probabilistic CL on Permuted MNIST via EWC'
    elif mode == 'cifar_resnet_ewc':
        description = 'Probabilistic CL on CIFAR-10/100 via EWC on a Resnet'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # If needed, add additional parameters.
    # Each branch below wires up the mode-specific argument groups (CL
    # options, training options, main-network options, evaluation options,
    # miscellaneous options and EWC-specific options). Output directories are
    # timestamped so repeated runs don't overwrite each other.
    if mode == 'regression_ewc':
        # Toy regression: small sigmoid MLP, multi-head option, plotting
        # enabled for the synthetic data.
        dout_dir = './out_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=True)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=5001, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True, show_batchnorm=True,
            show_no_batchnorm=False, show_bn_no_running_stats=True,
            show_bn_distill_stats=False, show_bn_no_stats_checkpointing=True,
            show_specnorm=False, show_dropout_rate=True, ddropout_rate=-1,
            show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=False,
            dval_batch_size=10000, dval_iter=250)
        rta.data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=False,
            dout_dir=dout_dir, show_publication_style=True)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)
        pmta.special_train_options(train_agroup, show_soft_targets=False)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
            dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=True,
            show_local_reparam_trick=False, show_kl_scale=False,
            show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False)
    elif mode == 'gmm_ewc':
        # GMM classification: classification-specific groups (CL scenario,
        # split-head CL3, during-accuracy criterion) are enabled here.
        dout_dir = './out_gmm_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=True,
            show_num_tasks=False, show_num_classes_per_task=False)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=2000, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True, show_batchnorm=True,
            show_no_batchnorm=False, show_bn_no_running_stats=True,
            show_bn_distill_stats=False, show_bn_no_stats_checkpointing=True,
            show_specnorm=False, show_dropout_rate=True, ddropout_rate=-1,
            show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
            dval_batch_size=10000, dval_iter=100)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)
        pmta.special_train_options(train_agroup, show_soft_targets=False)
        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
            dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
            show_local_reparam_trick=False, show_kl_scale=False,
            show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False, show_during_acc_criterion=True)
    elif mode == 'split_mnist_ewc':
        # Split MNIST: 5 tasks of 2 classes by default; larger nets allowed.
        dout_dir = './out_split_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=True,
            show_num_tasks=True, dnum_tasks=5,
            show_num_classes_per_task=True, dnum_classes_per_task=2)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=128,
            dn_iter=2000, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser,
            allowed_nets=['mlp', 'lenet', 'resnet', 'wrn'],
            dmlp_arch='400,400', dlenet_type='mnist_small', show_no_bias=True,
            show_batchnorm=True, show_no_batchnorm=False,
            show_bn_no_running_stats=True, show_bn_distill_stats=False,
            show_bn_no_stats_checkpointing=True, show_specnorm=False,
            show_dropout_rate=True, ddropout_rate=-1, show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
            dval_batch_size=1000, show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=False, show_plots=False, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)
        pmta.special_train_options(train_agroup, show_soft_targets=False)
        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
            dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
            show_local_reparam_trick=False, show_kl_scale=False,
            show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False, show_during_acc_criterion=True)
    elif mode == 'perm_mnist_ewc':
        # Permuted MNIST: 10 tasks by default; wide MLP; extra permutation
        # options are appended at the end of this branch.
        dout_dir = './out_perm_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=True,
            show_num_tasks=True, dnum_tasks=10,
            show_num_classes_per_task=False)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=128,
            dn_iter=5000, dlr=1e-4, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='1000,1000',
            show_no_bias=True, show_batchnorm=True, show_no_batchnorm=False,
            show_bn_no_running_stats=True, show_bn_distill_stats=False,
            show_bn_no_stats_checkpointing=True, show_specnorm=False,
            show_dropout_rate=True, ddropout_rate=-1, show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
            dval_batch_size=1000, show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=False, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)
        pmta.special_train_options(train_agroup, show_soft_targets=False)
        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
            dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
            show_local_reparam_trick=False, show_kl_scale=False,
            show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False, show_during_acc_criterion=True)
        pmta.perm_args(parser)
    elif mode == 'cifar_resnet_ewc':
        # CIFAR-10/100 with Resnet-style nets: data-augmentation options and
        # extra CL options (pcta) are exposed; batchnorm shown inverted
        # (enabled by default, "--no_batchnorm" flag offered instead).
        dout_dir = './out_resnet_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=True,
            show_num_tasks=True, dnum_tasks=6,
            show_num_classes_per_task=True, dnum_classes_per_task=10)
        pcta.extra_cl_args(cl_argroup)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=2000, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            depochs=200, show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['resnet', 'wrn', 'iresnet',
            'lenet', 'zenke', 'mlp'], dmlp_arch='10,10', dlenet_type='cifar',
            show_no_bias=True, show_batchnorm=False, show_no_batchnorm=True,
            show_bn_no_running_stats=True, show_bn_distill_stats=False,
            show_bn_no_stats_checkpointing=True, show_specnorm=False,
            show_dropout_rate=True, ddropout_rate=-1, show_net_act=True)
        cli.data_args(parser, show_disable_data_augmentation=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
            dval_batch_size=1000, show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=False, show_plots=False, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)
        pmta.special_train_options(train_agroup, show_soft_targets=False)
        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
            dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
            show_local_reparam_trick=False, show_kl_scale=False,
            show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False, show_during_acc_criterion=True)

    # Resolve what is actually parsed: an explicit "argv" wins unless
    # "default" is set, in which case an empty list forces pure defaults.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    rta.check_invalid_args_general(config)
    pmta.check_invalid_args_general(config)
    check_invalid_args_ewc(config)

    if mode == 'regression_ewc':
        if config.batchnorm:
            # Not properly handled in test and eval function!
            raise NotImplementedError()

    return config
def parse_cmd_arguments(mode='regression_bbb', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``'regression_bbb'``
                - ``'regression_avb'``
                - ``'regression_ssge'``
        default (optional): If True, command-line arguments will be ignored
            and only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of
            command-line argument that is passed to the parser in place of
            sys.argv.

    Returns:
        The Namespace object containing argument names and values.

    Raises:
        ValueError: If ``mode`` is not one of the options listed above.
    """
    if mode == 'regression_bbb':
        description = 'Toy regression with tasks trained by BbB and ' + \
            'protected by a hypernetwork'
    elif mode == 'regression_avb':
        description = 'Toy regression with tasks trained by implicit model ' \
            'using AVB and protected by a hypernetwork'
    elif mode == 'regression_ssge':
        description = 'Toy regression with tasks trained by implicit model ' \
            'using SSGE and protected by a hypernetwork'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # Default hnet keyword arguments.
    # These are shared between the BbB branch (used verbatim) and the
    # implicit-model branch (copied and tweaked for the implicit hnet and the
    # hyper-hypernetwork).
    hnet_args_kw = { # Function `cli.hnet_args`
        # Note, the first list element denotes the default hnet.
        'allowed_nets': ['hmlp', 'chunked_hmlp', 'hdeconv',
                         'chunked_hdeconv'],
        'dhmlp_arch': '10,10',
        'show_cond_emb_size': True,
        'dcond_emb_size': 2,
        'dchmlp_chunk_size': 64,
        'dchunk_emb_size': 8,
        'show_use_cond_chunk_embs': True,
        'show_net_act': True,
        'dnet_act': 'sigmoid',
        'show_no_bias': True,
        'show_dropout_rate': True,
        'ddropout_rate': -1,
        'show_specnorm': True,
        'show_batchnorm': False,
        'show_no_batchnorm': False
    }

    # If needed, add additional parameters.
    if mode == 'regression_bbb':
        # Bayes-by-Backprop: single hypernetwork protecting the main net.
        dout_dir = './out_bbb/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=True, dbeta=0.005,
            show_from_scratch=True, show_multi_head=True)
        train_argroup = cli.train_args(parser, show_lr=True, dn_iter=10001,
            dlr=1e-2, show_clip_grad_value=True, show_clip_grad_norm=True,
            show_momentum=False, show_adam_beta1=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True)
        cli.hnet_args(parser, **hnet_args_kw)
        init_agroup = cli.init_args(parser, custom_option=False)
        eval_agroup = cli.eval_args(parser, dval_iter=250)
        data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=True,
            show_publication_style=True, dout_dir=dout_dir)
        mc_args(train_argroup, eval_agroup)
        train_args(train_argroup, show_local_reparam_trick=True,
            show_radial_bnn=True)
        cl_args(cl_argroup)
        init_args(init_agroup)
        miscellaneous_args(misc_agroup, show_use_logvar_enc=True,
            show_disable_lrt_test=True, show_mean_only=True)
        pmta.train_args(train_argroup, show_init_with_prev_emb=False,
            show_use_prev_post_as_prior=True, show_num_kl_samples=True,
            show_training_set_size=False)
    # Config parameters for the implicit-model variants (AVB / SSGE).
    else:
        method = 'avb'
        if mode == 'regression_ssge':
            method = 'ssge'
        dout_dir = './out_%s/run_' % method + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=True, dbeta=0.005,
            show_from_scratch=True, show_multi_head=True)
        train_argroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=5001, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        # Main network.
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True)
        if mode == 'regression_avb':
            # Discriminator (only AVB needs one).
            cli.main_net_args(parser, allowed_nets=['mlp', 'chunked_mlp'],
                dmlp_arch='10,10', dcmlp_arch='10,10',
                dcmlp_chunk_arch='10,10', dcmlp_in_cdim=32, dcmlp_out_cdim=8,
                dcmlp_cemb_dim=8, dnet_act='sigmoid', show_no_bias=True,
                prefix='dis_', pf_name='discriminator')
        # Hypernetwork (weight generator).
        imp_hargs_kw = dict(hnet_args_kw)
        imp_hargs_kw['show_cond_emb_size'] = False
        imp_hargs_kw['show_use_cond_chunk_embs'] = False
        imp_hargs_kw['dcond_emb_size'] = 0 # Not used for implicit hnet!
        cli.hnet_args(parser, **imp_hargs_kw, prefix='imp_',
                      pf_name='implicit')
        # Hyper-hypernetwork.
        hhargs_kw = dict(hnet_args_kw)
        cli.hnet_args(parser, **hhargs_kw, prefix='hh_',
                      pf_name='hyper-hyper')
        cli.init_args(parser, custom_option=False, show_hyper_fan_init=True)
        eval_agroup = cli.eval_args(parser, dval_iter=250)
        data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=True,
            show_publication_style=True, dout_dir=dout_dir)
        pmta.special_train_options(train_argroup, show_soft_targets=False)
        mc_args(train_argroup, eval_agroup)
        train_args(train_argroup, show_local_reparam_trick=False,
            show_kl_scale=True)
        miscellaneous_args(misc_agroup, show_store_during_models=True)
        pmta.train_args(train_argroup, show_init_with_prev_emb=True,
            show_use_prev_post_as_prior=True, show_num_kl_samples=True,
            show_calc_hnet_reg_targets_online=True,
            show_hnet_reg_batch_size=True, show_training_set_size=False)
        pmta.ind_posterior_args(train_argroup, show_distill_iter=False)
        pcta.miscellaneous_args(misc_agroup, show_no_hhnet=True)
        pcta.imp_args(parser, dlatent_dim=8, show_prior_focused=True)
        # Method-specific extras.
        if mode == 'regression_avb':
            pcta.avb_args(parser)
        elif mode == 'regression_ssge':
            pcta.ssge_args(parser)

    # Resolve what is actually parsed: an explicit "argv" wins unless
    # "default" is set, in which case an empty list forces pure defaults.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    check_invalid_args_general(config)
    pmta.check_invalid_args_general(config)
    if mode == 'regression_bbb':
        check_invalid_bbb_args(config)
    else:
        # No extra checks for the implicit-model variants (yet).
        pass

    return config
def parse_cmd_arguments(mode='resnet_cifar', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``resnet_cifar``
                - ``zenke_cifar``
        default (bool, optional): If ``True``, command-line arguments will be
            ignored and only the default values will be parsed.
        argv (list, optional): If provided, it will be treated as a list of
            command-line argument that is passed to the parser in place of
            :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The Namespace object containing argument names
        and values.

    Raises:
        ValueError: If ``mode`` is unknown, or if a parsed argument
            combination is invalid (see checks at the end).
        NotImplementedError: If a CL scenario other than CL1 was requested.
    """
    if mode == 'resnet_cifar':
        description = 'CIFAR-10/100 CL experiment using a Resnet-32'
    elif mode == 'zenke_cifar':
        description = 'CIFAR-10/100 CL experiment using the Zenkenet'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # Options shared by both network architectures.
    general_options(parser)

    if mode == 'resnet_cifar':
        # Resnet-32: batchnorm on by default ("--no_batchnorm" flag exposed),
        # linear (empty-arch) hypernetwork head, chunk size 7000.
        dout_dir = './out_resnet/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_group = cli.cl_args(parser, show_beta=True, dbeta=0.05,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=False,
            show_num_tasks=True, dnum_tasks=6)
        cli.main_net_args(parser, allowed_nets=['resnet'],
            show_batchnorm=False, show_no_batchnorm=True,
            show_bn_no_running_stats=True, show_bn_distill_stats=True,
            show_bn_no_stats_checkpointing=True, show_specnorm=False,
            show_dropout_rate=False, show_net_act=False)
        cli.hypernet_args(parser, dhyper_chunks=7000, dhnet_arch='',
            dtemb_size=32, demb_size=32)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser, show_lr=True, dlr=0.001,
            show_epochs=True, depochs=200, dbatch_size=32, dn_iter=2000,
            show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=False, show_use_adagrad=False,
            show_clip_grad_value=False, show_clip_grad_norm=False)
    elif mode == 'zenke_cifar':
        # Zenkenet: no batchnorm, dropout 0.25 by default, deeper
        # hypernetwork head and larger embeddings.
        dout_dir = './out_zenke/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_group = cli.cl_args(parser, show_beta=True, dbeta=0.01,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=False,
            show_num_tasks=True, dnum_tasks=6)
        cli.main_net_args(parser, allowed_nets=['zenke'],
            show_batchnorm=False, show_no_batchnorm=False,
            show_dropout_rate=True, ddropout_rate=0.25, show_specnorm=False,
            show_net_act=False)
        cli.hypernet_args(parser, dhyper_chunks=5500,
            dhnet_arch='100,150,200', dtemb_size=48, demb_size=80)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser, show_lr=True, dlr=0.0001,
            show_epochs=True, depochs=80, dbatch_size=256, dn_iter=2000,
            show_use_adam=True, dadam_beta1=0.5, show_use_rmsprop=True,
            show_use_adadelta=False, show_use_adagrad=False,
            show_clip_grad_value=False, show_clip_grad_norm=False)

    # Groups shared by both modes (rely on cl_group/train_agroup set above).
    special_cl_options(cl_group)
    special_train_options(train_agroup)
    init_group = cli.init_args(parser, custom_option=True)
    special_init_options(init_group)
    cli.eval_args(parser, show_val_batch_size=True, dval_batch_size=1000)
    cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
        show_plots=False, no_cuda=False, dout_dir=dout_dir)

    # Resolve what is actually parsed: an explicit "argv" wins unless
    # "default" is set, in which case an empty list forces pure defaults.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)

    ### ... insert additional checks if necessary
    if config.num_tasks < 1 or config.num_tasks > 11:
        raise ValueError('Argument "num_tasks" must be between 1 and 11!')

    if config.cl_scenario != 1:
        raise NotImplementedError('CIFAR experiments are currently only ' +
                                  'implemented for CL1.')

    # The LR schedulers operate per epoch, so an epoch budget is required.
    if config.plateau_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "plateau_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')

    if config.lambda_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "lambda_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')

    if config.no_lookahead and config.backprop_dt:
        raise ValueError('Can\'t activate "no_lookahead" and "backprop_dt" ' +
                         'simultaneously.')

    return config
def parse_cmd_arguments(default=False, argv=None):
    """Parse command-line arguments.

    Args:
        default (optional): If True, command-line arguments will be ignored
            and only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of
            command-line argument that is passed to the parser in place of
            sys.argv.

    Returns:
        The Namespace object containing argument names and values.
    """
    description = 'Continual learning on copy task.'
    parser = argparse.ArgumentParser(description=description)

    # Continual-learning and training options.
    cli.cl_args(parser, show_beta=True, dbeta=0.005, show_from_scratch=True,
        show_multi_head=True, show_split_head_cl3=False,
        show_cl_scenario=False, show_num_tasks=True, dnum_tasks=6)
    cli.train_args(parser, show_lr=True, show_epochs=False, dbatch_size=64,
        dn_iter=5000, dlr=1e-3, show_clip_grad_value=False,
        show_clip_grad_norm=True, show_momentum=False, show_adam_beta1=True)
    # Main (recurrent) network and legacy hypernetwork options.
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser, dhyper_chunks=-1, dhnet_arch='10,10',
        dtemb_size=2, demb_size=32, dhnet_act='sigmoid')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser, allowed_nets=['hmlp', 'chunked_hmlp',
        'structured_hmlp', 'hdeconv', 'chunked_hdeconv'], dhmlp_arch='50,50',
        show_cond_emb_size=True, dcond_emb_size=32, dchmlp_chunk_size=1000,
        dchunk_emb_size=32, show_use_cond_chunk_embs=True,
        dhdeconv_shape='512,512,3', prefix='nh_',
        pf_name='new edition of a hyper-', show_net_act=True,
        dnet_act='relu', show_no_bias=True, show_dropout_rate=True,
        ddropout_rate=-1, show_specnorm=True, show_batchnorm=False,
        show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser, custom_option=False, show_normal_init=False,
        show_hyper_fan_init=True)
    cli.eval_args(parser, dval_iter=250)
    magroup = cli.miscellaneous_args(parser, big_data=False,
        synthetic_data=True, show_plots=True, no_cuda=True,
        show_publication_style=False)
    # Regularizer options (EWC, SI, context modulation).
    seq.ewc_args(parser, dewc_lambda=5000., dn_fisher=-1, dtbptt_fisher=-1,
        show_ts_weighting_fisher=False)
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser, dsparsification_reg_type='l1',
        dsparsification_reg_strength=1., dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup, dmask_fraction=0.8, dclassification=True,
        show_ts_weighting=False, show_use_ce_loss=False,
        show_early_stopping_thld=True, dearly_stopping_thld=-1)
    # Options specific to the copy-task dataset.
    copy_sequence_args(parser)
    # Replay arguments.
    rep_args = seq.replay_args(parser, show_all_task_softmax=False)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser, allowed_nets=['simple_rnn'],
        dsrnn_rec_layers='256', dsrnn_pre_fc_layers='',
        dsrnn_post_fc_layers='', show_net_act=True, dnet_act='tanh',
        show_no_bias=True, show_dropout_rate=False, show_specnorm=False,
        show_batchnorm=False, prefix='dec_', pf_name='replay decoder')

    # Resolve what is actually parsed: an explicit "argv" wins unless
    # "default" is set, in which case an empty list forces pure defaults.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)
    # Module-local checks in addition to the shared sequential checks.
    check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
def parse_cmd_arguments(mode='split_smnist', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): The mode of the experiment. Options:

            - ``'split_smnist'``: 5 tasks with 2 classes each.
            - ``'smnist'``: A single task with all 10 classes.
        default (optional): If True, command-line arguments will be ignored
            and only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line argument that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values. The chosen
        ``mode`` is additionally stored in attribute ``mode``.

    Raises:
        ValueError: If ``mode`` is not one of the options listed above.
    """
    description = 'Continual learning on SMNIST task.'
    parser = argparse.ArgumentParser(description=description)

    # Mode-dependent defaults for task layout and validation-set size.
    # Previously an unknown mode crashed later with a confusing `NameError`
    # on `dval_set_size`; fail early with a clear message instead, consistent
    # with the other parsers in this module.
    if mode == 'split_smnist':
        dnum_tasks = 5
        dnum_classes_per_task = 2
        dval_set_size = 100
    elif mode == 'smnist':
        dnum_tasks = 1
        dnum_classes_per_task = 10
        dval_set_size = 500
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    # Continual-learning and training options.
    cli.cl_args(parser, show_beta=True, dbeta=0.005, show_from_scratch=True,
        show_multi_head=True, show_split_head_cl3=False,
        show_cl_scenario=False, show_num_tasks=True, dnum_tasks=dnum_tasks,
        show_num_classes_per_task=True,
        dnum_classes_per_task=dnum_classes_per_task)
    cli.train_args(parser, show_lr=True, show_epochs=False, dbatch_size=64,
        dn_iter=5000, dlr=1e-3, show_clip_grad_value=False,
        show_clip_grad_norm=True, show_momentum=False, show_adam_beta1=True)
    # Main (recurrent) network and legacy hypernetwork options.
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser, dhyper_chunks=-1, dhnet_arch='50,50',
        dtemb_size=32, demb_size=32, dhnet_act='relu')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser, allowed_nets=['hmlp', 'chunked_hmlp',
        'structured_hmlp', 'hdeconv', 'chunked_hdeconv'], dhmlp_arch='50,50',
        show_cond_emb_size=True, dcond_emb_size=32, dchmlp_chunk_size=1000,
        dchunk_emb_size=32, show_use_cond_chunk_embs=True,
        dhdeconv_shape='512,512,3', prefix='nh_',
        pf_name='new edition of a hyper-', show_net_act=True,
        dnet_act='relu', show_no_bias=True, show_dropout_rate=True,
        ddropout_rate=-1, show_specnorm=True, show_batchnorm=False,
        show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser, custom_option=False, show_normal_init=False,
        show_hyper_fan_init=True)
    cli.eval_args(parser, dval_iter=250, show_val_set_size=True,
        dval_set_size=dval_set_size)
    magroup = cli.miscellaneous_args(parser, big_data=False,
        synthetic_data=True, show_plots=True, no_cuda=True,
        show_publication_style=False)
    # Regularizer options (EWC, SI, context modulation).
    seq.ewc_args(parser, dewc_lambda=5000., dn_fisher=-1, dtbptt_fisher=-1,
        dts_weighting_fisher='last')
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser, dsparsification_reg_type='l1',
        dsparsification_reg_strength=1., dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup, dmask_fraction=0.8, dclassification=True,
        dts_weighting='last', show_use_ce_loss=False)

    # Replay arguments.
    rep_args = seq.replay_args(parser)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser, allowed_nets=['simple_rnn'],
        dsrnn_rec_layers='256', dsrnn_pre_fc_layers='',
        dsrnn_post_fc_layers='', show_net_act=True, dnet_act='tanh',
        show_no_bias=True, show_dropout_rate=False, show_specnorm=False,
        show_batchnorm=False, prefix='dec_', pf_name='replay decoder')
    rep_args.add_argument('--distill_across_time', action='store_true',
                          help='The feature vector of an SMNIST sample ' +
                               'contains an "end-of-digit" bit, which in ' +
                               'the ' +
                               'original data is only hot for 1 timestep. ' +
                               'However, in replayed data, this information ' +
                               'might be blurry. If this option is ' +
                               'activated, the distillation loss will be ' +
                               'applied to every timestep weighted by the ' +
                               'softmaxed "end-of-digit" feature from the ' +
                               'replayed input. Otherwise, the argmax ' +
                               'timestep of the "end-of-digit" feature is ' +
                               'considered.')

    # Resolve what is actually parsed: an explicit "argv" wins unless
    # "default" is set, in which case an empty list forces pure defaults.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)
    # Remember the experiment mode on the returned namespace.
    config.mode = mode

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
def collect_rp_cmd_arguments(mode='split', description=""):
    """Collect command-line arguments.

    Note, this function only assembles the parser; it does not call
    :meth:`argparse.ArgumentParser.parse_args`.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - "split"
                - "perm"
        description: The description shown in the parser's help text.

    Returns:
        (argparse.ArgumentParser): The assembled parser (note, the parser
        itself is returned, not a parsed Namespace).
    """
    parser = argparse.ArgumentParser(description=description)

    # If needed, add additional parameters.
    if mode == 'split':
        # Split MNIST replay setup: 5 tasks, 400-400 encoder/decoder.
        dout_dir = './out_split/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True, dnum_tasks=5)
        train_argroup = cli.train_args(parser, show_lr=False,
            dbatch_size=128, dn_iter=2000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
            dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
            dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
            dtemb_size=96, demb_size=96, prefix='rp_', pf_name='replay',
            dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
            show_plots=True, no_cuda=False, dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
            pf_name='decoder')
        split_args(parser)
    elif mode == 'perm':
        # Permuted MNIST replay setup: 10 tasks, wider 1000-1000
        # encoder/decoder and explicit learning-rate defaults.
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True, dnum_tasks=10)
        train_argroup = cli.train_args(parser, show_lr=False,
            dbatch_size=128, dn_iter=5000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
            dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
            dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=85000, dhnet_arch='25,25',
            dtemb_size=24, demb_size=8, prefix='rp_', pf_name='replay',
            dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=True,
            show_plots=True, no_cuda=False, dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder',
            dlr=0.0001)
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
            dlr=0.0001, dlr_emb=0.0001, pf_name='Decoder')
        perm_args(parser)

    # Argument groups common to both modes.
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
def collect_rp_cmd_arguments(mode='mlp', description="", argv=None):
    """Collect command-line arguments.

    NOTE(review): this function name is also defined elsewhere in this file;
    only the last definition is visible at import time.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - "mlp"
        description (str): Description forwarded to the underlying
            :class:`argparse.ArgumentParser`.
        argv (list, optional): If provided, treated as the list of command-
            line arguments used for the early pre-parse below in place of
            :code:`sys.argv`. The former default of ``0`` always crashed
            ``parse_args``; ``None`` correctly falls back to ``sys.argv``.

    Returns:
        (argparse.ArgumentParser): The assembled parser itself (not a parsed
        Namespace); callers must still invoke ``parse_args``.
    """
    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    # ``%(default)s`` is used for non-integer defaults below; the former
    # ``%(default)d``/``%(default)f`` conversions made ``--help`` crash for
    # options whose default is a str, list or None.
    agroup.add_argument('--note', default='small,random,sep-emnist', type=str,
                        required=False, help='(default=%(default)s)')
    agroup.add_argument('--dis_ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--sim_ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom', default=0, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask', default=5, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    # NOTE(review): ``type=list`` splits the raw argument string into single
    # characters (e.g. "1,2" -> ['1', ',', '2']); kept as-is to preserve
    # behavior, but this is likely not what was intended — confirm.
    agroup.add_argument('--dims', default=[], type=list, required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--tasks_preserve', type=int, required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--upper_bound', action='store_true',
                        help='Train the classifier with "replay" data i.e ' +
                             'real data. This can be regarded an upper bound.')
    agroup.add_argument("--num_class_femnist", default=62, type=int,
                        required=False, help='(default=%(default)d)')

    # Pre-parse only the options defined so far; the bulk of the options is
    # added below, so a full ``parse_args`` here would reject any real
    # command line with "unrecognized arguments".
    config, _ = parser.parse_known_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'mlp':
        # Output directory is derived from the user-provided note.
        dout_dir = './out_split/' + config.note
        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True, show_num_tasks=True,
                                 dnum_tasks=config.tasks_preserve)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=128,
                                       dn_iter=2000, show_epochs=True)
        # MLP encoder and decoder networks, distinguished by prefixes.
        cli.main_net_args(parser, allowed_nets=['mlp'], dfc_arch='400,400',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['mlp'], dfc_arch='400,400',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        # Hypernetwork options for the replay model.
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          pf_name='decoder')
        split_args(parser)

    # Options shared by all modes (project helpers).
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
def collect_rp_cmd_arguments(mode='split', description="", argv=None):
    """Collect command-line arguments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - "split"
                - "perm"
        description (str): Description forwarded to the underlying
            :class:`argparse.ArgumentParser`.
        argv (list, optional): If provided, treated as the list of command-
            line arguments used for the early pre-parse below in place of
            :code:`sys.argv`. The former default of ``0`` always crashed
            ``parse_args``; ``None`` correctly falls back to ``sys.argv``.

    Returns:
        (argparse.ArgumentParser): The assembled parser itself (not a parsed
        Namespace); callers must still invoke ``parse_args``.
    """
    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    # ``%(default)s`` is used for non-integer defaults below; the former
    # ``%(default)d``/``%(default)f`` conversions made ``--help`` crash for
    # options whose default is a str or list.
    agroup.add_argument('--note', default='small,random,sep-emnist', type=str,
                        required=False, help='(default=%(default)s)')
    agroup.add_argument('--ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom', default=0, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask', default=5, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    # NOTE(review): ``type=list`` splits the raw argument string into single
    # characters (e.g. "1,2" -> ['1', ',', '2']); kept as-is to preserve
    # behavior, but this is likely not what was intended — confirm.
    agroup.add_argument('--taskcla', default=[], type=list, required=False,
                        help='(default=%(default)s)')

    # Pre-parse only the options defined so far; the bulk of the options is
    # added below, so a full ``parse_args`` here would reject any real
    # command line with "unrecognized arguments".
    config, _ = parser.parse_known_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'split':
        # Output directory is derived from the user-provided note.
        dout_dir = './out_split/' + config.note
        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True, show_num_tasks=True,
                                 dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=64,
                                       dn_iter=2000, show_epochs=True)
        # Fully-connected encoder and decoder networks, distinguished by
        # their argument prefixes.
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        # Hypernetwork options for the replay model.
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          pf_name='decoder')
        split_args(parser)
    elif mode == 'perm':
        # Timestamped default output directory so runs don't collide.
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True, show_num_tasks=True,
                                 dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=64,
                                       dn_iter=5000, show_epochs=True)
        # Permuted-MNIST variant uses wider networks than the split variant.
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=85000, dhnet_arch='25,25',
                          dtemb_size=24, demb_size=8, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=True,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder', dlr=0.0001)
        train_args_replay(parser, show_emb_lr=True, prefix='dec_', dlr=0.0001,
                          dlr_emb=0.0001, pf_name='Decoder')
        perm_args(parser)

    # Options shared by all modes (project helpers).
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser