def evaluate(self, data, sess):
        """Return root-label accuracy of ``self.pred`` over the full batches of ``data``.

        ``data`` is scored in consecutive batches of ``self.config.batch_size``;
        a trailing batch smaller than that is skipped. Each example is unpacked
        as a ``(_, label)`` pair to get its root label, which is compared with
        ``argmax(self.pred, axis=1)`` for the same row.

        Prints the number of scored examples and of correct predictions.

        Args:
            data: indexable dataset accepted by ``extract_batch_tree_data``.
            sess: session whose ``run`` evaluates ``self.pred``.

        Returns:
            float: correct root predictions / scored examples.

        Raises:
            ZeroDivisionError: if ``data`` has fewer than one full batch.
        """
        num_correct = 0
        total_data = 0
        data_idxs = range(len(data))
        test_batch_size = self.config.batch_size
        for start in range(0, len(data), test_batch_size):
            batch_size = min(start + test_batch_size, len(data)) - start
            # NOTE(review): the trailing partial batch is never scored; presumably
            # the graph expects exactly batch_size rows -- confirm before changing.
            if batch_size < test_batch_size:
                break
            batch_data = [data[ix] for ix in data_idxs[start:start + batch_size]]
            labels_root = [l for _, l in batch_data]
            input_b, treestr_b, labels_b = extract_batch_tree_data(
                batch_data, self.config.maxnodesize)

            # dropout placeholder is fed 1.0 for evaluation.
            feed = {
                self.input: input_b,
                self.treestr: treestr_b,
                self.labels: labels_b,
                self.dropout: 1.0,
                self.batch_len: len(input_b),
            }

            pred_y = sess.run(self.pred, feed_dict=feed)
            y = np.argmax(pred_y, axis=1)
            # Use a distinct index name so the outer batch offset is not rebound.
            for k, gold in enumerate(labels_root):
                if y[k] == gold:
                    num_correct += 1
                total_data += 1
        # A single parenthesised string prints identically on Python 2 and 3.
        print("total_data %d" % total_data)
        print("num_correct %d" % num_correct)
        acc = float(num_correct) / float(total_data)
        return acc
    def train(self, data, sess):
        """Run one pass over ``data`` in reverse index order, one step per full batch.

        For each batch of ``self.batch_size`` examples, runs ``self.loss``,
        ``self.batch_loss``, ``self.train_op1`` and ``self.train_op2`` with the
        configured dropout. A trailing batch smaller than ``self.batch_size`` is
        skipped. Writes a running average loss to stdout after every batch.

        Args:
            data: indexable dataset accepted by ``extract_batch_tree_data``.
            sess: session used to run the loss and train ops.

        Returns:
            Mean of the per-batch losses (``np.mean`` of an empty list, i.e. nan
            with a numpy RuntimeWarning, if no full batch exists).
        """
        # An explicit list is required: on Python 3 range() is immutable and
        # has no .reverse(). Examples are visited last-to-first; no shuffling.
        data_idxs = list(range(len(data)))
        data_idxs.reverse()
        losses = []
        for i in range(0, len(data), self.batch_size):
            batch_size = min(i + self.batch_size, len(data)) - i
            if batch_size < self.batch_size:
                break  # the trailing partial batch is skipped

            batch_data = [data[ix] for ix in data_idxs[i:i + batch_size]]

            input_b, treestr_b, labels_b = extract_batch_tree_data(
                batch_data, self.config.maxnodesize)

            feed = {
                self.input: input_b,
                self.treestr: treestr_b,
                self.labels: labels_b,
                self.dropout: self.config.dropout,
                self.batch_len: len(input_b),
            }

            # self.batch_loss is fetched but not used here.
            loss, _, _, _ = sess.run(
                [self.loss, self.batch_loss, self.train_op1, self.train_op2],
                feed_dict=feed)
            losses.append(loss)
            avg_loss = np.mean(losses)
            # i is the offset of the first example of this batch.
            sstr = 'avg loss %.2f at example %d of %d\r' % (avg_loss, i, len(data))
            sys.stdout.write(sstr)
            sys.stdout.flush()

        return np.mean(losses)
Пример #3
0
    def train(self, data, sess):
        """Run one pass over ``data`` in shuffled order, one step per full batch.

        For each batch of ``self.batch_size`` examples, runs ``self.loss``,
        ``self.train_op1`` and ``self.train_op2`` with the configured dropout.
        A trailing batch smaller than ``self.batch_size`` is skipped. Writes a
        running average loss to stdout after every batch.

        Args:
            data: indexable dataset accepted by ``extract_batch_tree_data``.
            sess: session used to run the loss and train ops.

        Returns:
            Mean of the per-batch losses (``np.mean`` of an empty list, i.e. nan
            with a numpy RuntimeWarning, if no full batch exists).
        """
        from random import shuffle
        # An explicit list is required: shuffle() permutes in place, and on
        # Python 3 range() returns an immutable object (TypeError).
        data_idxs = list(range(len(data)))
        shuffle(data_idxs)
        losses = []
        for i in range(0, len(data), self.batch_size):
            batch_size = min(i + self.batch_size, len(data)) - i
            if batch_size < self.batch_size:
                break  # the trailing partial batch is skipped

            batch_data = [data[ix] for ix in data_idxs[i:i + batch_size]]

            input_b, treestr_b, labels_b = extract_batch_tree_data(
                batch_data, self.config.maxnodesize)

            feed = {
                self.input: input_b,
                self.treestr: treestr_b,
                self.labels: labels_b,
                self.dropout: self.config.dropout,
                self.batch_len: len(input_b),
            }

            # Both train ops run in the same call; only the loss value is kept.
            loss, _, _ = sess.run(
                [self.loss, self.train_op1, self.train_op2], feed_dict=feed)

            losses.append(loss)
            avg_loss = np.mean(losses)
            # i is the offset of the first example of this batch.
            sstr = 'avg loss %.2f at example %d of %d\r' % (avg_loss, i, len(data))
            sys.stdout.write(sstr)
            sys.stdout.flush()

        return np.mean(losses)
Пример #4
0
    def evaluate(self, data, sess):
        """Compute root-label accuracy of ``self.pred`` on ``data``.

        Only complete batches of ``self.config.batch_size`` examples are
        scored; leftover examples at the end of ``data`` are ignored. Each
        example is unpacked as a ``(_, label)`` pair to obtain its root label,
        which is compared with ``argmax(self.pred, axis=1)`` for that row.

        Args:
            data: indexable dataset accepted by ``extract_batch_tree_data``.
            sess: session whose ``run`` evaluates ``self.pred``.

        Returns:
            float: correct root predictions divided by examples scored
            (ZeroDivisionError when no complete batch exists).
        """
        positions = range(len(data))
        step = self.config.batch_size
        hits = 0
        seen = 0
        # The stop value guarantees a full batch after every yielded offset.
        for offset in range(0, len(data) - step + 1, step):
            chunk = [data[pos] for pos in positions[offset:offset + step]]
            gold_roots = [label for _, label in chunk]
            input_b, treestr_b, labels_b = extract_batch_tree_data(
                chunk, self.config.maxnodesize)
            scores = sess.run(
                self.pred,
                feed_dict={
                    self.input: input_b,
                    self.treestr: treestr_b,
                    self.labels: labels_b,
                    self.dropout: 1.0,
                    self.batch_len: len(input_b),
                },
            )
            predicted = np.argmax(scores, axis=1)
            hits += sum(
                1 for pos, gold in enumerate(gold_roots) if predicted[pos] == gold)
            seen += len(gold_roots)
        return float(hits) / float(seen)
Пример #5
0
    def evaluate(self, data, sess):
        """Evaluate node labelling and relation extraction on the full batches of ``data``.

        Each batch of ``self.config.batch_size`` examples is run through the
        graph twice:
          1. with the relation structures produced by ``extract_batch_tree_data``
             (``relstrT/P/M``), fetching ``self.corrent_num``,
             ``self.total_num`` and ``self.pred``;
          2. with candidate relations built from those predictions by
             ``self.gen_predlst``, fetching ``self.predrelT/P/M``.
        A trailing batch smaller than ``batch_size`` is skipped.

        Args:
            data: indexable dataset accepted by ``extract_batch_tree_data`` and
                ``getTotalrelnum``.
            sess: session used to run the graph.

        Returns:
            (acc, relacc, pred_res, labels_res):
              acc        -- sum of ``corrent_num`` / sum of ``total_num`` over
                            the evaluated batches.
              relacc     -- candidates with predrel == 1 whose isRel label is 1,
                            divided by ``getTotalrelnum(data)`` (computed over
                            all of ``data``, including any skipped tail;
                            presumably the number of gold relations).
              pred_res   -- flattened node predictions after ``numpy_fillna``.
              labels_res -- flattened ``labels_b``, concatenated over batches.

        Raises:
            ZeroDivisionError: if the ``total_num`` sum or
            ``getTotalrelnum(data)`` is zero (e.g. no full batch).
        """
        num_correct = 0
        total_data = 0
        num_correct_rel = 0
        data_idxs = range(len(data))
        test_batch_size = self.config.batch_size
        pred_res = []
        labels_res = []

        for start in range(0, len(data), test_batch_size):
            batch_size = min(start + test_batch_size, len(data)) - start
            if batch_size < test_batch_size:
                break  # the trailing partial batch is skipped
            batch_data = [data[ix] for ix in data_idxs[start:start + batch_size]]
            (input_b, treestr_b, labels_b, relstrT_b, relstrP_b, relstrM_b,
             nonrelstrT_b, nonrelstrP_b, nonrelstrM_b) = extract_batch_tree_data(
                 batch_data, self.config.maxnodesize, self.config.maxrelsize,
                 self.config.maxnonrelTsize, self.config.maxnonrelPsize,
                 self.config.maxnonrelMsize)

            # Feed entries shared by both runs; only relstr* and batch_len differ.
            common_feed = {
                self.input: input_b,
                self.treestr: treestr_b,
                self.labels: labels_b,
                self.nonrelstrT: nonrelstrT_b,
                self.nonrelstrP: nonrelstrP_b,
                self.nonrelstrM: nonrelstrM_b,
                self.dropout: 1.0,
            }

            # Pass 1: node predictions given the extracted relation structures.
            feed = dict(common_feed)
            feed[self.relstrT] = relstrT_b
            feed[self.relstrP] = relstrP_b
            feed[self.relstrM] = relstrM_b
            feed[self.batch_len] = len(input_b)
            corrent_N, total_N, pred_val = sess.run(
                [self.corrent_num, self.total_num, self.pred], feed_dict=feed)

            # Candidate relations derived from the predictions, for T, P, M separately.
            pred_list_T, pred_list_P, pred_list_M = self.gen_predlst(pred_val)

            # 1/0 label for every candidate, judged by isRel against the
            # extracted relations of the same example.
            labelsT = self._relation_labels(pred_list_T, relstrT_b.tolist())
            labelsP = self._relation_labels(pred_list_P, relstrP_b.tolist())
            labelsM = self._relation_labels(pred_list_M, relstrM_b.tolist())

            # Pass 2: feed the predicted candidates in place of relstr*.
            dim1 = len(batch_data)
            max_rel = self.config.maxrelsize
            feed1 = dict(common_feed)
            feed1[self.relstrT] = self._pad_relations(pred_list_T, dim1, max_rel)
            feed1[self.relstrP] = self._pad_relations(pred_list_P, dim1, max_rel)
            feed1[self.relstrM] = self._pad_relations(pred_list_M, dim1, max_rel)
            feed1[self.batch_len] = dim1
            rel_predT, rel_predP, rel_predM = sess.run(
                [self.predrelT, self.predrelP, self.predrelM], feed_dict=feed1)

            num_correct_rel += (
                self._count_correct_relations(rel_predT, labelsT)
                + self._count_correct_relations(rel_predP, labelsP)
                + self._count_correct_relations(rel_predM, labelsM))

            mask = get_mask(pred_val, self.config.maxnodesize)
            pred_val = numpy_fillna(pred_val, mask)
            pred_res.extend(np.reshape(pred_val, [-1]))
            labels_res.extend(np.reshape(labels_b, [-1]))
            num_correct += corrent_N
            total_data += total_N

        total_rel = getTotalrelnum(data)
        acc = float(num_correct) / float(total_data)
        relacc = float(num_correct_rel) / float(total_rel)
        return acc, relacc, pred_res, labels_res

    @staticmethod
    def _relation_labels(pred_lists, gold_rels):
        """Return, per example k, [1 if isRel(rel, gold_rels[k]) else 0 for rel in pred_lists[k]]."""
        return [[1 if isRel(pred_rel, gold_rels[k]) else 0 for pred_rel in preds]
                for k, preds in enumerate(pred_lists)]

    @staticmethod
    def _pad_relations(pred_lists, batch_len, max_rel):
        """Pack pred_lists[k] (pairs) into an int32 [batch_len, max_rel, 2] array padded with -1."""
        padded = np.full([batch_len, max_rel, 2], -1, dtype='int32')
        for k in range(batch_len):
            rels = np.array(pred_lists[k], dtype='int32')
            # An empty list becomes shape (0,), which cannot broadcast into
            # padded[k, 0:0, 0:2]; such rows stay all -1. numpy raises
            # ValueError if an example has more than max_rel candidates.
            if rels.shape[0] != 0:
                padded[k, 0:len(rels), 0:2] = rels
        return padded

    @staticmethod
    def _count_correct_relations(rel_pred, labels):
        """Count positions (k, j) with rel_pred[k][j] == 1 and labels[k][j] == 1."""
        # labels[k][j] is only read where rel_pred[k][j] == 1.
        return sum(1 for k, row in enumerate(rel_pred)
                   for j, flag in enumerate(row)
                   if flag == 1 and labels[k][j] == 1)