def parse_cmd_arguments(default=False, argv=None):
    """Create the argument parser for the copy task and parse arguments.

    Args:
        default (optional): If True, the actual command line (and ``argv``)
            is disregarded and only the default values are parsed.
        argv (optional): A list of command-line arguments handed to the
            parser in place of ``sys.argv``. It is ignored, with a warning,
            when ``default`` is True.

    Returns:
        The Namespace object containing argument names and values.
    """
    parser = argparse.ArgumentParser(
        description='Continual learning on copy task.')

    # Generic continual-learning and training options.
    cli.cl_args(
        parser,
        show_beta=True,
        dbeta=0.005,
        show_from_scratch=True,
        show_multi_head=True,
        show_split_head_cl3=False,
        show_cl_scenario=False,
        show_num_tasks=True,
        dnum_tasks=6,
    )
    cli.train_args(
        parser,
        show_lr=True,
        show_epochs=False,
        dbatch_size=64,
        dn_iter=5000,
        dlr=1e-3,
        show_clip_grad_value=False,
        show_clip_grad_norm=True,
        show_momentum=False,
        show_adam_beta1=True,
    )

    # Recurrent main network and hypernetwork options.
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(
        parser,
        dhyper_chunks=-1,
        dhnet_arch='10,10',
        dtemb_size=2,
        demb_size=32,
        dhnet_act='sigmoid',
    )

    # Options of the new hypernetwork implementations (prefixed with "nh_").
    new_hnet_group = cli.hnet_args(
        parser,
        allowed_nets=['hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv',
                      'chunked_hdeconv'],
        dhmlp_arch='50,50',
        show_cond_emb_size=True,
        dcond_emb_size=32,
        dchmlp_chunk_size=1000,
        dchunk_emb_size=32,
        show_use_cond_chunk_embs=True,
        dhdeconv_shape='512,512,3',
        prefix='nh_',
        pf_name='new edition of a hyper-',
        show_net_act=True,
        dnet_act='relu',
        show_no_bias=True,
        show_dropout_rate=True,
        ddropout_rate=-1,
        show_specnorm=True,
        show_batchnorm=False,
        show_no_batchnorm=False,
    )
    seq.new_hnet_args(new_hnet_group)

    cli.init_args(
        parser,
        custom_option=False,
        show_normal_init=False,
        show_hyper_fan_init=True,
    )
    cli.eval_args(parser, dval_iter=250)

    # The miscellaneous group is reused below for sequence-specific options.
    misc_group = cli.miscellaneous_args(
        parser,
        big_data=False,
        synthetic_data=True,
        show_plots=True,
        no_cuda=True,
        show_publication_style=False,
    )

    # EWC, SI and context-modulation options.
    seq.ewc_args(
        parser,
        dewc_lambda=5000.,
        dn_fisher=-1,
        dtbptt_fisher=-1,
        show_ts_weighting_fisher=False,
    )
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(
        parser,
        dsparsification_reg_type='l1',
        dsparsification_reg_strength=1.,
        dcontext_mod_init='constant',
    )
    seq.miscellaneous_args(
        misc_group,
        dmask_fraction=0.8,
        dclassification=True,
        show_ts_weighting=False,
        show_use_ce_loss=False,
        show_early_stopping_thld=True,
        dearly_stopping_thld=-1,
    )
    copy_sequence_args(parser)

    # Replay: generator options plus the replay decoder network ("dec_").
    replay_group = seq.replay_args(parser, show_all_task_softmax=False)
    cli.generator_args(replay_group, dlatent_dim=100)
    cli.main_net_args(
        parser,
        allowed_nets=['simple_rnn'],
        dsrnn_rec_layers='256',
        dsrnn_pre_fc_layers='',
        dsrnn_post_fc_layers='',
        show_net_act=True,
        dnet_act='tanh',
        show_no_bias=True,
        show_dropout_rate=False,
        show_specnorm=False,
        show_batchnorm=False,
        prefix='dec_',
        pf_name='replay decoder',
    )

    # Decide what the parser sees: nothing (defaults only), the explicit
    # ``argv`` list, or ``None`` so that argparse falls back to sys.argv.
    if default:
        if argv is not None:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        cmd_line = []
    else:
        cmd_line = argv
    config = parser.parse_args(args=cmd_line)

    # Validate the parsed configuration.
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)
    check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config
# Example #2
# 0
def parse_cmd_arguments(mode='split_smnist', default=False, argv=None):
    """Parse command-line arguments.

    Args:
        mode (str): The mode of the experiment. It selects the defaults for
            the number of tasks, the number of classes per task and the
            validation set size:

            - ``'split_smnist'``: 5 tasks, 2 classes per task, 100 validation
              samples.
            - ``'smnist'``: 1 task, 10 classes, 500 validation samples.

            Any other value triggers a warning and falls back to the
            ``'smnist'`` defaults. The value is stored as ``config.mode``.
        default (optional): If True, command-line arguments will be ignored and
            only the default values will be parsed.
        argv (optional): If provided, it will be treated as a list of command-
            line arguments that is passed to the parser in place of sys.argv.
            Ignored (with a warning) if ``default`` is True.

    Returns:
        The Namespace object containing argument names and values.
    """

    description = 'Continual learning on SMNIST task.'

    parser = argparse.ArgumentParser(description=description)

    # Mode-dependent defaults:
    # (num_tasks, num_classes_per_task, val_set_size).
    # A single lookup guarantees that all three values, including
    # ``dval_set_size``, are bound for every mode.
    mode_defaults = {
        'split_smnist': (5, 2, 100),
        'smnist': (1, 10, 500),
    }
    if mode not in mode_defaults:
        warnings.warn('Unknown mode "%s"; falling back to the default ' % mode +
                      'settings of mode "smnist".')
    dnum_tasks, dnum_classes_per_task, dval_set_size = mode_defaults.get(
        mode, mode_defaults['smnist'])

    cli.cl_args(parser,
                show_beta=True,
                dbeta=0.005,
                show_from_scratch=True,
                show_multi_head=True,
                show_split_head_cl3=False,
                show_cl_scenario=False,
                show_num_tasks=True,
                dnum_tasks=dnum_tasks,
                show_num_classes_per_task=True,
                dnum_classes_per_task=dnum_classes_per_task)
    cli.train_args(parser,
                   show_lr=True,
                   show_epochs=False,
                   dbatch_size=64,
                   dn_iter=5000,
                   dlr=1e-3,
                   show_clip_grad_value=False,
                   show_clip_grad_norm=True,
                   show_momentum=False,
                   show_adam_beta1=True)
    seq.rnn_args(parser, drnn_arch='256', dnet_act='tanh')
    cli.hypernet_args(parser,
                      dhyper_chunks=-1,
                      dhnet_arch='50,50',
                      dtemb_size=32,
                      demb_size=32,
                      dhnet_act='relu')
    # Args of new hnets (prefixed with "nh_").
    nhnet_args = cli.hnet_args(parser,
                               allowed_nets=[
                                   'hmlp', 'chunked_hmlp', 'structured_hmlp',
                                   'hdeconv', 'chunked_hdeconv'
                               ],
                               dhmlp_arch='50,50',
                               show_cond_emb_size=True,
                               dcond_emb_size=32,
                               dchmlp_chunk_size=1000,
                               dchunk_emb_size=32,
                               show_use_cond_chunk_embs=True,
                               dhdeconv_shape='512,512,3',
                               prefix='nh_',
                               pf_name='new edition of a hyper-',
                               show_net_act=True,
                               dnet_act='relu',
                               show_no_bias=True,
                               show_dropout_rate=True,
                               ddropout_rate=-1,
                               show_specnorm=True,
                               show_batchnorm=False,
                               show_no_batchnorm=False)
    seq.new_hnet_args(nhnet_args)
    cli.init_args(parser,
                  custom_option=False,
                  show_normal_init=False,
                  show_hyper_fan_init=True)
    cli.eval_args(parser,
                  dval_iter=250,
                  show_val_set_size=True,
                  dval_set_size=dval_set_size)
    # The miscellaneous group is reused below for sequence-specific options.
    magroup = cli.miscellaneous_args(parser,
                                     big_data=False,
                                     synthetic_data=True,
                                     show_plots=True,
                                     no_cuda=True,
                                     show_publication_style=False)
    seq.ewc_args(parser,
                 dewc_lambda=5000.,
                 dn_fisher=-1,
                 dtbptt_fisher=-1,
                 dts_weighting_fisher='last')
    seq.si_args(parser, dsi_lambda=1.)
    seq.context_mod_args(parser,
                         dsparsification_reg_type='l1',
                         dsparsification_reg_strength=1.,
                         dcontext_mod_init='constant')
    seq.miscellaneous_args(magroup,
                           dmask_fraction=0.8,
                           dclassification=True,
                           dts_weighting='last',
                           show_use_ce_loss=False)
    # Replay arguments.
    rep_args = seq.replay_args(parser)
    cli.generator_args(rep_args, dlatent_dim=100)
    cli.main_net_args(parser,
                      allowed_nets=['simple_rnn'],
                      dsrnn_rec_layers='256',
                      dsrnn_pre_fc_layers='',
                      dsrnn_post_fc_layers='',
                      show_net_act=True,
                      dnet_act='tanh',
                      show_no_bias=True,
                      show_dropout_rate=False,
                      show_specnorm=False,
                      show_batchnorm=False,
                      prefix='dec_',
                      pf_name='replay decoder')

    rep_args.add_argument('--distill_across_time',
                          action='store_true',
                          help='The feature vector of an SMNIST sample ' +
                          'contains an "end-of-digit" bit, which in the ' +
                          'original data is only hot for 1 timestep. ' +
                          'However, in replayed data, this information ' +
                          'might be blurry. If this option is ' +
                          'activated, the distillation loss will be ' +
                          'applied to every timestep weighted by the ' +
                          'softmaxed "end-of-digit" feature from the ' +
                          'replayed input. Otherwise, the argmax ' +
                          'timestep of the "end-of-digit" feature is ' +
                          'considered.')

    # Decide what the parser sees: nothing (defaults only), the explicit
    # ``argv`` list, or ``None`` so that argparse falls back to sys.argv.
    if default:
        if argv is not None:
            warnings.warn('Provided "argv" will be ignored since "default" ' +
                          'option was turned on.')
        args = []
    else:
        args = argv
    config = parser.parse_args(args=args)
    config.mode = mode

    ### Check argument values!
    cli.check_invalid_argument_usage(config)
    seq.check_invalid_args_sequential(config)

    if config.train_from_scratch:
        # FIXME We could get rid of this warning by properly checkpointing and
        # loading all networks.
        warnings.warn('When training from scratch, only during accuracies ' +
                      'make sense. All other outputs should be ignored!')

    return config