Code Example #1
File: unet_pos.py Project: xzluo97/LGE_SRSCN
def dice_score(predictions, labels):
    """
    Return the Dice score based on dense predictions and labels.
    :param predictions: list of output predictions
    :param labels: list of ground truths
    """
    assert len(predictions) == len(
        labels), "Numbers of predictions and labels do not match."
    n_class = labels[0].shape[-1]
    dice = np.array([])
    n = len(predictions)
    eps = 1.  # smoothing constant to avoid division by zero
    for i in range(n):
        pred = np.array(predictions[i])
        label = util.crop_to_shape(np.array(labels[i]), pred.shape)
        # one-hot mask of the predicted class (argmax over the last axis)
        mask = np.where(np.equal(np.max(pred, -1, keepdims=True), pred),
                        np.ones_like(pred), np.zeros_like(pred))
        d = 0.
        for k in range(1, n_class):  # skip the background class (k = 0)
            numerator = 2 * np.sum(mask[..., k] * label[..., k])
            denominator = np.sum(mask[..., k] + label[..., k])
            d += numerator / (eps + denominator)

        dice = np.hstack((dice, d / (n_class - 1)))  # mean over foreground classes
    return dice
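A minimal usage sketch (not from the source): synthetic one-hot arrays stand in for real predictions and labels, and util.crop_to_shape is assumed to pass arrays through unchanged when the shapes already match.

import numpy as np

# two synthetic one-hot volumes with 3 classes (background + 2 foreground)
rng = np.random.RandomState(0)
pred = np.eye(3)[rng.randint(0, 3, size=(4, 4))]   # shape (4, 4, 3)
label = np.eye(3)[rng.randint(0, 3, size=(4, 4))]

scores = dice_score([pred], [label])  # one Dice value per input volume
print(scores)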
Code Example #2
File: ACNN.py Project: xzluo97/LGE_SRSCN
    def train(self, data_provider, model_path, training_iters=10, epochs=10, display_step=1, restore=False):
        """
        Launches the training process

        :param data_provider: callable returning training data
        :param model_path: path where to store checkpoints
        :param training_iters: number of mini-batch iterations per epoch
        :param epochs: number of epochs
        :param display_step: number of steps between statistics outputs
        :param restore: flag indicating whether a previous model should be restored
        """
        
        save_path = os.path.join(model_path, 'model.ckpt')
        if epochs == 0:
            return save_path
        
        init = self._initialize(training_iters, model_path, restore)

        with tf.Session() as sess:
            sess.run(init)
            
            if restore:
                ckpt = tf.train.get_checkpoint_state(model_path)
                if ckpt and ckpt.model_checkpoint_path:
                    var_list = tf.global_variables(scope='autoencoder') + self.__optimizer.variables() + [
                        tf.train.get_global_step()]
                    self.restore(sess, ckpt.model_checkpoint_path, var_list=var_list)

            summary_writer = tf.summary.FileWriter(model_path, graph=sess.graph)

            logging.info("Start Optimization!")

            for epoch in range(epochs):
                total_loss = 0.
                for step in range((epoch * training_iters), ((epoch + 1) * training_iters)):
                    _, batch_y, _ = data_provider(self.batch_size)
                    
                    # forward pass to determine the decoder output shape
                    decodes = sess.run(self.__decodes, feed_dict={self.__labels: batch_y,
                                                                  self.__train_phase: False})
    
                    _, batch_loss = sess.run((self.__optimizer, self.__cost),
                                             feed_dict={self.__labels: crop_to_shape(batch_y, decodes.shape),
                                                        self.__train_phase: True})
    
                    if step % display_step == 0:
                        logging.info("Iteration {:}, Mini-batch loss= {:.4f}".format(step, batch_loss))
                        summary_str = sess.run(self.summary_op, feed_dict={self.__labels: batch_y,
                                                                           self.__train_phase: True})
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()
                    
                    total_loss += batch_loss

                logging.info("Epoch {:}, Average mini-batch loss= {:.4f}".format(epoch, total_loss / training_iters))
                
                save_path = self.save(sess, save_path, "checkpoint")
                
            logging.info("Optimization Finished!")
        
        return save_path
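A hypothetical driver for this trainer: the provider below is illustrative (only the labels are consumed by this autoencoder), and trainer stands in for an instantiated ACNN model whose construction is not shown in the source.

import numpy as np

def data_provider(batch_size):
    # illustrative provider: returns (x, y, extra); only y is used by ACNN.train
    y = np.random.rand(batch_size, 32, 32, 32, 4).astype(np.float32)
    return None, y, None

ckpt_path = trainer.train(data_provider, model_path='./ae_model',
                          training_iters=20, epochs=5, restore=False)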
Code Example #3
File: unet_pos.py Project: xzluo97/LGE_SRSCN
    def store_prediction(self, sess, batch_x, batch_y, batch_affine):

        n = len(batch_y)
        loss = np.zeros([n])
        dice = np.zeros([n])
        batch_pred = []

        # reset streaming metric state (accuracy, AUC, sensitivity, specificity)
        sess.run(tf.local_variables_initializer())
        for i in range(n):

            pred = sess.run(self.net.predictor,
                            feed_dict={
                                self.net.x: batch_x[i],
                                self.net.y: batch_y[i],
                                self.net.p: self.p_dummy,
                                self.net.dropout_rate: 0.,
                                self.net.train_phase: False,
                                self.net.need_pos: False
                            })
            pred_shape = pred.shape
            batch_pred.append(pred)

            loss[i], dice[i] = sess.run(
                [self.net.cost, self.net.dice_score],
                feed_dict={
                    self.net.x: batch_x[i],
                    self.net.y: util.crop_to_shape(batch_y[i], pred_shape),
                    self.net.p: self.p_dummy,
                    self.net.dropout_rate: 0.,
                    self.net.train_phase: False,
                    self.net.need_pos: False
                })

            # add a batch axis and reorder dimensions for saving
            batch_x[i] = np.expand_dims(batch_x[i], axis=0).transpose(
                (0, 2, 3, 1, 4))
            batch_y[i] = np.expand_dims(batch_y[i], axis=0).transpose(
                (0, 2, 3, 1, 4))
            batch_pred[i] = np.expand_dims(batch_pred[i], axis=0).transpose(
                (0, 2, 3, 1, 4))

        acc, auc, sens, spec = sess.run(
            [self.net.acc, self.net.auc, self.net.sens, self.net.spec])
        logging.info(
            "Validation Error= {:.2f}%, Loss= {:.4f}, Dice score= {:.4f}, AUC= {:.4f}, Sensitivity= {:.2f}%, "
            "Specificity= {:.2f}% ".format((1 - acc) * 100, np.mean(loss),
                                           np.mean(dice), auc, sens * 100,
                                           spec * 100))
        util.save_prediction(batch_x, batch_y, batch_pred,
                             self.prediction_path)
        util.save_prediction_1(batch_pred, batch_affine, self.prediction_path)

        # restore the original axis order of the inputs and labels
        for i in range(n):
            batch_x[i] = np.squeeze(batch_x[i], axis=0).transpose((2, 0, 1, 3))
            batch_y[i] = np.squeeze(batch_y[i], axis=0).transpose((2, 0, 1, 3))

        return acc, np.mean(dice), auc, sens, spec
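store_prediction is called from the training loop in Code Example #6; a standalone invocation might look like the sketch below, assuming trainer holds a built self.net graph and test_x, test_y, test_affine come from a validation data provider (all names here are illustrative).

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # or restore a checkpoint first
    acc, mean_dice, auc, sens, spec = trainer.store_prediction(
        sess, test_x, test_y, test_affine)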
Code Example #4
File: unet_pos.py Project: xzluo97/LGE_SRSCN
def acc_rate(predictions, labels):
    """
    Return the accuracy rate (%) based on dense predictions and labels.
    :param predictions: list of output predictions
    :param labels: list of ground truths
    """
    assert len(predictions) == len(
        labels), "Numbers of predictions and labels do not match."
    acc = np.array([])
    n = len(predictions)
    for i in range(n):
        # per-volume accuracy (%): fraction of voxels whose argmax classes agree
        acc = np.hstack((acc, (100.0 * np.average(
            np.argmax(predictions[i], -1) == np.argmax(
                util.crop_to_shape(labels[i], predictions[i].shape), -1)))))
    return acc
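A minimal sanity check (not from the source): identical one-hot inputs should yield 100% accuracy, again assuming util.crop_to_shape passes matching shapes through unchanged.

import numpy as np

pred = np.eye(3)[np.random.randint(0, 3, size=(4, 4))]  # synthetic one-hot volume
print(acc_rate([pred], [pred]))  # -> [100.]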
Code Example #5
File: unet_pos.py Project: xzluo97/LGE_SRSCN
def auc_score(predictions, labels):
    """
    Return the AUC score based on dense predictions and labels.
    :param predictions: list of output predictions
    :param labels: list of ground truths
    """
    assert len(predictions) == len(
        labels), "Numbers of predictions and labels do not match."
    auc = np.array([])
    n = len(predictions)
    n_class = labels[0].shape[-1]
    for i in range(n):
        # flatten to (n_voxels, n_class) for scikit-learn's roc_auc_score
        flat_score = np.reshape(predictions[i], [-1, n_class])
        flat_true = np.reshape(
            util.crop_to_shape(labels[i], predictions[i].shape), [-1, n_class])
        auc = np.hstack((auc, roc_auc_score(flat_true, flat_score)))
    return auc
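A usage sketch with synthetic data (not from the source). Note that scikit-learn's roc_auc_score raises an error if some class never occurs in the ground truth; random labels of this size are very unlikely to trigger that.

import numpy as np

rng = np.random.RandomState(0)
label = np.eye(3)[rng.randint(0, 3, size=(8, 8))]   # one-hot ground truth
score = rng.dirichlet(np.ones(3), size=(8, 8))      # softmax-like class scores
print(auc_score([score], [label]))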
Code Example #6
File: unet_pos.py Project: xzluo97/LGE_SRSCN
    def train(self,
              train_data_provider,
              val_data_provider,
              train_original_data_provider,
              validation_batch_size,
              model_path,
              training_iters=10,
              epochs=100,
              dropout=0.75,
              clip_gradient=False,
              display_step=1,
              restore=True,
              write_graph=False,
              prediction_path='validation_prediction'):
        """
        Launches the training process

        :param train_data_provider: callable returning training data
        :param val_data_provider: callable returning validation data
        :param train_original_data_provider: callable returning the original training data
        :param validation_batch_size: number of validation samples
        :param model_path: path where to store checkpoints
        :param training_iters: number of mini-batch iterations per epoch
        :param epochs: number of epochs
        :param dropout: dropout probability
        :param clip_gradient: whether to apply gradient clipping
        :param display_step: number of steps between statistics outputs
        :param restore: flag indicating whether a previous model should be restored
        :param write_graph: flag indicating whether the computation graph should be written as a protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        """

        save_path = os.path.join(model_path, "best_model.ckpt")
        goon_path = os.path.join(model_path, "goon_model.ckpt")

        init = self._initialize(training_iters, clip_gradient, model_path,
                                restore, prediction_path)

        with tf.Session(config=config) as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, model_path, "graph.pb",
                                     False)

            # initialization
            sess.run(init)

            # ACNN regularization
            if self.net.regularizer_type == 'anatomical_constraint':
                ae_ckpt = tf.train.get_checkpoint_state(self.net.abs_ae_path)
                if ae_ckpt and ae_ckpt.model_checkpoint_path:
                    logging.info("Model restored from file: {:}".format(
                        ae_ckpt.model_checkpoint_path))
                    # print([v.name for v in self.net.ae_variables])
                    # strip the 'cost_function/' scope prefix and ':0' suffix
                    # (lstrip/rstrip strip character sets, not substrings)
                    ae_var_list = dict(
                        (v.name.replace('cost_function/', '', 1).split(':')[0], v)
                        for v in self.net.ae_variables)
                    self.net.restore(sess,
                                     ae_ckpt.model_checkpoint_path,
                                     var_list=ae_var_list)

            # restore model
            if restore:
                ckpt = tf.train.get_checkpoint_state(
                    model_path, latest_filename='goon_checkpoint')
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess,
                                     ckpt.model_checkpoint_path,
                                     var_list=self.net.training_variables +
                                     [tf.train.get_global_step()])

            # create summary writer for training summaries
            summary_writer = tf.summary.FileWriter(model_path,
                                                   graph=sess.graph)

            # read validation data
            test_x, test_y, test_affine, _ = val_data_provider(
                validation_batch_size)
            # read the original train data
            train_x, train_y, train_affine, _ = train_original_data_provider(
                25)
            # visualize performance on validation data
            self.store_prediction(sess, test_x, test_y, test_affine)

            test_acc = np.array([])
            test_dice = np.array([])
            test_auc = np.array([])
            test_sens = np.array([])
            test_spec = np.array([])

            if epochs == 0:
                return save_path, test_acc, test_dice, test_auc, test_sens, test_spec

            logging.info(
                "Start U-net optimization based on loss function: {} and regularizer type: {}"
                .format(self.net.cost_name, self.net.regularizer_type))
            if self.net.regularizer_type is not None:
                logging.info("Current regularization coefficient: {}".format(
                    self.net.regularization_coefficient))

            lr = 0.
            avg_gradients = None
            for epoch in range(epochs):
                total_loss = 0.
                for step in range((epoch * training_iters),
                                  ((epoch + 1) * training_iters)):
                    # read training data
                    batch_x, batch_y, _, batch_position = train_data_provider(
                        self.batch_size)

                    # get output shape
                    prediction = sess.run(self.net.predictor,
                                          feed_dict={
                                              self.net.x: batch_x,
                                              self.net.y: batch_y,
                                              self.net.p: batch_position,
                                              self.net.dropout_rate: 0.,
                                              self.net.train_phase: False,
                                              self.net.need_pos: False
                                          })
                    pred_shape = prediction.shape

                    # optimization operation (back-propagation)

                    _, loss, lr, gradients = sess.run(
                        [
                            self.train_op, self.net.cost,
                            self.learning_rate_node, self.net.gradients_node
                        ],
                        feed_dict={
                            self.net.x: batch_x,
                            self.net.y:
                            util.crop_to_shape(batch_y, pred_shape),
                            self.net.p: batch_position,
                            self.net.dropout_rate: dropout,
                            self.net.train_phase: True,
                            self.net.need_pos: True
                        })

                    # add normalized gradients to summaries
                    if self.net.summaries and self.norm_grads:
                        avg_gradients = _update_avg_gradients(
                            avg_gradients, gradients, step)
                        norm_gradients = [
                            np.linalg.norm(gradient)
                            for gradient in avg_gradients
                        ]
                        self.norm_gradients_node.assign(norm_gradients).eval()

                    # display mini-batch statistics
                    if step % display_step == 0:
                        self.output_minibatch_stats(
                            sess, summary_writer, step, batch_x,
                            util.crop_to_shape(batch_y, pred_shape))

                    total_loss += loss

                # display epoch statistics
                self.output_epoch_stats(epoch, total_loss, training_iters, lr)

                # save the current model
                model_path_per_epoch = os.path.join(
                    model_path, "model_{}.ckpt".format(epoch))
                self.net.save(
                    sess,
                    model_path_per_epoch,
                    latest_filename='model_{}_checkpoint'.format(epoch))
                self.net.save(sess,
                              goon_path,
                              latest_filename='goon_checkpoint')

                # visualize and display validation performance and metrics
                acc, dice, auc, sens, spec = self.store_prediction(
                    sess, test_x, test_y, test_affine)
                print(
                    '#################### result of original train data ######################'
                )
                self.store_prediction(sess, train_x, train_y, train_affine)

                # save the current model if it is the best one hitherto
                # (test_dice is still empty at epoch 0, so save unconditionally then)
                if test_dice.size == 0 or dice > np.max(test_dice):
                    save_path = self.net.save(
                        sess, save_path, latest_filename='best_checkpoint')

                # store the validation metrics
                test_acc = np.hstack((test_acc, acc))
                test_dice = np.hstack((test_dice, dice))
                test_auc = np.hstack((test_auc, auc))
                test_sens = np.hstack((test_sens, sens))
                test_spec = np.hstack((test_spec, spec))
            logging.info("Optimization Finished!")

            return save_path, test_acc, test_dice, test_auc, test_sens, test_spec
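A hypothetical end-to-end call; the trainer construction and the data providers are not shown in the source, so every name below is illustrative.

import numpy as np

# `trainer` stands for a constructed U-net trainer from this module
ckpt, acc, dice, auc, sens, spec = trainer.train(
    train_data_provider, val_data_provider, train_original_data_provider,
    validation_batch_size=10, model_path='./unet_model',
    training_iters=50, epochs=100, dropout=0.75,
    clip_gradient=True, display_step=10, restore=False)
best_epoch = int(np.argmax(dice))  # epoch with the highest validation Dice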