def train(self):
    """Train the network with periodic validation, summaries and checkpoints.

    Each step of the loop does exactly one of the following:

    * every ``conf.test_interval`` steps: evaluate one validation batch
      (no optimizer update) and print loss / accuracy / dice accuracy;
    * otherwise, every ``conf.summary_interval`` steps: one training update
      whose training summary is recorded;
    * otherwise: one training update that prints loss / accuracy / dice
      accuracy and also records the training summary.

    A checkpoint is written every ``conf.save_interval`` steps (step 0
    included). Step numbers given to ``save_summary`` and ``save`` are offset
    by ``conf.reload_step``, so a resumed run continues its step count.
    """
    if self.conf.reload_step > 0:
        self.reload(self.conf.reload_step)
    conf = self.conf

    def open_reader(file_name):
        # 2D data goes through H5DataLoader; every other data_type uses the
        # 3D loader, which also needs the network input shape.
        path = conf.data_dir + file_name
        if conf.data_type == '2D':
            return H5DataLoader(path)
        return H53DDataLoader(path, self.input_shape)

    train_reader = open_reader(conf.train_data)
    valid_reader = open_reader(conf.valid_data)

    for step in range(conf.max_step):
        global_step = step + conf.reload_step
        if step % conf.test_interval == 0:
            batch_x, batch_y = valid_reader.next_batch(conf.batch)
            loss, summary, accuracy, dice_accuracy = self.sess.run(
                [self.loss_op, self.valid_summary, self.accuracy_op,
                 self.dice_accuracy_op],
                feed_dict={self.inputs: batch_x,
                           self.annotations: batch_y})
            self.save_summary(summary, global_step)
            print('----valid loss', loss)
            print('----valid accuracy', accuracy)
            print('----valid dice accuracy', dice_accuracy)
        elif step % conf.summary_interval == 0:
            batch_x, batch_y = train_reader.next_batch(conf.batch)
            _, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict={self.inputs: batch_x,
                           self.annotations: batch_y})
            self.save_summary(summary, global_step)
        else:
            batch_x, batch_y = train_reader.next_batch(conf.batch)
            loss, summary, _, accuracy, dice_accuracy = self.sess.run(
                [self.loss_op, self.train_summary, self.train_op,
                 self.accuracy_op, self.dice_accuracy_op],
                feed_dict={self.inputs: batch_x,
                           self.annotations: batch_y})
            print('----train loss', loss)
            print('----train accuracy', accuracy)
            print('----train dice accuracy', dice_accuracy)
            self.save_summary(summary, global_step)
        if step % conf.save_interval == 0:
            self.save(global_step)
 def predict(self, test_step, data_index, sub_batch_index, test_type):
     """Build accuracy and loss tensors for one sub-batch of a 3D volume.

     Reloads checkpoint ``conf.test_step``. It then reads sub-batch
     ``sub_batch_index`` of volume ``data_index`` from the validation data
     (``test_type == 'valid'``) or the test data
     (``test_type == 'predict'``). The sub-batch is passed to
     ``self.predict_func``, and graph ops are built comparing the returned
     predictions with the labels.

     Args:
         test_step: only printed; the checkpoint actually reloaded is
             ``self.conf.test_step``.
         data_index: forwarded to ``generate_data``.
         sub_batch_index: forwarded to ``generate_data``.
         test_type: ``'valid'`` or ``'predict'``.

     Returns:
         ``(accuracy_op, loss)`` as unevaluated TensorFlow tensors. Returns
         ``None``, after printing a message, when ``conf.test_step`` is not
         positive or ``test_type`` is unrecognised.

     NOTE(review): every call creates new ops in the default graph, so
     calling this repeatedly grows the graph -- confirm callers bound it.
     """
     print('---->predicting ', test_step)
     if not self.conf.test_step > 0:
         print("please set a reasonable test_step")
         return
     self.reload(self.conf.test_step)

     if test_type == 'valid':
         data_file = self.conf.valid_data
     elif test_type == 'predict':
         data_file = self.conf.test_data
     else:
         print("invalid type")
         return
     test_reader = H53DDataLoader(self.conf.data_dir + data_file,
                                  self.input_shape,
                                  is_train=False)

     conf = self.conf
     sub_batch = test_reader.generate_data(
         data_index, sub_batch_index,
         [conf.depth, conf.height, conf.width],
         # NOTE(review): gaps are passed as (d, w, h) while sizes are
         # (depth, height, width); confirm generate_data expects this order.
         [conf.d_gap, conf.w_gap, conf.h_gap])
     s_predictions, s_labels = self.predict_func(sub_batch)

     # Labels: add a trailing axis, squeeze channel_axis, then one-hot encode
     # along channel_axis so they line up with the prediction logits.
     one_hot_annotations = tf.one_hot(
         tf.squeeze(
             tf.expand_dims(s_labels, -1, name='s_labels/expand_dims'),
             axis=[self.channel_axis],
             name='s_labels/squeeze'),
         depth=conf.class_num,
         axis=self.channel_axis,
         name='s_labels/one_hot')

     # Predictions: class index = argmax over channel_axis.
     decoded_predictions = tf.argmax(s_predictions, self.channel_axis,
                                     name='accuracy/decode_pred')
     labels_int64 = tf.cast(s_labels, tf.int64)
     correct_prediction = tf.equal(labels_int64, decoded_predictions,
                                   name='accuracy/predict_correct_pred')
     accuracy_op = tf.reduce_mean(
         tf.cast(correct_prediction, tf.float32, name='accuracy/cast'),
         name='accuracy/accuracy_op')
     loss = tf.reduce_mean(
         tf.losses.softmax_cross_entropy(one_hot_annotations, s_predictions))
     return accuracy_op, loss
 def test(self):
     """Evaluate checkpoint ``conf.test_step`` on the test set.

     Batches are read until the reader's continuation flag is falsy or it
     returns ``None`` inputs. Loss, accuracy and mean IoU are printed for each
     batch. At the end it prints the mean loss, the mean accuracy and the
     mean IoU fetched with the last batch.

     Returns early, after printing a message, when ``conf.test_step`` is not
     positive or when the reader produced no batches at all.
     """
     print('---->testing ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     if self.conf.data_type == '2D':
         test_reader = H5DataLoader(
             self.conf.data_dir + self.conf.test_data, False)
     else:
         test_reader = H53DDataLoader(
             self.conf.data_dir + self.conf.test_data, self.input_shape)
     # Initialise TensorFlow local variables before running self.miou_op.
     self.sess.run(tf.local_variables_initializer())
     losses = []
     accuracies = []
     m_ious = []
     sig = True
     while sig:
         # next_batch yields (continue_flag, inputs, labels); inputs is None
         # once there is nothing left to read.
         sig, inputs, labels = test_reader.next_batch(self.conf.batch)
         if inputs is None:
             break
         feed_dict = {self.inputs: inputs, self.labels: labels}
         # NOTE(review): m_iou is fetched in the same run as miou_op; if it is
         # a streaming-metric value it may not yet include this batch's update.
         loss, accuracy, m_iou, _ = self.sess.run(
             [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         losses.append(loss)
         accuracies.append(accuracy)
         m_ious.append(m_iou)
     if not losses:
         # Nothing to average, and m_ious[-1] below would raise IndexError.
         print('no test batches were produced')
         return
     print('Loss: ', np.mean(losses))
     print('Accuracy: ', np.mean(accuracies))
     print('M_iou: ', m_ious[-1])
示例#4
0
    def test(self):
        """Evaluate checkpoint ``conf.test_step`` on the validation patches.

        Each of the reader's ``num_of_valid_patches`` patches is fed through
        the network, and its loss and accuracy are printed. The means over
        all patches are printed at the end.
        """
        print('---->testing ', self.conf.test_step)
        if not self.conf.test_step > 0:
            print("please set a reasonable test_step")
            return
        self.reload(self.conf.test_step)

        reader = H53DDataLoader(self.conf.data_dir, self.conf.patch_size,
                                self.conf.validation_id,
                                self.conf.overlap_stepsize)
        # Initialise TensorFlow local variables before evaluation.
        self.sess.run(tf.local_variables_initializer())

        patch_losses, patch_accuracies = [], []
        for _ in range(reader.num_of_valid_patches):
            patch_x, patch_y, _location = reader.valid_next_batch()
            patch_loss, patch_accuracy = self.sess.run(
                [self.loss_op, self.accuracy_op],
                feed_dict={self.inputs: patch_x, self.annotations: patch_y})
            print('values----->', patch_loss, patch_accuracy)
            patch_losses.append(patch_loss)
            patch_accuracies.append(patch_accuracy)
        print('Loss: ', np.mean(patch_losses))
        print('Accuracy: ', np.mean(patch_accuracies))
示例#5
0
 def predict(self):
     """Predict the validation volume patch by patch and save averaged scores.

     Reloads checkpoint ``conf.test_step`` and runs
     ``self.softmax_predictions`` on each of the reader's
     ``num_of_valid_patches`` patches. Where patches overlap, the per-voxel
     outputs are averaged; voxels covered by no patch stay zero. The float32
     result of shape ``(t_d, t_h, t_w, class_num)`` is written with
     ``np.save`` to ``conf.savedir`` as
     ``results<test_step>_sub<validation_id>_overlap<overlap_stepsize>.npy``.

     Averaging uses two dense accumulators (probability sum and hit count)
     updated by NumPy slicing. Each patch therefore costs one vectorised add,
     instead of patch_size**3 Python iterations plus a per-voxel list in a
     dict.
     """
     print('---->predicting ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     data_reader = H53DDataLoader(self.conf.data_dir, self.conf.patch_size,
                                  self.conf.validation_id,
                                  self.conf.overlap_stepsize)
     self.sess.run(tf.local_variables_initializer())
     size = self.conf.patch_size
     spatial_shape = (data_reader.t_d, data_reader.t_h, data_reader.t_w)
     # Running sum of softmax outputs, and number of patches covering each
     # voxel (trailing axis of 1 so it broadcasts over classes).
     prob_sum = np.zeros(spatial_shape + (self.conf.class_num,),
                         dtype=np.float32)
     hits = np.zeros(spatial_shape + (1,), dtype=np.int32)
     for _ in range(data_reader.num_of_valid_patches):
         inputs, annotations, location = data_reader.valid_next_batch()
         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         preds = self.sess.run(self.softmax_predictions, feed_dict=feed_dict)
         print('--->processing results for: ', location)
         # location appears to be the patch origin (d, h, w) in the volume.
         # Only the first sample of the batch is used (assumes batch size 1;
         # TODO confirm).
         window = (slice(location[0], location[0] + size),
                   slice(location[1], location[1] + size),
                   slice(location[2], location[2] + size))
         prob_sum[window] += preds[0, :size, :size, :size, :]
         hits[window] += 1
     print('--->averaging results')
     # Mean over covering patches, computed in place; uncovered voxels
     # (hits == 0) keep their zero sum.
     results = np.divide(prob_sum, hits, out=prob_sum, where=hits > 0)
     print('--->saving results')
     save_filename = 'results' + str(self.conf.test_step) + '_sub' + str(self.conf.validation_id) + '_overlap' + str(self.conf.overlap_stepsize) + '.npy'
     save_file = os.path.join(self.conf.savedir, save_filename)
     np.save(save_file, results)
示例#6
0
 def predict(self):
     """Predict on the test set and save every decoded prediction as a PNG.

     Batches of ``conf.batch`` samples are read until the reader returns a
     batch with fewer samples; that short batch is not predicted. Prediction
     ``n`` is written to ``conf.sampledir + '<n>.png'`` with the project's
     ``imsave(image, path)``.
     """
     print('---->predicting ', self.conf.test_step)
     if not self.conf.test_step > 0:
         print("please set a reasonable test_step")
         return
     self.reload(self.conf.test_step)

     test_path = self.conf.data_dir + self.conf.test_data
     test_reader = (H5DataLoader(test_path, False)
                    if self.conf.data_type == '2D'
                    else H53DDataLoader(test_path, self.input_shape))

     batch_size = self.conf.batch
     decoded_batches = []
     while True:
         batch_x, batch_y = test_reader.next_batch(batch_size)
         # A short batch ends the loop and is discarded (presumably because
         # the graph expects full batches; TODO confirm).
         if batch_x.shape[0] < batch_size:
             break
         decoded_batches.append(self.sess.run(
             self.decoded_preds,
             feed_dict={self.inputs: batch_x, self.labels: batch_y}))

     print('----->saving predictions')
     for batch_no, batch_preds in enumerate(decoded_batches):
         per_batch = batch_preds.shape[0]
         for offset in range(per_batch):
             file_name = str(batch_no * per_batch + offset) + '.png'
             imsave(batch_preds[offset], self.conf.sampledir + file_name)
示例#7
0
    def predict(self, test_step, data_index, sub_batch_index, test_type):
        """Evaluate one sub-batch of a 3D volume and return its metrics.

        Reads sub-batch ``sub_batch_index`` of volume ``data_index`` from the
        validation data (``test_type == 'valid'``) or the test data
        (``test_type == 'predict'``). The predictions come from
        ``self.predict_func``. The method then evaluates accuracy, softmax
        cross-entropy loss and ``ops.dice_accuracy`` in ``self.sess``. No
        checkpoint is reloaded here; the weights already in the session are
        used.

        Args:
            test_step: not used; the printed step is ``self.conf.test_step``.
            data_index: forwarded to ``generate_data``.
            sub_batch_index: forwarded to ``generate_data``.
            test_type: ``'valid'`` or ``'predict'``.

        Returns:
            ``(accuracy, Loss, diceRatio, Sublist)`` as evaluated by
            ``self.sess.run``, or ``None`` (after printing "invalid type")
            for an unrecognised ``test_type``.

        NOTE(review): each call adds new ops to the default graph, so calling
        this in a loop grows the graph. The ``del`` / ``gc.collect()`` at the
        end presumably only drop Python references, not the graph ops.
        """
        print('---->predicting ', self.conf.test_step)
        if test_type == 'valid':
            test_reader = H53DDataLoader(
                self.conf.data_dir + self.conf.valid_data, self.input_shape,
                is_train=False)
        elif test_type == 'predict':
            print("start get predict data")
            test_reader = H53DDataLoader(
                self.conf.data_dir + self.conf.test_data, self.input_shape,
                is_train=False)
            print("finish get data")
        else:
            print("invalid type")
            return
        predict_generator = test_reader.generate_data(
            data_index, sub_batch_index,
            [self.conf.depth, self.conf.height, self.conf.width],
            # NOTE(review): gaps are (d, w, h) while sizes are
            # (depth, height, width); confirm generate_data expects this.
            [self.conf.d_gap, self.conf.w_gap, self.conf.h_gap])
        s_predictions, s_labels = self.predict_func(predict_generator)
        print("finish concate ---------+++++")
        # Labels: add a trailing axis, squeeze channel_axis, then one-hot
        # encode along channel_axis to match the prediction logits.
        expand_annotations = tf.expand_dims(
            s_labels, -1, name='s_labels/expand_dims')
        one_hot_annotations = tf.squeeze(
            expand_annotations, axis=[self.channel_axis],
            name='s_labels/squeeze')
        one_hot_annotations = tf.one_hot(
            one_hot_annotations, depth=self.conf.class_num,
            axis=self.channel_axis, name='s_labels/one_hot')
        # Predictions: class index = argmax over channel_axis.
        decoded_predictions = tf.argmax(
            s_predictions, self.channel_axis, name='accuracy/decode_pred')
        correct_prediction = tf.equal(
            tf.cast(s_labels, tf.int64), decoded_predictions,
            name='accuracy/predict_correct_pred')
        # Accuracy: fraction of voxels whose predicted class matches.
        accuracy_op = tf.reduce_mean(
            tf.cast(correct_prediction, tf.float32, name='accuracy/cast'),
            name='accuracy/accuracy_op')
        # Loss: mean softmax cross-entropy against the one-hot labels.
        loss = tf.reduce_mean(
            tf.losses.softmax_cross_entropy(one_hot_annotations,
                                            s_predictions))
        # Dice ratio (project helper in `ops`).
        DiceRatio, sublist = ops.dice_accuracy(
            decoded_predictions, s_labels, self.conf.class_num)
        print("session image labels-------vvvvvvv")
        accuracy, Loss, diceRatio, Sublist = self.sess.run(
            [accuracy_op, loss, DiceRatio, sublist])
        del DiceRatio, sublist
        gc.collect()
        return accuracy, Loss, diceRatio, Sublist
示例#8
0
 def train(self):
     """Train for ``conf.max_step + 1`` steps with validation and checkpoints.

     Step 0 is a plain training update. For every later step ``s``:

     * if ``s`` is a multiple of ``conf.test_interval``, a validation batch
       is evaluated in addition to the training update;
     * the training update records its summary when ``s`` is a multiple of
       ``conf.summary_interval``, and prints its loss otherwise;
     * a checkpoint is saved when ``s`` is a multiple of
       ``conf.save_interval``.

     Summary and checkpoint step numbers are offset by ``conf.reload_step``.
     """
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     if self.conf.data_type == '2D':
         loader_cls, extra_args = H5DataLoader, ()
     else:
         loader_cls, extra_args = H53DDataLoader, (self.input_shape,)
     train_reader = loader_cls(
         self.conf.data_dir + self.conf.train_data, *extra_args)
     valid_reader = loader_cls(
         self.conf.data_dir + self.conf.valid_data, *extra_args)

     for step in range(self.conf.max_step + 1):
         global_step = step + self.conf.reload_step
         if step > 0 and step % self.conf.test_interval == 0:
             val_x, val_y = valid_reader.next_batch(self.conf.batch)
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary],
                 feed_dict={self.inputs: val_x, self.labels: val_y})
             self.save_summary(summary, global_step)
             print('----testing loss', loss)
         record_summary = step > 0 and step % self.conf.summary_interval == 0
         batch_x, batch_y = train_reader.next_batch(self.conf.batch)
         feed = {self.inputs: batch_x, self.labels: batch_y}
         if record_summary:
             _, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed)
             self.save_summary(summary, global_step)
         else:
             loss, _ = self.sess.run(
                 [self.loss_op, self.train_op], feed_dict=feed)
             print('----training loss', loss)
         if step > 0 and step % self.conf.save_interval == 0:
             self.save(global_step)
    def predict(self):
        """Run the model on the test set and save predictions, inputs and labels.

        Batches are read until the reader's continuation flag is falsy or it
        returns ``None`` inputs. Samples are numbered ``k`` consecutively
        across all batches, and three files are written to
        ``conf.sampledir`` for each one:

        * ``<k>.png``: the decoded prediction, via the project's
          ``imsave(image, path)``;
        * ``<k>.jpg``: the input, cast to uint8;
        * ``<k>_label.png``: the label cast to uint8 and multiplied by 45
          for visibility.

        Numbering uses a running counter, so a shorter final batch no longer
        produces names that overwrite earlier samples. Inputs and labels are
        kept as one uint8 array per batch, so ragged batches do not break
        array construction.

        NOTE(review): ``scipy.misc.imsave`` was removed in SciPy 1.2, so this
        needs an older SciPy or a replacement writer.
        """
        print('---->predicting ', self.conf.test_step)
        if self.conf.test_step > 0:
            self.reload(self.conf.test_step)
        else:
            print("please set a reasonable test_step")
            return
        if self.conf.data_type == '2D':
            test_reader = H5DataLoader(
                self.conf.data_dir + self.conf.test_data, False)
        else:
            test_reader = H53DDataLoader(
                self.conf.data_dir + self.conf.test_data, self.input_shape)

        predictions = []
        final_inputs = []  # one uint8 array per batch
        final_labels = []  # one uint8 array per batch
        sig = True
        while sig:
            sig, inputs, labels = test_reader.next_batch(self.conf.batch)
            if inputs is None:
                break
            final_inputs.append(np.array(inputs, 'uint8'))
            final_labels.append(np.array(labels, 'uint8'))
            feed_dict = {self.inputs: inputs, self.labels: labels}
            predictions.append(
                self.sess.run(self.decoded_preds, feed_dict=feed_dict))

        print('----->saving predictions')
        sample_id = 0
        for batch_preds, batch_inputs, batch_labels in zip(
                predictions, final_inputs, final_labels):
            for i in range(batch_preds.shape[0]):
                stem = self.conf.sampledir + str(sample_id)
                imsave(batch_preds[i], stem + '.png')
                scipy.misc.imsave(stem + '.jpg', batch_inputs[i])
                scipy.misc.imsave(stem + '_label.png', batch_labels[i] * 45)
                sample_id += 1
示例#10
0
 def train(self):
     """Train on 3D patches for ``conf.max_step`` steps, numbered from 1.

     The reader is built with ``conf.aug_flip`` and ``conf.aug_rotate``
     forwarded to it. Each step runs one optimizer update on a batch from the
     reader's ``next_batch``. Every ``conf.summary_interval`` steps the
     training summary is recorded; on other steps the loss is printed. Every
     ``conf.save_interval`` steps a checkpoint is saved. Step numbers passed
     to ``save_summary`` and ``save`` are offset by ``conf.reload_step``.
     """
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     conf = self.conf
     reader = H53DDataLoader(conf.data_dir, conf.patch_size,
                             conf.validation_id, conf.overlap_stepsize,
                             conf.aug_flip, conf.aug_rotate)
     for step in range(1, conf.max_step + 1):
         with_summary = step % conf.summary_interval == 0
         batch_x, batch_y = reader.next_batch(conf.batch)
         feed = {self.inputs: batch_x, self.annotations: batch_y}
         if with_summary:
             _, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed)
             self.save_summary(summary, step + conf.reload_step)
         else:
             loss, _ = self.sess.run([self.loss_op, self.train_op],
                                     feed_dict=feed)
             print('----training loss', loss)
         if step % conf.save_interval == 0:
             self.save(step + conf.reload_step)