Example #1
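Both examples assume that argparse has been imported and that a bool_flag argument type and a HashingMemory class are provided by the surrounding codebase. As a rough sketch only (the project's real helper may differ), bool_flag can be implemented like this:

import argparse

FALSY_STRINGS = {"off", "false", "0"}
TRUTHY_STRINGS = {"on", "true", "1"}

def bool_flag(s):
    """Parse a boolean command-line argument such as `--fp16 true`."""
    if s.lower() in FALSY_STRINGS:
        return False
    if s.lower() in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")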
def get_parser():
    """
    Generate a parameters parser.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Language transfer")

    # main parameters
    parser.add_argument("--dump_path", type=str, default="./dumped/",
                        help="Experiment dump path")
    parser.add_argument("--exp_name", type=str, default="",
                        help="Experiment name")
    parser.add_argument("--save_periodic", type=int, default=0,
                        help="Save the model periodically (0 to disable)")
    parser.add_argument("--exp_id", type=str, default="",
                        help="Experiment ID")

    # float16 / AMP API
    parser.add_argument("--fp16", type=bool_flag, default=False,
                        help="Run model with float16")
    parser.add_argument("--amp", type=int, default=-1,
                        help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.")

    # only use an encoder (use a specific decoder for machine translation)
    parser.add_argument("--encoder_only", type=bool_flag, default=True,
                        help="Only use an encoder")

    # model parameters
    parser.add_argument("--emb_dim", type=int, default=512,
                        help="Embedding layer size")
    parser.add_argument("--n_layers", type=int, default=4,
                        help="Number of Transformer layers")
    parser.add_argument("--n_heads", type=int, default=8,
                        help="Number of Transformer heads")
    parser.add_argument("--dropout", type=float, default=0,
                        help="Dropout")
    parser.add_argument("--attention_dropout", type=float, default=0,
                        help="Dropout in the attention layer")
    parser.add_argument("--gelu_activation", type=bool_flag, default=False,
                        help="Use a GELU activation instead of ReLU")
    parser.add_argument("--share_inout_emb", type=bool_flag, default=True,
                        help="Share input and output embeddings")
    parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False,
                        help="Use sinusoidal embeddings")
    parser.add_argument("--use_lang_emb", type=bool_flag, default=True,
                        help="Use language embedding")
        
    

    # memory parameters
    parser.add_argument("--use_memory", type=bool_flag, default=False,
                        help="Use an external memory")
    if parser.parse_known_args()[0].use_memory:
        HashingMemory.register_args(parser)
        parser.add_argument("--mem_enc_positions", type=str, default="",
                            help="Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)")
        parser.add_argument("--mem_dec_positions", type=str, default="",
                            help="Memory positions in the decoder. Same syntax as `mem_enc_positions`.")

    # adaptive softmax
    parser.add_argument("--asm", type=bool_flag, default=False,
                        help="Use adaptive softmax")
    if parser.parse_known_args()[0].asm:
        parser.add_argument("--asm_cutoffs", type=str, default="8000,20000",
                            help="Adaptive softmax cutoffs")
        parser.add_argument("--asm_div_value", type=float, default=4,
                            help="Adaptive softmax cluster sizes ratio")

    # causal language modeling task parameters
    parser.add_argument("--context_size", type=int, default=0,
                        help="Context size (0 means that the first elements in sequences won't have any context)")

    # masked language modeling task parameters
    parser.add_argument("--word_pred", type=float, default=0.15,
                        help="Fraction of words for which we need to make a prediction")
    parser.add_argument("--sample_alpha", type=float, default=0,
                        help="Exponent for transforming word counts to probabilities (~word2vec sampling)")
    parser.add_argument("--word_mask_keep_rand", type=str, default="0.8,0.1,0.1",
                        help="Fraction of words to mask out / keep / randomize, among the words to predict")

    # input sentence noise
    parser.add_argument("--word_shuffle", type=float, default=0,
                        help="Randomly shuffle input words (0 to disable)")
    parser.add_argument("--word_dropout", type=float, default=0,
                        help="Randomly dropout input words (0 to disable)")
    parser.add_argument("--word_blank", type=float, default=0,
                        help="Randomly blank input words (0 to disable)")

    # data
    parser.add_argument("--data_path", type=str, default="",
                        help="Data path")
    parser.add_argument("--lgs", type=str, default="",
                        help="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)")
    parser.add_argument("--max_vocab", type=int, default=-1,
                        help="Maximum vocabulary size (-1 to disable)")
    parser.add_argument("--min_count", type=int, default=0,
                        help="Minimum vocabulary count")
    parser.add_argument("--lg_sampling_factor", type=float, default=-1,
                        help="Language sampling factor")


    # batch parameters
    parser.add_argument("--bptt", type=int, default=256,
                        help="Sequence length")
    parser.add_argument("--max_len", type=int, default=100,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--group_by_size", type=bool_flag, default=True,
                        help="Sort sentences by size during the training")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Number of sentences per batch")
    parser.add_argument("--max_batch_size", type=int, default=0,
                        help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)")
    parser.add_argument("--tokens_per_batch", type=int, default=-1,
                        help="Number of tokens per batch")

    # training parameters
    parser.add_argument("--split_data", type=bool_flag, default=False,
                        help="Split data across workers of a same node")
    parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001",
                        help="Optimizer (SGD / RMSprop / Adam, etc.)")
    parser.add_argument("--clip_grad_norm", type=float, default=5,
                        help="Clip gradients norm (0 to disable)")
    parser.add_argument("--epoch_size", type=int, default=100000,
                        help="Epoch size / evaluation frequency (-1 for parallel data size)")
    parser.add_argument("--max_epoch", type=int, default=100000,
                        help="Maximum epoch size")
    parser.add_argument("--stopping_criterion", type=str, default="",
                        help="Stopping criterion, and number of non-increase before stopping the experiment")
    parser.add_argument("--validation_metrics", type=str, default="",
                        help="Validation metrics")
    parser.add_argument("--accumulate_gradients", type=int, default=1,
                        help="Accumulate model gradients over N iterations (N times larger batch sizes)")

    # training coefficients
    parser.add_argument("--lambda_mlm", type=str, default="1",
                        help="Prediction coefficient (MLM)")
    parser.add_argument("--lambda_clm", type=str, default="1",
                        help="Causal coefficient (LM)")
    parser.add_argument("--lambda_pc", type=str, default="1",
                        help="PC coefficient")
    parser.add_argument("--lambda_ae", type=str, default="1",
                        help="AE coefficient")
    parser.add_argument("--lambda_mt", type=str, default="1",
                        help="MT coefficient")
    parser.add_argument("--lambda_bt", type=str, default="1",
                        help="BT coefficient")

    # training steps
    parser.add_argument("--clm_steps", type=str, default="",
                        help="Causal prediction steps (CLM)")
    parser.add_argument("--mlm_steps", type=str, default="",
                        help="Masked prediction steps (MLM / TLM)")
    parser.add_argument("--mt_steps", type=str, default="",
                        help="Machine translation steps")
    parser.add_argument("--ae_steps", type=str, default="",
                        help="Denoising auto-encoder steps")
    parser.add_argument("--bt_steps", type=str, default="",
                        help="Back-translation steps")
    parser.add_argument("--pc_steps", type=str, default="",
                        help="Parallel classification steps")
    parser.add_argument("--unsclts_steps", type=str, default="",
                        help="'en-EN', 'zh'-'ZH'")

    # reload pretrained embeddings / pretrained model / checkpoint
    parser.add_argument("--reload_emb", type=str, default="",
                        help="Reload pretrained word embeddings")
    parser.add_argument("--reload_model", type=str, default="",
                        help="Reload a pretrained model")
    parser.add_argument("--reload_checkpoint", type=str, default="",
                        help="Reload a checkpoint")

    

    # beam search (for MT only)
    parser.add_argument("--beam_size", type=int, default=1,
                        help="Beam size, default = 1 (greedy decoding)")
    parser.add_argument("--length_penalty", type=float, default=1,
                        help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.")
    parser.add_argument("--early_stopping", type=bool_flag, default=False,
                        help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.")

    # evaluation
    parser.add_argument("--eval_bleu", type=bool_flag, default=False,
                        help="Evaluate BLEU score during MT training")
    parser.add_argument("--eval_only", type=bool_flag, default=False,
                        help="Only run evaluations")

    # debug
    parser.add_argument("--debug_train", type=bool_flag, default=False,
                        help="Use valid sets for train sets (faster loading)")
    parser.add_argument("--debug_slurm", type=bool_flag, default=False,
                        help="Debug multi-GPU / multi-node within a SLURM job")
    parser.add_argument("--debug", help="Enable all debug flags",
                        action="store_true")

    # multi-gpu / multi-node
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Multi-GPU - Local rank")
    parser.add_argument("--master_port", type=int, default=-1,
                        help="Master port (for multi-node SLURM jobs)")



    ########## added by chiamin ##########
    
    # general
    parser.add_argument("--share_encdec_emb", type=bool_flag, default=False, help="Share encoder and decoder word embeddings")
    
    parser.add_argument("--eval_rouge", type=bool_flag, default=False, help="Evaluate ROUGE-1 F1 score during TS training")
    parser.add_argument("--label_smoothing", type=float, default=0., help="Label smoothing loss (0 to disable)")

    parser.add_argument("--separated_vocab", type=bool_flag, default=False, help"Use different vocabulary/dictionary for encoder and decoder. If True, shared_encdec_emb will always be False")


    # separated vocabulary/dictionary
    parser.add_argument("--src_max_vocab", type=int, default=-1,
                        help="Maximum source vocabulary size (-1 to disable)")
    parser.add_argument("--tgt_max_vocab", type=int, default=-1,
                        help="Maximum target vocabulary size (-1 to disable)")
    parser.add_argument("--src_min_count", type=int, default=0,
                        help="Minimum source vocabulary count")
    parser.add_argument("--tgt_min_count", type=int, default=0,
                        help="Minimum target vocabulary count")
    

    # clts-xencoder
    parser.add_argument("--use_xencoder", type=bool_flag, default=False, help="use cross-lingual encoder")
    parser.add_argument("--reload_xencoder", type=str, default="", help="Reload pretrained xlm (cross-lingual encoder). Used in clts-xencoder")
    parser.add_argument("--ts_emb_dim", type=int, default=512,
                        help="text summarization embedding layer size")
    parser.add_argument("--ts_n_layers", type=int, default=4,
                        help="Number of Transformer layers")
    parser.add_argument("--ts_n_heads", type=int, default=8,
                        help="Number of Transformer heads")
    parser.add_argument("--ts_dropout", type=float, default=0,
                        help="Dropout")
    parser.add_argument("--ts_attention_dropout", type=float, default=0,
                        help="Dropout in the attention layer")
    parser.add_argument("--ts_gelu_activation", type=bool_flag, default=False,
                        help="Use a GELU activation instead of ReLU")
    parser.add_argument("--xencoder_optimizer", type=str, default="adam,lr=0.0001",
                        help="Cross-lingual Optimizer (SGD / RMSprop / Adam, etc.)")


    # clts-elmo
    parser.add_argument("--reload_elmo", type=str, default="", help="Reload pretrained elmo. Used in clts-elmo evaluation")
    parser.add_argument("--elmo_tune_lm", type=bool_flag, default=True, help="")
    parser.add_argument("--elmo_weights_dropout", type=float, default=0.0, help="")
    parser.add_argument("--elmo_final_dropout", type=float, default=0.0, help="")
    parser.add_argument("--elmo_layer_norm", type=bool_flag, default=True, help="")
    parser.add_argument("--elmo_affine_layer_norm", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_apply_softmax", type=bool_flag, default=True, help="")
    parser.add_argument("--elmo_channelwise_weights", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_scaled_sigmoid", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_individual_norms", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_channelwise_norm", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_init_gamma", type=float, default=1.0, help="")
    parser.add_argument("--elmo_ltn", type=bool_flag, default=False, help="")
    parser.add_argument("--elmo_ltn_dims", type=str, default="", help="")
    parser.add_argument("--elmo_train_gamma", type=bool_flag, default=True, help="")
   
        
    ######################################

    return parser
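A minimal usage sketch (not part of the original example; the argument values shown are hypothetical): build the parser and read back a few of the parsed fields.

if __name__ == "__main__":
    parser = get_parser()
    params = parser.parse_args(["--exp_name", "mlm_en", "--emb_dim", "256"])
    print(params.exp_name, params.emb_dim, params.fp16)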
Example #2
def get_parser():
    """
    Generate a parameters parser.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Language transfer")

    # main parameters
    parser.add_argument("--dump_path",
                        type=str,
                        default="./dumped/",
                        help="Experiment dump path")
    parser.add_argument("--exp_name",
                        type=str,
                        default="",
                        help="Experiment name")
    parser.add_argument("--save_periodic",
                        type=int,
                        default=0,
                        help="Save the model periodically (0 to disable)")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")

    # float16 / AMP API
    parser.add_argument("--fp16",
                        type=bool_flag,
                        default=False,
                        help="Run model with float16")
    parser.add_argument(
        "--amp",
        type=int,
        default=-1,
        help=
        "Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable."
    )

    # only use an encoder (use a specific decoder for machine translation)
    parser.add_argument("--encoder_only",
                        type=bool_flag,
                        default=True,
                        help="Only use an encoder")

    # model parameters
    parser.add_argument("--emb_dim",
                        type=int,
                        default=512,
                        help="Embedding layer size")
    parser.add_argument("--n_layers",
                        type=int,
                        default=4,
                        help="Number of Transformer layers")
    parser.add_argument("--n_heads",
                        type=int,
                        default=8,
                        help="Number of Transformer heads")
    parser.add_argument("--dropout", type=float, default=0, help="Dropout")
    parser.add_argument("--attention_dropout",
                        type=float,
                        default=0,
                        help="Dropout in the attention layer")
    parser.add_argument("--gelu_activation",
                        type=bool_flag,
                        default=False,
                        help="Use a GELU activation instead of ReLU")
    parser.add_argument("--share_inout_emb",
                        type=bool_flag,
                        default=True,
                        help="Share input and output embeddings")
    parser.add_argument("--sinusoidal_embeddings",
                        type=bool_flag,
                        default=False,
                        help="Use sinusoidal embeddings")
    parser.add_argument("--use_lang_emb",
                        type=bool_flag,
                        default=True,
                        help="Use language embedding")

    # memory parameters
    parser.add_argument("--use_memory",
                        type=bool_flag,
                        default=False,
                        help="Use an external memory")
    if parser.parse_known_args()[0].use_memory:
        HashingMemory.register_args(parser)
        parser.add_argument(
            "--mem_enc_positions",
            type=str,
            default="",
            help=
            "Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)"
        )
        parser.add_argument(
            "--mem_dec_positions",
            type=str,
            default="",
            help=
            "Memory positions in the decoder. Same syntax as `mem_enc_positions`."
        )

    # adaptive softmax
    parser.add_argument("--asm",
                        type=bool_flag,
                        default=False,
                        help="Use adaptive softmax")
    if parser.parse_known_args()[0].asm:
        parser.add_argument("--asm_cutoffs",
                            type=str,
                            default="8000,20000",
                            help="Adaptive softmax cutoffs")
        parser.add_argument("--asm_div_value",
                            type=float,
                            default=4,
                            help="Adaptive softmax cluster sizes ratio")

    # causal language modeling task parameters
    parser.add_argument(
        "--context_size",
        type=int,
        default=0,
        help=
        "Context size (0 means that the first elements in sequences won't have any context)"
    )

    # masked language modeling task parameters
    parser.add_argument(
        "--word_pred",
        type=float,
        default=0.15,
        help="Fraction of words for which we need to make a prediction")
    parser.add_argument(
        "--sample_alpha",
        type=float,
        default=0,
        help=
        "Exponent for transforming word counts to probabilities (~word2vec sampling)"
    )
    parser.add_argument(
        "--word_mask_keep_rand",
        type=str,
        default="0.8,0.1,0.1",
        help=
        "Fraction of words to mask out / keep / randomize, among the words to predict"
    )

    # input sentence noise
    parser.add_argument("--word_shuffle",
                        type=float,
                        default=0,
                        help="Randomly shuffle input words (0 to disable)")
    parser.add_argument("--word_dropout",
                        type=float,
                        default=0,
                        help="Randomly dropout input words (0 to disable)")
    parser.add_argument("--word_blank",
                        type=float,
                        default=0,
                        help="Randomly blank input words (0 to disable)")

    # data
    parser.add_argument("--data_path", type=str, default="", help="Data path")
    parser.add_argument("--lgs",
                        type=str,
                        default="",
                        help="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)")
    parser.add_argument("--max_vocab",
                        type=int,
                        default=-1,
                        help="Maximum vocabulary size (-1 to disable)")
    parser.add_argument("--min_count",
                        type=int,
                        default=0,
                        help="Minimum vocabulary count")
    parser.add_argument("--lg_sampling_factor",
                        type=float,
                        default=-1,
                        help="Language sampling factor")

    # batch parameters
    parser.add_argument("--bptt",
                        type=int,
                        default=256,
                        help="Sequence length")
    parser.add_argument("--max_len",
                        type=int,
                        default=100,
                        help="Maximum length of sentences (after BPE)")
    parser.add_argument("--group_by_size",
                        type=bool_flag,
                        default=True,
                        help="Sort sentences by size during the training")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Number of sentences per batch")
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=0,
        help=
        "Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)"
    )
    parser.add_argument("--tokens_per_batch",
                        type=int,
                        default=-1,
                        help="Number of tokens per batch")

    # training parameters
    parser.add_argument("--split_data",
                        type=bool_flag,
                        default=False,
                        help="Split data across workers of a same node")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam,lr=0.0001",
        help=
        "Optimizer : Adam / AdamInverseSqrtWithWarmup / AdamCosineWithWarmup / \
                              Adadelta / Adagrad / Adamax / ASGD / SGD / RMSprop / Adam / Rprop)"
    )
    parser.add_argument("--clip_grad_norm",
                        type=float,
                        default=5,
                        help="Clip gradients norm (0 to disable)")
    parser.add_argument(
        "--epoch_size",
        type=int,
        default=100000,
        help="Epoch size / evaluation frequency (-1 for parallel data size)")
    parser.add_argument("--max_epoch",
                        type=int,
                        default=100000,
                        help="Maximum epoch size")
    parser.add_argument(
        "--stopping_criterion",
        type=str,
        default="",
        help=
        "Stopping criterion, and number of non-increase before stopping the experiment"
    )
    parser.add_argument("--validation_metrics",
                        type=str,
                        default="",
                        help="Validation metrics")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help=
        "Accumulate model gradients over N iterations (N times larger batch sizes)"
    )

    # training coefficients
    parser.add_argument("--lambda_mlm",
                        type=str,
                        default="1",
                        help="Prediction coefficient (MLM)")
    parser.add_argument("--lambda_clm",
                        type=str,
                        default="1",
                        help="Causal coefficient (LM)")
    parser.add_argument("--lambda_pc",
                        type=str,
                        default="1",
                        help="PC coefficient")
    parser.add_argument("--lambda_ae",
                        type=str,
                        default="1",
                        help="AE coefficient")
    parser.add_argument("--lambda_mt",
                        type=str,
                        default="1",
                        help="MT coefficient")
    parser.add_argument("--lambda_bt",
                        type=str,
                        default="1",
                        help="BT coefficient")

    # training steps
    parser.add_argument("--clm_steps",
                        type=str,
                        default="",
                        help="Causal prediction steps (CLM)")
    parser.add_argument("--mlm_steps",
                        type=str,
                        default="",
                        help="Masked prediction steps (MLM / TLM)")
    parser.add_argument("--mt_steps",
                        type=str,
                        default="",
                        help="Machine translation steps")
    parser.add_argument("--ae_steps",
                        type=str,
                        default="",
                        help="Denoising auto-encoder steps")
    parser.add_argument("--bt_steps",
                        type=str,
                        default="",
                        help="Back-translation steps")
    parser.add_argument("--pc_steps",
                        type=str,
                        default="",
                        help="Parallel classification steps")

    # reload pretrained embeddings / pretrained model / checkpoint
    parser.add_argument("--reload_emb",
                        type=str,
                        default="",
                        help="Reload pretrained word embeddings")
    parser.add_argument("--reload_model",
                        type=str,
                        default="",
                        help="Reload a pretrained model")
    parser.add_argument("--reload_checkpoint",
                        type=str,
                        default="",
                        help="Reload a checkpoint")

    # beam search (for MT only)
    parser.add_argument("--beam_size",
                        type=int,
                        default=1,
                        help="Beam size, default = 1 (greedy decoding)")
    parser.add_argument(
        "--length_penalty",
        type=float,
        default=1,
        help=
        "Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones."
    )
    parser.add_argument(
        "--early_stopping",
        type=bool_flag,
        default=False,
        help=
        "Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores."
    )

    # evaluation
    parser.add_argument("--eval_bleu",
                        type=bool_flag,
                        default=False,
                        help="Evaluate BLEU score during MT training")
    parser.add_argument("--eval_only",
                        type=bool_flag,
                        default=False,
                        help="Only run evaluations")

    # debug
    parser.add_argument("--debug_train",
                        type=bool_flag,
                        default=False,
                        help="Use valid sets for train sets (faster loading)")
    parser.add_argument("--debug_slurm",
                        type=bool_flag,
                        default=False,
                        help="Debug multi-GPU / multi-node within a SLURM job")
    parser.add_argument("--debug",
                        help="Enable all debug flags",
                        action="store_true")

    # multi-gpu / multi-node
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="Multi-GPU - Local rank")
    parser.add_argument("--master_port",
                        type=int,
                        default=-1,
                        help="Master port (for multi-node SLURM jobs)")

    # our
    # These three parameters will always be rounded to an integer number of batches, so don't be surprised if you see different values than the ones provided.
    parser.add_argument("--train_n_samples",
                        type=int,
                        default=-1,
                        help="Just consider train_n_sample train data")
    parser.add_argument("--valid_n_samples",
                        type=int,
                        default=-1,
                        help="Just consider valid_n_sample validation data")
    parser.add_argument("--test_n_samples",
                        type=int,
                        default=-1,
                        help="Just consider test_n_sample test data for")
    parser.add_argument("--remove_long_sentences_train",
                        type=bool_flag,
                        default=False,
                        help="remove long sentences in train dataset")
    parser.add_argument("--remove_long_sentences_valid",
                        type=bool_flag,
                        default=False,
                        help="remove long sentences in valid dataset")
    parser.add_argument("--remove_long_sentences_test",
                        type=bool_flag,
                        default=False,
                        help="remove long sentences in test dataset")

    parser.add_argument("--same_data_path", type=bool_flag, default=True,
                        help="In the case of metalearning, this parameter, when passed to False, the data are" \
                            "searched for each task in a folder with the name of the task and located in data_path otherwise all the data are searched in data_path.")

    parser.add_argument("--config_file", type=str, default="", help="")

    parser.add_argument("--log_file_prefix", type=str, default="",
                        help="Log file prefix. Name of the language to be evaluated in the case of the" \
                              "evaluation of one LM on another.")

    parser.add_argument(
        "--aggregation_metrics",
        type=str,
        default="",
        help="name_metric1=mean(m1,m2,...);name_metric2=sum(m4,m5,...);...")

    parser.add_argument("--eval_tasks", type=str, default="",
                        help="During metalearning we need tasks on which to refine and evaluate the model after each epoch." \
                              "task_name:train_n_samples,..."
                            )
    # TIM
    parser.add_argument("--tim_layers_pos",
                        type=str,
                        default="",
                        help="tim layers position : 0,1,5 for example")
    parser.add_argument("--use_group_comm", type=bool_flag, default=True)
    parser.add_argument("--use_mine", type=bool_flag, default=False)
    if parser.parse_known_args()[0].tim_layers_pos:
        # Transformers with Independent Mechanisms (TIM) model parameters
        parser.add_argument("--n_s",
                            type=int,
                            default=2,
                            help="number of mechanisms")
        parser.add_argument("--H",
                            type=int,
                            default=8,
                            help="number of heads for self-attention")
        parser.add_argument(
            "--H_c",
            type=int,
            default=8,
            help="number of heads for inter-mechanism attention")
        parser.add_argument("--custom_mha", type=bool_flag, default=False)
        #if parser.parse_known_args()[0].custom_mha:
        parser.add_argument("--d_k",
                            type=int,
                            default=512,
                            help="key dimension")
        parser.add_argument("--d_v",
                            type=int,
                            default=512,
                            help="value dimension")

    parser.add_argument(
        "--dim_feedforward",
        type=int,
        default=512 * 4,
        help="Dimension of Intermediate Layers in Positionwise Feedforward Net"
    )

    parser.add_argument(
        "--log_interval",
        type=int,
        default=-1,
        help=
        "Interval (number of steps) between two displays : batch_size by default"
    )
    parser.add_argument("--device", type=str, default="", help="cpu/cuda")
    parser.add_argument("--random_seed",
                        type=int,
                        default=0,
                        help="random seed for reproductibility")

    return parser
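Both examples register some argument groups only when a gating flag is already present on the command line: the use_memory, asm and tim_layers_pos checks call parse_known_args(), which reads sys.argv while the parser is still being built. A small illustrative sketch of this behaviour for the TIM block of Example #2 follows; the simulated command line is hypothetical.

import sys

# Simulate a command line that enables the TIM block, so that --n_s and the
# other TIM-specific options are registered when get_parser() runs.
sys.argv = ["train.py", "--tim_layers_pos", "0,1,5", "--n_s", "4"]
parser = get_parser()
params = parser.parse_args()
print(params.tim_layers_pos, params.n_s)  # -> 0,1,5 4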