Example #1
def recognition(model_name,
                predict_log,
                label_schema='BIOES',
                batch_size=32,
                n_epoch=50,
                learning_rate=0.001,
                optimizer_type='adam',
                use_char_input=True,
                embed_type=None,
                embed_trainable=True,
                use_bert_input=False,
                bert_type='bert',
                bert_trainable=True,
                bert_layer_num=1,
                use_bichar_input=False,
                bichar_embed_type=None,
                bichar_embed_trainable=True,
                use_word_input=False,
                word_embed_type=None,
                word_embed_trainable=True,
                use_charpos_input=False,
                charpos_embed_type=None,
                charpos_embed_trainable=True,
                use_softword_input=False,
                use_dictfeat_input=False,
                use_maxmatch_input=False,
                callbacks_to_add=None,
                swa_type=None,
                predict_on_dev=True,
                predict_on_final_test=True,
                **kwargs):
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR,
                            EMBEDDING_MATRIX_TEMPLATE,
                            type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        config.embeddings = None
        config.embed_trainable = True
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]

    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            model_name, config.embed_type if config.embed_type else 'random',
            'tune' if config.embed_trainable else 'fix', batch_size,
            optimizer_type, learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size,
                                                  optimizer_type,
                                                  learning_rate, label_schema)
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    assert config.use_char_input or config.use_bert_input
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(
            bert_type, bert_layer_num,
            'tune' if config.bert_trainable else 'fix')
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(
            bichar_embed_type if bichar_embed_type else 'random',
            'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(
            word_embed_type if word_embed_type else 'random',
            'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(
            charpos_embed_type if charpos_embed_type else 'random',
            'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # record prediction settings for logging
    predict_log.update({
        'er_exp_name': config.exp_name,
        'er_batch_size': batch_size,
        'er_optimizer': optimizer_type,
        'er_epoch': n_epoch,
        'er_learning_rate': learning_rate,
        'er_other_params': kwargs
    })

    print('Logging Info - Experiment: %s' % config.exp_name)
    model = RecognitionModel(config, **kwargs)

    dev_data_type = 'dev'
    if predict_on_final_test:
        test_data_type = 'test_final'
    else:
        test_data_type = 'test'
    valid_generator = RecognitionDataGenerator(
        dev_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else
        None, config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)
    test_generator = RecognitionDataGenerator(
        test_data_type, config.batch_size, config.label_schema,
        config.label_to_one_hot[config.label_schema],
        config.vocab if config.use_char_input else None,
        config.bert_vocab_file(config.bert_type) if config.use_bert_input else
        None, config.bert_seq_len, config.bichar_vocab, config.word_vocab,
        config.use_word_input, config.charpos_vocab, config.use_softword_input,
        config.use_dictfeat_input, config.use_maxmatch_input)

    model_save_path = os.path.join(config.checkpoint_dir,
                                   '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        raise FileNotFoundError(
            'Recognition model does not exist: {}'.format(model_save_path))

    if swa_type is None:
        model.load_best_model()
    elif 'swa' in config.callbacks_to_add:
        model.load_swa_model(swa_type)
        predict_log['er_exp_name'] += '_{}'.format(swa_type)

    if predict_on_dev:
        print('Logging Info - Generate submission for valid data:')
        dev_pred_mentions = model.predict(valid_generator)
    else:
        dev_pred_mentions = None
    print('Logging Info - Generate submission for test data:')
    test_pred_mentions = model.predict(test_generator)

    return dev_pred_mentions, test_pred_mentions
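A minimal usage sketch of recognition() follows; the model name and embedding key are illustrative placeholders, not values confirmed by this listing:

predict_log = {}
dev_mentions, test_mentions = recognition(model_name='bilstm_cnn_crf',  # hypothetical model name
                                          predict_log=predict_log,
                                          embed_type='c2v',  # hypothetical char-embedding key
                                          callbacks_to_add=['modelcheckpoint', 'earlystopping'])
# predict_log now carries the er_* settings; the two mention lists feed the linking stage.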
Example #2
def train_model(genre, input_level, word_embed_type, word_embed_trainable, batch_size, learning_rate,
                optimizer_type, model_name, n_epoch=50, add_features=False, scale_features=False, overwrite=False,
                lr_range_test=False, callbacks_to_add=None, eval_on_train=False, **kwargs):
    config = ModelConfig()
    config.genre = genre
    config.input_level = input_level
    config.max_len = config.word_max_len[genre] if input_level == 'word' else config.char_max_len[genre]
    config.word_embed_type = word_embed_type
    config.word_embed_trainable = word_embed_trainable
    config.callbacks_to_add = callbacks_to_add or []
    config.add_features = add_features
    config.batch_size = batch_size
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.n_epoch = n_epoch
    config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, genre,
                                                     word_embed_type))
    vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, genre, input_level))
    config.idx2token = dict((idx, token) for token, idx in vocab.items())

    # experiment name configuration
    config.exp_name = '{}_{}_{}_{}_{}_{}_{}_{}'.format(genre, model_name, input_level, word_embed_type,
                                                       'tune' if word_embed_trainable else 'fix', batch_size,
                                                       '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()]),
                                                       optimizer_type)
    if config.add_features:
        config.exp_name += '_feature_scaled' if scale_features else '_featured'
    if len(config.callbacks_to_add) > 0:
        callback_str = '_' + '_'.join(config.callbacks_to_add)
        callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
        config.exp_name += callback_str

    input_config = kwargs.get('input_config', 'token')  # input defaults to word embedding
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        # get elmo embedding based on cache, we first get a ELMoCache instance
        # pop the ELMo settings so they are not passed on to the model via kwargs
        elmo_model_type = kwargs.pop('elmo_model_type', 'allennlp')
        elmo_output_mode = kwargs.pop('elmo_output_mode', 'elmo')
        elmo_cache = ELMoCache(options_file=config.elmo_options_file, weight_file=config.elmo_weight_file,
                               cache_dir=config.cache_dir, idx2token=config.idx2token,
                               max_sentence_length=config.max_len, elmo_model_type=elmo_model_type,
                               elmo_output_mode=elmo_output_mode)
    elif input_config in ['elmo_id', 'elmo_s', 'token_combine_elmo_id', 'token_combine_elmo_s']:
        # get elmo embedding using tensorflow_hub, we must provide a tfhub_url
        kwargs['elmo_model_url'] = config.elmo_model_url

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type, 'epoch': n_epoch,
                 'learning_rate': learning_rate, 'other_params': kwargs}

    print('Logging Info - Experiment: %s' % config.exp_name)
    if model_name == 'KerasInfersent':
        model = KerasInfersentModel(config, **kwargs)
    elif model_name == 'KerasEsim':
        model = KerasEsimModel(config, **kwargs)
    elif model_name == 'KerasDecomposable':
        model = KerasDecomposableAttentionModel(config, **kwargs)
    elif model_name == 'KerasSiameseBiLSTM':
        model = KerasSimaeseBiLSTMModel(config, **kwargs)
    elif model_name == 'KerasSiameseCNN':
        model = KerasSiameseCNNModel(config, **kwargs)
    elif model_name == 'KerasIACNN':
        model = KerasIACNNModel(config, **kwargs)
    elif model_name == 'KerasSiameseLSTMCNNModel':
        model = KerasSiameseLSTMCNNModel(config, **kwargs)
    elif model_name == 'KerasRefinedSSAModel':
        model = KerasRefinedSSAModel(config, **kwargs)
    else:
        raise ValueError('Model Name Not Understood : {}'.format(model_name))
    # model.summary()

    train_input, dev_input, test_input = None, None, None
    if lr_range_test:   # conduct lr range test to find optimal learning rate (not train model)
        train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features, scale_features)
        dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features, scale_features)
        model.lr_range_test(x_train=train_input['x'], y_train=train_input['y'], x_valid=dev_input['x'],
                            y_valid=dev_input['y'])
        return

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()

        if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
            train_input = ELMoGenerator(genre, input_level, 'train', config.batch_size, elmo_cache,
                                        return_data=(input_config == 'token_combine_cache_elmo'),
                                        return_features=config.add_features)
            dev_input = ELMoGenerator(genre, input_level, 'dev', config.batch_size, elmo_cache,
                                      return_data=(input_config == 'token_combine_cache_elmo'),
                                      return_features=config.add_features)
            model.train_with_generator(train_input, dev_input)
        else:
            train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features, scale_features)
            dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features, scale_features)
            model.train(x_train=train_input['x'], y_train=train_input['y'], x_valid=dev_input['x'],
                        y_valid=dev_input['y'])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    def eval_on_data(eval_with_generator, input_data, data_type):
        model.load_best_model()
        if eval_with_generator:
            acc = model.evaluate_with_generator(generator=input_data, y=input_data.input_label)
        else:
            acc = model.evaluate(x=input_data['x'], y=input_data['y'])
        train_log['%s_acc' % data_type] = acc

        swa_type = None
        if 'swa' in config.callbacks_to_add:
            swa_type = 'swa'
        elif 'swa_clr' in config.callbacks_to_add:
            swa_type = 'swa_clr'
        if swa_type:
            print('Logging Info - %s Model' % swa_type)
            model.load_swa_model(swa_type=swa_type)
            swa_acc = model.evaluate(x=input_data['x'], y=input_data['y'])
            train_log['%s_%s_acc' % (swa_type, data_type)] = swa_acc

        ensemble_type = None
        if 'sse' in config.callbacks_to_add:
            ensemble_type = 'sse'
        elif 'fge' in config.callbacks_to_add:
            ensemble_type = 'fge'
        if ensemble_type:
            print('Logging Info - %s Ensemble Model' % ensemble_type)
            ensemble_predict = {}
            for model_file in os.listdir(config.checkpoint_dir):
                if model_file.startswith(config.exp_name + '_%s' % ensemble_type):
                    match = re.match(r'(%s_%s_)(\d+)(\.hdf5)' % (config.exp_name, ensemble_type), model_file)
                    model_id = int(match.group(2))
                    model_path = os.path.join(config.checkpoint_dir, model_file)
                    print('Logging Info: Loading {} ensemble model checkpoint: {}'.format(ensemble_type, model_file))
                    model.load_model(model_path)
                    ensemble_predict[model_id] = model.predict(x=input_data['x'])
            '''
            We expect models saved towards the end of the run to perform better than models
            saved earlier, so we sort the predictions so that the newer models (larger ids) come first.
            '''
            sorted_ensemble_predict = sorted(ensemble_predict.items(), key=lambda x: x[0], reverse=True)
            model_predicts = []
            for model_id, model_predict in sorted_ensemble_predict:
                single_acc = eval_acc(model_predict, input_data['y'])
                print('Logging Info - %s_single_%d_%s Acc : %f' % (ensemble_type, model_id, data_type, single_acc))
                train_log['%s_single_%d_%s_acc' % (ensemble_type, model_id, data_type)] = single_acc

                model_predicts.append(model_predict)
                ensemble_acc = eval_acc(np.mean(np.array(model_predicts), axis=0), input_data['y'])
                print('Logging Info - %s_ensemble_%d_%s Acc : %f' % (ensemble_type, model_id, data_type, ensemble_acc))
                train_log['%s_ensemble_%d_%s_acc' % (ensemble_type, model_id, data_type)] = ensemble_acc

    if eval_on_train:
        # might take a long time
        print('Logging Info - Evaluate over train data:')
        if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
            train_input = ELMoGenerator(genre, input_level, 'train', config.batch_size, elmo_cache,
                                        return_data=(input_config == 'token_combine_cache_elmo'),
                                        return_features=config.add_features, return_label=False)
            eval_on_data(eval_with_generator=True, input_data=train_input, data_type='train')
        else:
            train_input = load_input_data(genre, input_level, 'train', input_config, config.add_features, scale_features)
            eval_on_data(eval_with_generator=False, input_data=train_input, data_type='train')

    print('Logging Info - Evaluate over valid data:')
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        dev_input = ELMoGenerator(genre, input_level, 'dev', config.batch_size, elmo_cache,
                                  return_data=(input_config == 'token_combine_cache_elmo'),
                                  return_features=config.add_features, return_label=False)
        eval_on_data(eval_with_generator=True, input_data=dev_input, data_type='dev')
    else:
        if dev_input is None:
            dev_input = load_input_data(genre, input_level, 'dev', input_config, config.add_features, scale_features)
        eval_on_data(eval_with_generator=False, input_data=dev_input, data_type='dev')

    print('Logging Info - Evaluate over test data:')
    if input_config in ['cache_elmo', 'token_combine_cache_elmo']:
        test_input = ELMoGenerator(genre, input_level, 'test', config.batch_size, elmo_cache,
                                   return_data=(input_config == 'token_combine_cache_elmo'),
                                   return_features=config.add_features, return_label=False)
        eval_on_data(eval_with_generator=True, input_data=test_input, data_type='test')
    else:
        if test_input is None:
            test_input = load_input_data(genre, input_level, 'test', input_config, config.add_features, scale_features)
        eval_on_data(eval_with_generator=False, input_data=test_input, data_type='test')

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, genre), log=train_log, mode='a')
    return train_log
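A minimal call sketch for train_model(); the genre and embedding keys are assumed placeholders, while 'KerasEsim' is one of the model names dispatched above:

train_log = train_model(genre='snli',             # hypothetical genre key
                        input_level='word',
                        word_embed_type='glove',  # hypothetical embedding key
                        word_embed_trainable=False,
                        batch_size=128,
                        learning_rate=0.001,
                        optimizer_type='adam',
                        model_name='KerasEsim')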
Example #3
def train(dataset,
          char_embed_type,
          char_trainable,
          fw_embed_type,
          fw_trainable,
          bw_embed_type,
          bw_trainable,
          gaze_embed_dim,
          n_step,
          n_layer,
          rnn_units,
          dropout,
          batch_size,
          n_epoch,
          optimizer,
          callbacks_to_add=None,
          overwrite=False):
    config = ModelConfig()
    config.char_embedding = np.load(
        format_filename(PROCESSED_DATA_DIR,
                        EMBEDDING_MATRIX_TEMPLATE,
                        dataset=dataset,
                        type=char_embed_type))
    config.char_trainable = char_trainable
    config.fw_bigram_embeddings = np.load(
        format_filename(PROCESSED_DATA_DIR,
                        EMBEDDING_MATRIX_TEMPLATE,
                        dataset=dataset,
                        type=fw_embed_type))
    config.fw_bigram_trainable = fw_trainable
    config.bw_bigram_embeddings = np.load(
        format_filename(PROCESSED_DATA_DIR,
                        EMBEDDING_MATRIX_TEMPLATE,
                        dataset=dataset,
                        type=bw_embed_type))
    config.bw_bigram_trainable = bw_trainable
    config.char_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        dataset=dataset,
                        level='char'))
    config.fw_bigram_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        dataset=dataset,
                        level='fw_bigram'))
    config.bw_bigram_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        dataset=dataset,
                        level='bw_bigram'))
    config.tag_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        dataset=dataset,
                        level='tag'))
    config.idx2tag = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        IDX2TOKEN_TEMPLATE,
                        dataset=dataset,
                        level='tag'))

    config.gaze_embed_dim = gaze_embed_dim
    config.n_step = n_step
    config.n_layer = n_layer
    config.rnn_units = rnn_units
    config.dropout = dropout
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.optimizer = optimizer
    config.callbacks_to_add = callbacks_to_add or []  # avoid joining None below

    config.model_name = 'ggnn'
    config.exp_name = 'ggnn_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
        char_embed_type, 'tune' if char_trainable else 'fix', fw_embed_type,
        'tune' if fw_trainable else 'fix', bw_embed_type,
        'tune' if bw_trainable else 'fix', gaze_embed_dim, n_step, n_layer,
        rnn_units, batch_size, n_epoch)
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str

    dev_generator = NERGenerator('dev', dataset, batch_size, config.char_vocab,
                                 config.fw_bigram_vocab,
                                 config.bw_bigram_vocab, config.tag_vocab)
    dev_input, _ = dev_generator.prepare_input(range(dev_generator.data_size))
    dev_tags = [text_example.tags for text_example in dev_generator.data]

    test_generator = NERGenerator('test', dataset, batch_size,
                                  config.char_vocab, config.fw_bigram_vocab,
                                  config.bw_bigram_vocab, config.tag_vocab)
    test_input, _ = test_generator.prepare_input(
        range(test_generator.data_size))
    test_tags = [text_example.tags for text_example in test_generator.data]

    config.n_gaze = dev_generator.n_gaze

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir,
                                   '{}.hdf5'.format(config.exp_name))
    model = MultiDigraph(config)

    if not os.path.exists(model_save_path) or overwrite:
        train_generator = NERGenerator('train', dataset, batch_size,
                                       config.char_vocab,
                                       config.fw_bigram_vocab,
                                       config.bw_bigram_vocab,
                                       config.tag_vocab)
        start_time = time.time()
        model.fit_generator(train_generator, dev_generator)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' %
              time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S",
                                                time.gmtime(elapsed_time))

    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    dev_score = model.evaluate(dev_input, dev_tags)
    train_log['dev_performance'] = dev_score
    print('Logging Info - Evaluate over test data:')
    test_score = model.evaluate(test_input, test_tags)
    train_log['test_performance'] = test_score

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        swa_dev_score = model.evaluate(dev_input, dev_tags)
        train_log['swa_dev_performance'] = swa_dev_score
        print('Logging Info - Evaluate over test data based on swa model:')
        swa_test_score = model.evaluate(test_input, test_tags)
        train_log['swa_test_performance'] = swa_test_score

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S',
                                           time.localtime())
    write_log(format_filename(LOG_DIR,
                              PERFORMANCE_LOG_TEMPLATE,
                              dataset=dataset),
              log=train_log,
              mode='a')
    del model
    gc.collect()
    K.clear_session()
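A hedged invocation sketch for the train() above; the dataset and embedding-type keys are placeholders chosen for illustration:

train(dataset='msra',  # hypothetical dataset key
      char_embed_type='c2v', char_trainable=True,  # hypothetical embedding keys
      fw_embed_type='bi2v', fw_trainable=True,
      bw_embed_type='bi2v', bw_trainable=True,
      gaze_embed_dim=50, n_step=2, n_layer=1, rnn_units=128, dropout=0.5,
      batch_size=32, n_epoch=50, optimizer='adam',
      callbacks_to_add=['modelcheckpoint', 'earlystopping'])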
Example #4
def link(model_name,
         dev_pred_mentions,
         test_pred_mentions,
         predict_log,
         batch_size=32,
         n_epoch=50,
         learning_rate=0.001,
         optimizer_type='adam',
         embed_type=None,
         embed_trainable=True,
         use_relative_pos=False,
         n_neg=1,
         omit_one_cand=True,
         callbacks_to_add=None,
         swa_type=None,
         predict_on_final_test=True,
         **kwargs):
    config = ModelConfig()
    config.model_name = model_name
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    if embed_type:
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR,
                            EMBEDDING_MATRIX_TEMPLATE,
                            type=embed_type))
        config.embed_trainable = embed_trainable
    else:
        config.embeddings = None
        config.embed_trainable = True

    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]

    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))
    config.entity_desc = pickle_load(
        format_filename(PROCESSED_DATA_DIR, ENTITY_DESC_FILENAME))

    config.exp_name = '{}_{}_{}_{}_{}_{}'.format(
        model_name, embed_type if embed_type else 'random',
        'tune' if embed_trainable else 'fix', batch_size, optimizer_type,
        learning_rate)
    config.use_relative_pos = use_relative_pos
    if config.use_relative_pos:
        config.exp_name += '_rel'
    config.n_neg = n_neg
    if config.n_neg > 1:
        config.exp_name += '_neg_{}'.format(config.n_neg)
    config.omit_one_cand = omit_one_cand
    if not config.omit_one_cand:
        config.exp_name += '_not_omit'
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # record prediction settings for logging
    predict_log.update({
        'el_exp_name': config.exp_name,
        'el_batch_size': batch_size,
        'el_optimizer': optimizer_type,
        'el_epoch': n_epoch,
        'el_learning_rate': learning_rate,
        'el_other_params': kwargs
    })

    print('Logging Info - Experiment: %s' % config.exp_name)
    model = LinkModel(config, **kwargs)

    model_save_path = os.path.join(config.checkpoint_dir,
                                   '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path):
        raise FileNotFoundError(
            'Link model does not exist: {}'.format(model_save_path))
    if swa_type is None:
        model.load_best_model()
    elif 'swa' in config.callbacks_to_add:
        model.load_swa_model(swa_type)
        predict_log['el_exp_name'] += '_{}'.format(swa_type)

    dev_data_type = 'dev'
    dev_data = load_data(dev_data_type)
    dev_text_data, dev_gold_mention_entities = [], []
    for data in dev_data:
        dev_text_data.append(data['text'])
        dev_gold_mention_entities.append(data['mention_data'])

    if predict_on_final_test:
        test_data_type = 'test_final'
    else:
        test_data_type = 'test'
    test_data = load_data(test_data_type)
    test_text_data = [data['text'] for data in test_data]

    if dev_pred_mentions is not None:
        print(
            'Logging Info - Evaluate over valid data based on predicted mention:'
        )
        r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions,
                                  dev_gold_mention_entities)
        dev_performance = 'dev_performance' if swa_type is None else '%s_dev_performance' % swa_type
        predict_log[dev_performance] = (r, p, f1)
    print('Logging Info - Generate submission for test data:')
    test_pred_mention_entities = model.predict(test_text_data,
                                               test_pred_mentions)
    test_submit_file = predict_log[
        'er_exp_name'] + '_' + config.exp_name + '_%s%ssubmit.json' % (
            swa_type + '_' if swa_type else '',
            'final_' if predict_on_final_test else '')
    submit_result(test_submit_file, test_data, test_pred_mention_entities)

    predict_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S',
                                             time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step'),
              log=predict_log,
              mode='a')
    return predict_log
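link() consumes the mention lists produced by recognition() (Example #1) and shares its predict_log, so a minimal end-to-end sketch (model names hypothetical) is:

predict_log = {}
dev_mentions, test_mentions = recognition('bilstm_cnn_crf', predict_log)   # hypothetical ER model name
predict_log = link('cnn_match', dev_mentions, test_mentions, predict_log)  # hypothetical EL model name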
Example #5
def train(train_d, dev_d, test_d, kfold, dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr, optimizer_type,
          batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=True):
    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.dataset = dataset
    config.K_Fold = kfold
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    config.callbacks_to_add = callbacks_to_add or []  # avoid joining None below

    # drug id (should be SMILES)
    config.drug_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             DRUG_VOCAB_TEMPLATE,
                                                             dataset=dataset)))

    # entity id
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                               ENTITY_VOCAB_TEMPLATE,
                                                               dataset=dataset)))

    # relation id (string)
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                                 RELATION_VOCAB_TEMPLATE,
                                                                 dataset=dataset)))
    # chosen (sampled) entity adjacency matrix
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE,
                                                dataset=dataset))
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE,
                                                  dataset=dataset))

    config.drug_smile = np.load(format_filename(PROCESSED_DATA_DIR, DRUG_SMILE_TEMPLATE), allow_pickle=True)

    config.smile_hash = np.load(format_filename(PROCESSED_DATA_DIR, SMILE_HASH), allow_pickle=True)

    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')  # strip these two callback names from the experiment name; SWA-style weight averaging is used instead
    config.exp_name += callback_str

    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = DDKG(config)
    # model = KGCN(config)

    train_data = np.array(train_d)
    valid_data = np.array(dev_d)
    test_data = np.array(test_d)
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
                                                                 time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    print('Logging Info - Evaluate over valid data:')
    model.load_best_model()
    auc, acc, f1, aupr, fpr, tpr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])

    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_aupr: {aupr}')
    train_log['dev_auc'] = auc
    train_log['dev_acc'] = acc
    train_log['dev_f1'] = f1
    train_log['dev_aupr'] = aupr
    train_log['k_fold'] = kfold
    train_log['dataset'] = dataset
    train_log['aggregate_type'] = config.aggregator_type
    train_log['dev_fpr'] = fpr
    train_log['dev_tpr'] = tpr
    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        auc, acc, f1, aupr, fpr, tpr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])

        train_log['swa_dev_auc'] = auc
        train_log['swa_dev_acc'] = acc
        train_log['swa_dev_f1'] = f1
        train_log['swa_dev_aupr'] = aupr
        train_log['swa_dev_fpr'] = fpr
        train_log['swa_dev_tpr'] = tpr
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, swa_dev_aupr: {aupr}')  # updated output metrics
    print('Logging Info - Evaluate over test data:')
    model.load_best_model()
    auc, acc, f1, aupr, fpr, tpr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])

    train_log['test_auc'] = auc
    train_log['test_acc'] = acc
    train_log['test_f1'] = f1
    train_log['test_aupr'] = aupr
    train_log['test_fpr'] = fpr
    train_log['test_tpr'] = tpr
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_aupr: {aupr}, test_fpr: {fpr}')
    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over test data based on swa model:')
        auc, acc, f1, aupr, fpr, tpr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
        train_log['swa_test_auc'] = auc
        train_log['swa_test_acc'] = acc
        train_log['swa_test_f1'] = f1
        train_log['swa_test_aupr'] = aupr
        train_log['swa_test_fpr'] = fpr
        train_log['swa_test_tpr'] = tpr
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, swa_test_aupr: {aupr}')
    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
    del model
    gc.collect()
    K.clear_session()
    return train_log
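A hedged k-fold driver for the train() above; make_folds() and the dataset key are placeholders, only the signature comes from this example:

for fold, (train_d, dev_d, test_d) in enumerate(make_folds()):  # make_folds() is hypothetical
    train(train_d, dev_d, test_d, kfold=fold, dataset='kegg',   # hypothetical dataset key
          neighbor_sample_size=4, embed_dim=32, n_depth=2, l2_weight=1e-7,
          lr=2e-2, optimizer_type='adam', batch_size=1024,
          aggregator_type='sum', n_epoch=50,
          callbacks_to_add=['modelcheckpoint', 'earlystopping'])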
Example #6
def train(dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr, optimizer_type,
          batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=False):
    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    config.callbacks_to_add = callbacks_to_add or []  # avoid joining None below

    config.user_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             USER_VOCAB_TEMPLATE,
                                                             dataset=dataset)))
    config.item_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             ITEM_VOCAB_TEMPLATE,
                                                             dataset=dataset)))
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                               ENTITY_VOCAB_TEMPLATE,
                                                               dataset=dataset)))
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                                 RELATION_VOCAB_TEMPLATE,
                                                                 dataset=dataset)))
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE,
                                                dataset=dataset))
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE,
                                                  dataset=dataset))

    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = KGCN(config)
    train_data = load_data(dataset, 'train')
    valid_data = load_data(dataset, 'dev')
    test_data = load_data(dataset, 'test')

    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
                                                                 time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    print('Logging Info - Evaluate over valid data:')
    model.load_best_model()
    auc, acc, f1 = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])

    user_list, train_record, valid_record, item_set, k_list = topk_settings(train_data,
                                                                            valid_data,
                                                                            config.item_vocab_size)
    topk_p, topk_r = topk_eval(model, user_list, train_record, valid_record, item_set, k_list)
    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_topk_p: {topk_p}, '
          f'dev_topk_r: {topk_r}')
    train_log['dev_auc'] = auc
    train_log['dev_acc'] = acc
    train_log['dev_f1'] = f1
    train_log['dev_topk_p'] = topk_p
    train_log['dev_topk_r'] = topk_r

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        auc, acc, f1 = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
        topk_p, topk_r = topk_eval(model, user_list, train_record, valid_record, item_set, k_list)
        train_log['swa_dev_auc'] = auc
        train_log['swa_dev_acc'] = acc
        train_log['swa_dev_f1'] = f1
        train_log['swa_dev_topk_p'] = topk_p
        train_log['swa_dev_topk_r'] = topk_r
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, '
              f'swa_dev_topk_p: {topk_p}, swa_dev_topk_r: {topk_r}')

    print('Logging Info - Evaluate over test data:')
    model.load_best_model()
    auc, acc, f1 = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])

    user_list, train_record, test_record, item_set, k_list = topk_settings(train_data,
                                                                           test_data,
                                                                           config.item_vocab_size)
    topk_p, topk_r = topk_eval(model, user_list, train_record, test_record, item_set, k_list)
    train_log['test_auc'] = auc
    train_log['test_acc'] = acc
    train_log['test_f1'] = f1
    train_log['test_topk_p'] = topk_p
    train_log['test_topk_r'] = topk_r
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_topk_p: {topk_p}, '
          f'test_topk_r: {topk_r}')

    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over test data based on swa model:')
        auc, acc, f1 = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
        topk_p, topk_r = topk_eval(model, user_list, train_record, test_record, item_set, k_list)
        train_log['swa_test_auc'] = auc
        train_log['swa_test_acc'] = acc
        train_log['swa_test_f1'] = f1
        train_log['swa_test_topk_p'] = topk_p
        train_log['swa_test_topk_r'] = topk_r
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, '
              f'swa_test_topk_p: {topk_p}, swa_test_topk_r: {topk_r}')

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
    del model
    gc.collect()
    K.clear_session()
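A minimal call sketch for this KGCN train(); the dataset key and hyperparameters are illustrative assumptions:

train(dataset='music',  # hypothetical dataset key
      neighbor_sample_size=8, embed_dim=16, n_depth=1, l2_weight=1e-4,
      lr=5e-4, optimizer_type='adam', batch_size=32,
      aggregator_type='sum', n_epoch=50,
      callbacks_to_add=['modelcheckpoint', 'earlystopping'])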
Example #7
def train_recognition(model_name, label_schema='BIOES', batch_size=32, n_epoch=50, learning_rate=0.001,
                      optimizer_type='adam', use_char_input=True, embed_type=None, embed_trainable=True,
                      use_bert_input=False, bert_type='bert', bert_trainable=True, bert_layer_num=1,
                      use_bichar_input=False, bichar_embed_type=None, bichar_embed_trainable=True,
                      use_word_input=False, word_embed_type=None, word_embed_trainable=True,
                      use_charpos_input=False, charpos_embed_type=None, charpos_embed_trainable=True,
                      use_softword_input=False, use_dictfeat_input=False, use_maxmatch_input=False,
                      callbacks_to_add=None, overwrite=False, swa_start=3, early_stopping_patience=3, **kwargs):
    config = ModelConfig()
    config.model_name = model_name
    config.label_schema = label_schema
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    config.use_char_input = use_char_input
    if embed_type:
        config.embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, type=embed_type))
        config.embed_trainable = embed_trainable
        config.embed_dim = config.embeddings.shape[1]
    else:
        config.embeddings = None
        config.embed_trainable = True

    config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping']
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience

    config.vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))

    if config.use_char_input:
        config.exp_name = '{}_{}_{}_{}_{}_{}_{}'.format(model_name, config.embed_type if config.embed_type else 'random',
                                                        'tune' if config.embed_trainable else 'fix', batch_size,
                                                        optimizer_type, learning_rate, label_schema)
    else:
        config.exp_name = '{}_{}_{}_{}_{}'.format(model_name, batch_size, optimizer_type, learning_rate, label_schema)
    if config.n_epoch != 50:
        config.exp_name += '_{}'.format(config.n_epoch)
    if kwargs:
        config.exp_name += '_' + '_'.join([str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    config.use_bert_input = use_bert_input
    config.bert_type = bert_type
    config.bert_trainable = bert_trainable
    config.bert_layer_num = bert_layer_num
    assert config.use_char_input or config.use_bert_input
    if config.use_bert_input:
        config.exp_name += '_{}_layer_{}_{}'.format(bert_type, bert_layer_num, 'tune' if config.bert_trainable else 'fix')
    config.use_bichar_input = use_bichar_input
    if config.use_bichar_input:
        config.bichar_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='bichar'))
        config.bichar_vocab_size = len(config.bichar_vocab) + 2
        if bichar_embed_type:
            config.bichar_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                               type=bichar_embed_type))
            config.bichar_embed_trainable = bichar_embed_trainable
            config.bichar_embed_dim = config.bichar_embeddings.shape[1]
        else:
            config.bichar_embeddings = None
            config.bichar_embed_trainable = True
        config.exp_name += '_bichar_{}_{}'.format(bichar_embed_type if bichar_embed_type else 'random',
                                                  'tune' if config.bichar_embed_trainable else 'fix')
    config.use_word_input = use_word_input
    if config.use_word_input:
        config.word_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word'))
        config.word_vocab_size = len(config.word_vocab) + 2
        if word_embed_type:
            config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                             type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
        config.exp_name += '_word_{}_{}'.format(word_embed_type if word_embed_type else 'random',
                                                'tune' if config.word_embed_trainable else 'fix')
    config.use_charpos_input = use_charpos_input
    if config.use_charpos_input:
        config.charpos_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='charpos'))
        config.charpos_vocab_size = len(config.charpos_vocab) + 2
        if charpos_embed_type:
            config.charpos_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                                type=charpos_embed_type))
            config.charpos_embed_trainable = charpos_embed_trainable
            config.charpos_embed_dim = config.charpos_embeddings.shape[1]
        else:
            config.charpos_embeddings = None
            config.charpos_embed_trainable = True
        config.exp_name += '_charpos_{}_{}'.format(charpos_embed_type if charpos_embed_type else 'random',
                                                   'tune' if config.charpos_embed_trainable else 'fix')
    config.use_softword_input = use_softword_input
    if config.use_softword_input:
        config.exp_name += '_softword'
    config.use_dictfeat_input = use_dictfeat_input
    if config.use_dictfeat_input:
        config.exp_name += '_dictfeat'
    config.use_maxmatch_input = use_maxmatch_input
    if config.use_maxmatch_input:
        config.exp_name += '_maxmatch'

    # logger to log output of training process
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type, 'epoch': n_epoch,
                 'learning_rate': learning_rate, 'other_params': kwargs}

    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = RecognitionModel(config, **kwargs)

    train_data_type, dev_data_type = 'train', 'dev'
    train_generator = RecognitionDataGenerator(train_data_type, config.batch_size, config.label_schema,
                                               config.label_to_one_hot[config.label_schema],
                                               config.vocab if config.use_char_input else None,
                                               config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
                                               config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                               config.use_word_input, config.charpos_vocab, config.use_softword_input,
                                               config.use_dictfeat_input, config.use_maxmatch_input)
    valid_generator = RecognitionDataGenerator(dev_data_type, config.batch_size, config.label_schema,
                                               config.label_to_one_hot[config.label_schema],
                                               config.vocab if config.use_char_input else None,
                                               config.bert_vocab_file(config.bert_type) if config.use_bert_input else None,
                                               config.bert_seq_len, config.bichar_vocab, config.word_vocab,
                                               config.use_word_input, config.charpos_vocab, config.use_softword_input,
                                               config.use_dictfeat_input, config.use_maxmatch_input)

    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_generator, valid_generator)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

    model.load_best_model()

    print('Logging Info - Evaluate over valid data:')
    r, p, f1 = model.evaluate(valid_generator)
    train_log['dev_performance'] = (r, p, f1)

    swa_type = None
    if 'swa' in config.callbacks_to_add:
        swa_type = 'swa'
    elif 'swa_clr' in config.callbacks_to_add:
        swa_type = 'swa_clr'
    if swa_type:
        model.load_swa_model(swa_type)
        print('Logging Info - Evaluate over valid data based on swa model:')
        r, p, f1 = model.evaluate(valid_generator)
        train_log['swa_dev_performance'] = (r, p, f1)

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_er'), log=train_log, mode='a')

    del model
    gc.collect()
    K.clear_session()
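train_recognition() is the training counterpart of recognition() in Example #1: both must be called with the same arguments so they build the same exp_name and find the same .hdf5 checkpoint. A sketch with hypothetical names:

train_recognition('bilstm_cnn_crf', embed_type='c2v',  # hypothetical model/embedding names
                  callbacks_to_add=['modelcheckpoint', 'earlystopping'])
predict_log = {}
recognition('bilstm_cnn_crf', predict_log, embed_type='c2v',
            callbacks_to_add=['modelcheckpoint', 'earlystopping'])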
Example #8
def prepare_config(model_type='bert-base-uncased',
                   input_type='name_desc',
                   use_multi_task=True,
                   use_harl=False,
                   use_hal=False,
                   cate_embed_dim=100,
                   use_word_input=False,
                   word_embed_type='w2v',
                   word_embed_trainable=True,
                   word_embed_dim=300,
                   use_bert_input=True,
                   bert_trainable=True,
                   use_bert_type='pooler',
                   n_last_hidden_layer=0,
                   dense_after_bert=True,
                   use_pair_input=True,
                   max_len=None,
                   share_father_pred='no',
                   use_mask_for_cate2=False,
                   use_mask_for_cate3=True,
                   cate3_mask_type='cate1',
                   cate1_loss_weight=1.,
                   cate2_loss_weight=1.,
                   cate3_loss_weight=1.,
                   batch_size=32,
                   predict_batch_size=32,
                   n_epoch=50,
                   learning_rate=2e-5,
                   optimizer='adam',
                   use_focal_loss=False,
                   callbacks_to_add=None,
                   swa_start=15,
                   early_stopping_patience=5,
                   max_lr=6e-5,
                   min_lr=1e-5,
                   train_on_cv=False,
                   cv_random_state=42,
                   cv_fold=5,
                   cv_index=0,
                   exchange_pair=False,
                   exchange_threshold=0.1,
                   use_pseudo_label=False,
                   pseudo_path=None,
                   pseudo_random_state=42,
                   pseudo_rate=0.1,
                   pseudo_index=0,
                   pseudo_name=None,
                   exp_name=None):
    config = ModelConfig()
    config.model_type = model_type
    config.input_type = input_type
    config.use_multi_task = use_multi_task
    config.use_harl = use_harl
    config.use_hal = use_hal
    assert not (config.use_harl and config.use_hal)
    config.cate_embed_dim = cate_embed_dim

    config.use_word_input = use_word_input
    config.word_embed_type = word_embed_type
    if config.use_word_input:
        if word_embed_type:
            config.word_embeddings = np.load(
                format_filename(PROCESSED_DATA_DIR,
                                EMBEDDING_MATRIX_TEMPLATE,
                                type=word_embed_type))
            config.word_embed_trainable = word_embed_trainable
            config.word_embed_dim = config.word_embeddings.shape[1]
        else:
            config.word_embeddings = None
            config.word_embed_trainable = True
            config.word_embed_dim = word_embed_dim
        config.word_vocab = pickle_load(
            format_filename(PROCESSED_DATA_DIR,
                            VOCABULARY_TEMPLATE,
                            level='word'))
        config.word_vocab_size = len(
            config.word_vocab) + 2  # 0: mask, 1: padding
    else:
        config.word_vocab = None

    config.use_bert_input = use_bert_input
    config.bert_trainable = bert_trainable
    if config.use_bert_input:
        config.use_bert_type = use_bert_type
        config.dense_after_bert = dense_after_bert
        if config.use_bert_type in ['hidden', 'hidden_pooler'] or \
                (config.use_multi_task and (config.use_harl or config.use_hal)):
            config.output_hidden_state = True
            config.n_last_hidden_layer = n_last_hidden_layer
        else:
            config.output_hidden_state = False
            config.n_last_hidden_layer = 0
    if config.input_type == 'name_desc':
        config.use_pair_input = use_pair_input
    else:
        config.use_pair_input = False

    if config.use_bert_input and max_len is None:
        config.max_len = MAX_LEN[input_type]
    else:
        config.max_len = max_len

    config.cate1_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate1'))
    config.cate2_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate2'))
    config.cate3_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE,
                        level='cate3'))
    config.all_cate_vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        VOCABULARY_TEMPLATE,
                        level='all_cate'))
    config.idx2cate1 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'))
    config.idx2cate2 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2'))
    config.idx2cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3'))
    config.idx2all_cate = pickle_load(
        format_filename(PROCESSED_DATA_DIR,
                        IDX2TOKEN_TEMPLATE,
                        level='all_cate'))
    config.cate1_to_cate2 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT))
    config.cate2_to_cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_to_cate3 = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
    config.cate1_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT))
    config.cate2_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT))
    config.cate3_count_dict = pickle_load(
        format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT))
    config.n_cate1 = len(config.cate1_vocab)
    config.n_cate2 = len(config.cate2_vocab)
    config.n_cate3 = len(config.cate3_vocab)
    config.n_all_cate = len(config.all_cate_vocab)

    if config.use_multi_task and (config.use_harl or config.use_hal):
        config.share_father_pred = 'no'
        config.use_mask_for_cate2 = False
        config.use_mask_for_cate3 = False
        config.cate3_mask_type = None
    else:
        config.share_father_pred = share_father_pred
        config.use_mask_for_cate2 = use_mask_for_cate2
        config.use_mask_for_cate3 = use_mask_for_cate3
        config.cate3_mask_type = cate3_mask_type
        if config.use_mask_for_cate3:
            if config.cate3_mask_type == 'cate1':
                config.cate_to_cate3 = pickle_load(
                    format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT))
            elif config.cate3_mask_type == 'cate2':
                config.cate_to_cate3 = pickle_load(
                    format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT))
    config.cate1_loss_weight = cate1_loss_weight
    config.cate2_loss_weight = cate2_loss_weight
    config.cate3_loss_weight = cate3_loss_weight

    config.batch_size = batch_size
    config.predict_batch_size = predict_batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = optimizer
    config.use_focal_loss = use_focal_loss
    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience
    for lr_scheduler in [
            'clr', 'sgdr', 'clr_1', 'clr_2', 'warm_up', 'swa_clr'
    ]:
        if lr_scheduler in config.callbacks_to_add:
            config.max_lr = max_lr
            config.min_lr = min_lr

    config.train_on_cv = train_on_cv
    if config.train_on_cv:
        config.cv_random_state = cv_random_state
        config.cv_fold = cv_fold
        config.cv_index = cv_index

    config.exchange_pair = exchange_pair
    if config.exchange_pair:
        config.exchange_threshold = exchange_threshold

    config.use_pseudo_label = use_pseudo_label
    if config.use_pseudo_label:
        config.pseudo_path = pseudo_path
        config.pseudo_random_state = pseudo_random_state
        config.pseudo_rate = pseudo_rate
        config.pseudo_index = pseudo_index

    # build experiment name from parameter configuration
    config.exp_name = f'{config.model_type}_{config.input_type}'
    if config.use_pair_input:
        config.exp_name += '_pair'
    config.exp_name += f'_len_{config.max_len}'
    if config.use_word_input:
        config.exp_name += f"_word_{config.word_embed_type}_{'tune' if config.word_embed_trainable else 'fix'}"
    if config.use_bert_input:
        config.exp_name += f"_bert_{config.use_bert_type}_{'tune' if config.bert_trainable else 'fix'}"
        if config.output_hidden_state:
            config.exp_name += f'_hid_{config.n_last_hidden_layer}'
        if config.dense_after_bert:
            config.exp_name += '_dense'
    if config.use_multi_task:
        if config.use_harl:
            config.exp_name += f'_harl_{config.cate_embed_dim}'
        elif config.use_hal:
            config.exp_name += f'_hal_{config.cate_embed_dim}'
        config.exp_name += f'_{config.cate1_loss_weight}_{config.cate2_loss_weight}_{config.cate3_loss_weight}'
    else:
        config.exp_name += '_not_multi_task'
    if config.share_father_pred in ['after', 'before']:
        config.exp_name += f'_{config.share_father_pred}'
    if config.use_mask_for_cate2:
        config.exp_name += '_mask_cate2'
    if config.use_mask_for_cate3:
        config.exp_name += f'_mask_cate3_with_{config.cate3_mask_type}'
    if config.use_focal_loss:
        config.exp_name += '_focal'
    config.exp_name += f'_{config.optimizer}_{config.learning_rate}_{config.batch_size}_{config.n_epoch}'
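    # drop the default callbacks ('modelcheckpoint', 'earlystopping') from the
    # name so only non-default callbacks (e.g. swa, clr) appear in the
    # experiment identifier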
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str
    if config.train_on_cv:
        config.exp_name += f'_{config.cv_random_state}_{config.cv_fold}_{config.cv_index}'
    if config.exchange_pair:
        config.exp_name += f"_ex_pair_{config.exchange_threshold}"
    if config.use_pseudo_label:
        if pseudo_name:
            config.exp_name += f"_{pseudo_name}_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
        elif 'dev' in config.pseudo_path:
            config.exp_name += f"_dev_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
        else:
            config.exp_name += f"_test_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"

    if exp_name:
        config.exp_name = exp_name

    return config
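
# A minimal usage sketch (hedged: the argument values below are illustrative
# placeholders, not settings taken from this snippet).
if __name__ == '__main__':
    config = prepare_config(model_type='bert-base-uncased',
                            input_type='name_desc',
                            callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'],
                            learning_rate=2e-5)
    print(config.exp_name)  # the experiment name encodes the full configuration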
Example #9
def train_link(model_name,
               batch_size=32,
               n_epoch=50,
               learning_rate=0.001,
               optimizer_type='adam',
               embed_type=None,
               embed_trainable=True,
               callbacks_to_add=None,
               use_relative_pos=False,
               n_neg=1,
               omit_one_cand=True,
               overwrite=False,
               swa_start=5,
               early_stopping_patience=3,
               **kwargs):
    config = ModelConfig()
    config.model_name = model_name
    config.batch_size = batch_size
    config.n_epoch = n_epoch
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.embed_type = embed_type
    if embed_type:
        config.embeddings = np.load(
            format_filename(PROCESSED_DATA_DIR,
                            EMBEDDING_MATRIX_TEMPLATE,
                            type=embed_type))
        config.embed_trainable = embed_trainable
    else:
        config.embeddings = None
        config.embed_trainable = True

    config.callbacks_to_add = callbacks_to_add or [
        'modelcheckpoint', 'earlystopping'
    ]
    if 'swa' in config.callbacks_to_add:
        config.swa_start = swa_start
        config.early_stopping_patience = early_stopping_patience

    config.vocab = pickle_load(
        format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='char'))
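    # +2 reserves two extra indices for special tokens (presumably 0: mask,
    # 1: padding, matching the annotation in prepare_config above)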
    config.vocab_size = len(config.vocab) + 2
    config.mention_to_entity = pickle_load(
        format_filename(PROCESSED_DATA_DIR, MENTION_TO_ENTITY_FILENAME))
    config.entity_desc = pickle_load(
        format_filename(PROCESSED_DATA_DIR, ENTITY_DESC_FILENAME))

    config.exp_name = '{}_{}_{}_{}_{}_{}'.format(
        model_name, embed_type if embed_type else 'random',
        'tune' if config.embed_trainable else 'fix', batch_size,
        optimizer_type, learning_rate)
    config.use_relative_pos = use_relative_pos
    if config.use_relative_pos:
        config.exp_name += '_rel'
    config.n_neg = n_neg
    if config.n_neg > 1:
        config.exp_name += '_neg_{}'.format(config.n_neg)
    config.omit_one_cand = omit_one_cand
    if not config.omit_one_cand:
        config.exp_name += '_not_omit'
    if kwargs:
        config.exp_name += '_' + '_'.join(
            [str(k) + '_' + str(v) for k, v in kwargs.items()])
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    callback_str = callback_str.replace('_modelcheckpoint',
                                        '').replace('_earlystopping', '')
    config.exp_name += callback_str

    # log the training configuration and the results of this run
    train_log = {
        'exp_name': config.exp_name,
        'batch_size': batch_size,
        'optimizer': optimizer_type,
        'epoch': n_epoch,
        'learning_rate': learning_rate,
        'other_params': kwargs
    }

    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir,
                                   '{}.hdf5'.format(config.exp_name))
    model = LinkModel(config, **kwargs)

    train_data_type, dev_data_type = 'train', 'dev'
    train_generator = LinkDataGenerator(
        train_data_type, config.vocab, config.mention_to_entity,
        config.entity_desc, config.batch_size, config.max_desc_len,
        config.max_erl_len, config.use_relative_pos, config.n_neg,
        config.omit_one_cand)
    dev_data = load_data(dev_data_type)

    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_generator, dev_data)
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' %
              time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S",
                                                time.gmtime(elapsed_time))

    model.load_best_model()
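    # NOTE: the gold mention spans are reused as the "predicted" mentions
    # below, so this evaluates entity linking in isolation, assuming perfect
    # mention detection upstream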
    dev_text_data, dev_pred_mentions, dev_gold_mention_entities = [], [], []
    for data in dev_data:
        dev_text_data.append(data['text'])
        dev_pred_mentions.append(data['mention_data'])
        dev_gold_mention_entities.append(data['mention_data'])
    print('Logging Info - Evaluate over valid data:')
    r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions,
                              dev_gold_mention_entities)
    train_log['dev_performance'] = (r, p, f1)

    swa_type = None
    if 'swa' in config.callbacks_to_add:
        swa_type = 'swa'
    elif 'swa_clr' in config.callbacks_to_add:
        swa_type = 'swa_clr'
    if swa_type:
        model.load_swa_model(swa_type)
        print('Logging Info - Evaluate over valid data based on swa model:')
        r, p, f1 = model.evaluate(dev_text_data, dev_pred_mentions,
                                  dev_gold_mention_entities)
        train_log['swa_dev_performance'] = (r, p, f1)

    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S',
                                           time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG, model_type='2step_el'),
              log=train_log,
              mode='a')
    del model
    gc.collect()
    K.clear_session()
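
# A minimal usage sketch (hedged: the model name below is an illustrative
# placeholder, not a name defined in this snippet).
if __name__ == '__main__':
    train_link('semantic_match',
               batch_size=32,
               embed_type=None,  # None falls back to randomly initialised, trainable embeddings
               callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'],
               swa_start=5,
               overwrite=False)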