Beispiel #1
0
def _format_elapsed(seconds):
    """Format a duration given in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated. Unlike
    ``time.strftime('%H:%M:%S', time.gmtime(seconds))`` the hour field does not
    wrap around after 24 hours, so very long training runs are reported
    correctly (e.g. 90061 s -> ``'25:01:01'``).
    """
    total_seconds = int(seconds)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return '%02d:%02d:%02d' % (hours, minutes, secs)


def train_recognition(model_name, label_schema='BIOES', batch_size=32, n_epoch=50, learning_rate=0.001,
                      optimizer_type='adam', use_char_input=True, embed_type=None, embed_trainable=True,
                      use_bert_input=False, bert_type='bert', bert_trainable=True, bert_layer_num=1,
                      use_bichar_input=False, bichar_embed_type=None, bichar_embed_trainable=True,
                      use_word_input=False, word_embed_type=None, word_embed_trainable=True,
                      use_charpos_input=False, charpos_embed_type=None, charpos_embed_trainable=True,
                      use_softword_input=False, use_dictfeat_input=False, use_maxmatch_input=False,
                      callbacks_to_add=None, overwrite=False, swa_start=3, early_stopping_patience=3, **kwargs):
    """Train a recognition model (unless already trained) and log its dev-set scores.

    Builds a ``ModelConfig`` from the arguments and derives an experiment name
    that encodes the configuration. The model is trained only when no
    checkpoint ``<checkpoint_dir>/<exp_name>.hdf5`` exists or ``overwrite`` is
    true. The best checkpoint is then evaluated on the dev split, and so is the
    SWA checkpoint when an ``'swa'`` / ``'swa_clr'`` callback is enabled. The
    results are appended to the ``2step_er`` performance log.

    Args:
        model_name: Model identifier stored on the config and used in the experiment name.
        label_schema: Tagging schema (e.g. ``'BIOES'``); selects ``config.label_to_one_hot``.
        batch_size, n_epoch, learning_rate, optimizer_type: Training hyper-parameters.
        use_char_input / embed_type / embed_trainable: Character input and the
            optional pre-trained matrix loaded from ``PROCESSED_DATA_DIR``.
        use_bert_input / bert_type / bert_trainable / bert_layer_num: BERT input settings.
        use_bichar_input, use_word_input, use_charpos_input (+ their ``*_embed_type`` /
            ``*_embed_trainable``): Optional extra embedding inputs.
        use_softword_input, use_dictfeat_input, use_maxmatch_input: Optional feature flags.
        callbacks_to_add: Callback names; defaults to ``['modelcheckpoint', 'earlystopping']``.
        overwrite: Retrain even if a checkpoint with the same experiment name exists.
        swa_start, early_stopping_patience: Only stored when ``'swa'`` is in the callbacks.
        **kwargs: Forwarded to ``RecognitionModel`` and appended to the experiment name.

    Raises:
        ValueError: If neither ``use_char_input`` nor ``use_bert_input`` is true.
    """
    # Explicit check instead of ``assert`` (which ``python -O`` strips); done
    # first so a bad configuration fails before any file is loaded.
    if not (use_char_input or use_bert_input):
        raise ValueError('At least one of use_char_input and use_bert_input must be True')

    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        # No pre-trained matrix (labelled 'random' in exp_name): always trainable.
        config.embeddings = None
        config.embed_trainable = True

    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience

    config.vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    # --- experiment name: must stay in sync with `recognition`, which rebuilds it to find the checkpoint ---
    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(model_name, config.embed_type if config.embed_type else 'random',
                                                        'tune' if config.embed_trainable else 'fix', batch_size,
                                                        optimizer_type, learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size, optimizer_type, learning_rate, label_schema)
    if config.n_epoch != 50:
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()])
    # The default callbacks are omitted from the name; only extra ones are appended.
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(bert_type, bert_layer_num, 'tune' if config.bert_trainable else 'fix')
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                               type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(bichar_embed_type if bichar_embed_type else 'random',
                                                  'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                             type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(word_embed_type if word_embed_type else 'random',
                                                'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(charpos_embed_type if charpos_embed_type else 'random',
                                                   'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type, 'epoch': n_epoch,
                 'learning_rate': learning_rate, 'other_params': kwargs}

    def make_generator(data_type):
        # The train and dev generators differ only in the data split they read.
        return RecognitionDataGenerator(data_type, config.batch_size, config.label_schema,
                                        config.label_to_one_hot[config.label_schema],
                                        config.vocab if config.use_char_input else None,
                                        config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
                                        config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                        config.use_word_input, config.charpos_vocab, config.use_softword_input,
                                        config.use_dictfeat_input, config.use_maxmatch_input)

    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = RecognitionModel(config, **kwargs)
    try:
        train_generator = make_generator('train')
        valid_generator = make_generator('dev')

        if not os.path.exists(model_save_path) or overwrite:
            start_time = time.time()
            model.train(train_generator, valid_generator)
            train_time = _format_elapsed(time.time() - start_time)
            print('Logging Info - Training time: %s' % train_time)
            train_log['train_time'] = train_time

        model.load_best_model()

        print('Logging Info - Evaluate over valid data:')
        r, p, f1 = model.evaluate(valid_generator)
        train_log['dev_performance'] = (r, p, f1)

        # Evaluate the stochastic-weight-averaged checkpoint too, if one was produced.
        swa_type = None
        if 'swa' in config.callbacks_to_add:
            swa_type = 'swa'
        elif 'swa_clr' in config.callbacks_to_add:
            swa_type = 'swa_clr'
        if swa_type:
            model.load_swa_model(swa_type)
            print('Logging Info - Evaluate over valid data based on swa model:')
            r, p, f1 = model.evaluate(valid_generator)
            train_log['swa_dev_performance'] = (r, p, f1)

        train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
        write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_er'), log=train_log, mode='a')
    finally:
        # Release the model and reset the Keras session even if training or
        # evaluation fails, so later experiments in this process start clean.
        del model
        gc.collect()
        K.clear_session()
Beispiel #2
0
def recognition(model_name,
                predict_log,
                label_schema='BIOES',
                batch_size=32,
                n_epoch=50,
                learning_rate=0.001,
                optimizer_type='adam',
                use_char_input=True,
                embed_type=None,
                embed_trainable=True,
                use_bert_input=False,
                bert_type='bert',
                bert_trainable=True,
                bert_layer_num=1,
                use_bichar_input=False,
                bichar_embed_type=None,
                bichar_embed_trainable=True,
                use_word_input=False,
                word_embed_type=None,
                word_embed_trainable=True,
                use_charpos_input=False,
                charpos_embed_type=None,
                charpos_embed_trainable=True,
                use_softword_input=False,
                use_dictfeat_input=False,
                use_maxmatch_input=False,
                callbacks_to_add=None,
                swa_type=None,
                predict_on_dev=True,
                predict_on_final_test=True,
                **kwargs):
    """Load a trained recognition model and predict mentions on dev/test data.

    Rebuilds the ``ModelConfig`` and experiment name that ``train_recognition``
    derives from the same arguments. It then loads the matching checkpoint
    (best or SWA) and runs ``model.predict`` on the requested splits.

    Args:
        model_name: Model identifier; part of the experiment name.
        predict_log: Dict updated in place with ``er_*`` entries describing the run.
        label_schema ... callbacks_to_add: Same meaning as in ``train_recognition``;
            they must match the training call so the checkpoint name matches.
        swa_type: ``None`` to load the best checkpoint. Otherwise the name of the
            SWA callback used in training (``'swa'`` or ``'swa_clr'``), which must
            also be listed in ``callbacks_to_add``.
        predict_on_dev: Also predict on the ``'dev'`` split.
        predict_on_final_test: Predict on ``'test_final'`` instead of ``'test'``.
        **kwargs: Forwarded to ``RecognitionModel`` and appended to the experiment name.

    Returns:
        ``(dev_pred_mentions, test_pred_mentions)``. ``dev_pred_mentions`` is
        ``None`` when ``predict_on_dev`` is false.

    Raises:
        ValueError: If neither char nor BERT input is enabled, or ``swa_type``
            is not one of the configured callbacks.
        FileNotFoundError: If no checkpoint exists for the derived experiment name.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``).
    if not (use_char_input or use_bert_input):
        raise ValueError(
            'At least one of use_char_input and use_bert_input must be True')

    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR,
                            EMBEDDING_MATRIX_TEMPLATE,
                            type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]

    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    # --- experiment name: must match the one built by `train_recognition` ---
    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size,
            optimizer_type, learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size,
                                                  optimizer_type,
                                                  learning_rate, label_schema)
    # train_recognition appends the epoch count when it differs from 50;
    # without this the checkpoint of such a run could never be found.
    if config.n_epoch != 50:
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(
            bert_type, bert_layer_num,
            'tune' if config.bert_trainable else 'fix')
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(
            bichar_embed_type if bichar_embed_type else 'random',
            'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(
            word_embed_type if word_embed_type else 'random',
            'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(
            charpos_embed_type if charpos_embed_type else 'random',
            'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # logger to log output of training process
    predict_log.update({
        'er_exp_name': config.exp_name,
        'er_batch_size': batch_size,
        'er_optimizer': optimizer_type,
        'er_epoch': n_epoch,
        'er_learning_rate': learning_rate,
        'er_other_params': kwargs
    })

    print('Logging Info - Experiment: %s' % config.exp_name)
    model = RecognitionModel(config, **kwargs)

    def make_generator(data_type):
        # The dev and test generators differ only in the data split they read.
        return RecognitionDataGenerator(
            data_type, config.batch_size, config.label_schema,
            config.label_to_one_hot[config.label_schema],
            config.vocab if config.use_char_input else None,
            config.bert_vocab_file(config.bert_type)
            if config.use_bert_input else None, config.bert_seq_len,
            config.bichar_vocab, config.word_vocab, config.use_word_input,
            config.charpos_vocab, config.use_softword_input,
            config.use_dictfeat_input, config.use_maxmatch_input)

    valid_generator = make_generator('dev')
    test_generator = make_generator(
        'test_final' if predict_on_final_test else 'test')

    model_save_path = os.path.join(config.checkpoint_dir,
                                   '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        raise FileNotFoundError(
            'Recognition model not exist: {}'.format(model_save_path))

    if swa_type is None:
        model.load_best_model()
    else:
        # Use the normalised callback list: the raw argument may be None, and
        # 'swa_clr' must be matched as its own callback name rather than via
        # the literal 'swa'. Fail loudly instead of predicting with weights
        # that were never loaded.
        if swa_type not in config.callbacks_to_add:
            raise ValueError(
                'swa_type {!r} is not one of callbacks_to_add {}'.format(
                    swa_type, config.callbacks_to_add))
        model.load_swa_model(swa_type)
        predict_log['er_exp_name'] += '_{}'.format(swa_type)

    if predict_on_dev:
        print('Logging Info - Generate submission for valid data:')
        dev_pred_mentions = model.predict(valid_generator)
    else:
        dev_pred_mentions = None
    print('Logging Info - Generate submission for test data:')
    test_pred_mentions = model.predict(test_generator)

    return dev_pred_mentions, test_pred_mentions
def _build_exp_name(config, pseudo_name=None):
    """Derive the experiment name encoding the hyper-parameters stored on *config*.

    Only reads attributes that ``prepare_config`` sets. Attributes tied to a
    disabled feature (word input, BERT, CV, pair exchange, pseudo labels) are
    read only when that feature is enabled.

    Raises:
        ValueError: If pseudo labelling is enabled but neither ``pseudo_name``
            nor ``config.pseudo_path`` is given, so the pseudo-label source
            cannot be named.
    """
    name = f'{config.model_type}_{config.input_type}'
    if config.use_pair_input:
        name += '_pair'
    name += f'_len_{config.max_len}'
    if config.use_word_input:
        name += f"_word_{config.word_embed_type}_{'tune' if config.word_embed_trainable else 'fix'}"
    if config.use_bert_input:
        name += f"_bert_{config.use_bert_type}_{'tune' if config.bert_trainable else 'fix'}"
        if config.output_hidden_state:
            name += f'_hid_{config.n_last_hidden_layer}'
        if config.dense_after_bert:
            name += '_dense'
    if config.use_multi_task:
        if config.use_harl:
            name += f'_harl_{config.cate_embed_dim}'
        elif config.use_hal:
            name += f'_hal_{config.cate_embed_dim}'
        name += f'_{config.cate1_loss_weight}_{config.cate2_loss_weight}_{config.cate3_loss_weight}'
    else:
        name += '_not_multi_task'
    if config.share_father_pred in ('after', 'before'):
        name += f'_{config.share_father_pred}'
    if config.use_mask_for_cate2:
        name += '_mask_cate2'
    if config.use_mask_for_cate3:
        name += f'_mask_cate3_with_{config.cate3_mask_type}'
    if config.use_focal_loss:
        name += '_focal'
    name += f'_{config.optimizer}_{config.learning_rate}_{config.batch_size}_{config.n_epoch}'
    # The default callbacks are omitted from the name; only extra ones are appended.
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    name += callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    if config.train_on_cv:
        name += f'_{config.cv_random_state}_{config.cv_fold}_{config.cv_index}'
    if config.exchange_pair:
        name += f'_ex_pair_{config.exchange_threshold}'
    if config.use_pseudo_label:
        if pseudo_name:
            source = pseudo_name
        elif config.pseudo_path is None:
            raise ValueError(
                'pseudo_name or pseudo_path is required when use_pseudo_label is True')
        elif 'dev' in config.pseudo_path:
            source = 'dev'
        else:
            source = 'test'
        name += f'_{source}_pseudo_{config.pseudo_random_state}_{config.pseudo_rate}_{config.pseudo_index}'
    return name


def prepare_config(model_type='bert-base-uncased',
                   input_type='name_desc',
                   use_multi_task=True,
                   use_harl=False,
                   use_hal=False,
                   cate_embed_dim=100,
                   use_word_input=False,
                   word_embed_type='w2v',
                   word_embed_trainable=True,
                   word_embed_dim=300,
                   use_bert_input=True,
                   bert_trainable=True,
                   use_bert_type='pooler',
                   n_last_hidden_layer=0,
                   dense_after_bert=True,
                   use_pair_input=True,
                   max_len=None,
                   share_father_pred='no',
                   use_mask_for_cate2=False,
                   use_mask_for_cate3=True,
                   cate3_mask_type='cate1',
                   cate1_loss_weight=1.,
                   cate2_loss_weight=1.,
                   cate3_loss_weight=1.,
                   batch_size=32,
                   predict_batch_size=32,
                   n_epoch=50,
                   learning_rate=2e-5,
                   optimizer='adam',
                   use_focal_loss=False,
                   callbacks_to_add=None,
                   swa_start=15,
                   early_stopping_patience=5,
                   max_lr=6e-5,
                   min_lr=1e-5,
                   train_on_cv=False,
                   cv_random_state=42,
                   cv_fold=5,
                   cv_index=0,
                   exchange_pair=False,
                   exchange_threshold=0.1,
                   use_pseudo_label=False,
                   pseudo_path=None,
                   pseudo_random_state=42,
                   pseudo_rate=0.1,
                   pseudo_index=0,
                   pseudo_name=None,
                   exp_name=None):
    """Build a ``ModelConfig`` for the three-level (cate1/cate2/cate3) category model.

    Copies the hyper-parameters onto a fresh ``ModelConfig`` and loads from
    ``PROCESSED_DATA_DIR`` the category vocabularies, the index-to-category
    maps, the parent-to-child category maps and the per-category counts (plus
    word embeddings/vocabulary when ``use_word_input``). It then derives
    ``config.exp_name`` from the settings; a non-empty ``exp_name`` argument
    overrides the generated name.

    Settings that belong to a disabled feature (e.g. ``swa_start`` without an
    ``'swa'`` callback, CV or pseudo-label parameters) are not stored.

    Returns:
        The populated ``ModelConfig``.

    Raises:
        ValueError: If both ``use_harl`` and ``use_hal`` are set, or pseudo
            labelling is enabled without ``pseudo_name`` or ``pseudo_path``.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``).
    if use_harl and use_hal:
        raise ValueError('use_harl and use_hal are mutually exclusive')

    config = ModelConfig()
    config.model_type = model_type
    config.input_type = input_type
    config.use_multi_task = use_multi_task
    config.use_harl = use_harl
    config.use_hal = use_hal
    config.cate_embed_dim = cate_embed_dim

    # --- optional word-level input ---
    config.use_word_input = use_word_input
    config.word_embed_type = word_embed_type
    if config.use_word_input:
        if word_embed_type:
            config.word_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
            config.word_embed_dim = word_embed_dim
        config.word_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='word'))
        config.word_vocab_size = len(
            config.word_vocab) + 2  # 0: mask, 1: padding
    else:
        config.word_vocab = None

    # --- optional BERT input ---
    config.use_bert_input = use_bert_input
    config.bert_trainable = bert_trainable
    if config.use_bert_input:
        config.use_bert_type = use_bert_type
        config.dense_after_bert = dense_after_bert
        # Hidden states are requested for the 'hidden*' modes and for HARL/HAL.
        if config.use_bert_type in ['hidden', 'hidden_pooler'] or \
                (config.use_multi_task and (config.use_harl or config.use_hal)):
            config.output_hidden_state = True
            config.n_last_hidden_layer = n_last_hidden_layer
        else:
            config.output_hidden_state = False
            config.n_last_hidden_layer = 0
    if config.input_type == 'name_desc':
        config.use_pair_input = use_pair_input
    else:
        config.use_pair_input = False

    if config.use_bert_input and max_len is None:
        config.max_len = MAX_LEN[input_type]
    else:
        config.max_len = max_len

    # --- category vocabularies, hierarchy maps and counts ---
    config.cate1_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate1'))
    config.cate2_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate2'))
    config.cate3_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate3'))
    config.all_cate_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        level='all_cate'))
    config.idx2cate1 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'))
    config.idx2cate2 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2'))
    config.idx2cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3'))
    config.idx2all_cate = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        IDX2TOKEN_TEMPLATE,
                        level='all_cate'))
    config.cate1_to_cate2 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT))
    config.cate2_to_cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_to_cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
    config.cate1_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT))
    config.cate2_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT))
    config.cate3_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT))
    config.n_cate1 = len(config.cate1_vocab)
    config.n_cate2 = len(config.cate2_vocab)
    config.n_cate3 = len(config.cate3_vocab)
    config.n_all_cate = len(config.all_cate_vocab)

    # --- parent-prediction sharing / masking ---
    if config.use_multi_task and (config.use_harl or config.use_hal):
        # With HARL/HAL multi-task heads these options are forced off,
        # whatever the arguments say.
        config.share_father_pred = 'no'
        config.use_mask_for_cate2 = False
        config.use_mask_for_cate3 = False
        config.cate3_mask_type = None
    else:
        config.share_father_pred = share_father_pred
        config.use_mask_for_cate2 = use_mask_for_cate2
        config.use_mask_for_cate3 = use_mask_for_cate3
        config.cate3_mask_type = cate3_mask_type
        if config.use_mask_for_cate3:
            # cate_to_cate3 is only set for the 'cate1' / 'cate2' mask types.
            if config.cate3_mask_type == 'cate1':
                config.cate_to_cate3 = pickle_load(
                    format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
            elif config.cate3_mask_type == 'cate2':
                config.cate_to_cate3 = pickle_load(
                    format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_loss_weight = cate1_loss_weight
    config.cate2_loss_weight = cate2_loss_weight
    config.cate3_loss_weight = cate3_loss_weight

    # --- optimisation and callbacks ---
    config.batch_size = batch_size
    config.predict_batch_size = predict_batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = optimizer
    config.use_focal_loss = use_focal_loss
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience
    lr_schedulers = ('clr', 'sgdr', 'clr_1', 'clr_2', 'warm_up', 'swa_clr')
    if any(scheduler in config.callbacks_to_add for scheduler in lr_schedulers):
        config.max_lr = max_lr
        config.min_lr = min_lr

    # --- cross-validation, pair exchange and pseudo-label settings ---
    config.train_on_cv = train_on_cv
    if config.train_on_cv:
        config.cv_random_state = cv_random_state
        config.cv_fold = cv_fold
        config.cv_index = cv_index

    config.exchange_pair = exchange_pair
    if config.exchange_pair:
        config.exchange_threshold = exchange_threshold

    config.use_pseudo_label = use_pseudo_label
    if config.use_pseudo_label:
        config.pseudo_path = pseudo_path
        config.pseudo_random_state = pseudo_random_state
        config.pseudo_rate = pseudo_rate
        config.pseudo_index = pseudo_index

    # build experiment name from parameter configuration
    config.exp_name = _build_exp_name(config, pseudo_name)
    if exp_name:
        config.exp_name = exp_name

    return config