コード例 #1
0
ファイル: network.py プロジェクト: wangtingc/DilatedPixelCNN
 def train(self):
     """Run the training loop with periodic validation, summaries and saves.

     Restores checkpoint ``conf.reload_epoch`` when it is positive, then
     performs ``conf.max_epoch`` iterations, one batch each:

     * index ``% conf.test_step == 1``: evaluate a validation batch (no
       ``train_op``), write the validation summary and print the loss;
     * otherwise index ``% conf.summary_step == 1``: train on a batch and
       write the training summary;
     * otherwise: train on a batch and print the loss.

     A checkpoint is saved whenever index ``% conf.save_step == 1``.
     Summaries and checkpoints are numbered by index + ``conf.reload_epoch``.
     """
     if self.conf.reload_epoch > 0:
         self.reload(self.conf.reload_epoch)
     train_reader = H5DataLoader(self.conf.data_dir+self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir+self.conf.valid_data)

     for step in range(self.conf.max_epoch):
         global_step = step+self.conf.reload_epoch
         validating = step % self.conf.test_step == 1
         # Only consulted on non-validation iterations, like the original elif.
         summarizing = (not validating
                        and step % self.conf.summary_step == 1)

         reader = valid_reader if validating else train_reader
         inputs, annotations = reader.next_batch(self.conf.batch)
         feed_dict = {self.inputs: inputs, self.annotations: annotations}

         if validating:
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, global_step)
             print('----testing loss', loss)
         elif summarizing:
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, global_step)
         else:
             loss, _ = self.sess.run(
                 [self.loss_op, self.train_op], feed_dict=feed_dict)
             print('----training loss', loss)

         if step % self.conf.save_step == 1:
             self.save(global_step)
コード例 #2
0
 def train(self):
     """Run the step-based training loop.

     Restores checkpoint ``conf.reload_step`` when it is positive, then
     iterates over steps ``0 .. conf.max_step`` inclusive.  On every
     non-zero multiple of ``conf.test_interval`` a validation batch is
     evaluated first (summary written, loss printed).  Every step then
     trains on one batch: non-zero multiples of ``conf.summary_interval``
     write a training summary, all other steps print the loss.  Non-zero
     multiples of ``conf.save_interval`` save a checkpoint.  Summaries and
     checkpoints are numbered by step + ``conf.reload_step``.
     """
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)

     for step in range(self.conf.max_step + 1):
         global_step = step + self.conf.reload_step

         # Validation is additional to, not instead of, the training step.
         if step != 0 and step % self.conf.test_interval == 0:
             val_inputs, val_labels = valid_reader.next_batch(
                 self.conf.batch)
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary],
                 feed_dict={self.inputs: val_inputs,
                            self.labels: val_labels})
             self.save_summary(summary, global_step)
             print('----testing loss', loss)

         summarizing = (step != 0
                        and step % self.conf.summary_interval == 0)
         inputs, labels = train_reader.next_batch(self.conf.batch)
         feed_dict = {self.inputs: inputs, self.labels: labels}
         if summarizing:
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, global_step)
         else:
             loss, _ = self.sess.run([self.loss_op, self.train_op],
                                     feed_dict=feed_dict)
             print("Epoch", step, '----training loss', loss)

         if step != 0 and step % self.conf.save_interval == 0:
             self.save(global_step)
コード例 #3
0
ファイル: u_net.py プロジェクト: moliqingcha/Deformable-U-Net
 def train(self):
     """Train the network while persisting loss histories to disk.

     Restores checkpoint ``conf.reload_epoch`` when it is positive, then
     performs ``conf.max_epoch`` iterations, one batch each:

     * index ``% conf.test_step == 1``: evaluate a validation batch, write
       the validation summary and append the loss to the validation history;
     * otherwise index ``% conf.summary_step == 1``: train on a batch and
       write the training summary;
     * otherwise: train on a batch.

     Both training cases append the loss to the training history.  The
     history that was just extended is rewritten in full to
     ``conf.record_dir`` ("valid_loss.npy" / "train_loss.npy") on every
     iteration.  A checkpoint is saved when index ``% conf.save_step == 1``.
     """
     # Optionally resume from a previously trained model.
     if self.conf.reload_epoch > 0:
         self.reload(self.conf.reload_epoch)

     # Data sources.
     train_reader = H5DataLoader(self.conf.data_dir+self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir+self.conf.valid_data)

     # Loss histories.
     valid_losses = []
     train_losses = []

     for step in range(self.conf.max_epoch):
         validating = step % self.conf.test_step == 1
         # Only consulted on non-validation iterations, like the original elif.
         summarizing = (not validating
                        and step % self.conf.summary_step == 1)

         reader = valid_reader if validating else train_reader
         inputs, annotations = reader.next_batch(self.conf.batch)
         feed_dict = {self.inputs: inputs, self.annotations: annotations}

         if validating:
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, step)
             print(step, '----testing loss', loss)
             history, filename = valid_losses, "valid_loss.npy"
         elif summarizing:
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, step)
             history, filename = train_losses, "train_loss.npy"
         else:
             loss, _ = self.sess.run(
                 [self.loss_op, self.train_op], feed_dict=feed_dict)
             history, filename = train_losses, "train_loss.npy"

         print(step)

         # Record the loss and rewrite the whole history file.
         history.append(loss)
         np.save(self.conf.record_dir+filename, np.array(history))

         if step % self.conf.save_step == 1:
             self.save(step)
コード例 #4
0
ファイル: network.py プロジェクト: wangtingc/DilatedPixelCNN
 def test(self):
     """Evaluate checkpoint ``conf.test_epoch`` on the test data and print metrics.

     Batches of ``conf.batch`` samples are read until the reader yields a
     batch with fewer samples; that short batch is not evaluated.  For every
     evaluated batch the loss, accuracy and mean IoU are printed.  Afterwards
     the mean loss, the mean accuracy and the last fetched mean-IoU value
     are printed.

     Nothing is returned.  When ``conf.test_epoch`` is not positive, or when
     no full batch could be read, a message is printed and the method
     returns early (previously the empty case raised ``IndexError``).
     """
     print('---->testing ', self.conf.test_epoch)
     if self.conf.test_epoch <= 0:
         print("please set a reasonable test_epoch")
         return
     self.reload(self.conf.test_epoch)

     test_reader = H5DataLoader(
         self.conf.data_dir+self.conf.test_data, False)
     # Local variables are reset before evaluation; presumably they back the
     # mean-IoU metric updated by ``miou_op`` -- TODO confirm.
     self.sess.run(tf.local_variables_initializer())

     losses = []
     accuracies = []
     last_m_iou = None
     while True:
         inputs, annotations = test_reader.next_batch(self.conf.batch)
         if inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         loss, accuracy, m_iou, _ = self.sess.run(
             [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         losses.append(loss)
         accuracies.append(accuracy)
         # Only the most recent value is reported at the end.
         last_m_iou = m_iou

     if not losses:
         # Avoid NaN means and an IndexError when nothing was evaluated.
         print('no full batch was available; nothing to report')
         return

     print('Loss: ', np.mean(losses))
     print('Accuracy: ', np.mean(accuracies))
     print('M_iou: ', last_m_iou)
コード例 #5
0
ファイル: u_net.py プロジェクト: moliqingcha/Deformable-U-Net
 def test(self,model_i):
     """Evaluate checkpoint ``model_i`` on the validation data.

     Batches of ``conf.batch`` samples are read until the reader yields a
     batch with fewer samples; that short batch is not evaluated.  Metrics
     are computed per batch and printed.

     Args:
         model_i: Checkpoint number to reload; must be positive.

     Returns:
         ``(mean loss, mean accuracy, last fetched mean-IoU value)``, or
         ``None`` when ``model_i`` is not positive or when no full batch
         could be read (previously the empty case raised ``IndexError``).
     """
     print('---->testing ', model_i)

     # Load the model.
     if model_i <= 0:
         print("please set a reasonable test_epoch")
         return
     self.reload(model_i)

     # Read the data; per the original comment the ``False`` argument
     # indicates that the loader is not used for training.
     valid_reader = H5DataLoader(self.conf.data_dir+self.conf.valid_data,False)
     # Presumably resets the state behind the mean-IoU metric -- TODO confirm.
     self.sess.run(tf.local_variables_initializer())

     # Per-batch metrics.
     losses = []
     accuracies = []
     last_m_iou = None
     while True:
         inputs, annotations = valid_reader.next_batch(self.conf.batch)

         # Stop as soon as the reader can no longer fill a whole batch.
         if inputs.shape[0] < self.conf.batch:
             break

         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         loss, accuracy, m_iou, _ = self.sess.run(
             [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         losses.append(loss)
         accuracies.append(accuracy)
         last_m_iou = m_iou

     if not losses:
         # Avoid NaN means and an IndexError when nothing was evaluated.
         print('no full batch was available; nothing to report')
         return

     # Loss and accuracy are averaged over the evaluated batches.
     return np.mean(losses),np.mean(accuracies),last_m_iou
コード例 #6
0
 def train(self):
     """Train until the data reader reports ``conf.max_epoch`` epochs.

     Restores checkpoint ``conf.reload_epoch`` when it is positive.  A step
     counter starts at 1 and drives the schedule:

     * counter ``% conf.test_step == 1``: run the validation summary on an
       all-zero input of shape ``self.input_shape`` with ``self.valid`` fed
       as True (no batch is read);
     * otherwise counter ``% conf.summary_step == 1``: train on a batch and
       write the training summary;
     * otherwise: train on a batch and print the two loss terms.

     A checkpoint is saved when counter ``% conf.save_step == 0``.  The loop
     condition uses the reader's ``epoch`` attribute, re-read after each step.
     """
     if self.conf.reload_epoch > 0:
         self.reload(self.conf.reload_epoch)
     data_reader = H5DataLoader('../../Data/data/celeba_train_test.h5')

     step = 1
     epoch = 0
     while epoch < self.conf.max_epoch:
         if step % self.conf.test_step == 1:
             blank = np.zeros(self.input_shape)
             summary = self.sess.run(
                 self.valid_summary,
                 feed_dict={self.inputs: blank, self.valid: True})
             self.save_summary(summary, step)
         else:
             # Decided before the batch is read, like the original elif.
             summarizing = step % self.conf.summary_step == 1
             targets = data_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: targets, self.valid: False}
             if summarizing:
                 _, summary = self.sess.run(
                     [self.train_opg, self.train_summary],
                     feed_dict=feed_dict)
                 self.save_summary(summary, step)
             else:
                 rec_loss_l, kl_loss, _ = self.sess.run(
                     [self.rec_loss_l, self.kl_loss, self.train_opg],
                     feed_dict=feed_dict)
                 print('epoch = ', epoch, '----training loss l=', rec_loss_l,
                       ', kl= ', kl_loss)

         if step % self.conf.save_step == 0:
             self.save(step)

         epoch = data_reader.epoch
         step += 1
コード例 #7
0
 def predict(self):
     """Evaluate checkpoint ``conf.test_epoch`` on the test data and save predictions.

     Batches of ``conf.batch`` samples are read until the reader yields a
     batch with fewer samples; that short batch is not evaluated.  For each
     batch the metrics and the decoded predictions are fetched in a single
     session run.  Every prediction is written to ``conf.sample_dir`` as
     ``<running index>.png``.

     Returns:
         ``(mean loss, mean accuracy, last fetched mean-IoU value)``, or
         ``None`` when ``conf.test_epoch`` is not positive or when no full
         batch could be read (previously the empty case raised
         ``IndexError``).
     """
     print('---->predicting ', self.conf.test_epoch)
     if self.conf.test_epoch <= 0:
         print("please set a reasonable test_epoch")
         return
     self.reload(self.conf.test_epoch)

     test_reader = H5DataLoader(
         self.conf.data_dir+self.conf.test_data, False)
     # Presumably resets the state behind the mean-IoU metric -- TODO confirm.
     self.sess.run(tf.local_variables_initializer())

     predictions = []
     losses = []
     accuracies = []
     last_m_iou = None
     while True:
         inputs, annotations = test_reader.next_batch(self.conf.batch)
         if inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         # One run yields both metrics and predictions, so the forward pass
         # is not repeated for the same batch.
         loss, accuracy, m_iou, _, decoded = self.sess.run(
             [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
              self.decoded_predictions],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         losses.append(loss)
         accuracies.append(accuracy)
         last_m_iou = m_iou
         predictions.append(decoded)

     if not losses:
         # Avoid NaN means and an IndexError when nothing was evaluated.
         print('no full batch was available; nothing to save')
         return

     print('----->saving predictions')
     for index, prediction in enumerate(predictions):
         batch_len = prediction.shape[0]
         for i in range(batch_len):
             imsave(prediction[i], self.conf.sample_dir +
                    str(index*batch_len+i)+'.png')
     return np.mean(losses),np.mean(accuracies),last_m_iou
コード例 #8
0
    def predict(self):
        """Evaluate checkpoint ``conf.test_epoch`` on the test data and save predictions.

        Batches of ``conf.batch`` samples are read until the reader yields a
        batch with fewer samples; that short batch is not evaluated.  For
        each batch the metrics and the decoded predictions are fetched in a
        single session run.  Every prediction is written to
        ``conf.sample_dir`` twice: as ``pred<n>.npy`` and, through the
        project's ``imsave`` helper, as ``<n>.png``.

        Returns:
            ``(mean loss, mean accuracy, last fetched mean-IoU value)``, or
            ``None`` when ``conf.test_epoch`` is not positive or when no
            full batch could be read (previously the empty case raised
            ``IndexError``).
        """
        print('---->predicting ', self.conf.test_epoch)

        if self.conf.test_epoch <= 0:
            print("please set a reasonable test_epoch")
            return
        self.reload(self.conf.test_epoch)

        # Read the data.
        test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                   False)
        # Presumably resets the state behind the mean-IoU metric -- TODO confirm.
        self.sess.run(tf.local_variables_initializer())
        predictions = []
        losses = []
        accuracies = []
        last_m_iou = None

        while True:
            inputs, annotations = test_reader.next_batch(self.conf.batch)

            # Stop as soon as the reader can no longer fill a whole batch.
            if inputs.shape[0] < self.conf.batch:
                break

            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            # One run yields both metrics and predictions, so the forward
            # pass is not repeated for the same batch.
            loss, accuracy, m_iou, _, decoded = self.sess.run(
                [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
                 self.decoded_predictions],
                feed_dict=feed_dict)
            print('values----->', loss, accuracy, m_iou)
            # Record the metrics.
            losses.append(loss)
            accuracies.append(accuracy)
            last_m_iou = m_iou
            # Record the predictions.
            predictions.append(decoded)

        if not losses:
            # Avoid NaN means and an IndexError when nothing was evaluated.
            print('no full batch was available; nothing to save')
            return

        print('----->saving predictions')
        print(np.shape(predictions))
        num = 0
        for index, prediction in enumerate(predictions):
            batch_len = prediction.shape[0]
            for i in range(batch_len):
                # Raw prediction array.
                np.save(self.conf.sample_dir + "pred" + str(num) + ".npy",
                        prediction[i])
                num += 1
                # Image rendering; per the original comment ``imsave`` is a
                # project helper that stores the one-channel prediction as a
                # three-channel picture.
                imsave(
                    prediction[i], self.conf.sample_dir +
                    str(index * batch_len + i) + '.png')

        # Loss and accuracy are averaged over all evaluated batches.
        return np.mean(losses), np.mean(accuracies), last_m_iou
コード例 #9
0
ファイル: actions.py プロジェクト: moliqingcha/DACN
    def test(self, model_i):
        """Evaluate checkpoint ``model_i`` on the validation data.

        Batches are read until the reader yields a short batch or
        ``conf.valid_num`` batches have been evaluated.  Loss, accuracy and
        mean IoU are fetched for every batch; a Dice score is computed from
        ``self.out`` and ``self.gt`` only when ``conf.class_num == 2``.

        Args:
            model_i: Checkpoint number to reload; must be positive.

        Returns:
            ``(mean loss, mean accuracy, last fetched mean-IoU value,
            mean Dice)``.  The Dice entry is NaN when no Dice score was
            computed.  ``None`` is returned when ``model_i`` is not positive
            or when no full batch could be read (previously the empty case
            raised ``IndexError``).
        """
        print('---->testing ', model_i)

        if model_i <= 0:
            print("please set a reasonable test_epoch")
            return
        self.reload(model_i)

        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data,
                                    False)
        # Presumably resets the state behind the mean-IoU metric -- TODO confirm.
        self.sess.run(tf.local_variables_initializer())

        losses = []
        accuracies = []
        dices = []
        last_m_iou = None
        count = 0
        while True:
            inputs, annotations = valid_reader.next_batch(self.conf.batchsize)

            # NOTE(review): the batch is requested with ``conf.batchsize`` but
            # compared against ``conf.batch``; confirm both options exist and
            # are meant to differ.
            if inputs.shape[0] < self.conf.batch:
                break

            feed_dict = {
                self.inputs: inputs,
                self.annotations: annotations,
                self.is_train: False
            }
            # One run yields metrics and the tensors needed for Dice, so the
            # forward pass is not repeated for the same batch.
            loss, accuracy, m_iou, _, out, gt = self.sess.run(
                [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
                 self.out, self.gt],
                feed_dict=feed_dict)
            print(count)
            print('values----->', loss, accuracy, m_iou)
            losses.append(loss)
            accuracies.append(accuracy)
            last_m_iou = m_iou

            if self.conf.class_num == 2:
                tp = np.sum(out * gt)
                # Small constant keeps the denominator non-zero.
                fenmu = np.sum(out) + np.sum(gt) + 0.000001
                dice = 2 * tp / fenmu
                dices.append(dice)
                # Printed only when a Dice score exists; the unconditional
                # print raised NameError for other class counts.
                print('dice----->', dice)

            count += 1
            if count == self.conf.valid_num:
                break

        if not losses:
            # Avoid NaN means and an IndexError when nothing was evaluated.
            print('no full batch was available; nothing to report')
            return

        mean_dice = np.mean(dices) if dices else np.nan
        return np.mean(losses), np.mean(accuracies), last_m_iou, mean_dice
コード例 #10
0
 def test(self):
     """Evaluate checkpoint ``conf.test_step`` on the test data and print the accuracy.

     Batches of ``conf.batch`` samples are read until the reader returns
     ``None`` inputs or a batch with fewer samples; that short batch is not
     evaluated.  The mean of the per-batch accuracies is printed.

     Nothing is returned.  When ``conf.test_step`` is not positive, or when
     no full batch could be read, a message is printed and the method
     returns early (previously the empty case raised ``ZeroDivisionError``).
     """
     print('---->testing ', self.conf.test_step)
     if self.conf.test_step <= 0:
         print("please set a reasonable test_step")
         return
     self.reload(self.conf.test_step)

     test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                False)
     accuracies = []
     while True:
         inputs, labels = test_reader.next_batch(self.conf.batch)
         if inputs is None or inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.labels: labels}
         accuracies.append(
             self.sess.run(self.accuracy_op, feed_dict=feed_dict))

     if not accuracies:
         # Nothing was evaluated, so there is no mean to divide out.
         print('no full batch was available; accuracy is undefined')
         return
     print('accuracy is ', sum(accuracies) / len(accuracies))
コード例 #11
0
ファイル: actions.py プロジェクト: yanlong-sun/DACN_package
    def predict(self):
        """Save the decoded network predictions for the test data as PNG files.

        Reloads checkpoint ``conf.test_epoch`` and reads batches until the
        reader yields a short batch or ``conf.test_num`` batches have been
        processed.  For each batch ``self.decoded_net_pred`` is evaluated;
        every prediction is then written to ``conf.sample_net_dir`` as
        ``<running index>.png``.

        Returns:
            A ``(loss, accuracy, dice)`` triple kept for interface
            compatibility.  No metric is computed here, so all three
            entries are NaN.  The original averaged three lists that were
            never filled, which gave the same NaNs plus a RuntimeWarning.
        """
        self.reload(self.conf.test_epoch)
        test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                   False)
        self.sess.run(tf.local_variables_initializer())
        net_predictions = []
        count = 0

        while True:
            inputs, annotations = test_reader.next_batch(self.conf.batchsize)
            # NOTE(review): the batch is requested with ``conf.batchsize`` but
            # compared against ``conf.batch``; confirm both options exist and
            # are meant to differ.
            if inputs.shape[0] < self.conf.batch:
                break

            feed_dict = {
                self.inputs: inputs,
                self.annotations: annotations,
                self.is_train: False
            }

            # The original also evaluated ``self.predictions`` here and
            # never used the result; that discarded run has been dropped.
            net_predictions.append(
                self.sess.run(self.decoded_net_pred, feed_dict=feed_dict))

            count += 1
            if count == self.conf.test_num:
                break

        for index, prediction in enumerate(net_predictions):
            batch_len = prediction.shape[0]
            for i in range(batch_len):
                imsave(
                    prediction[i], self.conf.sample_net_dir +
                    str(index * batch_len + i) + '.png')

        return np.nan, np.nan, np.nan
コード例 #12
0
 def output_train(self):
     """Write one batch of targets and their super-resolved outputs as PNG files.

     Reloads checkpoint ``conf.test_epoch`` (which must be positive; if it
     is not, a message is printed and the method returns), reads a single
     batch from the hard-coded CelebA file and evaluates
     ``self.predictions`` and ``self.super_predictions`` with
     ``self.valid`` fed as False.  For every sample ``i`` the target is
     saved as ``<i>_o.png`` and the super prediction as ``<i>_h.png`` in
     ``conf.sample_dir``, each reshaped to (128, 128, 3).
     """
     print('---->output', self.conf.test_epoch)
     if not self.conf.test_epoch > 0:
         print("please set a reasonable test_epoch")
         return
     self.reload(self.conf.test_epoch)

     data_reader = H5DataLoader('../../Data/data/celeba_train_test.h5')
     targets = data_reader.next_batch(self.conf.batch)
     fetched = self.sess.run(
         [self.predictions, self.super_predictions],
         feed_dict={self.inputs: targets, self.valid: False})
     predictions, super_predictions = fetched
     super_predictions = np.array(super_predictions)
     predictions = np.array(predictions)

     image_shape = (128, 128, 3)
     sample_count = predictions.shape[0]
     for i in range(sample_count):
         prefix = self.conf.sample_dir + str(i)
         imsave(prefix + '_o.png', np.reshape(targets[i], image_shape))
         # The plain (non-super) predictions are intentionally not written.
         imsave(prefix + '_h.png',
                np.reshape(super_predictions[i], image_shape))
コード例 #13
0
ファイル: network.py プロジェクト: wangtingc/DilatedPixelCNN
 def predict(self):
     """Save decoded predictions for the test data as PNG files.

     Reloads checkpoint ``conf.test_epoch`` (which must be positive; if it
     is not, a message is printed and the method returns).  Batches of
     ``conf.batch`` samples are read until the reader yields a batch with
     fewer samples; that short batch is not processed.  Every prediction
     is written to ``conf.sample_dir`` as ``<running index>.png``.
     """
     print('---->predicting ', self.conf.test_epoch)
     if not self.conf.test_epoch > 0:
         print("please set a reasonable test_epoch")
         return
     self.reload(self.conf.test_epoch)

     test_reader = H5DataLoader(
         self.conf.data_dir+self.conf.test_data, False)

     batches = []
     inputs, annotations = test_reader.next_batch(self.conf.batch)
     while not inputs.shape[0] < self.conf.batch:
         decoded = self.sess.run(
             self.decoded_predictions,
             feed_dict={self.inputs: inputs,
                        self.annotations: annotations})
         batches.append(decoded)
         inputs, annotations = test_reader.next_batch(self.conf.batch)

     print('----->saving predictions')
     for batch_index, batch in enumerate(batches):
         batch_len = batch.shape[0]
         first_id = batch_index*batch_len
         for offset in range(batch_len):
             target_path = self.conf.sample_dir+str(first_id+offset)+'.png'
             imsave(batch[offset], target_path)
コード例 #14
0
ファイル: actions.py プロジェクト: moliqingcha/ASCNet
    def train(self):
        """Train the network and persist loss/accuracy/mIoU histories.

        Sets up the train/valid summaries, restores checkpoint
        ``conf.reload_epoch`` when it is positive, then performs
        ``conf.max_epoch`` iterations:

        * index ``% conf.test_step == 1``: evaluate one validation batch
          (``is_train`` False) and record its metrics, then train on one
          batch (``is_train`` True) and record the training metrics;
        * otherwise index ``% conf.summary_step == 1``: read a training
          batch and build its feed, but run nothing;
        * otherwise: train on one batch and print the index.

        Every recorded metric is appended to its history and the whole
        history is rewritten to ``conf.record_dir`` ("valid_loss.npy",
        "valid_acc.npy", "valid_miou.npy", "train_loss.npy",
        "train_acc.npy", "train_miou.npy").  A checkpoint is saved when
        index ``% conf.save_step == 1``.
        """
        # Summary writers.
        self.train_summary = self.config_summary('train')
        self.valid_summary = self.config_summary('valid')

        # Optionally resume from a previously trained model.
        if self.conf.reload_epoch > 0:
            self.reload(self.conf.reload_epoch)

        # Data sources.
        train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)

        # One history per persisted metric, keyed by the file stem.
        histories = {
            "valid_loss": [], "train_loss": [],
            "train_acc": [], "valid_acc": [],
            "train_miou": [], "valid_miou": [],
        }

        def next_feed(reader, training):
            # Pull one batch from `reader` and wrap it as a feed dict.
            inputs, annotations = reader.next_batch(self.conf.batchsize)
            return {
                self.inputs: inputs,
                self.annotations: annotations,
                self.is_train: training
            }

        def record(stem, value):
            # Append `value` to its history and rewrite the history file.
            series = histories[stem]
            series.append(value)
            np.save(self.conf.record_dir + stem + ".npy", np.array(series))

        # Per the original comment, local variables are initialised so that
        # the mIoU computed during training can be kept.
        self.sess.run(tf.local_variables_initializer())

        # Training loop.
        for step in range(self.conf.max_epoch):

            if step % self.conf.test_step == 1:
                # Validation pass.
                loss, accuracy, m_iou, _ = self.sess.run(
                    [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
                    feed_dict=next_feed(valid_reader, False))

                print(step, '----valid loss', loss)

                record("valid_loss", loss)
                record("valid_acc", accuracy)
                record("valid_miou", m_iou)

                # A training step is still performed on this iteration.
                _, loss, accuracy, m_iou, _ = self.sess.run(
                    [
                        self.train_op, self.loss_op, self.accuracy_op,
                        self.m_iou, self.miou_op
                    ],
                    feed_dict=next_feed(train_reader, True))

                print(step, '----train loss', loss)

                record("train_loss", loss)
                record("train_acc", accuracy)
                record("train_miou", m_iou)

            elif step % self.conf.summary_step == 1:
                # NOTE(review): the batch is read and the feed is built, but
                # nothing is run with it; the summary code that used it was
                # commented out in the original.
                next_feed(train_reader, False)

            else:
                loss, _ = self.sess.run(
                    [self.loss_op, self.train_op],
                    feed_dict=next_feed(train_reader, True))

                print(step)

            # Save the model.
            if step % self.conf.save_step == 1:
                self.save(step)
コード例 #15
0
ファイル: actions.py プロジェクト: moliqingcha/DACN
    def predict(self):
        """Evaluate checkpoint ``conf.test_epoch`` on the test data and save all outputs.

        Batches are read until the reader yields a short batch or
        ``conf.test_num`` batches have been processed.  For each batch the
        metrics, ``decoded_predictions``, ``decoded_net_pred`` and
        ``outputs`` are fetched in a single session run.  Afterwards, in
        ``conf.sample_dir``:

        * all raw outputs are stored together as ``outputs.npy``;
        * every decoded prediction is stored as ``pred<n>.npy`` and
          ``<n>.png``;
        * every decoded net prediction is stored as ``netpred<n>.npy`` and
          ``<n>net.png``.

        Returns:
            ``(mean loss, mean accuracy, last fetched mean-IoU value)``, or
            ``None`` when ``conf.test_epoch`` is not positive or when no
            full batch could be read (previously the empty case raised
            ``IndexError``).
        """
        print('---->predicting ', self.conf.test_epoch)

        if self.conf.test_epoch <= 0:
            print("please set a reasonable test_epoch")
            return
        self.reload(self.conf.test_epoch)

        test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                   False)
        # Presumably resets the state behind the mean-IoU metric -- TODO confirm.
        self.sess.run(tf.local_variables_initializer())
        predictions = []
        net_predictions = []
        outputs = []
        losses = []
        accuracies = []
        last_m_iou = None
        count = 0

        while True:
            inputs, annotations = test_reader.next_batch(self.conf.batchsize)

            # NOTE(review): the batch is requested with ``conf.batchsize`` but
            # compared against ``conf.batch``; confirm both options exist and
            # are meant to differ.
            if inputs.shape[0] < self.conf.batch:
                break

            feed_dict = {
                self.inputs: inputs,
                self.annotations: annotations,
                self.is_train: False
            }
            # One run yields everything, so the forward pass is not repeated
            # four times for the same batch.
            (loss, accuracy, m_iou, _,
             decoded, net_decoded, raw_output) = self.sess.run(
                [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
                 self.decoded_predictions, self.decoded_net_pred,
                 self.outputs],
                feed_dict=feed_dict)
            print('values----->', loss, accuracy, m_iou)

            losses.append(loss)
            accuracies.append(accuracy)
            last_m_iou = m_iou

            predictions.append(decoded)
            net_predictions.append(net_decoded)
            outputs.append(raw_output)

            count += 1
            if count == self.conf.test_num:
                break

        if not losses:
            # Avoid NaN means and an IndexError when nothing was evaluated.
            print('no full batch was available; nothing to save')
            return

        print('----->saving outputs')
        # Report the shape of what is saved (the original printed the shape
        # of a list that was never filled).
        print(np.shape(outputs))
        np.save(self.conf.sample_dir + "outputs" + ".npy", np.array(outputs))

        print('----->saving predictions')
        print(np.shape(predictions))
        num = 0
        for index, prediction in enumerate(predictions):
            batch_len = prediction.shape[0]
            for i in range(batch_len):
                np.save(self.conf.sample_dir + "pred" + str(num) + ".npy",
                        prediction[i])
                num += 1
                imsave(
                    prediction[i], self.conf.sample_dir +
                    str(index * batch_len + i) + '.png')

        print('----->saving net_predictions')
        print(np.shape(net_predictions))
        num = 0
        for index, prediction in enumerate(net_predictions):
            batch_len = prediction.shape[0]
            for i in range(batch_len):
                np.save(self.conf.sample_dir + "netpred" + str(num) + ".npy",
                        prediction[i])
                num += 1
                imsave(
                    prediction[i], self.conf.sample_dir +
                    str(index * batch_len + i) + 'net.png')

        return np.mean(losses), np.mean(accuracies), last_m_iou
コード例 #16
0
ファイル: actions.py プロジェクト: moliqingcha/DACN
    def train(self):
        """Train the network and persist loss/accuracy/mIoU histories.

        Sets up the train/valid summaries, restores checkpoint
        ``conf.reload_epoch`` when it is positive, then performs
        ``conf.max_epoch`` iterations:

        * index ``% conf.test_step == 1``: evaluate one validation batch
          (``is_train`` False) and record its metrics, then train on one
          batch (``is_train`` True) and record the training metrics;
        * otherwise index ``% conf.summary_step == 1``: read a training
          batch and build its feed, but run nothing;
        * otherwise: train on one batch and print the index.

        Every recorded metric is appended to its history and the whole
        history is rewritten to ``conf.record_dir`` as
        ``<split>_<metric>.npy`` (split: "train"/"valid"; metric:
        "loss"/"acc"/"miou").  A checkpoint is saved when index
        ``% conf.save_step == 1``.
        """
        self.train_summary = self.config_summary('train')
        self.valid_summary = self.config_summary('valid')

        if self.conf.reload_epoch > 0:
            self.reload(self.conf.reload_epoch)

        train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)

        # Metric histories, one list per (split, metric) pair.
        logs = {
            split: {"loss": [], "acc": [], "miou": []}
            for split in ("valid", "train")
        }

        def make_feed(reader, is_training):
            # Read the next batch from `reader` and build the feed dict.
            batch_inputs, batch_annotations = reader.next_batch(
                self.conf.batchsize)
            return {
                self.inputs: batch_inputs,
                self.annotations: batch_annotations,
                self.is_train: is_training
            }

        def log_metrics(split, loss, accuracy, m_iou):
            # Extend the three histories of `split` and rewrite their files,
            # in loss / acc / miou order.
            for metric, value in (("loss", loss), ("acc", accuracy),
                                  ("miou", m_iou)):
                series = logs[split][metric]
                series.append(value)
                np.save(
                    self.conf.record_dir + split + "_" + metric + ".npy",
                    np.array(series))

        self.sess.run(tf.local_variables_initializer())

        for step in range(self.conf.max_epoch):

            if step % self.conf.test_step == 1:
                # Validation pass.
                loss, accuracy, m_iou, _ = self.sess.run(
                    [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
                    feed_dict=make_feed(valid_reader, False))
                print(step, '----valid loss', loss)
                log_metrics("valid", loss, accuracy, m_iou)

                # Training pass on the same iteration.
                _, loss, accuracy, m_iou, _ = self.sess.run(
                    [
                        self.train_op, self.loss_op, self.accuracy_op,
                        self.m_iou, self.miou_op
                    ],
                    feed_dict=make_feed(train_reader, True))
                print(step, '----train loss', loss)
                log_metrics("train", loss, accuracy, m_iou)

            elif step % self.conf.summary_step == 1:
                # NOTE(review): the batch is read and the feed is built, but
                # nothing is run with it; the summary code that used it was
                # commented out in the original.
                make_feed(train_reader, False)

            else:
                loss, _ = self.sess.run(
                    [self.loss_op, self.train_op],
                    feed_dict=make_feed(train_reader, True))
                print(step)

            if step % self.conf.save_step == 1:
                self.save(step)