Esempio n. 1
0
    def dev_one_epoch(self, sess, dev):
        """Run the model over the dev set once and collect gold/predicted labels.

        The dev data is split into un-shuffled batches of ``self.batch_size``
        by ``dev_batch_iter``. Each batch is fed to ``self.labels_predict``,
        and the predictions are paired with gold labels obtained by taking the
        argmax over the last axis of the label column.

        :param sess: session used to evaluate ``self.labels_predict``
            (presumably a TensorFlow ``tf.Session`` -- confirm).
        :param dev: dev examples; each element is unpacked into
            (word ids, position ids 1, position ids 2, label vector).
        :return: tuple ``(true_labels, predicted_labels)`` of flat lists.
        """
        gold, predicted = [], []
        for batch in dev_batch_iter(dev, self.batch_size, shuffle=False):
            word_ids, pos_ids1, pos_ids2, label_vecs = zip(*batch)
            feed = {
                self.target_word_ids: word_ids,
                self.target_position_ids1: pos_ids1,
                self.target_position_ids2: pos_ids2,
                self.relation_labels: label_vecs,
            }
            # One predicted class index per example in the batch.
            batch_pred = sess.run(self.labels_predict, feed_dict=feed)
            # Gold index = argmax of each label vector (presumably one-hot).
            gold.extend(np.argmax(label_vecs, axis=-1))
            predicted.extend(batch_pred)

        print("true label:" + str(gold))
        print("pre label:" + str(predicted))
        return gold, predicted
Esempio n. 2
0
    def dev_one_epoch(self, sess, dev, epoch):
        """Evaluate the model on the dev set and log the macro-F1 score.

        The dev data is split into un-shuffled batches of ``self.batch_size``
        by ``dev_batch_iter``. Each batch is fed to ``self.labels_predict``,
        and the predictions are compared with the gold labels (argmax of the
        label column). All gold and predicted labels are printed, and the
        macro-averaged F1 over the whole dev set is logged.

        :param sess: session used to evaluate ``self.labels_predict``
            (presumably a TensorFlow ``tf.Session`` -- confirm).
        :param dev: dev examples; each element is unpacked into
            (word ids, position ids 1, position ids 2, label vector).
        :param epoch: zero-based epoch index; logged as ``epoch + 1``.
        :return: macro-averaged F1 score over the whole dev set.
        """
        batches = dev_batch_iter(dev, self.batch_size, shuffle=False)
        true_label_list, pre_label_list = [], []

        # Use a distinct name so the ``dev`` parameter is not shadowed.
        for batch in batches:
            dev_x, dev_p1, dev_p2, dev_y = zip(*batch)
            dev_feed_dict = {self.source_word_ids: dev_x,
                             self.source_position_ids1: dev_p1,
                             self.source_position_ids2: dev_p2,
                             self.relation_labels: dev_y,
                             # Presumably a dropout keep-probability;
                             # 1.0 disables dropout for evaluation.
                             self.dropout: 1.0}
            # [batch,] predicted class indices
            label_pre = sess.run(self.labels_predict, feed_dict=dev_feed_dict)
            # [batch,] gold class indices from the label vectors
            label_true = np.argmax(dev_y, axis=-1)

            true_label_list.extend(label_true)
            pre_label_list.extend(label_pre)

        print("true label:" + str(true_label_list))
        print("pre label:" + str(pre_label_list))

        macro_f1 = f1_score(true_label_list, pre_label_list, average="macro")
        self.logger.info('epoch {}, macro_f1: {:.4}'
                         .format(epoch + 1, macro_f1))
        # Return the score so callers can track/compare epochs, matching the
        # other dev_one_epoch variant that takes an ``epoch`` argument.
        return macro_f1
Esempio n. 3
0
File: GSN.py Progetto: bobobe/backup
    def dev_one_epoch(self, sess, dev, epoch):
        """Evaluate the model on the dev set for one epoch.

        For each un-shuffled batch of ``dev``, this evaluates the predictions,
        the dev loss/accuracy tensors and the merged dev summary. The summary
        is written to ``self.file_writer`` at a global step derived from
        ``epoch``. Afterwards it averages loss and accuracy over the batches,
        computes macro-F1 over all predictions, and logs the results.

        :param sess: session used to evaluate the graph tensors
            (presumably a TensorFlow ``tf.Session`` -- confirm).
        :param dev: sized sequence of dev examples; each element is unpacked
            into (word ids, position ids 1, position ids 2, label vector).
        :param epoch: zero-based epoch index; used for the summary step and
            logged as ``epoch + 1``.
        :return: macro-averaged F1 score over the whole dev set.
        """
        batches = dev_batch_iter(dev, self.batch_size, shuffle=False)
        true_label_list, pre_label_list = [], []
        # Float accumulators so the '{:.4}' log format also works when no
        # batch was evaluated.
        dev_acc = 0.0
        dev_los = 0.0
        num = 0
        # Ceiling division: batches per epoch, so summary steps of
        # consecutive epochs do not overlap.
        num_batches = (len(dev) + self.batch_size - 1) // self.batch_size
        # Use a distinct name so the ``dev`` parameter is not shadowed.
        for step, batch in enumerate(batches):
            step_num = epoch * num_batches + step + 1

            dev_x, dev_p1, dev_p2, dev_y = zip(*batch)
            dev_feed_dict = {
                self.target_word_ids: dev_x,
                self.target_position_ids1: dev_p1,
                self.target_position_ids2: dev_p2,
                self.relation_labels: dev_y,
                # Presumably the dropout keep-probability; 1.0 disables
                # dropout for evaluation.
                self.dropout_kp: 1.0
            }
            # [batch,] predictions plus scalar loss/accuracy and a summary
            label_pre, dev_loss, dev_accuracy, summary = sess.run(
                [
                    self.labels_predict, self.dev_loss, self.dev_accuracy,
                    self.dev_merged
                ],
                feed_dict=dev_feed_dict)
            # [batch,] gold class indices from the label vectors
            label_true = np.argmax(dev_y, axis=-1)

            self.file_writer.add_summary(summary, step_num)
            true_label_list.extend(label_true)
            pre_label_list.extend(label_pre)
            dev_acc += dev_accuracy
            dev_los += dev_loss
            num += 1
        # Mean over the number of batches actually evaluated. ``num`` already
        # counts every batch, so no extra +1 belongs in the divisor. max()
        # guards against division by zero on an empty dev set.
        dev_acc = dev_acc / max(num, 1)
        dev_los = dev_los / max(num, 1)
        print("true label:" + str(true_label_list))
        print("pre label:" + str(pre_label_list))

        macro_f1 = f1_score(true_label_list, pre_label_list, average="macro")
        self.logger.info(
            'epoch {},  dev_loss: {:.4}, dev_accuracy: {:.4},macro_f1: {:.4}'.
            format(epoch + 1, dev_los, dev_acc, macro_f1))
        return macro_f1