def parse_cmd_arguments(default=False, argv=None):
    """Parse command-line arguments.

    Args:
        default (optional): If True, command-line arguments will be ignored and
            only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line argument that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.
    """
    description = 'Continual learning on copy task.'
    parser = argparse.ArgumentParser(description=description)

    # NOTE(review): `cli` and `seq` appear to be project-local helper modules
    # that register whole groups of related options on `parser`. The `d*`
    # keywords look like option default values and the `show_*` flags toggle
    # whether an option is exposed at all -- confirm against their definitions.
    cli.cl_args(parser, show_beta=True, dbeta=0.005, show_from_scratch=True,
                show_multi_head=True, show_split_head_cl3=False,
                show_cl_scenario=False, show_num_tasks=True, dnum_tasks=6)
    cli.train_args(parser, show_lr=True, show_epochs=False, dbatch_size=64,
                   dn_iter=5000, dlr=1e-3, show_clip_grad_value=False,
                   show_clip_grad_norm=True, show_momentum=False,
                   show_adam_beta1=True)
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser, dhyper_chunks=-1, dhnet_arch='10,10',
                      dtemb_size=2, demb_size=32, dhnet_act='sigmoid')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser, allowed_nets=['hmlp', 'chunked_hmlp',
        'structured_hmlp', 'hdeconv', 'chunked_hdeconv'], dhmlp_arch='50,50',
        show_cond_emb_size=True, dcond_emb_size=32, dchmlp_chunk_size=1000,
        dchunk_emb_size=32, show_use_cond_chunk_embs=True,
        dhdeconv_shape='512,512,3', prefix='nh_',
        pf_name='new edition of a hyper-', show_net_act=True, dnet_act='relu',
        show_no_bias=True, show_dropout_rate=True, ddropout_rate=-1,
        show_specnorm=True, show_batchnorm=False, show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser, custom_option=False, show_normal_init=False,
                  show_hyper_fan_init=True)
    cli.eval_args(parser, dval_iter=250)
    magroup = cli.miscellaneous_args(parser, big_data=False,
        synthetic_data=True, show_plots=True, no_cuda=True,
        show_publication_style=False)
    # Regularizer options (EWC, SI, context modulation).
    seq.ewc_args(parser, dewc_lambda=5000., dn_fisher=-1, dtbptt_fisher=-1,
                 show_ts_weighting_fisher=False)
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser, dsparsification_reg_type='l1',
                         dsparsification_reg_strength=1.,
                         dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup, dmask_fraction=0.8, dclassification=True,
                           show_ts_weighting=False, show_use_ce_loss=False,
                           show_early_stopping_thld=True,
                           dearly_stopping_thld=-1)
    copy_sequence_args(parser)
    # Replay arguments.
    rep_args = seq.replay_args(parser, show_all_task_softmax=False)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser, allowed_nets=['simple_rnn'],
                      dsrnn_rec_layers='256', dsrnn_pre_fc_layers='',
                      dsrnn_post_fc_layers='', show_net_act=True,
                      dnet_act='tanh', show_no_bias=True,
                      show_dropout_rate=False, show_specnorm=False,
                      show_batchnorm=False, prefix='dec_',
                      pf_name='replay decoder')

    # Decide what argument list the parser actually consumes: explicit
    # `argv` wins unless `default` forces an empty list (pure defaults).
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    # Both the generic sequential check and a module-local
    # `check_invalid_args_sequential` are run; the duplication of names
    # appears intentional (different scopes).
    seq.check_invalid_args_sequential(config)
    check_invalid_args_sequential(config)
    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
def parse_cmd_arguments(mode='resnet_cifar', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``resnet_cifar``
                - ``zenke_cifar``
        default (bool, optional): If ``True``, command-line arguments will be
            ignored and only the default values will be parsed.
        argv (list, optional): If provided, it will be treated as a list of
            command-line argument that is passed to the parser in place of
            :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The Namespace object containing argument names
        and values.

    Raises:
        ValueError: If ``mode`` is not one of the two supported options.
    """
    if mode == 'resnet_cifar':
        description = 'CIFAR-10/100 CL experiment using a Resnet-32'
    elif mode == 'zenke_cifar':
        description = 'CIFAR-10/100 CL experiment using the Zenkenet'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    general_options(parser)

    # The two modes differ only in default hyperparameters and the network
    # family they expose; both produce the same *set* of argument groups.
    if mode == 'resnet_cifar':
        # Timestamped output dir so repeated runs do not overwrite each other.
        dout_dir = './out_resnet/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_group = cli.cl_args(parser, show_beta=True, dbeta=0.05,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=False,
            show_num_tasks=True, dnum_tasks=6)
        cli.main_net_args(parser, allowed_nets=['resnet'],
            show_batchnorm=False, show_no_batchnorm=True,
            show_bn_no_running_stats=True, show_bn_distill_stats=True,
            show_bn_no_stats_checkpointing=True, show_specnorm=False,
            show_dropout_rate=False, show_net_act=False)
        cli.hypernet_args(parser, dhyper_chunks=7000, dhnet_arch='',
            dtemb_size=32, demb_size=32)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser, show_lr=True, dlr=0.001,
            show_epochs=True, depochs=200, dbatch_size=32, dn_iter=2000,
            show_use_adam=True, show_use_rmsprop=True, show_use_adadelta=False,
            show_use_adagrad=False, show_clip_grad_value=False,
            show_clip_grad_norm=False)
    elif mode == 'zenke_cifar':
        dout_dir = './out_zenke/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_group = cli.cl_args(parser, show_beta=True, dbeta=0.01,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=False,
            show_num_tasks=True, dnum_tasks=6)
        cli.main_net_args(parser, allowed_nets=['zenke'],
            show_batchnorm=False, show_no_batchnorm=False,
            show_dropout_rate=True, ddropout_rate=0.25, show_specnorm=False,
            show_net_act=False)
        cli.hypernet_args(parser, dhyper_chunks=5500,
            dhnet_arch='100,150,200', dtemb_size=48, demb_size=80)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser, show_lr=True, dlr=0.0001,
            show_epochs=True, depochs=80, dbatch_size=256, dn_iter=2000,
            show_use_adam=True, dadam_beta1=0.5, show_use_rmsprop=True,
            show_use_adadelta=False, show_use_adagrad=False,
            show_clip_grad_value=False, show_clip_grad_norm=False)

    # Script-specific extras are attached to the groups created above.
    special_cl_options(cl_group)
    special_train_options(train_agroup)

    init_group = cli.init_args(parser, custom_option=True)
    special_init_options(init_group)

    cli.eval_args(parser, show_val_batch_size=True, dval_batch_size=1000)
    cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                           show_plots=False, no_cuda=False, dout_dir=dout_dir)

    # Explicit `argv` wins unless `default` forces an empty argument list.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)

    ### ... insert additional checks if necessary
    if config.num_tasks < 1 or config.num_tasks > 11:
        raise ValueError('Argument "num_tasks" must be between 1 and 11!')

    if config.cl_scenario != 1:
        raise NotImplementedError('CIFAR experiments are currently only ' +
                                  'implemented for CL1.')

    # The LR schedulers operate per-epoch, so an iteration-based run
    # (epochs == -1) cannot use them.
    if config.plateau_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "plateau_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')
    if config.lambda_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "lambda_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')

    if config.no_lookahead and config.backprop_dt:
        raise ValueError('Can\'t activate "no_lookahead" and "backprop_dt" ' +
                         'simultaneously.')

    return config
def parse_cmd_arguments(mode='split_smnist', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): The mode of the experiment. Options:

            - ``split_smnist``
            - ``smnist``
        default (optional): If True, command-line arguments will be ignored
            and only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line argument that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.

    Raises:
        ValueError: If ``mode`` is not one of the two supported options.
    """
    description = 'Continual learning on SMNIST task.'
    parser = argparse.ArgumentParser(description=description)

    # FIX: Previously an unrecognized `mode` silently fell through and left
    # `dval_set_size` unbound, causing a confusing NameError below. Validate
    # the mode up front (consistent with the CIFAR parser).
    if mode == 'split_smnist':
        dnum_tasks = 5
        dnum_classes_per_task = 2
        dval_set_size = 100
    elif mode == 'smnist':
        dnum_tasks = 1
        dnum_classes_per_task = 10
        dval_set_size = 500
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    cli.cl_args(parser, show_beta=True, dbeta=0.005, show_from_scratch=True,
                show_multi_head=True, show_split_head_cl3=False,
                show_cl_scenario=False, show_num_tasks=True,
                dnum_tasks=dnum_tasks, show_num_classes_per_task=True,
                dnum_classes_per_task=dnum_classes_per_task)
    cli.train_args(parser, show_lr=True, show_epochs=False, dbatch_size=64,
                   dn_iter=5000, dlr=1e-3, show_clip_grad_value=False,
                   show_clip_grad_norm=True, show_momentum=False,
                   show_adam_beta1=True)
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser, dhyper_chunks=-1, dhnet_arch='50,50',
                      dtemb_size=32, demb_size=32, dhnet_act='relu')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser, allowed_nets=['hmlp', 'chunked_hmlp',
        'structured_hmlp', 'hdeconv', 'chunked_hdeconv'], dhmlp_arch='50,50',
        show_cond_emb_size=True, dcond_emb_size=32, dchmlp_chunk_size=1000,
        dchunk_emb_size=32, show_use_cond_chunk_embs=True,
        dhdeconv_shape='512,512,3', prefix='nh_',
        pf_name='new edition of a hyper-', show_net_act=True, dnet_act='relu',
        show_no_bias=True, show_dropout_rate=True, ddropout_rate=-1,
        show_specnorm=True, show_batchnorm=False, show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser, custom_option=False, show_normal_init=False,
                  show_hyper_fan_init=True)
    cli.eval_args(parser, dval_iter=250, show_val_set_size=True,
                  dval_set_size=dval_set_size)
    magroup = cli.miscellaneous_args(parser, big_data=False,
        synthetic_data=True, show_plots=True, no_cuda=True,
        show_publication_style=False)
    # Regularizer options (EWC, SI, context modulation).
    seq.ewc_args(parser, dewc_lambda=5000., dn_fisher=-1, dtbptt_fisher=-1,
                 dts_weighting_fisher='last')
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser, dsparsification_reg_type='l1',
                         dsparsification_reg_strength=1.,
                         dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup, dmask_fraction=0.8, dclassification=True,
                           dts_weighting='last', show_use_ce_loss=False)

    # Replay arguments.
    rep_args = seq.replay_args(parser)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser, allowed_nets=['simple_rnn'],
                      dsrnn_rec_layers='256', dsrnn_pre_fc_layers='',
                      dsrnn_post_fc_layers='', show_net_act=True,
                      dnet_act='tanh', show_no_bias=True,
                      show_dropout_rate=False, show_specnorm=False,
                      show_batchnorm=False, prefix='dec_',
                      pf_name='replay decoder')
    rep_args.add_argument('--distill_across_time', action='store_true',
                          help='The feature vector of an SMNIST sample ' +
                               'contains an "end-of-digit" bit, which in the ' +
                               'original data is only hot for 1 timestep. ' +
                               'However, in replayed data, this information ' +
                               'might be blurry. If this option is ' +
                               'activated, the distillation loss will be ' +
                               'applied to every timestep weighted by the ' +
                               'softmaxed "end-of-digit" feature from the ' +
                               'replayed input. Otherwise, the argmax ' +
                               'timestep of the "end-of-digit" feature is ' +
                               'considered.')

    # Explicit `argv` wins unless `default` forces an empty argument list.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)
    # Record the experiment mode on the config for downstream consumers.
    config.mode = mode

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)
    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
def collect_rp_cmd_arguments(mode='split', description=""):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - "split"
                - "perm"
        description: Description text forwarded to the ``ArgumentParser``.

    Returns:
        The ``argparse.ArgumentParser`` with all arguments registered.
        Note: this function returns the *parser* itself (nothing is parsed
        here); callers invoke ``parse_args`` on the result.
    """
    parser = argparse.ArgumentParser(description=description)

    # If needed, add additional parameters.
    # The two modes differ only in default values (task count, architectures,
    # hypernet sizes, learning rates) and their dataset-specific extras
    # (`split_args` vs. `perm_args`).
    if mode == 'split':
        # Timestamped output dir so repeated runs do not overwrite each other.
        dout_dir = './out_split/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True, dnum_tasks=5)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=128,
            dn_iter=2000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          pf_name='decoder')
        split_args(parser)
    elif mode == 'perm':
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True, dnum_tasks=10)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=128,
            dn_iter=5000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=85000, dhnet_arch='25,25',
                          dtemb_size=24, demb_size=8, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=True,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder',
                          dlr=0.0001)
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          dlr=0.0001, dlr_emb=0.0001, pf_name='Decoder')
        perm_args(parser)

    # Options shared by both modes.
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
def collect_rp_cmd_arguments(mode='mlp', description="", argv=0):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - "mlp"
        description: Description text forwarded to the ``ArgumentParser``.
        argv: List of command-line arguments that is parsed *eagerly* inside
            this function (the ``note``-group options must be known up front
            because they determine defaults such as the output directory).
            NOTE(review): the default ``0`` is not a valid value for
            ``parse_args`` -- callers are expected to always pass a list.

    Returns:
        The ``argparse.ArgumentParser`` with all arguments registered.
        Note: this function returns the *parser* itself, not a parsed
        Namespace; callers invoke ``parse_args`` on the result.
    """
    # FIX: This print used to precede the docstring, which made the string
    # above a dead statement (the function had no ``__doc__``).
    print('collect_rp_cmd_arguments')

    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    # FIX: help strings below previously used '%(default)d' / '%(default)f'
    # with non-numeric defaults, which made `--help` raise a TypeError.
    agroup.add_argument('--note', default='small,random,sep-emnist', type=str,
                        required=False, help='(default=%(default)s)')
    agroup.add_argument('--dis_ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--sim_ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom', default=0, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask', default=5, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    # NOTE(review): `type=list` splits the raw string into characters (e.g.
    # "1,2" -> ['1', ',', '2']); presumably the default [] is always used --
    # confirm before relying on this option from the command line.
    agroup.add_argument('--dims', default=[], type=list, required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--tasks_preserve', type=int, required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--upper_bound', action='store_true',
                        help='Train the classifier with "replay" data i.e ' +
                             'real data. This can be regarded an upper bound.')
    agroup.add_argument("--num_class_femnist", default=62, type=int,
                        required=False, help='(default=%(default)d)')

    # Eager parse: `config.note` / `config.tasks_preserve` feed the defaults
    # of the option groups registered below.
    config = parser.parse_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'mlp':
        # dout_dir = './out_split/run_' + \
        #     datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        dout_dir = './out_split/' + config.note
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True,
            dnum_tasks=config.tasks_preserve)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=128,
            dn_iter=2000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dfc_arch='400,400',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['mlp'], dfc_arch='400,400',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          pf_name='decoder')
        split_args(parser)

    # Options shared regardless of mode.
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
def parse_cmd_arguments(mode='split', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - "split"
                - "perm"
        default (optional): If True, command-line arguments will be ignored
            and only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line argument that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.

    Raises:
        ValueError: If ``mode`` is not one of the two supported options.
    """
    if mode == 'split':
        description = 'Training classifier sequentially on splitMNIST'
    elif mode == 'perm':
        description = 'Training classifier sequentially on permutedMNIST'
    else:
        # FIX: raise ValueError (consistent with the other parsers in this
        # file) instead of a bare Exception; ValueError subclasses Exception,
        # so existing callers remain unaffected.
        raise ValueError('Mode "%s" unknown.' % (mode))

    # The bulk of the options is registered by the shared collector; `argv`
    # is forwarded because the collector parses some options eagerly.
    parser = collect_rp_cmd_arguments(mode=mode, description=description,
                                      argv=argv)

    # If needed, add additional parameters.
    # Per-mode classifier network and its hypernetwork.
    if mode == 'split':
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='class_',
                          pf_name='classifier')
        cli.hypernet_args(parser, dhyper_chunks=42000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='class_',
                          pf_name='classifier', dhnet_act='relu')
        train_args_replay(parser, show_emb_lr=True, prefix='class_',
                          dlr=0.001, dlr_emb=0.001, pf_name='Classifier')
    elif mode == 'perm':
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='class_',
                          pf_name='classifier')
        cli.hypernet_args(parser, dhyper_chunks=78000, dhnet_arch='25,25',
                          dtemb_size=24, demb_size=8, prefix='class_',
                          pf_name='classifier', dhnet_act='relu')
        train_args_replay(parser, show_emb_lr=True, prefix='class_',
                          dlr=0.0001, dlr_emb=0.0001, pf_name='Classifier')

    cl_arguments_general(parser)
    cl_arguments_classificiation(parser)

    # Explicit `argv` wins unless `default` forces an empty argument list.
    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)
    print('args num_tasks: ', config.num_tasks)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    # if mode == 'split':
    #     if config.num_tasks > 5:
    #         raise ValueError('SplitMNIST may have maximally 5 tasks.')

    return config
def collect_rp_cmd_arguments(mode='split', description="", argv=0):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of
            command-line parameters? Options:

                - "split"
                - "perm"
        description: Description text forwarded to the ``ArgumentParser``.
        argv: List of command-line arguments that is parsed *eagerly* inside
            this function (the ``note``-group options must be known up front
            because they determine defaults such as the task count).
            NOTE(review): the default ``0`` is not a valid value for
            ``parse_args`` -- callers are expected to always pass a list.

    Returns:
        The ``argparse.ArgumentParser`` with all arguments registered.
        Note: this function returns the *parser* itself, not a parsed
        Namespace; callers invoke ``parse_args`` on the result.
    """
    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    # FIX: help strings below previously used '%(default)d' / '%(default)f'
    # with non-numeric defaults, which made `--help` raise a TypeError.
    agroup.add_argument('--note', default='small,random,sep-emnist', type=str,
                        required=False, help='(default=%(default)s)')
    agroup.add_argument('--ntasks', default=10, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom', default=0, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask', default=5, type=int, required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2', default=-1, type=float, required=False,
                        help='(default=%(default)f)')
    # NOTE(review): `type=list` splits the raw string into characters;
    # presumably the default [] is always used -- confirm before relying on
    # this option from the command line.
    agroup.add_argument('--taskcla', default=[], type=list, required=False,
                        help='(default=%(default)s)')

    # Eager parse: `config.note` / `config.ntasks` feed the defaults of the
    # option groups registered below.
    config = parser.parse_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'split':
        # dout_dir = './out_split/run_' + \
        #     datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        dout_dir = './out_split/' + config.note
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True,
            dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=64,
            dn_iter=2000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='400,400',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=50000, dhnet_arch='10,10',
                          dtemb_size=96, demb_size=96, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=False,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          pf_name='decoder')
        split_args(parser)
    elif mode == 'perm':
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_num_tasks=True,
            dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser, show_lr=False, dbatch_size=64,
            dn_iter=5000, show_epochs=True)
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='enc_', pf_name='encoder')
        cli.main_net_args(parser, allowed_nets=['fc'], dfc_arch='1000,1000',
                          dnet_act='relu', prefix='dec_', pf_name='decoder')
        cli.hypernet_args(parser, dhyper_chunks=85000, dhnet_arch='25,25',
                          dtemb_size=24, demb_size=8, prefix='rp_',
                          pf_name='replay', dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser, big_data=False, synthetic_data=True,
                               show_plots=True, no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder',
                          dlr=0.0001)
        train_args_replay(parser, show_emb_lr=True, prefix='dec_',
                          dlr=0.0001, dlr_emb=0.0001, pf_name='Decoder')
        perm_args(parser)

    # Options shared by both modes.
    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser