Code example #1
File: predict.py Project: yuedy/text_classifier
    def __init__(self, data_manager, logger):
        hidden_dim = classifier_config['hidden_dim']
        classifier = classifier_config['classifier']
        self.dataManager = data_manager
        self.seq_length = data_manager.max_sequence_length
        num_classes = data_manager.max_label_number
        self.embedding_dim = data_manager.embedding_dim
        vocab_size = data_manager.vocab_size

        self.logger = logger
        # number of convolution kernels
        num_filters = classifier_config['num_filters']
        self.checkpoints_dir = classifier_config['checkpoints_dir']
        self.embedding_method = classifier_config['embedding_method']
        if self.embedding_method == 'Bert':
            from transformers import TFBertModel
            self.bert_model = TFBertModel.from_pretrained('bert-base-multilingual-cased')
        logger.info('loading model parameters')
        if classifier == 'textcnn':
            from engines.models.textcnn import TextCNN
            self.model = TextCNN(self.seq_length, num_filters, num_classes, self.embedding_dim, vocab_size)
        elif classifier == 'textrcnn':
            from engines.models.textrcnn import TextRCNN
            self.model = TextRCNN(self.seq_length, num_classes, hidden_dim, self.embedding_dim, vocab_size)
        elif classifier == 'textrnn':
            from engines.models.textrnn import TextRNN
            self.model = TextRNN(self.seq_length, num_classes, hidden_dim, self.embedding_dim, vocab_size)
        else:
            raise Exception('the configured classifier does not exist')
        # instantiate a Checkpoint whose restore target is the newly built model
        checkpoint = tf.train.Checkpoint(model=self.model)
        # restore the model parameters from the checkpoint file
        checkpoint.restore(tf.train.latest_checkpoint(self.checkpoints_dir))
        logger.info('model loaded successfully')
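
The constructor above loads TFBertModel, but the snippet ends before showing how it is used at prediction time. Below is a minimal sketch of a matching predict_one method, assuming (as in the Bert branch of train() in code example #3) that prepare_single_sentence returns token ids and that the classifier consumes Bert's last hidden states in place of an embedding lookup; the method body is an assumption, not part of the original project:

    def predict_one(self, sentence):
        # hypothetical sketch: mirrors predict_one from code example #2
        # plus the Bert branch used in train() (code example #3)
        vector = self.dataManager.prepare_single_sentence(sentence)
        if self.embedding_method == 'Bert':
            # swap the id sequence for Bert's last hidden states
            vector = self.bert_model(vector)[0]
        logits = self.model(vector)
        prediction = tf.argmax(logits, axis=-1).numpy()[0]
        reverse_classes = {v: k for k, v in self.dataManager.class_id.items()}
        return reverse_classes[prediction]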
Code example #2
File: predict.py Project: chdwlx/comments_classifier
class Predictor:
    def __init__(self, data_manager, logger):
        hidden_dim = classifier_config['hidden_dim']
        classifier = classifier_config['classifier']
        self.dataManager = data_manager
        seq_length = data_manager.max_sequence_length
        num_classes = data_manager.max_label_number
        embedding_dim = data_manager.embedding_dim
        self.logger = logger
        # number of convolution kernels
        num_filters = classifier_config['num_filters']
        checkpoints_dir = classifier_config['checkpoints_dir']
        logger.info('loading model parameters')
        if classifier == 'textcnn':
            from engines.models.textcnn import TextCNN
            self.model = TextCNN(seq_length, num_filters, num_classes,
                                 embedding_dim)
        elif classifier == 'textrcnn':
            from engines.models.textrcnn import TextRCNN
            self.model = TextRCNN(seq_length, num_classes, hidden_dim,
                                  embedding_dim)
        else:
            raise Exception('the configured classifier does not exist')
        # instantiate a Checkpoint whose restore target is the newly built model
        checkpoint = tf.train.Checkpoint(model=self.model)
        # restore the model parameters from the checkpoint file
        checkpoint.restore(tf.train.latest_checkpoint(checkpoints_dir))
        logger.info('model loaded successfully')

    def predict_one(self, sentence):
        """
        对输入的句子分类预测
        :param sentence:
        :return:
        """
        reverse_classes = {
            class_id: class_name
            for class_name, class_id in self.dataManager.class_id.items()
        }
        vector = self.dataManager.prepare_single_sentence(sentence)
        logits = self.model(vector)
        prediction = tf.argmax(logits, axis=-1)
        prediction = prediction.numpy()[0]
        return reverse_classes[prediction]
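
predict_one depends on DataManager.prepare_single_sentence, which none of the snippets define. Below is a self-contained sketch of what it plausibly does, assuming a token2id vocabulary with id 0 reserved for padding and id 1 for unknown tokens (all names and defaults here are assumptions):

import numpy as np

def prepare_single_sentence(sentence, token2id, max_sequence_length, unk_id=1):
    # map each token to its id, then pad/truncate to max_sequence_length
    # so the result matches the (1, seq_length) input the model expects
    ids = [token2id.get(token, unk_id) for token in sentence]
    ids = ids[:max_sequence_length]
    ids += [0] * (max_sequence_length - len(ids))
    return np.array([ids], dtype=np.int32)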
Code example #3
File: train.py Project: yuedy/text_classifier
def train(data_manager, logger):
    embedding_dim = data_manager.embedding_dim
    num_classes = data_manager.max_label_number
    seq_length = data_manager.max_sequence_length

    train_file = classifier_config['train_file']
    dev_file = classifier_config['dev_file']
    train_df = pd.read_csv(train_file).sample(frac=1)

    if dev_file == '':
        # split the data into train and validation sets
        split = int(len(train_df) * 0.9)
        train_df, dev_df = train_df[:split], train_df[split:]
    else:
        dev_df = pd.read_csv(dev_file).sample(frac=1)

    train_dataset = data_manager.get_dataset(train_df, step='train')
    dev_dataset = data_manager.get_dataset(dev_df)

    vocab_size = data_manager.vocab_size

    embedding_method = classifier_config['embedding_method']
    if embedding_method == 'Bert':
        from transformers import TFBertModel
        bert_model = TFBertModel.from_pretrained(
            'bert-base-multilingual-cased')
    else:
        bert_model = None
    checkpoints_dir = classifier_config['checkpoints_dir']
    checkpoint_name = classifier_config['checkpoint_name']
    num_filters = classifier_config['num_filters']
    learning_rate = classifier_config['learning_rate']
    epoch = classifier_config['epoch']
    max_to_keep = classifier_config['max_to_keep']
    print_per_batch = classifier_config['print_per_batch']
    is_early_stop = classifier_config['is_early_stop']
    patient = classifier_config['patient']
    hidden_dim = classifier_config['hidden_dim']
    classifier = classifier_config['classifier']

    reverse_classes = {
        str(class_id): class_name
        for class_name, class_id in data_manager.class_id.items()
    }

    best_f1_val = 0.0
    best_at_epoch = 0
    unprocessed = 0
    batch_size = data_manager.batch_size
    very_start_time = time.time()
    loss_obj = FocalLoss() if classifier_config['use_focal_loss'] else None
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # load the model
    if classifier == 'textcnn':
        from engines.models.textcnn import TextCNN
        model = TextCNN(seq_length, num_filters, num_classes, embedding_dim,
                        vocab_size)
    elif classifier == 'textrcnn':
        from engines.models.textrcnn import TextRCNN
        model = TextRCNN(seq_length, num_classes, hidden_dim, embedding_dim,
                         vocab_size)
    elif classifier == 'textrnn':
        from engines.models.textrnn import TextRNN
        model = TextRNN(seq_length, num_classes, hidden_dim, embedding_dim,
                        vocab_size)
    else:
        raise Exception('the configured classifier does not exist')
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=checkpoints_dir,
        checkpoint_name=checkpoint_name,
        max_to_keep=max_to_keep)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    logger.info(('+' * 20) + 'training starting' + ('+' * 20))
    for i in range(epoch):
        start_time = time.time()
        logger.info('epoch:{}/{}'.format(i + 1, epoch))
        for step, batch in tqdm(
                train_dataset.shuffle(
                    len(train_dataset)).batch(batch_size).enumerate()):
            X_train_batch, y_train_batch = batch
            if embedding_method == 'Bert':
                X_train_batch = bert_model(X_train_batch)[0]

            with tf.GradientTape() as tape:
                logits = model(X_train_batch, training=True)
                if classifier_config['use_focal_loss']:
                    loss_vec = loss_obj.call(y_true=y_train_batch,
                                             y_pred=logits)
                else:
                    loss_vec = tf.keras.losses.categorical_crossentropy(
                        y_true=y_train_batch, y_pred=logits)
                loss = tf.reduce_mean(loss_vec)
            # compute gradients of the loss w.r.t. the trainable parameters
            gradients = tape.gradient(loss, model.trainable_variables)
            # backpropagation: apply the gradients to update the parameters
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
            if step % print_per_batch == 0 and step != 0:
                predictions = tf.argmax(logits, axis=-1).numpy()
                y_train_batch = tf.argmax(y_train_batch, axis=-1).numpy()
                measures, _ = cal_metrics(y_true=y_train_batch,
                                          y_pred=predictions)
                res_str = ''
                for k, v in measures.items():
                    res_str += (k + ': %.3f ' % v)
                logger.info('training batch: %5d, loss: %.5f, %s' %
                            (step, loss, res_str))

        # validation
        logger.info('start evaluating...')
        y_true, y_pred = np.array([]), np.array([])
        loss_values = []

        for dev_batch in tqdm(dev_dataset.batch(batch_size)):
            X_val_batch, y_val_batch = dev_batch
            if embedding_method == 'Bert':
                X_val_batch = bert_model(X_val_batch)[0]

            logits = model(X_val_batch)
            val_loss_vec = tf.keras.losses.categorical_crossentropy(
                y_true=y_val_batch, y_pred=logits)
            val_loss = tf.reduce_mean(val_loss_vec)
            predictions = tf.argmax(logits, axis=-1)
            y_val_batch = tf.argmax(y_val_batch, axis=-1)
            y_true = np.append(y_true, y_val_batch)
            y_pred = np.append(y_pred, predictions)
            loss_values.append(val_loss)

        measures, each_classes = cal_metrics(y_true=y_true, y_pred=y_pred)

        # log the per-class metrics
        classes_val_str = ''
        for k, v in each_classes.items():
            if k in reverse_classes:
                classes_val_str += ('\n' + reverse_classes[k] + ': ' +
                                    str(each_classes[k]))
        logger.info(classes_val_str)
        # log the validation loss
        val_res_str = 'loss: %.3f ' % np.mean(loss_values)
        for k, v in measures.items():
            val_res_str += (k + ': %.3f ' % measures[k])

        time_span = (time.time() - start_time) / 60

        logger.info('time consumption:%.2f(min), %s' %
                    (time_span, val_res_str))
        if measures['f1'] > best_f1_val:
            unprocessed = 0
            best_f1_val = measures['f1']
            best_at_epoch = i + 1
            checkpoint_manager.save()
            logger.info('saved the new best model with f1: %.3f' % best_f1_val)
        else:
            unprocessed += 1

        if is_early_stop:
            if unprocessed >= patient:
                logger.info(
                    'early stopped, no progress obtained within {} epochs'.
                    format(patient))
                logger.info('overall best f1 is {} at {} epoch'.format(
                    best_f1_val, best_at_epoch))
                logger.info('total training time consumption: %.3f(min)' %
                            ((time.time() - very_start_time) / 60))
                return
    logger.info('overall best f1 is {} at {} epoch'.format(
        best_f1_val, best_at_epoch))
    logger.info('total training time consumption: %.3f(min)' %
                ((time.time() - very_start_time) / 60))
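
Both train() variants can swap categorical cross-entropy for a project-level FocalLoss whose definition is not shown. Below is a minimal sketch under the standard formulation FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t) for one-hot targets; the alpha/gamma defaults are assumptions:

import tensorflow as tf

class FocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=0.25, epsilon=1e-9):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon

    def call(self, y_true, y_pred):
        # y_true: one-hot labels; y_pred: softmax probabilities
        y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0)
        # (1 - p_t)^gamma down-weights well-classified examples
        weight = self.alpha * tf.pow(1.0 - y_pred, self.gamma)
        loss = -weight * y_true * tf.math.log(y_pred)
        # per-example loss vector; the caller reduces it with tf.reduce_mean
        return tf.reduce_sum(loss, axis=-1)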
Code example #4
def train(data_manager, logger):
    embedding_dim = data_manager.embedding_dim
    num_classes = data_manager.max_label_number
    seq_length = data_manager.max_sequence_length

    checkpoints_dir = classifier_config['checkpoints_dir']
    checkpoint_name = classifier_config['checkpoint_name']
    num_filters = classifier_config['num_filters']
    learning_rate = classifier_config['learning_rate']
    epoch = classifier_config['epoch']
    max_to_keep = classifier_config['max_to_keep']
    print_per_batch = classifier_config['print_per_batch']
    is_early_stop = classifier_config['is_early_stop']
    patient = classifier_config['patient']
    hidden_dim = classifier_config['hidden_dim']
    classifier = classifier_config['classifier']

    best_f1_val = 0.0
    best_at_epoch = 0
    unprocessed = 0
    batch_size = data_manager.batch_size
    very_start_time = time.time()
    loss_obj = FocalLoss() if classifier_config['use_focal_loss'] else None
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    X_train, y_train, X_val, y_val = data_manager.get_training_set()
    # load the model
    if classifier == 'textcnn':
        from engines.models.textcnn import TextCNN
        model = TextCNN(seq_length, num_filters, num_classes, embedding_dim)
    elif classifier == 'textrcnn':
        from engines.models.textrcnn import TextRCNN
        model = TextRCNN(seq_length, num_classes, hidden_dim, embedding_dim)
    else:
        raise Exception('the configured classifier does not exist')
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=checkpoints_dir,
        checkpoint_name=checkpoint_name,
        max_to_keep=max_to_keep)
    num_iterations = int(math.ceil(1.0 * len(X_train) / batch_size))
    num_val_iterations = int(math.ceil(1.0 * len(X_val) / batch_size))
    logger.info(('+' * 20) + 'training starting' + ('+' * 20))
    for i in range(epoch):
        start_time = time.time()
        # shuffle train at each epoch
        sh_index = np.arange(len(X_train))
        np.random.shuffle(sh_index)
        X_train = X_train[sh_index]
        y_train = y_train[sh_index]
        logger.info('epoch:{}/{}'.format(i + 1, epoch))
        for iteration in tqdm(range(num_iterations)):
            X_train_batch, y_train_batch = data_manager.next_batch(
                X_train, y_train, start_index=iteration * batch_size)
            with tf.GradientTape() as tape:
                logits = model(X_train_batch, training=True)
                if classifier_config['use_focal_loss']:
                    loss_vec = loss_obj.call(y_true=y_train_batch,
                                             y_pred=logits)
                else:
                    loss_vec = tf.keras.losses.categorical_crossentropy(
                        y_true=y_train_batch, y_pred=logits)
                loss = tf.reduce_mean(loss_vec)
            # compute gradients of the loss w.r.t. the trainable parameters
            gradients = tape.gradient(loss, model.trainable_variables)
            # backpropagation: apply the gradients to update the parameters
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
            if iteration % print_per_batch == 0 and iteration != 0:
                predictions = tf.argmax(logits, axis=-1)
                y_train_batch = tf.argmax(y_train_batch, axis=-1)
                measures = cal_metrics(y_true=y_train_batch,
                                       y_pred=predictions)
                res_str = ''
                for k, v in measures.items():
                    res_str += (k + ': %.3f ' % v)
                logger.info('training batch: %5d, loss: %.5f, %s' %
                            (iteration, loss, res_str))

        # validation
        logger.info('start evaluating...')
        val_results = {'precision': 0, 'recall': 0, 'f1': 0}
        for iteration in tqdm(range(num_val_iterations)):
            X_val_batch, y_val_batch = data_manager.next_batch(
                X_val, y_val, iteration * batch_size)
            logits = model(X_val_batch)
            predictions = tf.argmax(logits, axis=-1)
            y_val_batch = tf.argmax(y_val_batch, axis=-1)
            measures = cal_metrics(y_true=y_val_batch, y_pred=predictions)
            for k, v in measures.items():
                val_results[k] += v

        time_span = (time.time() - start_time) / 60
        val_res_str = ''
        dev_f1_avg = 0
        for k, v in val_results.items():
            val_results[k] /= num_val_iterations
            val_res_str += (k + ': %.3f ' % val_results[k])
            if k == 'f1':
                dev_f1_avg = val_results[k]
        logger.info('time consumption:%.2f(min), %s' %
                    (time_span, val_res_str))

        if dev_f1_avg > best_f1_val:
            unprocessed = 0
            best_f1_val = dev_f1_avg
            best_at_epoch = i + 1
            checkpoint_manager.save()
            logger.info('saved the new best model with f1: %.3f' % best_f1_val)
        else:
            unprocessed += 1

        if is_early_stop:
            if unprocessed >= patient:
                logger.info(
                    'early stopped, no progress obtained within {} epochs'.
                    format(patient))
                logger.info('overall best f1 is {} at {} epoch'.format(
                    best_f1_val, best_at_epoch))
                logger.info('total training time consumption: %.3f(min)' %
                            ((time.time() - very_start_time) / 60))
                return
    logger.info('overall best f1 is {} at {} epoch'.format(
        best_f1_val, best_at_epoch))
    logger.info('total training time consumption: %.3f(min)' %
                ((time.time() - very_start_time) / 60))
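
Both training loops also rely on cal_metrics, which is not included. Below is a minimal sketch with scikit-learn that matches how code example #3 consumes it (a macro-averaged summary plus a per-class report keyed by the stringified class id); note that code example #4 unpacks a single return value, so that project's version presumably returns only the summary dict. The key names and averaging mode are assumptions:

from sklearn import metrics

def cal_metrics(y_true, y_pred):
    # macro-averaged summary used for logging and model selection
    measures = {
        'precision': metrics.precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': metrics.recall_score(y_true, y_pred, average='macro', zero_division=0),
        'f1': metrics.f1_score(y_true, y_pred, average='macro', zero_division=0),
    }
    # per-class report; keys are string labels, matching the
    # reverse_classes lookup in code example #3
    each_classes = metrics.classification_report(
        y_true, y_pred, output_dict=True, zero_division=0)
    return measures, each_classes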