import gc
import os
import time

import numpy as np
from keras import backend as K  # assumed import; needed for K.clear_session()

# NOTE: project-level helpers and constants (ModelConfig, RecognitionModel,
# RecognitionDataGenerator, get_optimizer, pickle_load, format_filename,
# write_log, and the data-path / template constants such as PROCESSED_DATA_DIR,
# LOG_DIR, PERFORMANCE_LOG, MAX_LEN, VOCABULARY_TEMPLATE, EMBEDDING_MATRIX_TEMPLATE,
# MENTION_TO_ENTITY_FILENAME) are assumed to be imported from this repo's own modules.


def train_recognition(model_name, label_schema='BIOES', batch_size=32, n_epoch=50, learning_rate=0.001,
                      optimizer_type='adam', use_char_input=True, embed_type=None, embed_trainable=True,
                      use_bert_input=False, bert_type='bert', bert_trainable=True, bert_layer_num=1,
                      use_bichar_input=False, bichar_embed_type=None, bichar_embed_trainable=True,
                      use_word_input=False, word_embed_type=None, word_embed_trainable=True,
                      use_charpos_input=False, charpos_embed_type=None, charpos_embed_trainable=True,
                      use_softword_input=False, use_dictfeat_input=False, use_maxmatch_input=False,
                      callbacks_to_add=None, overwrite=False, swa_start=3, early_stopping_patience=3, **kwargs):
    # assemble the model configuration from the training arguments
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                    type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
    config.early_stopping_patience = early_stopping_patience
    config.vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    # the experiment name encodes every setting that distinguishes this run
    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size, optimizer_type, learning_rate,
            label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size, optimizer_type, learning_rate,
                                                  label_schema)
    if config.n_epoch != 50:
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    assert config.use_char_input or config.use_bert_input, 'at least one of char and bert input must be used'
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(bert_type, bert_layer_num,
                                                    'tune' if config.bert_trainable else 'fix')

    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                                                          level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                               type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(bichar_embed_type if bichar_embed_type else 'random',
                                                  'tune' if config.bichar_embed_trainable else 'fix')

    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                             type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(word_embed_type if word_embed_type else 'random',
                                                'tune' if config.word_embed_trainable else 'fix')

    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                                                           level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(charpos_embed_type if charpos_embed_type else 'random',
                                                   'tune' if config.charpos_embed_trainable else 'fix')

    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # collect hyper-parameters and results for the performance log
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': learning_rate, 'other_params': kwargs}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = RecognitionModel(config, **kwargs)

    # NOTE: the feature vocabs below are passed even when the corresponding
    # use_*_input flags are off; ModelConfig is assumed to default them to None
    train_data_type, dev_data_type = 'train', 'dev'
    train_generator = RecognitionDataGenerator(train_data_type, config.batch_size, config.label_schema,
                                               config.label_to_one_hot[config.label_schema],
                                               config.vocab if config.use_char_input else None,
                                               config.bert_vocab_file(config.bert_type)
                                               if config.use_bert_input else None,
                                               config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                               config.use_word_input, config.charpos_vocab,
                                               config.use_softword_input, config.use_dictfeat_input,
                                               config.use_maxmatch_input)
    valid_generator = RecognitionDataGenerator(dev_data_type, config.batch_size, config.label_schema,
                                               config.label_to_one_hot[config.label_schema],
                                               config.vocab if config.use_char_input else None,
                                               config.bert_vocab_file(config.bert_type)
                                               if config.use_bert_input else None,
                                               config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                               config.use_word_input, config.charpos_vocab,
                                               config.use_softword_input, config.use_dictfeat_input,
                                               config.use_maxmatch_input)

    # train only when no checkpoint exists for this experiment (or overwrite is forced)
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_generator, valid_generator)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime('%H:%M:%S', time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))

    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    r, p, f1 = model.evaluate(valid_generator)
    train_log['dev_performance'] = (r, p, f1)

    # when stochastic weight averaging was used, also evaluate the averaged weights
    swa_type = None
    if 'swa' in config.callbacks_to_add:
        swa_type = 'swa'
    elif 'swa_clr' in config.callbacks_to_add:
        swa_type = 'swa_clr'
    if swa_type:
        model.load_swa_model(swa_type)
        print('Logging Info - Evaluate over valid data based on swa model:')
        r, p, f1 = model.evaluate(valid_generator)
        train_log['swa_dev_performance'] = (r, p, f1)

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_er'), log=train_log, mode='a')
    del model
    gc.collect()
    K.clear_session()
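
# A minimal usage sketch (hypothetical values: 'bilstm_cnn_crf' and 'c2v' are
# placeholder model/embedding names; the preprocessed vocabularies and embedding
# matrices referenced above must already exist under PROCESSED_DATA_DIR):
def _example_train_recognition():
    train_recognition(model_name='bilstm_cnn_crf', label_schema='BIOES', embed_type='c2v',
                      callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'],
                      swa_start=3, early_stopping_patience=3, overwrite=False)
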
def recognition(model_name, predict_log, label_schema='BIOES', batch_size=32, n_epoch=50, learning_rate=0.001,
                optimizer_type='adam', use_char_input=True, embed_type=None, embed_trainable=True,
                use_bert_input=False, bert_type='bert', bert_trainable=True, bert_layer_num=1,
                use_bichar_input=False, bichar_embed_type=None, bichar_embed_trainable=True,
                use_word_input=False, word_embed_type=None, word_embed_trainable=True,
                use_charpos_input=False, charpos_embed_type=None, charpos_embed_trainable=True,
                use_softword_input=False, use_dictfeat_input=False, use_maxmatch_input=False,
                callbacks_to_add=None, swa_type=None, predict_on_dev=True, predict_on_final_test=True, **kwargs):
    # rebuild exactly the same configuration (and hence exp_name) used in
    # train_recognition, so that the trained checkpoint can be located
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                    type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    config.vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size, optimizer_type, learning_rate,
            label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size, optimizer_type, learning_rate,
                                                  label_schema)
    if config.n_epoch != 50:
        # mirror train_recognition, which appends the epoch count for non-default runs
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    assert config.use_char_input or config.use_bert_input, 'at least one of char and bert input must be used'
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(bert_type, bert_layer_num,
                                                    'tune' if config.bert_trainable else 'fix')

    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                                                          level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                               type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(bichar_embed_type if bichar_embed_type else 'random',
                                                  'tune' if config.bichar_embed_trainable else 'fix')

    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                             type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(word_embed_type if word_embed_type else 'random',
                                                'tune' if config.word_embed_trainable else 'fix')

    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                                                           level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(charpos_embed_type if charpos_embed_type else 'random',
                                                   'tune' if config.charpos_embed_trainable else 'fix')

    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # record the recognition settings in the shared prediction log
    predict_log.update({'er_exp_name': config.exp_name, 'er_batch_size': batch_size,
                        'er_optimizer': optimizer_type, 'er_epoch': n_epoch,
                        'er_learning_rate': learning_rate, 'er_other_params': kwargs})
    print('Logging Info - Experiment: %s' % config.exp_name)
    model = RecognitionModel(config, **kwargs)

    dev_data_type = 'dev'
    if predict_on_final_test:
        test_data_type = 'test_final'
    else:
        test_data_type = 'test'
    valid_generator = RecognitionDataGenerator(dev_data_type, config.batch_size, config.label_schema,
                                               config.label_to_one_hot[config.label_schema],
                                               config.vocab if config.use_char_input else None,
                                               config.bert_vocab_file(config.bert_type)
                                               if config.use_bert_input else None,
                                               config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                               config.use_word_input, config.charpos_vocab,
                                               config.use_softword_input, config.use_dictfeat_input,
                                               config.use_maxmatch_input)
    test_generator = RecognitionDataGenerator(test_data_type, config.batch_size, config.label_schema,
                                              config.label_to_one_hot[config.label_schema],
                                              config.vocab if config.use_char_input else None,
                                              config.bert_vocab_file(config.bert_type)
                                              if config.use_bert_input else None,
                                              config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                              config.use_word_input, config.charpos_vocab,
                                              config.use_softword_input, config.use_dictfeat_input,
                                              config.use_maxmatch_input)

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        raise FileNotFoundError('Recognition model does not exist: {}'.format(model_save_path))
    if swa_type is None:
        model.load_best_model()
    elif swa_type in config.callbacks_to_add:
        # check against config.callbacks_to_add (never None) and match the
        # requested swa variant ('swa' or 'swa_clr') exactly
        model.load_swa_model(swa_type)
        predict_log['er_exp_name'] += '_{}'.format(swa_type)

    if predict_on_dev:
        print('Logging Info - Generate submission for valid data:')
        dev_pred_mentions = model.predict(valid_generator)
    else:
        dev_pred_mentions = None
    print('Logging Info - Generate submission for test data:')
    test_pred_mentions = model.predict(test_generator)
    return dev_pred_mentions, test_pred_mentions
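
# A minimal usage sketch (hypothetical values, matching the training sketch above;
# the arguments must reproduce the training configuration so that exp_name, and
# hence the checkpoint path, resolve to the trained model):
def _example_recognition():
    predict_log = {}
    dev_mentions, test_mentions = recognition(
        model_name='bilstm_cnn_crf', predict_log=predict_log, embed_type='c2v',
        callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'], swa_type='swa',
        predict_on_dev=True, predict_on_final_test=False)
    return dev_mentions, test_mentions
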
def prepare_config(model_type='bert-base-uncased', input_type='name_desc', use_multi_task=True, use_harl=False,
                   use_hal=False, cate_embed_dim=100, use_word_input=False, word_embed_type='w2v',
                   word_embed_trainable=True, word_embed_dim=300, use_bert_input=True, bert_trainable=True,
                   use_bert_type='pooler', n_last_hidden_layer=0, dense_after_bert=True, use_pair_input=True,
                   max_len=None, share_father_pred='no', use_mask_for_cate2=False, use_mask_for_cate3=True,
                   cate3_mask_type='cate1', cate1_loss_weight=1., cate2_loss_weight=1., cate3_loss_weight=1.,
                   batch_size=32, predict_batch_size=32, n_epoch=50, learning_rate=2e-5, optimizer='adam',
                   use_focal_loss=False, callbacks_to_add=None, swa_start=15, early_stopping_patience=5,
                   max_lr=6e-5, min_lr=1e-5, train_on_cv=False, cv_random_state=42, cv_fold=5, cv_index=0,
                   exchange_pair=False, exchange_threshold=0.1, use_pseudo_label=False, pseudo_path=None,
                   pseudo_random_state=42, pseudo_rate=0.1, pseudo_index=0, pseudo_name=None, exp_name=None):
    config = ModelConfig()
    config.model_type = model_type
    config.input_type = input_type
    config.use_multi_task = use_multi_task
    config.use_harl = use_harl
    config.use_hal = use_hal
    assert not (config.use_harl and config.use_hal), 'use_harl and use_hal are mutually exclusive'
    config.cate_embed_dim = cate_embed_dim

    config.use_word_input = use_word_input
    config.word_embed_type = word_embed_type
    if config.use_word_input:
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                             type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
            config.word_embed_dim = word_embed_dim
        config.word_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2  # 0: mask, 1: padding
    else:
        config.word_vocab = None

    config.use_bert_input = use_bert_input
    config.bert_trainable = bert_trainable
    if config.use_bert_input:
        config.use_bert_type = use_bert_type
        config.dense_after_bert = dense_after_bert
        # hidden-state output is needed when pooling over hidden layers or when
        # the hierarchical attention variants (harl / hal) are used
        if config.use_bert_type in ['hidden', 'hidden_pooler'] or \
                (config.use_multi_task and (config.use_harl or config.use_hal)):
            config.output_hidden_state = True
            config.n_last_hidden_layer = n_last_hidden_layer
        else:
            config.output_hidden_state = False
            config.n_last_hidden_layer = 0

    if config.input_type == 'name_desc':
        config.use_pair_input = use_pair_input
    else:
        config.use_pair_input = False
    if config.use_bert_input and max_len is None:
        config.max_len = MAX_LEN[input_type]
    else:
        config.max_len = max_len

    # load category vocabularies, index mappings and hierarchy/count dictionaries
    config.cate1_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1'))
    config.cate2_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate2'))
    config.cate3_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate3'))
    config.all_cate_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                                                        level='all_cate'))
    config.idx2cate1 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'))
    config.idx2cate2 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2'))
    config.idx2cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3'))
    config.idx2all_cate = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='all_cate'))
    config.cate1_to_cate2 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT))
    config.cate2_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
    config.cate1_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT))
    config.cate2_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT))
    config.cate3_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT))
    config.n_cate1 = len(config.cate1_vocab)
    config.n_cate2 = len(config.cate2_vocab)
    config.n_cate3 = len(config.cate3_vocab)
    config.n_all_cate = len(config.all_cate_vocab)

    # harl / hal handle the category hierarchy themselves, so prediction sharing
    # and label masking are disabled in that case
    if config.use_multi_task and (config.use_harl or config.use_hal):
        config.share_father_pred = 'no'
        config.use_mask_for_cate2 = False
        config.use_mask_for_cate3 = False
        config.cate3_mask_type = None
    else:
        config.share_father_pred = share_father_pred
        config.use_mask_for_cate2 = use_mask_for_cate2
        config.use_mask_for_cate3 = use_mask_for_cate3
        config.cate3_mask_type = cate3_mask_type
    # if config.use_mask_for_cate2:
    if config.use_mask_for_cate3:
        if config.cate3_mask_type == 'cate1':
            config.cate_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
        elif config.cate3_mask_type == 'cate2':
            config.cate_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))

    config.cate1_loss_weight = cate1_loss_weight
    config.cate2_loss_weight = cate2_loss_weight
    config.cate3_loss_weight = cate3_loss_weight
    config.batch_size = batch_size
    config.predict_batch_size = predict_batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = optimizer
    config.use_focal_loss = use_focal_loss
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
    config.early_stopping_patience = early_stopping_patience
    for lr_scheduler in ['clr', 'sgdr', 'clr_1', 'clr_2', 'warm_up', 'swa_clr']:
        if lr_scheduler in config.callbacks_to_add:
            config.max_lr = max_lr
            config.min_lr = min_lr

    config.train_on_cv = train_on_cv
    if config.train_on_cv:
        config.cv_random_state = cv_random_state
        config.cv_fold = cv_fold
        config.cv_index = cv_index
    config.exchange_pair = exchange_pair
    if config.exchange_pair:
        config.exchange_threshold = exchange_threshold
    config.use_pseudo_label = use_pseudo_label
    if config.use_pseudo_label:
        config.pseudo_path = pseudo_path
        config.pseudo_random_state = pseudo_random_state
        config.pseudo_rate = pseudo_rate
        config.pseudo_index = pseudo_index

    # build experiment name from parameter configuration
    config.exp_name = f'{config.model_type}_{config.input_type}'
    if config.use_pair_input:
        config.exp_name += '_pair'
    config.exp_name += f'_len_{config.max_len}'
    if config.use_word_input:
        config.exp_name += f"_word_{config.word_embed_type}_{'tune' if config.word_embed_trainable else 'fix'}"
    if config.use_bert_input:
        config.exp_name += f"_bert_{config.use_bert_type}_{'tune' if config.bert_trainable else 'fix'}"
        if config.output_hidden_state:
            config.exp_name += f'_hid_{config.n_last_hidden_layer}'
        if config.dense_after_bert:
            config.exp_name += '_dense'
    if config.use_multi_task:
        if config.use_harl:
            config.exp_name += f'_harl_{config.cate_embed_dim}'
        elif config.use_hal:
            config.exp_name += f'_hal_{config.cate_embed_dim}'
        config.exp_name += f'_{config.cate1_loss_weight}_{config.cate2_loss_weight}_{config.cate3_loss_weight}'
    else:
        config.exp_name += '_not_multi_task'
    if config.share_father_pred in ['after', 'before']:
        config.exp_name += f'_{config.share_father_pred}'
    if config.use_mask_for_cate2:
        config.exp_name += '_mask_cate2'
    if config.use_mask_for_cate3:
        config.exp_name += f'_mask_cate3_with_{config.cate3_mask_type}'
    if config.use_focal_loss:
        config.exp_name += '_focal'
    config.exp_name += f'_{config.optimizer}_{config.learning_rate}_{config.batch_size}_{config.n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str
    if config.train_on_cv:
        config.exp_name += f'_{config.cv_random_state}_{config.cv_fold}_{config.cv_index}'
    if config.exchange_pair:
        config.exp_name += f'_ex_pair_{config.exchange_threshold}'
    if config.use_pseudo_label:
        if pseudo_name:
            config.exp_name += f'_{pseudo_name}_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}'
        elif 'dev' in config.pseudo_path:
            config.exp_name += f'_dev_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}'
        else:
            config.exp_name += f'_test_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}'
    if exp_name:
        # an explicitly provided name overrides the generated one
        config.exp_name = exp_name
    return config
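
# A minimal usage sketch (illustrative argument values; prepare_config only loads
# the processed vocabularies and assembles a ModelConfig, which a separate
# classification trainer is assumed to consume):
def _example_prepare_config():
    config = prepare_config(model_type='bert-base-uncased', input_type='name_desc',
                            use_multi_task=True, use_mask_for_cate3=True, cate3_mask_type='cate1',
                            callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'],
                            swa_start=15, learning_rate=2e-5, batch_size=32)
    print(config.exp_name)  # inspect the generated experiment name
    return config
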