示例#1
0
    def predict(self, seqs, demo=True):
        """Predict label ids for a batch of sequences.

        Args:
            seqs: raw token sequences when ``demo`` is True, otherwise
                already-vectorized id sequences to be padded.
            demo: when True, vectorize via ``dev2vec`` using the model's
                vocab and max sequence length; otherwise only pad.

        Returns:
            numpy array of predicted class indices (argmax over dim 1).
        """
        if demo:
            input_X = dev2vec(seqs,
                              word_dict=self.vocab,
                              max_seq_len=self.sequence_length)
        else:
            input_X, _ = pad_sequences(seqs)

        inputs = torch.LongTensor(input_X)
        # Inference only: no_grad avoids building the autograd graph and
        # guarantees .numpy() works (without it, a forward pass through
        # parameters that require grad makes .numpy() raise RuntimeError).
        with torch.no_grad():
            outputs = self(inputs)
        return torch.argmax(outputs, dim=1).numpy()
示例#2
0
    def predict_prob(self, seqs, demo=True):
        """Predict class probabilities for a batch of sequences.

        Args:
            seqs: raw token sequences when ``demo`` is True, otherwise
                already-vectorized id sequences to be padded.
            demo: when True, vectorize via ``dev2vec`` using the model's
                vocab and max sequence length; otherwise only pad.

        Returns:
            numpy array of shape (batch, num_classes) with rows summing to 1.
        """
        if demo:
            input_X = dev2vec(seqs,
                              word_dict=self.vocab,
                              max_seq_len=self.sequence_length)
        else:
            input_X, _ = pad_sequences(seqs)

        inputs = torch.LongTensor(input_X)
        # torch.softmax is numerically stable; the original
        # exp(logits) / sum(exp(logits)) overflows to inf/NaN for large
        # logits. no_grad: inference only, so .numpy() needs no .detach().
        with torch.no_grad():
            outputs_softmax = torch.softmax(self(inputs), dim=1)
        return outputs_softmax.numpy()
示例#3
0
    def predict_prob(self, sess, seqs, demo=True):
        """Return class probabilities for *seqs* via the given TF session."""
        if demo:
            vectors = dev2vec(seqs,
                              word_dict=self.vocab,
                              max_seq_len=self.sequence_length)
        else:
            vectors, _ = pad_sequences(seqs)

        # Dropout keep-prob pinned to 1.0: inference, no dropout.
        feed = {self.input_x: vectors, self.dropout_kp: 1.0}
        return sess.run(self.possibility, feed_dict=feed)
示例#4
0
 def predict(self, sess, seqs, demo=True):
     """Return predicted label ids for *seqs* via the given TF session."""
     if demo:
         batch = dev2vec(seqs,
                         word_dict=self.vocab,
                         max_seq_len=self.sequence_length)
     else:
         batch, _ = pad_sequences(seqs)

     # keep_prob 1.0 disables dropout; tst=True flags inference mode.
     feed = {
         self.input_x: batch,
         self.keep_prob: 1.0,
         self.tst: True
     }
     return sess.run(self.predictions, feed_dict=feed)
示例#5
0
    def train(self, sess, train, dev, shuffle=True, re_train=False):
        """Run the full training loop with per-epoch validation.

        Each epoch: iterate mini-batches of ``train``, log loss/accuracy,
        write TensorBoard summaries, save a checkpoint at the last batch,
        then evaluate on the pre-built ``dev`` batches and print a
        classification report.

        Args:
            sess: an open ``tf.compat.v1.Session`` whose graph holds this model.
            train: training samples, consumed by ``batch_yield``.
            dev: validation samples, consumed by ``batch_yield``.
            shuffle: whether ``batch_yield`` shuffles samples each pass.
            re_train: when True, skip variable initialization (resume from
                previously restored weights).
        """
        checkpoints_path = None  # set once the first checkpoint is written
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        # DEV split: materialize validation batches once, reused every epoch.
        dev_batches = list(
            batch_yield(dev,
                        1000,
                        self.vocab,
                        self.tag2label,
                        max_seq_len=self.sequence_length,
                        shuffle=shuffle))

        # with tf.compat.v1.Session(config=self.config) as sess:
        # Fresh run: initialize all variables; a resumed run keeps restored ones.
        if not re_train:
            sess.run(tf.compat.v1.global_variables_initializer())

        self.merged = tf.compat.v1.summary.merge_all()
        # Separate TensorBoard writers for train and test summaries.
        train_writer = tf.compat.v1.summary.FileWriter(
            self.model_path + os.sep + "summaries" + os.sep + 'train',
            sess.graph)
        test_writer = tf.compat.v1.summary.FileWriter(self.model_path +
                                                      os.sep + "summaries" +
                                                      os.sep + 'test')

        # NOTE(review): attribute is spelled "eopches" — presumably the epoch
        # count set in __init__; confirm the spelling matches its definition.
        for epoch in range(self.eopches):
            # Ceiling division: number of batches in this epoch.
            num_batches = (len(train) + self.batch_size - 1) // self.batch_size
            st = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

            # Tokens have already been converted to ids by batch_yield.
            batches = batch_yield(train,
                                  self.batch_size,
                                  self.vocab,
                                  self.tag2label,
                                  max_seq_len=self.sequence_length,
                                  shuffle=shuffle)

            for step, (seqs, labels) in enumerate(batches):
                b_x, b_len_x = pad_sequences(
                    seqs, max_sequence_length=self.sequence_length)
                b_y = labels  # PADDING
                # In-place progress line (carriage return, no newline).
                sys.stdout.write(' processing: {} batch / {} batches.'.format(
                    step + 1, num_batches) + '\r')
                # Global step across epochs, used for summaries/checkpoints.
                step_num = epoch * num_batches + step + 1

                # NOTE(review): feeding 1 - keep_rate into dropout_kp suggests
                # keep_rate is actually a *drop* rate despite its name — verify
                # against the placeholder's definition.
                summary, loss, acc, _ = sess.run(
                    [self.merged, self.loss_val, self.accuracy, self.opt],
                    feed_dict={
                        self.input_x: b_x,
                        self.input_y: b_y,
                        self.dropout_kp: 1 - self.keep_rate
                    })

                train_writer.add_summary(summary, step_num)
                # Log on the first batch, every 100th batch, and the last batch.
                if step + 1 == 1 or (step +
                                     1) % 100 == 0 or step + 1 == num_batches:
                    logger.info(
                        '{} <TRAIN> epoch: {}, step: {}, loss: {:.4}, global_step: {}, acc: {}'
                        .format(st, epoch + 1, step + 1, loss, step_num, acc))

                # Checkpoint once per epoch, after the final batch.
                if step + 1 == num_batches:
                    checkpoints_path = os.path.join(self.model_path,
                                                    "checkpoints")
                    if not os.path.exists(checkpoints_path):
                        os.makedirs(checkpoints_path)
                    saver.save(sess,
                               checkpoints_path + os.sep + "model",
                               global_step=step_num)

            # DEV: evaluate on the pre-built validation batches.
            logger.info(
                '======================validation / test======================'
            )
            _step = (epoch + 1) * num_batches
            y_trues, y_preds = [], []
            tmp_loss, tmp_acc = [], []
            for dev_step, (dev_X, dev_y) in tqdm(enumerate(dev_batches)):
                # Only the first dev batch fetches merged summaries (one
                # test-summary point per epoch); the rest skip that overhead.
                if dev_step == 0:
                    test_summary, test_loss, test_acc, y_pred = \
                        sess.run([self.merged, self.loss_val, self.accuracy, self.predictions],
                                 feed_dict={self.input_x: dev_X,
                                            self.input_y: dev_y,
                                            self.dropout_kp: 1.0, })
                    test_writer.add_summary(test_summary, _step)
                else:
                    test_loss, test_acc, y_pred = sess.run(
                        [self.loss_val, self.accuracy, self.predictions],
                        feed_dict={
                            self.input_x: dev_X,
                            self.input_y: dev_y,
                            self.dropout_kp: 1.0,
                        })
                y_trues.extend(dev_y)
                y_preds.extend(y_pred)
                tmp_loss.append(test_loss)
                tmp_acc.append(test_acc)

            # Per-epoch dev metrics, averaged over all dev batches.
            logger.info(
                "{} <DEV> epoch: {} | step: {} | loss:{} | acc: {} ".format(
                    st, epoch + 1, _step, np.average(tmp_loss),
                    np.average(tmp_acc)))
            print(
                classification_report(y_trues,
                                      y_preds,
                                      target_names=self.target_names))

        logger.info("model save in {}".format(checkpoints_path))