def recognition(model_name, predict_log, label_schema='BIOES', batch_size=32, n_epoch=50, learning_rate=0.001,
                optimizer_type='adam', use_char_input=True, embed_type=None, embed_trainable=True,
                use_bert_input=False, bert_type='bert', bert_trainable=True, bert_layer_num=1,
                use_bichar_input=False, bichar_embed_type=None, bichar_embed_trainable=True,
                use_word_input=False, word_embed_type=None, word_embed_trainable=True,
                use_charpos_input=False, charpos_embed_type=None, charpos_embed_trainable=True,
                use_softword_input=False, use_dictfeat_input=False, use_maxmatch_input=False,
                callbacks_to_add=None, swa_type=None, predict_on_dev=True, predict_on_final_test=True, **kwargs):
    """Run mention recognition with a previously trained model.

    Rebuilds the same ``ModelConfig`` and experiment name that training used
    (the hyper-parameters must match the training run exactly, otherwise the
    checkpoint file will not be found), loads the saved weights, and predicts
    mentions over the dev set (optionally) and the test set.

    Args:
        model_name: architecture identifier used by ``RecognitionModel``.
        predict_log: dict that is updated in place with experiment metadata
            (keys prefixed ``er_``).
        label_schema: tagging scheme, e.g. 'BIOES'.
        use_*_input / *_embed_type / *_trainable: feature-input switches that
            must mirror the training configuration — they only affect the
            experiment name and generator inputs here.
        swa_type: if not None, load an SWA-averaged checkpoint variant
            instead of the best model.
        predict_on_dev: also predict mentions on the dev split.
        predict_on_final_test: use the 'test_final' split instead of 'test'.
        **kwargs: extra model hyper-parameters; they are appended to the
            experiment name and forwarded to ``RecognitionModel``.

    Returns:
        (dev_pred_mentions, test_pred_mentions); ``dev_pred_mentions`` is
        None when ``predict_on_dev`` is False.

    Raises:
        FileNotFoundError: if the expected checkpoint file does not exist.
    """
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        # pre-trained char embedding matrix; embed_dim is inferred from it
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        # randomly initialized embeddings are always trainable
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]
    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    # +2 reserves ids for padding and unknown tokens — presumably; confirm
    # against the vocabulary-building code
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    # The experiment name must reproduce exactly what the training run
    # generated, since it is used to locate the checkpoint file below.
    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size,
            optimizer_type, learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size,
                                                  optimizer_type, learning_rate,
                                                  label_schema)
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # checkpointing / early stopping are considered defaults and are not
    # encoded into the experiment name
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    # the model needs at least one primary input stream
    assert config.use_char_input or config.use_bert_input
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(
            bert_type, bert_layer_num, 'tune' if config.bert_trainable else 'fix')
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(
            bichar_embed_type if bichar_embed_type else 'random',
            'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(
            word_embed_type if word_embed_type else 'random',
            'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(
            charpos_embed_type if charpos_embed_type else 'random',
            'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # logger to log output of training process
    predict_log.update({
        'er_exp_name': config.exp_name,
        'er_batch_size': batch_size,
        'er_optimizer': optimizer_type,
        'er_epoch': n_epoch,
        'er_learning_rate': learning_rate,
        'er_other_params': kwargs
    })

    print('Logging Info - Experiment: %s' % config.exp_name)
    model = RecognitionModel(config, **kwargs)

    dev_data_type = 'dev'
    if predict_on_final_test:
        test_data_type = 'test_final'
    else:
        test_data_type = 'test'
    # NOTE(review): config.bichar_vocab / word_vocab / charpos_vocab are only
    # assigned above when the corresponding use_* flag is on — presumably
    # ModelConfig declares None defaults for them; confirm.
    valid_generator = RecognitionDataGenerator(
        dev_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
        config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)
    test_generator = RecognitionDataGenerator(
        test_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
        config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        raise FileNotFoundError(
            'Recognition model not exist: {}'.format(model_save_path))
    if swa_type is None:
        model.load_best_model()
    # NOTE(review): if swa_type is set but callbacks_to_add is None this 'in'
    # check raises TypeError, and if 'swa' is absent no weights are loaded at
    # all — confirm callers always pass a callbacks list containing 'swa'.
    elif 'swa' in callbacks_to_add:
        model.load_swa_model(swa_type)
        predict_log['er_exp_name'] += '_{}'.format(swa_type)

    if predict_on_dev:
        print('Logging Info - Generate submission for valid data:')
        dev_pred_mentions = model.predict(valid_generator)
    else:
        dev_pred_mentions = None

    print('Logging Info - Generate submission for test data:')
    test_pred_mentions = model.predict(test_generator)

    return dev_pred_mentions, test_pred_mentions
def train_model(genre, input_level, word_embed_type, word_embed_trainable, batch_size, learning_rate,
                optimizer_type, model_name, n_epoch=50, add_features=False, scale_features=False,
                overwrite=False, lr_range_test=False, callbacks_to_add=None, eval_on_train=False, **kwargs):
    """Train and evaluate a Keras sentence-matching model.

    Builds a ``ModelConfig`` from the hyper-parameters, optionally prepares
    an ELMo embedding source (cached or via tensorflow_hub depending on
    ``kwargs['input_config']``), trains the selected model unless a
    checkpoint already exists (and ``overwrite`` is False), then evaluates
    on train (optional), dev and test, including SWA-averaged and
    snapshot-ensemble checkpoints when the corresponding callbacks were used.

    Args:
        genre: dataset genre key used for file lookups.
        input_level: 'word' or 'char'; selects max_len and vocabulary.
        word_embed_type / word_embed_trainable: pre-trained embedding choice.
        model_name: one of the Keras* model identifiers dispatched below.
        add_features / scale_features: append (scaled) handcrafted features.
        overwrite: retrain even if a checkpoint already exists.
        lr_range_test: run an LR range test and return early (returns None).
        callbacks_to_add: training callback names; affects the experiment name.
        eval_on_train: also evaluate on the training split (slow).
        **kwargs: extra model hyper-parameters; 'input_config',
            'elmo_model_type' and 'elmo_output_mode' are consumed here.

    Returns:
        dict of logged metrics (``train_log``), or None when
        ``lr_range_test`` is True.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    config = ModelConfig()
    config.genre = genre
    config.input_level = input_level
    config.max_len = config.word_max_len[genre] if input_level == 'word' else config.char_max_len[genre]
    config.word_embed_type = word_embed_type
    config.word_embed_trainable = word_embed_trainable
    config.callbacks_to_add = callbacks_to_add or []
    config.add_features = add_features
    config.batch_size = batch_size
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.n_epoch = n_epoch
    config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                     genre, word_embed_type))
    vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, genre, input_level))
    # reverse mapping (id -> token), needed by the ELMo cache below
    config.idx2token = dict((idx, token) for token, idx in vocab.items())

    # experiment name configuration
    config.exp_name = '{}_{}_{}_{}_{}_{}_{}_{}'.format(genre, model_name, input_level, word_embed_type,
                                                       'tune' if word_embed_trainable else 'fix', batch_size,
                                                       '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()]),
                                                       optimizer_type)
    if config.add_features:
        config.exp_name = config.exp_name + '_feature_scaled' if scale_features else config.exp_name + '_featured'
    if len(config.callbacks_to_add) > 0:
        callback_str = '_' + '_'.join(config.callbacks_to_add)
        # checkpointing / early stopping are defaults and not encoded in the name
        callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
        config.exp_name += callback_str

    input_config = kwargs['input_config'] if 'input_config' in kwargs else 'token'  # input default is word embedding
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        # get elmo embedding based on cache, we first get a ELMoCache instance
        if 'elmo_model_type' in kwargs:
            elmo_model_type = kwargs['elmo_model_type']
            kwargs.pop('elmo_model_type')  # we don't need it in kwargs any more
        else:
            elmo_model_type = 'allennlp'
        if 'elmo_output_mode' in kwargs:
            elmo_output_mode = kwargs['elmo_output_mode']
            kwargs.pop('elmo_output_mode')  # we don't need it in kwargs any more
        else:
            elmo_output_mode ='elmo'
        elmo_cache = ELMoCache(options_file=config.elmo_options_file, weight_file=config.elmo_weight_file,
                               cache_dir=config.cache_dir, idx2token=config.idx2token,
                               max_sentence_length=config.max_len, elmo_model_type=elmo_model_type,
                               elmo_output_mode=elmo_output_mode)
    elif input_config in ['elmo_id', 'elmo_s', 'token_combine_elmo_id', 'token_combine_elmo_s']:
        # get elmo embedding using tensorflow_hub, we must provide a tfhub_url
        kwargs['elmo_model_url'] = config.elmo_model_url

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': learning_rate, 'other_params': kwargs}

    print('Logging Info - Experiment: %s' % config.exp_name)
    # dispatch on the requested architecture
    if model_name == 'KerasInfersent':
        model = KerasInfersentModel(config, **kwargs)
    elif model_name == 'KerasEsim':
        model = KerasEsimModel(config, **kwargs)
    elif model_name == 'KerasDecomposable':
        model = KerasDecomposableAttentionModel(config, **kwargs)
    elif model_name == 'KerasSiameseBiLSTM':
        model = KerasSimaeseBiLSTMModel(config, **kwargs)
    elif model_name == 'KerasSiameseCNN':
        model = KerasSiameseCNNModel(config, **kwargs)
    elif model_name == 'KerasIACNN':
        model = KerasIACNNModel(config, **kwargs)
    elif model_name == 'KerasSiameseLSTMCNNModel':
        model = KerasSiameseLSTMCNNModel(config, **kwargs)
    elif model_name == 'KerasRefinedSSAModel':
        model = KerasRefinedSSAModel(config, **kwargs)
    else:
        raise ValueError('Model Name Not Understood : {}'.format(model_name))
    # model.summary()

    train_input, dev_input, test_input = None, None, None
    if lr_range_test:
        # conduct lr range test to find optimal learning rate (not train model)
        train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features, scale_features)
        dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features, scale_features)
        model.lr_range_test(x_train=train_input['x'], y_train=train_input['y'],
                            x_valid=dev_input['x'], y_valid=dev_input['y'])
        return

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
            # stream ELMo embeddings from the cache during training
            train_input = ELMoGenerator(genre, input_level, 'train', config.batch_size, elmo_cache,
                                        return_data=(input_config == 'token_combine_cache_elmo'),
                                        return_features=config.add_features)
            dev_input = ELMoGenerator(genre, input_level, 'dev', config.batch_size, elmo_cache,
                                      return_data=(input_config == 'token_combine_cache_elmo'),
                                      return_features=config.add_features)
            model.train_with_generator(train_input, dev_input)
        else:
            train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features,
                                          scale_features)
            dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features,
                                        scale_features)
            model.train(x_train=train_input['x'], y_train=train_input['y'],
                        x_valid=dev_input['x'], y_valid=dev_input['y'])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    def eval_on_data(eval_with_generator, input_data, data_type):
        # Evaluate the best checkpoint (and, when enabled, the SWA-averaged
        # and snapshot-ensemble checkpoints) on one split; results go into
        # the enclosing train_log.
        model.load_best_model()
        if eval_with_generator:
            acc = model.evaluate_with_generator(generator=input_data, y=input_data.input_label)
        else:
            acc = model.evaluate(x=input_data['x'], y=input_data['y'])
        train_log['%s_acc' % data_type] = acc

        swa_type = None
        if 'swa' in config.callbacks_to_add:
            swa_type = 'swa'
        elif 'swa_clr' in config.callbacks_to_add:
            swa_type = 'swa_clr'
        if swa_type:
            print('Logging Info - %s Model' % swa_type)
            model.load_swa_model(swa_type=swa_type)
            # NOTE(review): this always evaluates via x/y even when
            # eval_with_generator is True (input_data is a generator then,
            # so input_data['x'] would fail) — confirm swa is never combined
            # with generator-based evaluation.
            swa_acc = model.evaluate(x=input_data['x'], y=input_data['y'])
            train_log['%s_%s_acc' % (swa_type, data_type)] = swa_acc

        ensemble_type = None
        if 'sse' in config.callbacks_to_add:
            ensemble_type = 'sse'
        elif 'fge' in config.callbacks_to_add:
            ensemble_type = 'fge'
        if ensemble_type:
            print('Logging Info - %s Ensemble Model' % ensemble_type)
            ensemble_predict = {}
            # scan the checkpoint dir for the numbered snapshot checkpoints
            for model_file in os.listdir(config.checkpoint_dir):
                if model_file.startswith(config.exp_name+'_%s' % ensemble_type):
                    match = re.match(r'(%s_%s_)([\d+])(.hdf5)' % (config.exp_name, ensemble_type), model_file)
                    model_id = int(match.group(2))
                    model_path = os.path.join(config.checkpoint_dir, model_file)
                    print('Logging Info: Loading {} ensemble model checkpoint: {}'.format(ensemble_type, model_file))
                    model.load_model(model_path)
                    ensemble_predict[model_id] = model.predict(x=input_data['x'])
            '''
            we expect the models saved towards the end of run may have better performance than models saved earlier
            in the run, we sort the models so that the older models ('s id) are first.
            '''
            sorted_ensemble_predict = sorted(ensemble_predict.items(), key=lambda x: x[0], reverse=True)
            model_predicts = []
            for model_id, model_predict in sorted_ensemble_predict:
                single_acc = eval_acc(model_predict, input_data['y'])
                print('Logging Info - %s_single_%d_%s Acc : %f' % (ensemble_type, model_id, data_type, single_acc))
                train_log['%s_single_%d_%s_acc' % (ensemble_type, model_id, data_type)] = single_acc

                # cumulative ensemble: mean of the predictions seen so far
                model_predicts.append(model_predict)
                ensemble_acc = eval_acc(np.mean(np.array(model_predicts), axis=0), input_data['y'])
                print('Logging Info - %s_ensemble_%d_%s Acc : %f' % (ensemble_type, model_id, data_type,
                                                                     ensemble_acc))
                train_log['%s_ensemble_%d_%s_acc' % (ensemble_type, model_id, data_type)] = ensemble_acc

    if eval_on_train:
        # might take a long time
        print('Logging Info - Evaluate over train data:')
        if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
            train_input = ELMoGenerator(genre, input_level, 'train', config.batch_size, elmo_cache,
                                        return_data=(input_config == 'token_combine_cache_elmo'),
                                        return_features=config.add_features, return_label=False)
            eval_on_data(eval_with_generator=True, input_data=train_input, data_type='train')
        else:
            train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features,
                                          scale_features)
            eval_on_data(eval_with_generator=False, input_data=train_input, data_type='train')

    print('Logging Info - Evaluate over valid data:')
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        dev_input = ELMoGenerator(genre, input_level, 'dev', config.batch_size, elmo_cache,
                                  return_data=(input_config == 'token_combine_cache_elmo'),
                                  return_features=config.add_features, return_label=False)
        eval_on_data(eval_with_generator=True, input_data=dev_input, data_type='dev')
    else:
        # reuse the dev data loaded for training when available
        if dev_input is None:
            dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features,
                                        scale_features)
        eval_on_data(eval_with_generator=False, input_data=dev_input, data_type='dev')

    print('Logging Info - Evaluate over test data:')
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        test_input = ELMoGenerator(genre, input_level, 'test', config.batch_size, elmo_cache,
                                   return_data=(input_config == 'token_combine_cache_elmo'),
                                   return_features=config.add_features, return_label=False)
        eval_on_data(eval_with_generator=True, input_data=test_input, data_type='test')
    else:
        if test_input is None:
            test_input = load_input_data(genre, input_level, 'test', input_config, config.add_features,
                                         scale_features)
        eval_on_data(eval_with_generator=False, input_data=test_input, data_type='test')

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, genre), log=train_log, mode='a')
    return train_log
def train(dataset, char_embed_type, char_trainable, fw_embed_type, fw_trainable, bw_embed_type, bw_trainable,
          gaze_embed_dim, n_step, n_layer, rnn_units, dropout, batch_size, n_epoch, optimizer,
          callbacks_to_add=None, overwrite=False):
    """Train and evaluate a multi-digraph (GGNN) NER model on ``dataset``.

    Loads pre-computed char / forward-bigram / backward-bigram embedding
    matrices and vocabularies, builds the experiment name from the
    hyper-parameters, trains unless a checkpoint already exists (and
    ``overwrite`` is False), evaluates on dev and test (plus the
    SWA-averaged weights when the 'swa' callback was used), and appends the
    results to the performance log.

    Args:
        dataset: dataset key used for all file lookups.
        char_embed_type / char_trainable: char embedding choice and whether
            it is fine-tuned; same for fw_* and bw_* bigram embeddings.
        gaze_embed_dim, n_step, n_layer, rnn_units, dropout: model
            hyper-parameters.
        batch_size, n_epoch, optimizer: training hyper-parameters.
        callbacks_to_add: training callback names; affects the experiment
            name. Defaults to no extra callbacks.
        overwrite: retrain even if a checkpoint already exists.
    """
    config = ModelConfig()
    # pre-trained embedding matrices (char + bigrams in both directions)
    config.char_embedding = np.load(
        format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, dataset=dataset, type=char_embed_type))
    config.char_trainable = char_trainable
    config.fw_bigram_embeddings = np.load(
        format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, dataset=dataset, type=fw_embed_type))
    config.fw_bigram_trainable = fw_trainable
    config.bw_bigram_embeddings = np.load(
        format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, dataset=dataset, type=bw_embed_type))
    config.bw_bigram_trainable = bw_trainable
    # vocabularies and tag mappings
    config.char_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, dataset=dataset, level='char'))
    config.fw_bigram_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, dataset=dataset, level='fw_bigram'))
    config.bw_bigram_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, dataset=dataset, level='bw_bigram'))
    config.tag_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, dataset=dataset, level='tag'))
    config.idx2tag = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, dataset=dataset, level='tag'))
    config.gaze_embed_dim = gaze_embed_dim
    config.n_step = n_step
    config.n_layer = n_layer
    config.rnn_units = rnn_units
    config.dropout = dropout
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.optimizer = optimizer
    # Fix: guard against None so the '_'.join / 'in' checks below don't raise
    config.callbacks_to_add = callbacks_to_add or []
    config.model_name = 'ggnn'
    config.exp_name = 'ggnn_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
        char_embed_type, 'tune' if char_trainable else 'fix',
        fw_embed_type, 'tune' if fw_trainable else 'fix',
        # Fix: original tested bw_embed_type here instead of bw_trainable,
        # so the backward-bigram tune/fix flag in the name was always 'tune'
        # whenever an embedding type was given
        bw_embed_type, 'tune' if bw_trainable else 'fix',
        gaze_embed_dim, n_step, n_layer, rnn_units, batch_size, n_epoch)
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # checkpointing / early stopping are defaults and not encoded in the name
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    dev_generator = NERGenerator('dev', dataset, batch_size, config.char_vocab, config.fw_bigram_vocab,
                                 config.bw_bigram_vocab, config.tag_vocab)
    dev_input, _ = dev_generator.prepare_input(range(dev_generator.data_size))
    dev_tags = [text_example.tags for text_example in dev_generator.data]
    test_generator = NERGenerator('test', dataset, batch_size, config.char_vocab, config.fw_bigram_vocab,
                                  config.bw_bigram_vocab, config.tag_vocab)
    test_input, _ = test_generator.prepare_input(range(test_generator.data_size))
    test_tags = [text_example.tags for text_example in test_generator.data]
    config.n_gaze = dev_generator.n_gaze

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = MultiDigraph(config)
    if not os.path.exists(model_save_path) or overwrite:
        train_generator = NERGenerator('train', dataset, batch_size, config.char_vocab, config.fw_bigram_vocab,
                                       config.bw_bigram_vocab, config.tag_vocab)
        start_time = time.time()
        model.fit_generator(train_generator, dev_generator)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    dev_score = model.evaluate(dev_input, dev_tags)
    train_log['dev_performance'] = dev_score
    print('Logging Info - Evaluate over test data:')
    test_score = model.evaluate(test_input, test_tags)
    train_log['test_performance'] = test_score

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        swa_dev_score = model.evaluate(dev_input, dev_tags)
        train_log['swa_dev_performance'] = swa_dev_score
        print('Logging Info - Evaluate over test data based on swa model:')
        swa_test_score = model.evaluate(test_input, test_tags)
        train_log['swa_test_performance'] = swa_test_score

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG_TEMPLATE, dataset=dataset), log=train_log, mode='a')
    # free Keras/TensorFlow resources between runs
    del model
    gc.collect()
    K.clear_session()
def link(model_name, dev_pred_mentions, test_pred_mentions, predict_log, batch_size=32, n_epoch=50,
         learning_rate=0.001, optimizer_type='adam', embed_type=None, embed_trainable=True,
         use_relative_pos=False, n_neg=1, omit_one_cand=True, callbacks_to_add=None, swa_type=None,
         predict_on_final_test=True, **kwargs):
    """Run entity linking over previously recognized mentions.

    Rebuilds the ``ModelConfig`` and experiment name the linking model was
    trained with (hyper-parameters must match the training run so the
    checkpoint can be found), loads the saved weights, optionally evaluates
    on the dev split using ``dev_pred_mentions``, predicts entities for the
    test mentions, writes a submission file, and appends the metrics to the
    performance log.

    Args:
        model_name: architecture identifier used by ``LinkModel``.
        dev_pred_mentions: mentions predicted on dev by the recognition
            step, or None to skip dev evaluation.
        test_pred_mentions: mentions predicted on the test split.
        predict_log: dict updated in place with experiment metadata
            (keys prefixed ``el_``) and results.
        use_relative_pos / n_neg / omit_one_cand: training-time switches
            that only affect the experiment name here.
        swa_type: if not None, load an SWA-averaged checkpoint variant.
        predict_on_final_test: use the 'test_final' split instead of 'test'.
        **kwargs: extra model hyper-parameters, appended to the experiment
            name and forwarded to ``LinkModel``.

    Returns:
        The updated ``predict_log`` dict.

    Raises:
        FileNotFoundError: if the expected checkpoint file does not exist.
    """
    config = ModelConfig()
    config.model_name = model_name
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    if embed_type:
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
    else:
        # randomly initialized embeddings are always trainable
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]
    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))
    config.entity_desc = pickle_load(
        format_filename(PROCESSED_DATA_DIR, ENTITY_DESC_FILENAME))

    # The experiment name must reproduce what the training run generated;
    # it is used to locate the checkpoint file below.
    config.exp_name = '{}_{}_{}_{}_{}_{}'.format(
        model_name, embed_type if embed_type else 'random',
        'tune' if embed_trainable else 'fix', batch_size, optimizer_type,
        learning_rate)
    config.use_relative_pos = use_relative_pos
    if config.use_relative_pos:
        config.exp_name += '_rel'
    config.n_neg = n_neg
    if config.n_neg > 1:
        config.exp_name += '_neg_{}'.format(config.n_neg)
    config.omit_one_cand = omit_one_cand
    if not config.omit_one_cand:
        config.exp_name += '_not_omit'
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # checkpointing / early stopping are defaults and not encoded in the name
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # logger to log output of training process
    predict_log.update({
        'el_exp_name': config.exp_name,
        'el_batch_size': batch_size,
        'el_optimizer': optimizer_type,
        'el_epoch': n_epoch,
        'el_learning_rate': learning_rate,
        'el_other_params': kwargs
    })

    print('Logging Info - Experiment: %s' % config.exp_name)
    model = LinkModel(config, **kwargs)

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        # Fix: the message previously said "Recognition model", which was
        # misleading when the *link* model checkpoint is missing
        raise FileNotFoundError(
            'Link model not exist: {}'.format(model_save_path))
    if swa_type is None:
        model.load_best_model()
    # Fix: read config.callbacks_to_add (None-guarded above) instead of the
    # raw callbacks_to_add argument, which raised TypeError when None
    elif 'swa' in config.callbacks_to_add:
        model.load_swa_model(swa_type)
        # NOTE(review): appends to 'er_exp_name' (set by the recognition
        # step) so the submit-file name below reflects the swa variant —
        # confirm this is intended rather than 'el_exp_name'.
        predict_log['er_exp_name'] += '_{}'.format(swa_type)

    dev_data_type = 'dev'
    dev_data = load_data(dev_data_type)
    dev_text_data, dev_gold_mention_entities = [], []
    for data in dev_data:
        dev_text_data.append(data['text'])
        dev_gold_mention_entities.append(data['mention_data'])

    if predict_on_final_test:
        test_data_type = 'test_final'
    else:
        test_data_type = 'test'
    test_data = load_data(test_data_type)
    test_text_data = [data['text'] for data in test_data]

    if dev_pred_mentions is not None:
        print(
            'Logging Info - Evaluate over valid data based on predicted mention:'
        )
        r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions,
                                  dev_gold_mention_entities)
        dev_performance = 'dev_performance' if swa_type is None else '%s_dev_performance' % swa_type
        predict_log[dev_performance] = (r, p, f1)

    print('Logging Info - Generate submission for test data:')
    test_pred_mention_entities = model.predict(test_text_data, test_pred_mentions)
    test_submit_file = predict_log[
        'er_exp_name'] + '_' + config.exp_name + '_%s%ssubmit.json' % (
            swa_type + '_' if swa_type else '',
            'final_' if predict_on_final_test else '')
    submit_result(test_submit_file, test_data, test_pred_mention_entities)

    predict_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step'),
              log=predict_log, mode='a')
    return predict_log
def train(train_d, dev_d, test_d, kfold, dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr,
          optimizer_type, batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=True):
    """Train and evaluate a DDKG model for drug-drug interaction prediction.

    Each row of ``train_d`` / ``dev_d`` / ``test_d`` is expected to be
    (drug_one_id, drug_two_id, label) — TODO confirm against the data
    pipeline. Trains unless a checkpoint already exists (and ``overwrite``
    is False), scores on dev and test (plus the SWA-averaged weights when
    the 'swa' callback was used), logs the metrics, and releases the Keras
    session.

    Args:
        train_d, dev_d, test_d: array-like interaction triples per split.
        kfold: index of the current cross-validation fold (logged only).
        dataset: dataset key used for all file lookups.
        neighbor_sample_size, embed_dim, n_depth, l2_weight, aggregator_type:
            KGCN-style model hyper-parameters.
        lr, optimizer_type, batch_size, n_epoch: training hyper-parameters.
        callbacks_to_add: training callback names; affects the experiment
            name. Defaults to no extra callbacks.
        overwrite: retrain even if a checkpoint already exists.

    Returns:
        dict of logged metrics (``train_log``).
    """
    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.dataset = dataset
    config.K_Fold = kfold
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    # Fix: guard against None so the '_'.join / 'in' checks below don't raise
    config.callbacks_to_add = callbacks_to_add or []
    # drug id vocabulary (keys should be SMILES strings)
    config.drug_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             DRUG_VOCAB_TEMPLATE, dataset=dataset)))
    # entity id vocabulary
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                               ENTITY_VOCAB_TEMPLATE, dataset=dataset)))
    # relation id vocabulary (string keys)
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                                 RELATION_VOCAB_TEMPLATE, dataset=dataset)))
    # pre-sampled neighbourhood matrices of the knowledge graph
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset))
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset))
    config.drug_smile = np.load(format_filename(PROCESSED_DATA_DIR, DRUG_SMILE_TEMPLATE), allow_pickle=True)
    config.smile_hash = np.load(format_filename(PROCESSED_DATA_DIR, SMILE_HASH), allow_pickle=True)

    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # checkpointing / early stopping are defaults and not encoded in the
    # experiment name; swa-style weight-averaging callbacks still are
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = DDKG(config)

    # column layout per row: drug_one id, drug_two id, label
    train_data = np.array(train_d)
    valid_data = np.array(dev_d)
    test_data = np.array(test_d)

    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        # (removed a leftover debug print that dumped the full training arrays)
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    print('Logging Info - Evaluate over valid data:')
    model.load_best_model()
    auc, acc, f1, aupr, fpr, tpr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]],
                                               y=valid_data[:, 2:3])
    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_aupr: {aupr}')
    train_log['dev_auc'] = auc
    train_log['dev_acc'] = acc
    train_log['dev_f1'] = f1
    train_log['dev_aupr'] = aupr
    train_log['k_fold'] = kfold
    train_log['dataset'] = dataset
    train_log['aggregate_type'] = config.aggregator_type
    train_log['dev_fpr'] = fpr
    train_log['dev_tpr'] = tpr

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        auc, acc, f1, aupr, fpr, tpr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]],
                                                   y=valid_data[:, 2:3])
        train_log['swa_dev_auc'] = auc
        train_log['swa_dev_acc'] = acc
        train_log['swa_dev_f1'] = f1
        train_log['swa_dev_aupr'] = aupr
        train_log['swa_dev_fpr'] = fpr
        train_log['swa_dev_tpr'] = tpr
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, swa_dev_aupr: {aupr}')

    print('Logging Info - Evaluate over test data:')
    model.load_best_model()
    auc, acc, f1, aupr, fpr, tpr = model.score(x=[test_data[:, :1], test_data[:, 1:2]],
                                               y=test_data[:, 2:3])
    train_log['test_auc'] = auc
    train_log['test_acc'] = acc
    train_log['test_f1'] = f1
    train_log['test_aupr'] = aupr
    train_log['test_fpr'] = fpr
    train_log['test_tpr'] = tpr
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_aupr: {aupr}, test_fpr: {fpr}')

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over test data based on swa model:')
        auc, acc, f1, aupr, fpr, tpr = model.score(x=[test_data[:, :1], test_data[:, 1:2]],
                                                   y=test_data[:, 2:3])
        train_log['swa_test_auc'] = auc
        train_log['swa_test_acc'] = acc
        train_log['swa_test_f1'] = f1
        train_log['swa_test_aupr'] = aupr
        train_log['swa_test_fpr'] = fpr
        train_log['swa_test_tpr'] = tpr
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, swa_test_aupr: {aupr}')

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
    # free Keras/TensorFlow resources between folds
    del model
    gc.collect()
    K.clear_session()
    return train_log
def train(dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr, optimizer_type,
          batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=False):
    """Train and evaluate a KGCN model on ``dataset``; append metrics to the performance log.

    Args:
        dataset: name of the preprocessed dataset; used to locate vocab/adjacency files.
        neighbor_sample_size: number of KG neighbors sampled per entity.
        embed_dim: embedding dimension.
        n_depth: number of GNN aggregation hops.
        l2_weight: L2 regularization weight.
        lr: learning rate.
        optimizer_type: optimizer name resolved via ``get_optimizer``.
        batch_size: training batch size.
        aggregator_type: KGCN neighbor-aggregator variant.
        n_epoch: number of training epochs.
        callbacks_to_add: callback names for training. Defaults to
            ``['modelcheckpoint', 'earlystopping']`` (consistent with the other train
            functions in this module). The original code stored ``None`` directly,
            which crashed on the ``'_'.join`` and ``in`` checks below.
        overwrite: if True, retrain even when a saved model file already exists.

    Returns:
        dict: the accumulated train/dev/test metric log.
    """
    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    # Fix: fall back to the module-wide default callback list instead of keeping None,
    # which previously raised TypeError in the join/membership checks below.
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    # Vocabulary sizes and precomputed KG adjacency tables produced by preprocessing.
    config.user_vocab_size = len(pickle_load(format_filename(
        PROCESSED_DATA_DIR, USER_VOCAB_TEMPLATE, dataset=dataset)))
    config.item_vocab_size = len(pickle_load(format_filename(
        PROCESSED_DATA_DIR, ITEM_VOCAB_TEMPLATE, dataset=dataset)))
    config.entity_vocab_size = len(pickle_load(format_filename(
        PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset)))
    config.relation_vocab_size = len(pickle_load(format_filename(
        PROCESSED_DATA_DIR, RELATION_VOCAB_TEMPLATE, dataset=dataset)))
    config.adj_entity = np.load(format_filename(
        PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset))
    config.adj_relation = np.load(format_filename(
        PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset))

    # Experiment name encodes the hyper-parameter configuration; checkpoint files and
    # log entries are keyed by it, so its construction order must stay stable.
    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # The two default callbacks are stripped so they don't inflate the name.
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size,
                 'optimizer': optimizer_type, 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = KGCN(config)

    train_data = load_data(dataset, 'train')
    valid_data = load_data(dataset, 'dev')
    test_data = load_data(dataset, 'test')

    # Columns: [:, :1] user ids, [:, 1:2] item ids, [:, 2:3] labels.
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
        elapsed_time = time.time() - start_time
        train_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        print('Logging Info - Training time: %s' % train_time)
        train_log['train_time'] = train_time

    print('Logging Info - Evaluate over valid data:')
    model.load_best_model()
    auc, acc, f1 = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
    user_list, train_record, valid_record, item_set, k_list = topk_settings(
        train_data, valid_data, config.item_vocab_size)
    topk_p, topk_r = topk_eval(model, user_list, train_record, valid_record, item_set, k_list)
    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_topk_p: {topk_p}, '
          f'dev_topk_r: {topk_r}')
    train_log['dev_auc'] = auc
    train_log['dev_acc'] = acc
    train_log['dev_f1'] = f1
    train_log['dev_topk_p'] = topk_p
    train_log['dev_topk_r'] = topk_r

    if 'swa' in config.callbacks_to_add:
        # Re-evaluate dev with the stochastic-weight-averaged weights.
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        auc, acc, f1 = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]],
                                   y=valid_data[:, 2:3])
        topk_p, topk_r = topk_eval(model, user_list, train_record, valid_record,
                                   item_set, k_list)
        train_log['swa_dev_auc'] = auc
        train_log['swa_dev_acc'] = acc
        train_log['swa_dev_f1'] = f1
        train_log['swa_dev_topk_p'] = topk_p
        train_log['swa_dev_topk_r'] = topk_r
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, '
              f'swa_dev_topk_p: {topk_p}, swa_dev_topk_r: {topk_r}')

    print('Logging Info - Evaluate over test data:')
    model.load_best_model()
    auc, acc, f1 = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
    user_list, train_record, test_record, item_set, k_list = topk_settings(
        train_data, test_data, config.item_vocab_size)
    topk_p, topk_r = topk_eval(model, user_list, train_record, test_record, item_set, k_list)
    train_log['test_auc'] = auc
    train_log['test_acc'] = acc
    train_log['test_f1'] = f1
    train_log['test_topk_p'] = topk_p
    train_log['test_topk_r'] = topk_r
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_topk_p: {topk_p}, '
          f'test_topk_r: {topk_r}')

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over test data based on swa model:')
        auc, acc, f1 = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
        topk_p, topk_r = topk_eval(model, user_list, train_record, test_record,
                                   item_set, k_list)
        train_log['swa_test_auc'] = auc
        train_log['swa_test_acc'] = acc
        train_log['swa_test_f1'] = f1
        train_log['swa_test_topk_p'] = topk_p
        train_log['swa_test_topk_r'] = topk_r
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, '
              f'swa_test_topk_p: {topk_p}, swa_test_topk_r: {topk_r}')

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
    # Free the Keras session so repeated calls (e.g. grid search) don't accumulate graphs.
    del model
    gc.collect()
    K.clear_session()
    # Return the log for programmatic use, matching the sibling trainer in this module.
    return train_log
def train_recognition(model_name, label_schema='BIOES', batch_size=32, n_epoch=50,
                      learning_rate=0.001, optimizer_type='adam', use_char_input=True,
                      embed_type=None, embed_trainable=True, use_bert_input=False,
                      bert_type='bert', bert_trainable=True, bert_layer_num=1,
                      use_bichar_input=False, bichar_embed_type=None,
                      bichar_embed_trainable=True, use_word_input=False,
                      word_embed_type=None, word_embed_trainable=True,
                      use_charpos_input=False, charpos_embed_type=None,
                      charpos_embed_trainable=True, use_softword_input=False,
                      use_dictfeat_input=False, use_maxmatch_input=False,
                      callbacks_to_add=None, overwrite=False, swa_start=3,
                      early_stopping_patience=3, **kwargs):
    """Train an entity-recognition (sequence labeling) model and log dev performance.

    Builds a ModelConfig from the arguments, derives a unique experiment name that
    encodes the configuration (used for the checkpoint filename and the log entry),
    trains unless a checkpoint already exists (or `overwrite` is set), then evaluates
    recall/precision/F1 on the dev split — additionally with SWA-averaged weights
    when an SWA callback was requested.
    """
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        # Pretrained char embedding matrix; its width fixes the embedding dim.
        config.embeddings = np.load(format_filename(
            PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        # No pretrained vectors: random init, always trainable.
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience
    config.vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    # +2 reserves two extra ids — presumably mask/padding; verify against vocab builder.
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(
        PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    # Experiment name: appended segment by segment below; checkpoint paths and log
    # entries are keyed on it, so the construction order matters.
    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size, optimizer_type,
            learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(
            model_name, batch_size, optimizer_type, learning_rate, label_schema)
    if config.n_epoch != 50:
        # Only non-default epoch counts are encoded in the name.
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # The two default callbacks are stripped so they don't inflate the name.
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    # The model needs at least one primary input stream.
    assert config.use_char_input or config.use_bert_input
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(
            bert_type, bert_layer_num, 'tune' if config.bert_trainable else 'fix')

    # Optional auxiliary input streams; each follows the same pattern as the char
    # embeddings above (load vocab, optionally load pretrained matrix, tag exp_name).
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(format_filename(
            PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(format_filename(
                PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(
            bichar_embed_type if bichar_embed_type else 'random',
            'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(format_filename(
            PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(
                PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(
            word_embed_type if word_embed_type else 'random',
            'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(format_filename(
            PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(format_filename(
                PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(
            charpos_embed_type if charpos_embed_type else 'random',
            'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size,
                 'optimizer': optimizer_type, 'epoch': n_epoch,
                 'learning_rate': learning_rate, 'other_params': kwargs}

    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = RecognitionModel(config, **kwargs)

    train_data_type, dev_data_type = 'train', 'dev'
    # NOTE(review): when the optional inputs are disabled, the config.*_vocab
    # attributes are never assigned here — presumably ModelConfig defines None
    # defaults for them; confirm, otherwise this raises AttributeError.
    train_generator = RecognitionDataGenerator(
        train_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
        config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)
    valid_generator = RecognitionDataGenerator(
        dev_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
        config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)

    # Train only when no checkpoint exists (or the caller forces a retrain).
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_generator, valid_generator)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime(
            "%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    r, p, f1 = model.evaluate(valid_generator)
    train_log['dev_performance'] = (r, p, f1)

    # When SWA was used, also evaluate the weight-averaged model.
    swa_type = None
    if 'swa' in config.callbacks_to_add:
        swa_type = 'swa'
    elif 'swa_clr' in config.callbacks_to_add:
        swa_type = 'swa_clr'
    if swa_type:
        model.load_swa_model(swa_type)
        print('Logging Info - Evaluate over valid data based on swa model:')
        r, p, f1 = model.evaluate(valid_generator)
        train_log['swa_dev_performance'] = (r, p, f1)

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_er'),
              log=train_log, mode='a')
    # Release the model and Keras session so repeated runs don't accumulate graphs.
    del model
    gc.collect()
    K.clear_session()
def prepare_config(model_type='bert-base-uncased', input_type='name_desc', use_multi_task=True,
                   use_harl=False, use_hal=False, cate_embed_dim=100, use_word_input=False,
                   word_embed_type='w2v', word_embed_trainable=True, word_embed_dim=300,
                   use_bert_input=True, bert_trainable=True, use_bert_type='pooler',
                   n_last_hidden_layer=0, dense_after_bert=True, use_pair_input=True,
                   max_len=None, share_father_pred='no', use_mask_for_cate2=False,
                   use_mask_for_cate3=True, cate3_mask_type='cate1', cate1_loss_weight=1.,
                   cate2_loss_weight=1., cate3_loss_weight=1., batch_size=32,
                   predict_batch_size=32, n_epoch=50, learning_rate=2e-5, optimizer='adam',
                   use_focal_loss=False, callbacks_to_add=None, swa_start=15,
                   early_stopping_patience=5, max_lr=6e-5, min_lr=1e-5, train_on_cv=False,
                   cv_random_state=42, cv_fold=5, cv_index=0, exchange_pair=False,
                   exchange_threshold=0.1, use_pseudo_label=False, pseudo_path=None,
                   pseudo_random_state=42, pseudo_rate=0.1, pseudo_index=0, pseudo_name=None,
                   exp_name=None):
    """Assemble a ModelConfig for hierarchical (cate1/cate2/cate3) classification.

    Loads vocabularies, index maps, category-hierarchy dictionaries and count
    statistics from PROCESSED_DATA_DIR, copies the hyper-parameters onto the
    config, and derives `config.exp_name` — a string that encodes the whole
    configuration and is used elsewhere as the checkpoint/log key. Passing
    `exp_name` overrides the derived name entirely.

    Fixes vs. the previous version: removed a duplicated
    `config.learning_rate = learning_rate` assignment and a stray commented-out
    line; behavior is otherwise unchanged.

    Returns:
        ModelConfig: the fully populated configuration object.
    """
    config = ModelConfig()
    config.model_type = model_type
    config.input_type = input_type
    config.use_multi_task = use_multi_task
    config.use_harl = use_harl
    config.use_hal = use_hal
    # HARL and HAL are mutually exclusive attention mechanisms.
    assert not (config.use_harl and config.use_hal)
    config.cate_embed_dim = cate_embed_dim
    config.use_word_input = use_word_input
    config.word_embed_type = word_embed_type
    if config.use_word_input:
        if word_embed_type:
            # Pretrained word vectors; their width fixes the embedding dim.
            config.word_embeddings = np.load(format_filename(
                PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
            config.word_embed_dim = word_embed_dim
        config.word_vocab = pickle_load(format_filename(
            PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2  # 0: mask, 1: padding
    else:
        config.word_vocab = None

    config.use_bert_input = use_bert_input
    config.bert_trainable = bert_trainable
    if config.use_bert_input:
        config.use_bert_type = use_bert_type
        config.dense_after_bert = dense_after_bert
        # Hidden states are needed either when the pooling strategy reads them or
        # when the hierarchical attention variants (HARL/HAL) consume them.
        if config.use_bert_type in ['hidden', 'hidden_pooler'] or \
                (config.use_multi_task and (config.use_harl or config.use_hal)):
            config.output_hidden_state = True
            config.n_last_hidden_layer = n_last_hidden_layer
        else:
            config.output_hidden_state = False
            config.n_last_hidden_layer = 0

    # Pair (sentence-A/sentence-B) input only makes sense for name+description input.
    if config.input_type == 'name_desc':
        config.use_pair_input = use_pair_input
    else:
        config.use_pair_input = False
    if config.use_bert_input and max_len is None:
        # Default sequence length per input type when the caller didn't choose one.
        config.max_len = MAX_LEN[input_type]
    else:
        config.max_len = max_len

    # Category vocabularies, reverse index maps, hierarchy links and label counts.
    config.cate1_vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR if False else PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1')) \
        if False else pickle_load(format_filename(
            PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1'))
    config.cate2_vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate2'))
    config.cate3_vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate3'))
    config.all_cate_vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='all_cate'))
    config.idx2cate1 = pickle_load(format_filename(
        PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'))
    config.idx2cate2 = pickle_load(format_filename(
        PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2'))
    config.idx2cate3 = pickle_load(format_filename(
        PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3'))
    config.idx2all_cate = pickle_load(format_filename(
        PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='all_cate'))
    config.cate1_to_cate2 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT))
    config.cate2_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
    config.cate1_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT))
    config.cate2_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT))
    config.cate3_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT))
    config.n_cate1 = len(config.cate1_vocab)
    config.n_cate2 = len(config.cate2_vocab)
    config.n_cate3 = len(config.cate3_vocab)
    config.n_all_cate = len(config.all_cate_vocab)

    # HARL/HAL encode the hierarchy internally, so the explicit sharing/masking
    # mechanisms are disabled in that mode.
    if config.use_multi_task and (config.use_harl or config.use_hal):
        config.share_father_pred = 'no'
        config.use_mask_for_cate2 = False
        config.use_mask_for_cate3 = False
        config.cate3_mask_type = None
    else:
        config.share_father_pred = share_father_pred
        config.use_mask_for_cate2 = use_mask_for_cate2
        config.use_mask_for_cate3 = use_mask_for_cate3
        config.cate3_mask_type = cate3_mask_type
    if config.use_mask_for_cate3:
        # Mask cate3 predictions with the children of the predicted cate1 or cate2.
        if config.cate3_mask_type == 'cate1':
            config.cate_to_cate3 = pickle_load(format_filename(
                PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
        elif config.cate3_mask_type == 'cate2':
            config.cate_to_cate3 = pickle_load(format_filename(
                PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))

    config.cate1_loss_weight = cate1_loss_weight
    config.cate2_loss_weight = cate2_loss_weight
    config.cate3_loss_weight = cate3_loss_weight
    config.batch_size = batch_size
    config.predict_batch_size = predict_batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = optimizer
    config.use_focal_loss = use_focal_loss
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience
    # Any LR-scheduler callback needs the cyclic-LR bounds.
    for lr_scheduler in ['clr', 'sgdr', 'clr_1', 'clr_2', 'warm_up', 'swa_clr']:
        if lr_scheduler in config.callbacks_to_add:
            config.max_lr = max_lr
            config.min_lr = min_lr
    config.train_on_cv = train_on_cv
    if config.train_on_cv:
        config.cv_random_state = cv_random_state
        config.cv_fold = cv_fold
        config.cv_index = cv_index
    config.exchange_pair = exchange_pair
    if config.exchange_pair:
        config.exchange_threshold = exchange_threshold
    config.use_pseudo_label = use_pseudo_label
    if config.use_pseudo_label:
        config.pseudo_path = pseudo_path
        config.pseudo_random_state = pseudo_random_state
        config.pseudo_rate = pseudo_rate
        config.pseudo_index = pseudo_index

    # build experiment name from parameter configuration
    config.exp_name = f'{config.model_type}_{config.input_type}'
    if config.use_pair_input:
        config.exp_name += '_pair'
    config.exp_name += f'_len_{config.max_len}'
    if config.use_word_input:
        config.exp_name += f"_word_{config.word_embed_type}_{'tune' if config.word_embed_trainable else 'fix'}"
    if config.use_bert_input:
        config.exp_name += f"_bert_{config.use_bert_type}_{'tune' if config.bert_trainable else 'fix'}"
        if config.output_hidden_state:
            config.exp_name += f'_hid_{config.n_last_hidden_layer}'
        if config.dense_after_bert:
            config.exp_name += '_dense'
    if config.use_multi_task:
        if config.use_harl:
            config.exp_name += f'_harl_{config.cate_embed_dim}'
        elif config.use_hal:
            config.exp_name += f'_hal_{config.cate_embed_dim}'
        config.exp_name += f'_{config.cate1_loss_weight}_{config.cate2_loss_weight}_{config.cate3_loss_weight}'
    else:
        config.exp_name += f'_not_multi_task'
    if config.share_father_pred in ['after', 'before']:
        config.exp_name += f'_{config.share_father_pred}'
    if config.use_mask_for_cate2:
        config.exp_name += f'_mask_cate2'
    if config.use_mask_for_cate3:
        config.exp_name += f'_mask_cate3_with_{config.cate3_mask_type}'
    if config.use_focal_loss:
        config.exp_name += f'_focal'
    config.exp_name += f'_{config.optimizer}_{config.learning_rate}_{config.batch_size}_{config.n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # The two default callbacks are stripped so they don't inflate the name.
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str
    if config.train_on_cv:
        config.exp_name += f'_{config.cv_random_state}_{config.cv_fold}_{config.cv_index}'
    if config.exchange_pair:
        config.exp_name += f"_ex_pair_{config.exchange_threshold}"
    if config.use_pseudo_label:
        # NOTE(review): when pseudo_name is None this reads config.pseudo_path;
        # presumably callers always supply a path with use_pseudo_label — confirm.
        if pseudo_name:
            config.exp_name += f"_{pseudo_name}_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
        elif 'dev' in config.pseudo_path:
            config.exp_name += f"_dev_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
        else:
            config.exp_name += f"_test_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
    if exp_name:
        # Explicit name overrides everything derived above.
        config.exp_name = exp_name
    return config
def train_link(model_name, batch_size=32, n_epoch=50, learning_rate=0.001,
               optimizer_type='adam', embed_type=None, embed_trainable=True,
               callbacks_to_add=None, use_relative_pos=False, n_neg=1,
               omit_one_cand=True, overwrite=False, swa_start=5,
               early_stopping_patience=3, **kwargs):
    """Train an entity-linking model and log dev recall/precision/F1.

    Builds a ModelConfig, derives an experiment name encoding the configuration
    (checkpoint filename and log key), trains unless a checkpoint already exists
    (or `overwrite` is set), then evaluates on the dev split — additionally with
    SWA-averaged weights when an SWA callback was requested.
    """
    config = ModelConfig()
    config.model_name = model_name
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    if embed_type:
        # Pretrained char embedding matrix.
        config.embeddings = np.load(format_filename(
            PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
    else:
        # No pretrained vectors: random init, always trainable.
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience
    config.vocab = pickle_load(format_filename(
        PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    # +2 reserves two extra ids — presumably mask/padding; verify against vocab builder.
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(
        PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))
    config.entity_desc = pickle_load(format_filename(
        PROCESSED_DATA_DIR, ENTITY_DESC_FILENAME))

    # Experiment name: appended segment by segment; checkpoint paths and log
    # entries are keyed on it, so the construction order matters.
    config.exp_name = '{}_{}_{}_{}_{}_{}'.format(
        model_name, embed_type if embed_type else 'random',
        'tune' if config.embed_trainable else 'fix', batch_size, optimizer_type,
        learning_rate)
    config.use_relative_pos = use_relative_pos
    if config.use_relative_pos:
        config.exp_name += '_rel'
    config.n_neg = n_neg
    if config.n_neg > 1:
        # Number of negative candidate entities sampled per mention.
        config.exp_name += '_neg_{}'.format(config.n_neg)
    config.omit_one_cand = omit_one_cand
    if not config.omit_one_cand:
        config.exp_name += '_not_omit'
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # The two default callbacks are stripped so they don't inflate the name.
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size,
                 'optimizer': optimizer_type, 'epoch': n_epoch,
                 'learning_rate': learning_rate, 'other_params': kwargs}

    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = LinkModel(config, **kwargs)

    train_data_type, dev_data_type = 'train', 'dev'
    train_generator = LinkDataGenerator(
        train_data_type, config.vocab, config.mention_to_entity, config.entity_desc,
        config.batch_size, config.max_desc_len, config.max_erl_len,
        config.use_relative_pos, config.n_neg, config.omit_one_cand)
    dev_data = load_data(dev_data_type)

    # Train only when no checkpoint exists (or the caller forces a retrain).
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_generator, dev_data)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime(
            "%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    model.load_best_model()
    # Gold mentions are fed in as the "predicted" mentions too, so this scores the
    # linking step in isolation from mention detection.
    dev_text_data, dev_pred_mentions, dev_gold_mention_entities = [], [], []
    for data in dev_data:
        dev_text_data.append(data['text'])
        dev_pred_mentions.append(data['mention_data'])
        dev_gold_mention_entities.append(data['mention_data'])
    print('Logging Info - Evaluate over valid data:')
    r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions, dev_gold_mention_entities)
    train_log['dev_performance'] = (r, p, f1)

    # When SWA was used, also evaluate the weight-averaged model.
    swa_type = None
    if 'swa' in config.callbacks_to_add:
        swa_type = 'swa'
    elif 'swa_clr' in config.callbacks_to_add:
        swa_type = 'swa_clr'
    if swa_type:
        model.load_swa_model(swa_type)
        print('Logging Info - Evaluate over valid data based on swa model:')
        r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions,
                                  dev_gold_mention_entities)
        train_log['swa_dev_performance'] = (r, p, f1)

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_el'),
              log=train_log, mode='a')
    # Release the model and Keras session so repeated runs don't accumulate graphs.
    del model
    gc.collect()
    K.clear_session()