示例#1
0
    def test_unit(self, text):
        """Generate text token-by-token from the frozen PB model.

        Starting from `text`, repeatedly feeds the last generated word and
        RNN state back into the model until a sentence terminator '。' is
        produced or 80 tokens have been emitted, then logs the full text.

        Args:
            text: seed string to start generation from.
        """
        # Export the frozen graph first if it is not on disk yet.
        if not os.path.exists(self.model_path):
            self.save_pb()
        graph = load_pb(self.model_path)
        sess = tf.Session(graph=graph)
        try:
            self.y = graph.get_operation_by_name("y").outputs[0]
            self.is_training = graph.get_operation_by_name("is_training").outputs[0]

            # NOTE(review): assumes the second-to-last exported output node
            # is the RNN state tensor — confirm against save_pb().
            self.state = graph.get_tensor_by_name(self.output_nodes[-2] + ":0")
            self.prob = graph.get_tensor_by_name("output/prob:0")
            self.predictions = graph.get_tensor_by_name("output/predictions:0")

            vocab_dict = embedding[self.embedding_type].build_dict(self.dict_path,
                                                                   mode='test')
            vocab_dict_rev = {vocab_dict[key]: key for key in vocab_dict}

            # First step: no previous state, feed the whole seed text.
            word, state = self.run(sess, graph, vocab_dict, vocab_dict_rev,
                                   None, text)
            len_idx = 0
            while word != '。':
                len_idx += 1
                if len_idx > 80:
                    # Hard cap so a model that never emits '。' cannot loop forever.
                    break
                text += word
                word, state = self.run(sess, graph, vocab_dict, vocab_dict_rev,
                                       state, word)
            text += word
            logging.info(text)
        finally:
            # Fix: the session was previously leaked; always release it.
            sess.close()
示例#2
0
    def predict(self):
        """Batch-predict classes for every line of `self.predict_path`.

        Loads the frozen PB model, scores each input line, and writes a CSV
        (`<predict_path>.result.csv`) with columns: text, pred (class name),
        score (probability of the predicted class).
        """
        predict_file = self.predict_path
        # Export the frozen graph first if it is not on disk yet.
        if not os.path.exists(self.model_path):
            self.save_pb()
        graph = load_pb(self.model_path)
        sess = tf.Session(graph=graph)
        try:
            self.y = graph.get_operation_by_name("y").outputs[0]
            self.is_training = graph.get_operation_by_name("is_training").outputs[0]

            self.scores = graph.get_tensor_by_name("output/scores:0")
            self.predictions = graph.get_tensor_by_name("output/predictions:0")

            vocab_dict = embedding[self.embedding_type].build_dict(self.dict_path,
                                                                   mode='test')
            # mp: class name -> id, mp_rev: id -> class name.
            mp, mp_rev = load_class_mp(self.classes_path)
            with open(predict_file) as f:
                lines = [line.strip() for line in f]
                batches = batch_iter(lines, self.batch_size, 1, shuffle=False)
                scores = []
                predicts = []
                for batch_x in batches:
                    feed_dict = {
                        self.is_training: False
                    }
                    if not self.use_language_model:
                        preprocess_x, batch_x, len_batch = \
                            self.embedding.text2id(batch_x, vocab_dict)
                        feed_dict.update(
                            self.embedding.pb_feed_dict(graph, batch_x, 'x'))
                        feed_dict.update(
                            self.encoder.pb_feed_dict(graph, len=len_batch))
                    else:
                        feed_dict.update(
                            self.encoder.pb_feed_dict(graph, batch_x))
                    predictions_out, scores_out = sess.run(
                        [self.predictions, self.scores], feed_dict=feed_dict)
                    # Probability of the predicted class for each sample.
                    max_scores = [scores_out[idx][predictions_out[idx]]
                                  for idx in range(len(predictions_out))]

                    predicts += list(predictions_out)
                    scores += list(max_scores)

                # Map numeric class ids back to human-readable labels.
                predicts = [mp_rev[item] for item in predicts]

                dt = pd.DataFrame({'text': lines,
                                   'pred': predicts,
                                   'score': scores})
                dt.to_csv(self.predict_path + '.result.csv',
                          index=False, sep=',')
        finally:
            # Fix: the session was previously leaked; always release it.
            sess.close()
示例#3
0
    def test_unit(self, text):
        """Encode `text`, then decode word-by-word from the frozen PB model.

        Runs the encoder once to obtain the initial decoder state, then
        generates tokens until '。' or '</s>' is produced or 80 tokens have
        been emitted, logging the resulting text.

        Args:
            text: seed string to encode and continue from.
        """
        # Export the frozen graph first if it is not on disk yet.
        if not os.path.exists(self.model_path):
            self.save_pb()
        graph = load_pb(self.model_path)
        sess = tf.Session(graph=graph)
        try:
            self.target = graph.get_operation_by_name("target_seq").outputs[0]
            self.is_training = graph.get_operation_by_name(
                "is_training").outputs[0]

            # NOTE(review): assumes output_nodes[-3]/[-2] are the decoder and
            # encoder final-state tensors respectively — confirm with save_pb().
            self.final_state_decode = graph.get_tensor_by_name(
                self.output_nodes[-3] + ":0")
            self.final_state_encode = graph.get_tensor_by_name(
                self.output_nodes[-2] + ":0")
            self.prob = graph.get_tensor_by_name("output/prob:0")
            self.predictions = graph.get_tensor_by_name("output/predictions:0")

            vocab_dict = embedding[self.embedding_type].build_dict(self.dict_path,
                                                                   mode='test')
            vocab_dict_rev = {vocab_dict[key]: key for key in vocab_dict}

            feed_dict = {self.is_training: False}
            preprocess_x, encode_batch, len_batch = self.embedding.text2id(
                [text], vocab_dict, self.maxlen)
            feed_dict.update(
                self.embedding.pb_feed_dict(graph, encode_batch, 'encode_seq'))
            feed_dict.update(
                self.encoder.pb_feed_dict(graph, len=(len_batch, None)))
            state = sess.run(self.final_state_encode, feed_dict=feed_dict)
            # Fix: removed leftover `pdb.set_trace()` debugger call, which
            # halted every non-interactive run at this point.
            state = state.tolist()
            word = '<s>'
            len_idx = 0
            while word != '。':
                len_idx += 1
                if len_idx > 80:
                    # Hard cap so a model that never terminates cannot loop forever.
                    break
                if word == '</s>':
                    break
                text += word
                word, state = self.run(sess, graph, vocab_dict, vocab_dict_rev,
                                       state, word)
            text += word
            logging.info(text)
        finally:
            # Fix: the session was previously leaked; always release it.
            sess.close()
示例#4
0
    def test_unit(self, text):
        """Classify a single input string with the frozen PB model.

        Args:
            text: raw input string to classify.

        Returns:
            Tuple of (predicted class name, score of the predicted class).
        """
        # Export the frozen graph first if it is not on disk yet.
        if not os.path.exists(self.model_path):
            self.save_pb()
        graph = load_pb(self.model_path)
        sess = tf.Session(graph=graph)
        try:
            self.y = graph.get_operation_by_name("y").outputs[0]
            self.is_training = graph.get_operation_by_name("is_training").outputs[0]

            self.scores = graph.get_tensor_by_name("output/scores:0")
            self.predictions = graph.get_tensor_by_name("output/predictions:0")

            vocab_dict = embedding[self.embedding_type].build_dict(self.dict_path,
                                                                   mode='test')
            # mp: class name -> id, mp_rev: id -> class name.
            mp, mp_rev = load_class_mp(self.classes_path)
            # Fix: `preprocess_x` was only bound in the non-language-model
            # branch but logged unconditionally below (NameError otherwise).
            preprocess_x = None
            batches = batch_iter([text], self.batch_size, 1, shuffle=False)
            for batch_x in batches:
                feed_dict = {
                    self.is_training: False
                }
                if not self.use_language_model:
                    preprocess_x, batch_x, len_batch = \
                        self.embedding.text2id(batch_x, vocab_dict)
                    feed_dict.update(
                        self.embedding.pb_feed_dict(graph, batch_x, 'x'))
                    feed_dict.update(
                        self.encoder.pb_feed_dict(graph, len=len_batch))
                else:
                    feed_dict.update(
                        self.encoder.pb_feed_dict(graph, batch_x))
                predictions_out, scores_out = sess.run(
                    [self.predictions, self.scores], feed_dict=feed_dict)
                # Probability of the predicted class for each sample.
                max_scores = [scores_out[idx][predictions_out[idx]]
                              for idx in range(len(predictions_out))]
            logging.info("preprocess: {}".format(preprocess_x))
            logging.info("class:{}, score:{}, class_id:{}".format(
                mp_rev[predictions_out[0]],
                max_scores[0],
                predictions_out[0]))
            return mp_rev[predictions_out[0]], max_scores[0]
        finally:
            # Fix: the session was previously leaked; always release it.
            sess.close()
示例#5
0
    def test(self):
        """Evaluate the frozen PB model on the test split.

        Computes batch-mean accuracy and a thresholded accuracy (prediction
        correct AND score above `self.thre_score`), writes per-sample results
        to `<test_path>.result.csv`, and logs both metrics.
        """
        # Export the frozen graph first if it is not on disk yet.
        if not os.path.exists(self.model_path):
            self.save_pb()
        graph = load_pb(self.model_path)
        sess = tf.Session(graph=graph)
        try:
            self.y = graph.get_operation_by_name("y").outputs[0]
            self.is_training = graph.get_operation_by_name("is_training").outputs[0]
            self.accuracy = graph.get_operation_by_name(
                "accuracy/accuracy").outputs[0]

            self.scores = graph.get_tensor_by_name("output/scores:0")
            self.predictions = graph.get_tensor_by_name("output/predictions:0")

            # mp: class name -> id, mp_rev: id -> class name.
            mp, mp_rev = load_class_mp(self.classes_path)

            test_x, test_y = self.load_data("test")
            pred_y = []
            scores = []
            batches = batch_iter(zip(test_x, test_y), self.batch_size, 1,
                                 shuffle=False)
            sum_accuracy, cnt = 0, 0
            # Fix: renamed `all` -> `total` to stop shadowing the builtin.
            right, total = 0, 0
            vocab_dict = embedding[self.embedding_type].build_dict(self.dict_path,
                                                                   mode='test')
            all_test_x = []
            all_test_y = []
            for batch in batches:
                batch_x, batch_y = zip(*batch)

                feed_dict = {
                    self.y: batch_y,
                    self.is_training: False
                }
                if not self.use_language_model:
                    preprocess_x, batch_x_id, len_batch = self.embedding.text2id(
                        batch_x, vocab_dict, need_preprocess=True)
                    feed_dict.update(
                        self.embedding.pb_feed_dict(graph, batch_x_id, 'x'))
                    feed_dict.update(
                        self.encoder.pb_feed_dict(graph, len=len_batch))
                else:
                    feed_dict.update(
                        self.encoder.pb_feed_dict(graph, batch_x))
                accuracy_out, predictions_out, scores_out = sess.run(
                    [self.accuracy, self.predictions, self.scores],
                    feed_dict=feed_dict)
                # Probability of the predicted class for each sample.
                max_scores = [scores_out[idx][predictions_out[idx]]
                              for idx in range(len(predictions_out))]
                sum_accuracy += accuracy_out
                cnt += 1
                pred_y += list(predictions_out)
                scores += list(max_scores)
                all_test_x += list(batch_x)
                all_test_y += list(batch_y)

                for idx in range(len(predictions_out)):
                    # A sample is "right" only if predicted correctly AND
                    # the score clears the confidence threshold.
                    if predictions_out[idx] == int(batch_y[idx]) \
                            and max_scores[idx] > self.thre_score:
                        right += 1
                    total += 1
            dt = pd.DataFrame({'text': all_test_x,
                               'target': [mp_rev[int(item)]
                                          for item in all_test_y],
                               'pred': [mp_rev[item] for item in pred_y],
                               'score': scores})
            dt.to_csv(self.test_path + '.result.csv', index=False, sep=',')
            # Fix: guard against an empty test set (ZeroDivisionError before).
            if cnt:
                logging.info("Test Accuracy : {0}".format(sum_accuracy / cnt))
            if total:
                logging.info("Test Thre Accuracy : {0}".format(right / total))
        finally:
            # Fix: the session was previously leaked; always release it.
            sess.close()