Example #1
def parse_cmd_arguments(mode='regression_ewc', default=False, argv=None):
    """Parse command-line arguments for EWC experiments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``'regression_ewc'``
                - ``'gmm_ewc'``
                - ``'split_mnist_ewc'``
                - ``'perm_mnist_ewc'``
                - ``'cifar_resnet_ewc'``
        default (bool, optional): If ``True``, command-line arguments will be
            ignored and only the default values will be parsed.
        argv (list, optional): If provided, it will be treated as a list of
            command-line arguments that is passed to the parser in place of
            :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The Namespace object containing argument names and
            values.
    """
    if mode == 'regression_ewc':
        description = 'Toy regression with tasks trained via EWC'
    elif mode == 'gmm_ewc':
        description = 'Probabilistic CL on GMM Datasets via EWC'
    elif mode == 'split_mnist_ewc':
        description = 'Probabilistic CL on Split MNIST via EWC'
    elif mode == 'perm_mnist_ewc':
        description = 'Probabilistic CL on Permuted MNIST via EWC'
    elif mode == 'cifar_resnet_ewc':
        description = 'Probabilistic CL on CIFAR-10/100 via EWC on a Resnet'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # If needed, add additional parameters.
    if mode == 'regression_ewc':
        dout_dir = './out_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser, show_beta=False,
                    show_from_scratch=True, show_multi_head=True)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=5001, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True, show_batchnorm=True,
            show_no_batchnorm=False, show_bn_no_running_stats=True,
            show_bn_distill_stats=False, show_bn_no_stats_checkpointing=True,
            show_specnorm=False, show_dropout_rate=True, ddropout_rate=-1,
            show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=False,
                                    dval_batch_size=10000, dval_iter=250)
        rta.data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=False,
            dout_dir=dout_dir, show_publication_style=True)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

        pmta.special_train_options(train_agroup, show_soft_targets=False)

        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
                    dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=True,
                       show_local_reparam_trick=False, show_kl_scale=False,
                       show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False)

    elif mode == 'gmm_ewc':
        dout_dir = './out_gmm_ewc/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser, show_beta=False,
            show_from_scratch=True, show_multi_head=False,
            show_cl_scenario=True, show_split_head_cl3=True,
            show_num_tasks=False, show_num_classes_per_task=False)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
            dn_iter=2000, dlr=1e-3, show_use_adam=True, show_use_rmsprop=True,
            show_use_adadelta=True, show_use_adagrad=True, show_epochs=True,
            show_clip_grad_value=True, show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='10,10',
            dnet_act='sigmoid', show_no_bias=True, show_batchnorm=True,
            show_no_batchnorm=False, show_bn_no_running_stats=True,
            show_bn_distill_stats=False, show_bn_no_stats_checkpointing=True,
            show_specnorm=False, show_dropout_rate=True, ddropout_rate=-1,
            show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
                                    dval_batch_size=10000, dval_iter=100)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=True, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

        pmta.special_train_options(train_agroup, show_soft_targets=False)

        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
                    dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
                       show_local_reparam_trick=False, show_kl_scale=False,
                       show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
            show_use_logvar_enc=False, show_disable_lrt_test=False,
            show_mean_only=False, show_during_acc_criterion=True)

    elif mode == 'split_mnist_ewc':
        dout_dir = './out_split_ewc/run_' + \
                   datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_split_head_cl3=True,
                                 show_num_tasks=True, dnum_tasks=5,
                                 show_num_classes_per_task=True,
                                 dnum_classes_per_task=2)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=128,
                                      dn_iter=2000, dlr=1e-3,
                                      show_use_adam=True, show_use_rmsprop=True,
                                      show_use_adadelta=True,
                                      show_use_adagrad=True, show_epochs=True,
                                      show_clip_grad_value=True,
                                      show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp', 'lenet', 'resnet', 'wrn'],
                          dmlp_arch='400,400', dlenet_type='mnist_small',
                          show_no_bias=True, show_batchnorm=True,
                          show_no_batchnorm=False,
                          show_bn_no_running_stats=True,
                          show_bn_distill_stats=False,
                          show_bn_no_stats_checkpointing=True,
                          show_specnorm=False, show_dropout_rate=True,
                          ddropout_rate=-1, show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
                dval_batch_size=1000, show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
                                             synthetic_data=False,
                                             show_plots=False, no_cuda=False,
                                             dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

        pmta.special_train_options(train_agroup, show_soft_targets=False)

        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
                    dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
                       show_local_reparam_trick=False, show_kl_scale=False,
                       show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
                               show_use_logvar_enc=False,
                               show_disable_lrt_test=False,
                               show_mean_only=False,
                               show_during_acc_criterion=True)

    elif mode == 'perm_mnist_ewc':
        dout_dir = './out_perm_ewc/run_' + \
                   datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_split_head_cl3=True,
                                 show_num_tasks=True, dnum_tasks=10,
                                 show_num_classes_per_task=False)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=128,
                                      dn_iter=5000, dlr=1e-4,
                                      show_use_adam=True, show_use_rmsprop=True,
                                      show_use_adadelta=True,
                                      show_use_adagrad=True, show_epochs=True,
                                      show_clip_grad_value=True,
                                      show_clip_grad_norm=True)
        cli.main_net_args(parser, allowed_nets=['mlp'], dmlp_arch='1000,1000',
                          show_no_bias=True, show_batchnorm=True,
                          show_no_batchnorm=False,
                          show_bn_no_running_stats=True,
                          show_bn_distill_stats=False,
                          show_bn_no_stats_checkpointing=True,
                          show_specnorm=False, show_dropout_rate=True,
                          ddropout_rate=-1, show_net_act=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
            dval_batch_size=1000, show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
            synthetic_data=True, show_plots=False, no_cuda=False,
            dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

        pmta.special_train_options(train_agroup, show_soft_targets=False)

        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
                    dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
                       show_local_reparam_trick=False, show_kl_scale=False,
                       show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
                               show_use_logvar_enc=False,
                               show_disable_lrt_test=False,
                               show_mean_only=False,
                               show_during_acc_criterion=True)

        pmta.perm_args(parser)

    elif mode == 'cifar_resnet_ewc':
        dout_dir = './out_resnet_ewc/run_' + \
                   datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser, show_beta=False,
                                 show_from_scratch=True, show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_split_head_cl3=True,
                                 show_num_tasks=True, dnum_tasks=6,
                                 show_num_classes_per_task=True,
                                 dnum_classes_per_task=10)
        pcta.extra_cl_args(cl_argroup)
        train_agroup = cli.train_args(parser, show_lr=True, dbatch_size=32,
                                      dn_iter=2000, dlr=1e-3,
                                      show_use_adam=True, show_use_rmsprop=True,
                                      show_use_adadelta=True,
                                      show_use_adagrad=True, show_epochs=True,
                                      depochs=200, show_clip_grad_value=True,
                                      show_clip_grad_norm=True)
        cli.main_net_args(parser,
            allowed_nets=['resnet', 'wrn', 'iresnet', 'lenet', 'zenke', 'mlp'],
            dmlp_arch='10,10', dlenet_type='cifar', show_no_bias=True,
            show_batchnorm=False, show_no_batchnorm=True,
            show_bn_no_running_stats=True, show_bn_distill_stats=False,
            show_bn_no_stats_checkpointing=True,
            show_specnorm=False, show_dropout_rate=True,
            ddropout_rate=-1, show_net_act=True)
        cli.data_args(parser, show_disable_data_augmentation=True)
        eval_agroup = cli.eval_args(parser, show_val_batch_size=True,
                                    dval_batch_size=1000,
                                    show_val_set_size=True, dval_set_size=0)
        misc_agroup = cli.miscellaneous_args(parser, big_data=False,
                                             synthetic_data=False,
                                             show_plots=False, no_cuda=False,
                                             dout_dir=dout_dir)
        ewc_args(parser, dewc_lambda=1., dn_fisher=-1)

        pmta.special_train_options(train_agroup, show_soft_targets=False)

        cl_args(cl_argroup, show_det_multi_head=True)
        pmta.eval_args(eval_agroup)
        rta.mc_args(train_agroup, eval_agroup, show_train_sample_size=False,
                    dval_sample_size=10)
        rta.train_args(train_agroup, show_ll_dist_std=False,
                       show_local_reparam_trick=False, show_kl_scale=False,
                       show_radial_bnn=False)
        rta.miscellaneous_args(misc_agroup, show_mnet_only=False,
                               show_use_logvar_enc=False,
                               show_disable_lrt_test=False,
                               show_mean_only=False,
                               show_during_acc_criterion=True)

    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    rta.check_invalid_args_general(config)
    pmta.check_invalid_args_general(config)
    check_invalid_args_ewc(config)

    if mode == 'regression_ewc':
        if config.batchnorm:
            # Not properly handled in the test and eval functions!
            raise NotImplementedError('Batchnorm is not properly handled by ' +
                                      'the test and eval functions of this ' +
                                      'script.')

    return config
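A small usage sketch for the parser above (an illustration under assumptions, not part of the original module):

# Obtain the default Split MNIST EWC configuration without reading `sys.argv`,
# e.g., inside unit tests or a programmatic hyperparameter search.
config = parse_cmd_arguments(mode='split_mnist_ewc', default=True)

# Alternatively, a custom argument list can be supplied in place of `sys.argv`.
config = parse_cmd_arguments(mode='perm_mnist_ewc', argv=[])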
Example #2
def parse_cmd_arguments(mode='regression_bbb', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of command-line
            parameters? Options:

                - ``'regression_bbb'``
                - ``'regression_avb'``
                - ``'regression_ssge'``

        default (optional): If True, command-line arguments will be ignored and
            only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line arguments that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.
    """
    if mode == 'regression_bbb':
        description = 'Toy regression with tasks trained by BbB and ' + \
            'protected by a hypernetwork'
    elif mode == 'regression_avb':
        description = 'Toy regression with tasks trained by implicit model ' \
                      'using AVB and protected by a hypernetwork'
    elif mode == 'regression_ssge':
        description = 'Toy regression with tasks trained by implicit model ' \
                      'using SSGE and protected by a hypernetwork'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    # Default hnet keyword arguments.
    hnet_args_kw = {  # Function `cli.hnet_args`
        # Note, the first list element denotes the default hnet.
        'allowed_nets': ['hmlp', 'chunked_hmlp', 'hdeconv', 'chunked_hdeconv'],
        'dhmlp_arch': '10,10',
        'show_cond_emb_size': True,
        'dcond_emb_size': 2,
        'dchmlp_chunk_size': 64,
        'dchunk_emb_size': 8,
        'show_use_cond_chunk_embs': True,
        'show_net_act': True,
        'dnet_act': 'sigmoid',
        'show_no_bias': True,
        'show_dropout_rate': True,
        'ddropout_rate': -1,
        'show_specnorm': True,
        'show_batchnorm': False,
        'show_no_batchnorm': False
    }

    # If needed, add additional parameters.
    if mode == 'regression_bbb':
        dout_dir = './out_bbb/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser,
                                 show_beta=True,
                                 dbeta=0.005,
                                 show_from_scratch=True,
                                 show_multi_head=True)
        train_argroup = cli.train_args(parser,
                                       show_lr=True,
                                       dn_iter=10001,
                                       dlr=1e-2,
                                       show_clip_grad_value=True,
                                       show_clip_grad_norm=True,
                                       show_momentum=False,
                                       show_adam_beta1=True)
        cli.main_net_args(parser,
                          allowed_nets=['mlp'],
                          dmlp_arch='10,10',
                          dnet_act='sigmoid',
                          show_no_bias=True)
        cli.hnet_args(parser, **hnet_args_kw)
        init_agroup = cli.init_args(parser, custom_option=False)
        eval_agroup = cli.eval_args(parser, dval_iter=250)
        data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser,
                                             big_data=False,
                                             synthetic_data=True,
                                             show_plots=True,
                                             no_cuda=True,
                                             show_publication_style=True,
                                             dout_dir=dout_dir)

        mc_args(train_argroup, eval_agroup)
        train_args(train_argroup,
                   show_local_reparam_trick=True,
                   show_radial_bnn=True)
        cl_args(cl_argroup)
        init_args(init_agroup)
        miscellaneous_args(misc_agroup,
                           show_use_logvar_enc=True,
                           show_disable_lrt_test=True,
                           show_mean_only=True)
        pmta.train_args(train_argroup,
                        show_init_with_prev_emb=False,
                        show_use_prev_post_as_prior=True,
                        show_num_kl_samples=True,
                        show_training_set_size=False)

    # Config parameters for the implicit-model scripts (AVB / SSGE).
    else:
        method = 'avb'
        if mode == 'regression_ssge':
            method = 'ssge'
        dout_dir = './out_%s/run_' % method + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser,
                                 show_beta=True,
                                 dbeta=0.005,
                                 show_from_scratch=True,
                                 show_multi_head=True)
        train_argroup = cli.train_args(parser,
                                       show_lr=True,
                                       dbatch_size=32,
                                       dn_iter=5001,
                                       dlr=1e-3,
                                       show_use_adam=True,
                                       show_use_rmsprop=True,
                                       show_use_adadelta=True,
                                       show_use_adagrad=True,
                                       show_epochs=True,
                                       show_clip_grad_value=True,
                                       show_clip_grad_norm=True)
        # Main network.
        cli.main_net_args(parser,
                          allowed_nets=['mlp'],
                          dmlp_arch='10,10',
                          dnet_act='sigmoid',
                          show_no_bias=True)
        if mode == 'regression_avb':
            # Discriminator.
            cli.main_net_args(parser,
                              allowed_nets=['mlp', 'chunked_mlp'],
                              dmlp_arch='10,10',
                              dcmlp_arch='10,10',
                              dcmlp_chunk_arch='10,10',
                              dcmlp_in_cdim=32,
                              dcmlp_out_cdim=8,
                              dcmlp_cemb_dim=8,
                              dnet_act='sigmoid',
                              show_no_bias=True,
                              prefix='dis_',
                              pf_name='discriminator')
        # Hypernetwork (weight generator).
        imp_hargs_kw = dict(hnet_args_kw)
        imp_hargs_kw['show_cond_emb_size'] = False
        imp_hargs_kw['show_use_cond_chunk_embs'] = False
        imp_hargs_kw['dcond_emb_size'] = 0  # Not used for implicit hnet!
        cli.hnet_args(parser,
                      **imp_hargs_kw,
                      prefix='imp_',
                      pf_name='implicit')
        # Hyper-hypernetwork.
        hhargs_kw = dict(hnet_args_kw)
        cli.hnet_args(parser, **hhargs_kw, prefix='hh_', pf_name='hyper-hyper')
        cli.init_args(parser, custom_option=False, show_hyper_fan_init=True)
        eval_agroup = cli.eval_args(parser, dval_iter=250)
        data_args(parser)
        misc_agroup = cli.miscellaneous_args(parser,
                                             big_data=False,
                                             synthetic_data=True,
                                             show_plots=True,
                                             no_cuda=True,
                                             show_publication_style=True,
                                             dout_dir=dout_dir)

        pmta.special_train_options(train_argroup, show_soft_targets=False)
        mc_args(train_argroup, eval_agroup)
        train_args(train_argroup,
                   show_local_reparam_trick=False,
                   show_kl_scale=True)
        miscellaneous_args(misc_agroup, show_store_during_models=True)

        pmta.train_args(train_argroup,
                        show_init_with_prev_emb=True,
                        show_use_prev_post_as_prior=True,
                        show_num_kl_samples=True,
                        show_calc_hnet_reg_targets_online=True,
                        show_hnet_reg_batch_size=True,
                        show_training_set_size=False)
        pmta.ind_posterior_args(train_argroup, show_distill_iter=False)
        pcta.miscellaneous_args(misc_agroup, show_no_hhnet=True)
        pcta.imp_args(parser, dlatent_dim=8, show_prior_focused=True)

        if mode == 'regression_avb':
            pcta.avb_args(parser)
        elif mode == 'regression_ssge':
            pcta.ssge_args(parser)

    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    check_invalid_args_general(config)
    pmta.check_invalid_args_general(config)

    if mode == 'regression_bbb':
        check_invalid_bbb_args(config)

    return config
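The `default`/`argv` handling above is plain `argparse` behavior; a minimal standalone sketch of the same mechanism (the `--lr` flag is purely illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-2)  # Illustrative flag only.

config_cli = parser.parse_args(args=None)      # args=None reads sys.argv[1:].
config_custom = parser.parse_args(args=['--lr', '1e-3'])  # Custom argv list.
config_default = parser.parse_args(args=[])    # Pure defaults, as with default=True.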
Example #3
def parse_cmd_arguments(mode='resnet_cifar', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): For what script should the parser assemble the set of
            command-line parameters? Options:

                - ``resnet_cifar``
                - ``zenke_cifar``

        default (bool, optional): If ``True``, command-line arguments will be
            ignored and only the default values will be parsed.
        argv (list, optional): If provided, it will be treated as a list of
            command-line arguments that is passed to the parser in place of
            :code:`sys.argv`.

    Returns:
        (argparse.Namespace): The Namespace object containing argument names and
            values.
    """
    if mode == 'resnet_cifar':
        description = 'CIFAR-10/100 CL experiment using a Resnet-32'
    elif mode == 'zenke_cifar':
        description = 'CIFAR-10/100 CL experiment using the Zenkenet'
    else:
        raise ValueError('Mode "%s" unknown.' % (mode))

    parser = argparse.ArgumentParser(description=description)

    general_options(parser)

    if mode == 'resnet_cifar':
        dout_dir = './out_resnet/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_group = cli.cl_args(parser,
                               show_beta=True,
                               dbeta=0.05,
                               show_from_scratch=True,
                               show_multi_head=False,
                               show_cl_scenario=True,
                               show_split_head_cl3=False,
                               show_num_tasks=True,
                               dnum_tasks=6)
        cli.main_net_args(parser,
                          allowed_nets=['resnet'],
                          show_batchnorm=False,
                          show_no_batchnorm=True,
                          show_bn_no_running_stats=True,
                          show_bn_distill_stats=True,
                          show_bn_no_stats_checkpointing=True,
                          show_specnorm=False,
                          show_dropout_rate=False,
                          show_net_act=False)
        cli.hypernet_args(parser,
                          dhyper_chunks=7000,
                          dhnet_arch='',
                          dtemb_size=32,
                          demb_size=32)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser,
                                      show_lr=True,
                                      dlr=0.001,
                                      show_epochs=True,
                                      depochs=200,
                                      dbatch_size=32,
                                      dn_iter=2000,
                                      show_use_adam=True,
                                      show_use_rmsprop=True,
                                      show_use_adadelta=False,
                                      show_use_adagrad=False,
                                      show_clip_grad_value=False,
                                      show_clip_grad_norm=False)

    elif mode == 'zenke_cifar':
        dout_dir = './out_zenke/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_group = cli.cl_args(parser,
                               show_beta=True,
                               dbeta=0.01,
                               show_from_scratch=True,
                               show_multi_head=False,
                               show_cl_scenario=True,
                               show_split_head_cl3=False,
                               show_num_tasks=True,
                               dnum_tasks=6)
        cli.main_net_args(parser,
                          allowed_nets=['zenke'],
                          show_batchnorm=False,
                          show_no_batchnorm=False,
                          show_dropout_rate=True,
                          ddropout_rate=0.25,
                          show_specnorm=False,
                          show_net_act=False)
        cli.hypernet_args(parser,
                          dhyper_chunks=5500,
                          dhnet_arch='100,150,200',
                          dtemb_size=48,
                          demb_size=80)
        cli.data_args(parser, show_disable_data_augmentation=True)
        train_agroup = cli.train_args(parser,
                                      show_lr=True,
                                      dlr=0.0001,
                                      show_epochs=True,
                                      depochs=80,
                                      dbatch_size=256,
                                      dn_iter=2000,
                                      show_use_adam=True,
                                      dadam_beta1=0.5,
                                      show_use_rmsprop=True,
                                      show_use_adadelta=False,
                                      show_use_adagrad=False,
                                      show_clip_grad_value=False,
                                      show_clip_grad_norm=False)

    special_cl_options(cl_group)
    special_train_options(train_agroup)
    init_group = cli.init_args(parser, custom_option=True)
    special_init_options(init_group)
    cli.eval_args(parser, show_val_batch_size=True, dval_batch_size=1000)
    cli.miscellaneous_args(parser,
                           big_data=False,
                           synthetic_data=False,
                           show_plots=False,
                           no_cuda=False,
                           dout_dir=dout_dir)

    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)

    ### ... insert additional checks if necessary
    if config.num_tasks < 1 or config.num_tasks > 11:
        raise ValueError('Argument "num_tasks" must be between 1 and 11!')

    if config.cl_scenario != 1:
        raise NotImplementedError('CIFAR experiments are currently only ' +
                                  'implemented for CL1.')

    if config.plateau_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "plateau_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')

    if config.lambda_lr_scheduler and config.epochs == -1:
        raise ValueError('Flag "lambda_lr_scheduler" can only be used if ' +
                         '"epochs" was set.')

    if config.no_lookahead and config.backprop_dt:
        raise ValueError('Can\'t activate "no_lookahead" and "backprop_dt" ' +
                         'simultaneously.')

    return config
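The `argv`/`default` resolution block is repeated verbatim in each of these parsers; a small helper capturing it could look as follows (a sketch, not a function from the original codebase):

import warnings

def _resolve_argv(default, argv):
    """Return the argument list to pass to ``parser.parse_args(args=...)``.

    ``default=True`` forces an empty list (pure defaults). Otherwise, a
    provided ``argv`` replaces ``sys.argv``, and ``None`` lets argparse read
    ``sys.argv`` itself.
    """
    if default:
        if argv is not None:
            warnings.warn('Provided "argv" will be ignored since "default" '
                          'option was turned on.')
        return []
    return argv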
Example #4
def parse_cmd_arguments(default=False, argv=None):
    """Parse command-line arguments.

    Args:
        default (optional): If True, command-line arguments will be ignored and
            only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line arguments that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.
    """

    description = 'Continual learning on copy task.'

    parser = argparse.ArgumentParser(description=description)

    cli.cl_args(parser,
                show_beta=True,
                dbeta=0.005,
                show_from_scratch=True,
                show_multi_head=True,
                show_split_head_cl3=False,
                show_cl_scenario=False,
                show_num_tasks=True,
                dnum_tasks=6)
    cli.train_args(parser,
                   show_lr=True,
                   show_epochs=False,
                   dbatch_size=64,
                   dn_iter=5000,
                   dlr=1e-3,
                   show_clip_grad_value=False,
                   show_clip_grad_norm=True,
                   show_momentum=False,
                   show_adam_beta1=True)
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser,
                      dhyper_chunks=-1,
                      dhnet_arch='10,10',
                      dtemb_size=2,
                      demb_size=32,
                      dhnet_act='sigmoid')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser,
                               allowed_nets=[
                                   'hmlp', 'chunked_hmlp', 'structured_hmlp',
                                   'hdeconv', 'chunked_hdeconv'
                               ],
                               dhmlp_arch='50,50',
                               show_cond_emb_size=True,
                               dcond_emb_size=32,
                               dchmlp_chunk_size=1000,
                               dchunk_emb_size=32,
                               show_use_cond_chunk_embs=True,
                               dhdeconv_shape='512,512,3',
                               prefix='nh_',
                               pf_name='new edition of a hyper-',
                               show_net_act=True,
                               dnet_act='relu',
                               show_no_bias=True,
                               show_dropout_rate=True,
                               ddropout_rate=-1,
                               show_specnorm=True,
                               show_batchnorm=False,
                               show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser,
                  custom_option=False,
                  show_normal_init=False,
                  show_hyper_fan_init=True)
    cli.eval_args(parser, dval_iter=250)
    magroup = cli.miscellaneous_args(parser,
                                     big_data=False,
                                     synthetic_data=True,
                                     show_plots=True,
                                     no_cuda=True,
                                     show_publication_style=False)
    seq.ewc_args(parser,
                 dewc_lambda=5000.,
                 dn_fisher=-1,
                 dtbptt_fisher=-1,
                 show_ts_weighting_fisher=False)
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser,
                         dsparsification_reg_type='l1',
                         dsparsification_reg_strength=1.,
                         dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup,
                           dmask_fraction=0.8,
                           dclassification=True,
                           show_ts_weighting=False,
                           show_use_ce_loss=False,
                           show_early_stopping_thld=True,
                           dearly_stopping_thld=-1)
    copy_sequence_args(parser)

    # Replay arguments.
    rep_args = seq.replay_args(parser, show_all_task_softmax=False)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser,
                      allowed_nets=['simple_rnn'],
                      dsrnn_rec_layers='256',
                      dsrnn_pre_fc_layers='',
                      dsrnn_post_fc_layers='',
                      show_net_act=True,
                      dnet_act='tanh',
                      show_no_bias=True,
                      show_dropout_rate=False,
                      show_specnorm=False,
                      show_batchnorm=False,
                      prefix='dec_',
                      pf_name='replay decoder')

    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)
    check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
Example #5
def parse_cmd_arguments(mode='split_smnist', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): The mode of the experiment. Options: 'split_smnist' or
            'smnist'.
        default (optional): If True, command-line arguments will be ignored and
            only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line arguments that is passed to the parser in place of sys.argv.

    Returns:
        The Namespace object containing argument names and values.
    """

    description = 'Continual learning on SMNIST task.'

    parser = argparse.ArgumentParser(description=description)

    dnum_tasks = 1
    dnum_classes_per_task = 10
    dval_set_size = 500
    if mode == 'split_smnist':
        dnum_tasks = 5
        dnum_classes_per_task = 2
        dval_set_size = 100
    elif mode == 'smnist':
        dnum_tasks = 1
        dnum_classes_per_task = 10
        dval_set_size = 500

    cli.cl_args(parser,
                show_beta=True,
                dbeta=0.005,
                show_from_scratch=True,
                show_multi_head=True,
                show_split_head_cl3=False,
                show_cl_scenario=False,
                show_num_tasks=True,
                dnum_tasks=dnum_tasks,
                show_num_classes_per_task=True,
                dnum_classes_per_task=dnum_classes_per_task)
    cli.train_args(parser,
                   show_lr=True,
                   show_epochs=False,
                   dbatch_size=64,
                   dn_iter=5000,
                   dlr=1e-3,
                   show_clip_grad_value=False,
                   show_clip_grad_norm=True,
                   show_momentum=False,
                   show_adam_beta1=True)
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser,
                      dhyper_chunks=-1,
                      dhnet_arch='50,50',
                      dtemb_size=32,
                      demb_size=32,
                      dhnet_act='relu')
    # Args of new hnets.
    nhnet_args = cli.hnet_args(parser,
                               allowed_nets=[
                                   'hmlp', 'chunked_hmlp', 'structured_hmlp',
                                   'hdeconv', 'chunked_hdeconv'
                               ],
                               dhmlp_arch='50,50',
                               show_cond_emb_size=True,
                               dcond_emb_size=32,
                               dchmlp_chunk_size=1000,
                               dchunk_emb_size=32,
                               show_use_cond_chunk_embs=True,
                               dhdeconv_shape='512,512,3',
                               prefix='nh_',
                               pf_name='new edition of a hyper-',
                               show_net_act=True,
                               dnet_act='relu',
                               show_no_bias=True,
                               show_dropout_rate=True,
                               ddropout_rate=-1,
                               show_specnorm=True,
                               show_batchnorm=False,
                               show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser,
                  custom_option=False,
                  show_normal_init=False,
                  show_hyper_fan_init=True)
    cli.eval_args(parser,
                  dval_iter=250,
                  show_val_set_size=True,
                  dval_set_size=dval_set_size)
    magroup = cli.miscellaneous_args(parser,
                                     big_data=False,
                                     synthetic_data=True,
                                     show_plots=True,
                                     no_cuda=True,
                                     show_publication_style=False)
    seq.ewc_args(parser,
                 dewc_lambda=5000.,
                 dn_fisher=-1,
                 dtbptt_fisher=-1,
                 dts_weighting_fisher='last')
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser,
                         dsparsification_reg_type='l1',
                         dsparsification_reg_strength=1.,
                         dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup,
                           dmask_fraction=0.8,
                           dclassification=True,
                           dts_weighting='last',
                           show_use_ce_loss=False)
    # Replay arguments.
    rep_args = seq.replay_args(parser)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser,
                      allowed_nets=['simple_rnn'],
                      dsrnn_rec_layers='256',
                      dsrnn_pre_fc_layers='',
                      dsrnn_post_fc_layers='',
                      show_net_act=True,
                      dnet_act='tanh',
                      show_no_bias=True,
                      show_dropout_rate=False,
                      show_specnorm=False,
                      show_batchnorm=False,
                      prefix='dec_',
                      pf_name='replay decoder')

    rep_args.add_argument('--distill_across_time',
                          action='store_true',
                          help='The feature vector of an SMNIST sample ' +
                          'contains an "end-of-digit" bit, which in the ' +
                          'original data is only hot for 1 timestep. ' +
                          'However, in replayed data, this information ' +
                          'might be blurry. If this option is ' +
                          'activated, the distillation loss will be ' +
                          'applied to every timestep weighted by the ' +
                          'softmaxed "end-of-digit" feature from the ' +
                          'replayed input. Otherwise, the argmax ' +
                          'timestep of the "end-of-digit" feature is ' +
                          'considered.')

    args = None
    if argv is not None:
        if default:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = argv
    if default:
        args = []
    config = parser.parse_args(args=args)
    config.mode = mode

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
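Note how the parser above attaches `config.mode = mode` after parsing; this works because `argparse.Namespace` is a plain attribute container. A minimal illustration:

import argparse

parser = argparse.ArgumentParser()
config = parser.parse_args(args=[])
config.mode = 'split_smnist'  # Extra attributes can be attached after parsing.
print(config)                 # Namespace(mode='split_smnist')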
Example #6
def collect_rp_cmd_arguments(mode='split', description=""):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of command-line
            parameters? Options:

                - "split"
                - "perm"
        description (str, optional): The description of the parser.

    Returns:
        The argparse parser with all arguments added.
    """
    parser = argparse.ArgumentParser(description=description)

    # If needed, add additional parameters.
    if mode == 'split':
        dout_dir = './out_split/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        cl_argroup = cli.cl_args(parser,
                                 show_beta=False,
                                 show_from_scratch=True,
                                 show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_num_tasks=True,
                                 dnum_tasks=5)
        train_argroup = cli.train_args(parser,
                                       show_lr=False,
                                       dbatch_size=128,
                                       dn_iter=2000,
                                       show_epochs=True)
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='enc_',
                          pf_name='encoder')
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='dec_',
                          pf_name='decoder')
        cli.hypernet_args(parser,
                          dhyper_chunks=50000,
                          dhnet_arch='10,10',
                          dtemb_size=96,
                          demb_size=96,
                          prefix='rp_',
                          pf_name='replay',
                          dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser,
                               big_data=False,
                               synthetic_data=False,
                               show_plots=True,
                               no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)

        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser,
                          show_emb_lr=True,
                          prefix='dec_',
                          pf_name='decoder')
        split_args(parser)

    elif mode == 'perm':
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser,
                                 show_beta=False,
                                 show_from_scratch=True,
                                 show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_num_tasks=True,
                                 dnum_tasks=10)
        train_argroup = cli.train_args(parser,
                                       show_lr=False,
                                       dbatch_size=128,
                                       dn_iter=5000,
                                       show_epochs=True)
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='1000,1000',
                          dnet_act='relu',
                          prefix='enc_',
                          pf_name='encoder')
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='1000,1000',
                          dnet_act='relu',
                          prefix='dec_',
                          pf_name='decoder')
        cli.hypernet_args(parser,
                          dhyper_chunks=85000,
                          dhnet_arch='25,25',
                          dtemb_size=24,
                          demb_size=8,
                          prefix='rp_',
                          pf_name='replay',
                          dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser,
                               big_data=False,
                               synthetic_data=True,
                               show_plots=True,
                               no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder', dlr=0.0001)
        train_args_replay(parser,
                          show_emb_lr=True,
                          prefix='dec_',
                          dlr=0.0001,
                          dlr_emb=0.0001,
                          pf_name='Decoder')

        perm_args(parser)

    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
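Unlike the `parse_cmd_arguments` functions above, `collect_rp_cmd_arguments` returns the assembled parser rather than a parsed config. A hypothetical caller-side wrapper (the name and the `default`/`argv` handling are assumptions mirroring the earlier examples):

def parse_rp_cmd_arguments(mode='split', description="", default=False,
                           argv=None):
    """Parse the replay command-line arguments collected above (sketch)."""
    parser = collect_rp_cmd_arguments(mode=mode, description=description)
    if default:
        argv = []
    return parser.parse_args(args=argv)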
Example #7
def collect_rp_cmd_arguments(mode='mlp', description="", argv=None):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of command-line
            parameters? Options:

                - "mlp"
        description (str, optional): The description of the parser.
        argv (list, optional): If provided, it will be treated as a list of
            command-line arguments that is passed to the parser when parsing
            the preliminary options defined below (in place of sys.argv).

    Returns:
        The argparse parser with all arguments added.
    """
    print('collect_rp_cmd_arguments')
    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    agroup.add_argument('--note',
                        default='small,random,sep-emnist',
                        type=str,
                        required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--dis_ntasks',
                        default=10,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--sim_ntasks',
                        default=10,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom',
                        default=0,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask',
                        default=5,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1',
                        default=-1,
                        type=float,
                        required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2',
                        default=-1,
                        type=float,
                        required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--dims',
                        default=[],
                        type=list,
                        required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--tasks_preserve',
                        type=int,
                        required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--upper_bound',
                        action='store_true',
                        help='Train the classifier with "replay" data, i.e., ' +
                        'real data. This can be regarded as an upper bound.')
    agroup.add_argument("--num_class_femnist",
                        default=62,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')

    # Pre-parse the preliminary options defined above. Note, if `argv` is
    # provided, it should only contain these options.
    config = parser.parse_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'mlp':
        # dout_dir = './out_split/run_' + \
        # datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        dout_dir = './out_split/' + config.note

        cl_argroup = cli.cl_args(parser,
                                 show_beta=False,
                                 show_from_scratch=True,
                                 show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_num_tasks=True,
                                 dnum_tasks=config.tasks_preserve)
        train_argroup = cli.train_args(parser,
                                       show_lr=False,
                                       dbatch_size=128,
                                       dn_iter=2000,
                                       show_epochs=True)
        cli.main_net_args(parser,
                          allowed_nets=['mlp'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='enc_',
                          pf_name='encoder')
        cli.main_net_args(parser,
                          allowed_nets=['mlp'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='dec_',
                          pf_name='decoder')
        cli.hypernet_args(parser,
                          dhyper_chunks=50000,
                          dhnet_arch='10,10',
                          dtemb_size=96,
                          demb_size=96,
                          prefix='rp_',
                          pf_name='replay',
                          dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser,
                               big_data=False,
                               synthetic_data=False,
                               show_plots=True,
                               no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)

        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser,
                          show_emb_lr=True,
                          prefix='dec_',
                          pf_name='decoder')
        split_args(parser)

    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser
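The function above first parses a preliminary 'note' group and then keeps adding argument groups to the same parser. A standalone sketch of this two-stage pattern using `parse_known_args`, so that stage one tolerates options defined later (flag names are illustrative):

import argparse

parser = argparse.ArgumentParser()
note_group = parser.add_argument_group('note')
note_group.add_argument('--note', type=str, default='small,random,sep-emnist')

# Stage 1: read only the options known so far and ignore the rest.
pre_config, _ = parser.parse_known_args()
dout_dir = './out_split/' + pre_config.note

# Stage 2: add the remaining arguments and parse the full command line.
parser.add_argument('--out_dir', type=str, default=dout_dir)
config = parser.parse_args()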
Example #8
def collect_rp_cmd_arguments(mode='split', description="", argv=None):
    """Collect command-line arguments.

    Args:
        mode: For what script should the parser assemble the set of command-line
            parameters? Options:

                - "split"
                - "perm"
        description (str, optional): The description of the parser.
        argv (list, optional): If provided, it will be treated as a list of
            command-line arguments that is passed to the parser when parsing
            the preliminary options defined below (in place of sys.argv).

    Returns:
        The argparse parser with all arguments added.
    """
    parser = argparse.ArgumentParser(description=description)

    agroup = parser.add_argument_group('note')
    agroup.add_argument('--note',
                        default='small,random,sep-emnist',
                        type=str,
                        required=False,
                        help='(default=%(default)s)')
    agroup.add_argument('--ntasks',
                        default=10,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--idrandom',
                        default=0,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--classptask',
                        default=5,
                        type=int,
                        required=False,
                        help='(default=%(default)d)')
    agroup.add_argument('--pdrop1',
                        default=-1,
                        type=float,
                        required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--pdrop2',
                        default=-1,
                        type=float,
                        required=False,
                        help='(default=%(default)f)')
    agroup.add_argument('--taskcla',
                        default=[],
                        type=list,
                        required=False,
                        help='(default=%(default)s)')

    # Pre-parse the preliminary options defined above. Note, if `argv` is
    # provided, it should only contain these options.
    config = parser.parse_args(args=argv)

    # If needed, add additional parameters.
    if mode == 'split':
        # dout_dir = './out_split/run_' + \
        # datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        dout_dir = './out_split/' + config.note

        cl_argroup = cli.cl_args(parser,
                                 show_beta=False,
                                 show_from_scratch=True,
                                 show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_num_tasks=True,
                                 dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser,
                                       show_lr=False,
                                       dbatch_size=64,
                                       dn_iter=2000,
                                       show_epochs=True)
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='enc_',
                          pf_name='encoder')
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='400,400',
                          dnet_act='relu',
                          prefix='dec_',
                          pf_name='decoder')
        cli.hypernet_args(parser,
                          dhyper_chunks=50000,
                          dhnet_arch='10,10',
                          dtemb_size=96,
                          demb_size=96,
                          prefix='rp_',
                          pf_name='replay',
                          dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser,
                               big_data=False,
                               synthetic_data=False,
                               show_plots=True,
                               no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)

        train_args_replay(parser, prefix='enc_', pf_name='encoder')
        train_args_replay(parser,
                          show_emb_lr=True,
                          prefix='dec_',
                          pf_name='decoder')
        split_args(parser)

    elif mode == 'perm':
        dout_dir = './out_permuted/run_' + \
            datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        cl_argroup = cli.cl_args(parser,
                                 show_beta=False,
                                 show_from_scratch=True,
                                 show_multi_head=False,
                                 show_cl_scenario=True,
                                 show_num_tasks=True,
                                 dnum_tasks=config.ntasks)
        train_argroup = cli.train_args(parser,
                                       show_lr=False,
                                       dbatch_size=64,
                                       dn_iter=5000,
                                       show_epochs=True)
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='1000,1000',
                          dnet_act='relu',
                          prefix='enc_',
                          pf_name='encoder')
        cli.main_net_args(parser,
                          allowed_nets=['fc'],
                          dfc_arch='1000,1000',
                          dnet_act='relu',
                          prefix='dec_',
                          pf_name='decoder')
        cli.hypernet_args(parser,
                          dhyper_chunks=85000,
                          dhnet_arch='25,25',
                          dtemb_size=24,
                          demb_size=8,
                          prefix='rp_',
                          pf_name='replay',
                          dhnet_act='elu')
        cli.init_args(parser, custom_option=False)
        cli.miscellaneous_args(parser,
                               big_data=False,
                               synthetic_data=True,
                               show_plots=True,
                               no_cuda=False,
                               dout_dir=dout_dir)
        cli.generator_args(parser, dlatent_dim=100)
        cli.eval_args(parser, dval_iter=1000)
        train_args_replay(parser, prefix='enc_', pf_name='Encoder', dlr=0.0001)
        train_args_replay(parser,
                          show_emb_lr=True,
                          prefix='dec_',
                          dlr=0.0001,
                          dlr_emb=0.0001,
                          pf_name='Decoder')

        perm_args(parser)

    cl_arguments_replay(parser)
    cl_arguments_general(parser)
    data_args(parser)

    return parser