예제 #1
0
    def train(self, data, sess, isTree=True):
        """Run one training pass over ``data`` and return the mean batch loss.

        ``data`` is shuffled in place with ``random.shuffle``, so the caller's
        sequence is reordered as a side effect. It is then consumed in
        consecutive slices of ``self.batch_size`` examples. Each slice is
        converted by ``extract_seq_data`` (with ``self.internal`` and
        ``self.config.maxseqlen``) and fed to ``sess.run`` together with
        ``self.train_op1`` and ``self.train_op2``. A running average loss is
        written to stdout after every batch, using ``\\r`` so the line is
        overwritten in place.

        Args:
            data: mutable, sliceable sequence of training examples.
            sess: object whose ``run(fetches, feed_dict=...)`` evaluates the
                loss and both training ops (presumably a TensorFlow
                ``Session`` -- confirm).
            isTree: unused; kept for interface compatibility.

        Returns:
            Mean of the per-batch losses, or NaN when ``data`` is empty.
        """
        from random import shuffle

        shuffle(data)
        losses = []
        n_examples = len(data)
        for start in range(0, n_examples, self.batch_size):
            # Slicing clamps at the end of the sequence, so the final batch
            # is simply shorter; no explicit size arithmetic is needed.
            batch_data = data[start:start + self.batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, self.internal, self.config.maxseqlen)
            feed = {
                self.input: seqdata,
                self.labels: seqlabels,
                self.dropout: self.config.dropout,
                self.lngths: seqlngths,
                self.batch_len: len(seqdata),
                self.max_time: max_len,
            }
            loss, _, _ = sess.run(
                [self.loss, self.train_op1, self.train_op2], feed_dict=feed)

            losses.append(loss)
            # Report how many examples have been processed so far, so the
            # last progress line reads "N of N" rather than the final
            # batch's start offset.
            done = start + len(batch_data)
            sstr = 'avg loss %.2f at example %d of %d\r' % (
                np.mean(losses), done, n_examples)
            sys.stdout.write(sstr)
            sys.stdout.flush()
        if not losses:
            # Nothing was trained; np.mean([]) would emit a RuntimeWarning,
            # so return NaN explicitly.
            return float('nan')
        return np.mean(losses)
예제 #2
0
    def evaluate(self, data, sess):
        """Return the classification accuracy of the model on ``data``.

        ``data`` is processed in consecutive slices of ``self.batch_size``
        examples. Each slice is converted by ``extract_seq_data`` (second
        argument 0, ``self.config.maxseqlen``) and run through ``self.pred``
        with ``self.dropout`` fed as 1.0. The predicted class for each row is
        the argmax over axis 1, and it is compared with the matching label.

        Args:
            data: sliceable sequence of evaluation examples.
            sess: object whose ``run(fetch, feed_dict=...)`` evaluates
                ``self.pred`` (presumably a TensorFlow ``Session`` -- confirm).

        Returns:
            Fraction of correct predictions as a float, or NaN when no
            predictions were made (e.g. ``data`` is empty) instead of raising
            ZeroDivisionError.
        """
        num_correct = 0
        total_data = 0
        for start in range(0, len(data), self.batch_size):
            # Slicing clamps at the end, so the last batch is just shorter.
            batch_data = data[start:start + self.batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, 0, self.config.maxseqlen)
            feed = {
                self.input: seqdata,
                self.labels: seqlabels,
                # Training feeds self.config.dropout here; 1.0 is presumably
                # a keep-probability that disables dropout -- confirm.
                self.dropout: 1.0,
                self.lngths: seqlngths,
                self.batch_len: len(seqdata),
                self.max_time: max_len,
            }
            pred = sess.run(self.pred, feed_dict=feed)
            y = np.argmax(pred, axis=1)
            # Use a separate index name so the batch offset is not shadowed.
            for j, predicted in enumerate(y):
                if seqlabels[j] == predicted:
                    num_correct += 1
            total_data += len(y)
        if total_data == 0:
            # Accuracy over zero predictions is undefined.
            return float('nan')
        return float(num_correct) / float(total_data)
예제 #3
0
    def train(self, data, sess, isTree=True):
        """Run one training pass over ``data`` and return the mean batch loss.

        ``data`` is shuffled in place with ``random.shuffle``, so the caller's
        sequence is reordered as a side effect. It is then consumed in
        consecutive slices of ``self.batch_size`` examples. Each slice is
        converted by ``extract_seq_data`` (with ``self.internal`` and
        ``self.config.maxseqlen``) and fed to ``sess.run`` together with
        ``self.train_op1`` and ``self.train_op2``. A running average loss is
        written to stdout after every batch, using ``\\r`` so the line is
        overwritten in place.

        Args:
            data: mutable, sliceable sequence of training examples.
            sess: object whose ``run(fetches, feed_dict=...)`` evaluates the
                loss and both training ops (presumably a TensorFlow
                ``Session`` -- confirm).
            isTree: unused; kept for interface compatibility.

        Returns:
            Mean of the per-batch losses, or NaN when ``data`` is empty.
        """
        from random import shuffle

        shuffle(data)
        losses = []
        n_examples = len(data)
        for start in range(0, n_examples, self.batch_size):
            # Slicing clamps at the end of the sequence, so the final batch
            # is simply shorter; no explicit size arithmetic is needed.
            batch_data = data[start:start + self.batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, self.internal, self.config.maxseqlen)
            feed = {
                self.input: seqdata,
                self.labels: seqlabels,
                self.dropout: self.config.dropout,
                self.lngths: seqlngths,
                self.batch_len: len(seqdata),
                self.max_time: max_len,
            }
            loss, _, _ = sess.run(
                [self.loss, self.train_op1, self.train_op2], feed_dict=feed)

            losses.append(loss)
            # Report how many examples have been processed so far, so the
            # last progress line reads "N of N" rather than the final
            # batch's start offset.
            done = start + len(batch_data)
            sstr = 'avg loss %.2f at example %d of %d\r' % (
                np.mean(losses), done, n_examples)
            sys.stdout.write(sstr)
            sys.stdout.flush()
        if not losses:
            # Nothing was trained; np.mean([]) would emit a RuntimeWarning,
            # so return NaN explicitly.
            return float('nan')
        return np.mean(losses)
예제 #4
0
    def evaluate(self, data, sess):
        """Return the classification accuracy of the model on ``data``.

        ``data`` is processed in consecutive slices of ``self.batch_size``
        examples. Each slice is converted by ``extract_seq_data`` (second
        argument 0, ``self.config.maxseqlen``) and run through ``self.pred``
        with ``self.dropout`` fed as 1.0. The predicted class for each row is
        the argmax over axis 1, and it is compared with the matching label.

        Args:
            data: sliceable sequence of evaluation examples.
            sess: object whose ``run(fetch, feed_dict=...)`` evaluates
                ``self.pred`` (presumably a TensorFlow ``Session`` -- confirm).

        Returns:
            Fraction of correct predictions as a float, or NaN when no
            predictions were made (e.g. ``data`` is empty) instead of raising
            ZeroDivisionError.
        """
        num_correct = 0
        total_data = 0
        for start in range(0, len(data), self.batch_size):
            # Slicing clamps at the end, so the last batch is just shorter.
            batch_data = data[start:start + self.batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, 0, self.config.maxseqlen)
            feed = {
                self.input: seqdata,
                self.labels: seqlabels,
                # Training feeds self.config.dropout here; 1.0 is presumably
                # a keep-probability that disables dropout -- confirm.
                self.dropout: 1.0,
                self.lngths: seqlngths,
                self.batch_len: len(seqdata),
                self.max_time: max_len,
            }
            pred = sess.run(self.pred, feed_dict=feed)
            y = np.argmax(pred, axis=1)
            # Use a separate index name so the batch offset is not shadowed.
            for j, predicted in enumerate(y):
                if seqlabels[j] == predicted:
                    num_correct += 1
            total_data += len(y)
        if total_data == 0:
            # Accuracy over zero predictions is undefined.
            return float('nan')
        return float(num_correct) / float(total_data)