def dice_score(predictions, labels):
    """
    Compute the mean foreground Dice coefficient for every sample.

    Each prediction is binarised by setting to 1 every channel whose value
    equals the channel-wise maximum at that location (ties therefore yield
    several ones). The label is cropped to the prediction's shape. A smoothed
    Dice, with ``eps = 1.`` added to the denominator, is then computed for
    every class channel except channel 0 and averaged over those classes.

    :param predictions: list of output predictions (class channels on the last axis)
    :param labels: list of ground truths (class channels on the last axis)
    :return: 1-D float64 array holding one mean Dice value per prediction
    """
    assert len(predictions) == len(
        labels), "Number of predictions and labels don't equal."

    n_class = labels[0].shape[-1]
    eps = 1.
    per_sample = []

    for idx, raw_pred in enumerate(predictions):
        pred = np.array(raw_pred)
        label = util.crop_to_shape(np.array(labels[idx]), pred.shape)

        # Mark the arg-max channel(s) at every location.
        peak = np.max(pred, -1, keepdims=True)
        mask = np.where(np.equal(peak, pred),
                        np.ones_like(pred), np.zeros_like(pred))

        total = 0.
        for cls in range(1, n_class):  # channel 0 is skipped
            overlap = 2 * np.sum(mask[..., cls] * label[..., cls])
            volume = np.sum(mask[..., cls] + label[..., cls])
            total += overlap / (eps + volume)

        per_sample.append(total / (n_class - 1))

    return np.array(per_sample, dtype=np.float64)
def train(self, data_provider, model_path, training_iters=10, epochs=10, display_step=1, restore=False):
    """
    Launches the training process.

    :param data_provider: callable returning training data; called as
        ``data_provider(self.batch_size)`` and expected to return a 3-tuple whose
        second element is the label batch
    :param model_path: path where to store checkpoints
    :param training_iters: number of training mini batch iteration
    :param epochs: number of epochs
    :param display_step: number of steps till outputting stats
    :param restore: Flag if previous model should be restored
    :return: path returned by the last ``self.save`` call, or
        ``<model_path>/model.ckpt`` unchanged when ``epochs == 0``
    """
    save_path = os.path.join(model_path, 'model.ckpt')
    if epochs == 0:
        return save_path

    init = self._initialize(training_iters, model_path, restore)

    with tf.Session() as sess:
        sess.run(init)

        if restore:
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt and ckpt.model_checkpoint_path:
                # NOTE(review): self.__optimizer is used here via .variables() and below as a
                # sess.run fetch; confirm which kind of object it actually is.
                var_list = tf.global_variables(scope='autoencoder') + self.__optimizer.variables() + [
                    tf.train.get_global_step()]
                self.restore(sess, ckpt.model_checkpoint_path, var_list=var_list)

        summary_writer = tf.summary.FileWriter(model_path, graph=sess.graph)
        logging.info("Start Optimization!")

        for epoch in range(epochs):
            total_loss = 0.
            for step in range((epoch * training_iters), ((epoch + 1) * training_iters)):
                _, batch_y, _ = data_provider(self.batch_size)

                # Inference-mode forward pass, used only to learn the decode shape.
                decodes = sess.run(self.__decodes,
                                   feed_dict={self.__labels: batch_y, self.__train_phase: False})
                cropped_y = crop_to_shape(batch_y, decodes.shape)

                _, batch_loss = sess.run((self.__optimizer, self.__cost),
                                         feed_dict={self.__labels: cropped_y, self.__train_phase: True})

                if step % display_step == 0:
                    logging.info("Iteration {:}, Mini-batch loss= {:.4f}".format(step, batch_loss))
                    # Feed the same cropped labels as the optimisation step so the
                    # summary describes the batch that was actually trained on.
                    summary_str = sess.run(self.summary_op,
                                           feed_dict={self.__labels: cropped_y, self.__train_phase: True})
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                total_loss += batch_loss

            # Guard against training_iters == 0 (no steps ran, nothing to average).
            avg_loss = total_loss / training_iters if training_iters else 0.
            logging.info("Epoch {:}, Average mini-batch loss= {:.4f}".format(epoch, avg_loss))
            save_path = self.save(sess, save_path, "checkpoint")

    logging.info("Optimization Finished!")
    return save_path
def store_prediction(self, sess, batch_x, batch_y, batch_affine):
    """
    Evaluate the network on a list of samples, log the metrics and save the predictions.

    Every sample is run through ``self.net.predictor`` with dropout 0 and
    ``train_phase``/``need_pos`` False. Its labels are then cropped to the
    prediction shape to evaluate ``self.net.cost`` and ``self.net.dice_score``.
    Local variables are re-initialised first. The aggregate ``acc``, ``auc``,
    ``sens`` and ``spec`` tensors are read after the loop; presumably these are
    streaming metrics accumulated by the per-sample runs (TODO confirm).

    The caller's ``batch_x`` / ``batch_y`` lists are not modified. Transposed
    copies are built for the saving helpers instead.

    :param sess: active TensorFlow session
    :param batch_x: list of input samples; each must be 4-D for the save-layout transpose
    :param batch_y: list of label samples, same length as ``batch_x``
    :param batch_affine: per-sample affine data forwarded to ``util.save_prediction_1``
    :return: tuple ``(acc, mean dice, auc, sens, spec)``
    """
    n = len(batch_y)
    loss = np.zeros([n])
    dice = np.zeros([n])
    batch_pred = []

    def _eval_feed(x, y):
        # Inference-mode feed: dropout off, not training, no position input.
        return {
            self.net.x: x,
            self.net.y: y,
            self.net.p: self.p_dummy,
            self.net.dropout_rate: 0.,
            self.net.train_phase: False,
            self.net.need_pos: False,
        }

    sess.run(tf.local_variables_initializer())

    for i in range(n):
        pred = sess.run(self.net.predictor, feed_dict=_eval_feed(batch_x[i], batch_y[i]))
        batch_pred.append(pred)
        # The prediction may be smaller than the input, so crop the labels to match.
        loss[i], dice[i] = sess.run(
            [self.net.cost, self.net.dice_score],
            feed_dict=_eval_feed(batch_x[i], util.crop_to_shape(batch_y[i], pred.shape)))

    acc, auc, sens, spec = sess.run(
        [self.net.acc, self.net.auc, self.net.sens, self.net.spec])
    logging.info(
        "Validation Error= {:.2f}%, Loss= {:.4f}, Dice score= {:.4f}, AUC= {:.4f}, Sensitivity= {:.2f}%, "
        "Specificity= {:.2f}% ".format((1 - acc) * 100, np.mean(loss), np.mean(dice), auc, sens * 100,
                                       spec * 100))

    def _to_save_layout(volume):
        # Add a leading axis and move axis 1 (of the 4-D sample) behind axes 2 and 3.
        return np.expand_dims(volume, axis=0).transpose((0, 2, 3, 1, 4))

    # Build new lists instead of transposing the caller's data in place; the
    # original in-place version left the inputs transposed if saving raised.
    saved_x = [_to_save_layout(batch_x[i]) for i in range(n)]
    saved_y = [_to_save_layout(batch_y[i]) for i in range(n)]
    saved_pred = [_to_save_layout(pred) for pred in batch_pred]

    util.save_prediction(saved_x, saved_y, saved_pred, self.prediction_path)
    util.save_prediction_1(saved_pred, batch_affine, self.prediction_path)

    return acc, np.mean(dice), auc, sens, spec
def acc_rate(predictions, labels):
    """
    Return the per-sample pixel accuracy, in percent, of dense predictions.

    For every sample, the label is cropped to the prediction's shape. The
    accuracy is the percentage of locations where the arg-max class (last
    axis) of the prediction matches that of the label.

    :param predictions: list of output predictions (arrays, class channels on the last axis)
    :param labels: list of ground truths (class channels on the last axis)
    :return: 1-D float64 array with one accuracy percentage (0-100) per prediction
    """
    assert len(predictions) == len(
        labels), "Number of predictions and labels don't equal."

    # Collect scores in a list and convert once; repeated np.hstack copies the
    # whole accumulated array on every iteration.
    accuracies = []
    for i, pred in enumerate(predictions):
        cropped = util.crop_to_shape(labels[i], pred.shape)
        hits = np.argmax(pred, -1) == np.argmax(cropped, -1)
        accuracies.append(100.0 * np.average(hits))

    return np.array(accuracies, dtype=np.float64)
def auc_score(predictions, labels):
    """
    Compute the ROC AUC of every dense prediction against its ground truth.

    Each label is cropped to the prediction's shape. Prediction and label are
    flattened to ``(-1, n_class)`` matrices, with ``n_class`` taken from the
    last axis of the first label. The pair is scored with ``roc_auc_score``.

    NOTE(review): roc_auc_score cannot score a class column whose cropped label
    holds a single value; confirm every class is present in each sample.

    :param predictions: list of output predictions
    :param labels: list of ground truths
    :return: 1-D float64 array with one AUC value per prediction
    """
    assert len(predictions) == len(
        labels), "Number of predictions and labels don't equal."

    n_class = labels[0].shape[-1]
    scores = []
    for idx, pred in enumerate(predictions):
        score_matrix = np.reshape(pred, [-1, n_class])
        truth_matrix = np.reshape(
            util.crop_to_shape(labels[idx], pred.shape), [-1, n_class])
        scores.append(roc_auc_score(truth_matrix, score_matrix))

    return np.array(scores, dtype=np.float64)
def train(self,
          train_data_provider,
          val_data_provider,
          train_original_data_provider,
          validation_batch_size,
          model_path,
          training_iters=10,
          epochs=100,
          dropout=0.75,
          clip_gradient=False,
          display_step=1,
          restore=True,
          write_graph=False,
          prediction_path='validation_prediction',
          original_train_batch_size=25):
    """
    Launches the training process.

    :param train_data_provider: callable returning training data; called as
        ``train_data_provider(self.batch_size)`` and unpacked into
        ``(x, y, _, position)``
    :param val_data_provider: callable returning validation data
    :param train_original_data_provider: callable returning training samples that
        are evaluated with ``store_prediction`` after every epoch
    :param validation_batch_size: number of data for validation
    :param model_path: path where to store checkpoints
    :param training_iters: number of training mini batch iteration
    :param epochs: number of epochs
    :param dropout: value fed to ``self.net.dropout_rate`` during training
        (evaluation always feeds 0)
    :param clip_gradient: whether to apply gradient clipping
    :param display_step: number of steps till outputting stats
    :param restore: Flag if previous model should be restored
    :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
    :param prediction_path: path where to save predictions on each epoch
    :param original_train_batch_size: number of samples requested from
        ``train_original_data_provider`` for the per-epoch evaluation
    :return: tuple ``(save_path, test_acc, test_dice, test_auc, test_sens, test_spec)``;
        the arrays hold one validation metric per epoch, and ``save_path`` is the
        best-Dice checkpoint path
    """

    def _checkpoint_name(var_name, prefix='cost_function/', suffix=':0'):
        # str.lstrip/rstrip strip *sets of characters*, not substrings, so the
        # scope prefix and tensor suffix are removed explicitly here.
        if var_name.startswith(prefix):
            var_name = var_name[len(prefix):]
        if var_name.endswith(suffix):
            var_name = var_name[:-len(suffix)]
        return var_name

    save_path = os.path.join(model_path, "best_model.ckpt")
    goon_path = os.path.join(model_path, "goon_model.ckpt")
    init = self._initialize(training_iters, clip_gradient, model_path, restore, prediction_path)

    with tf.Session(config=config) as sess:
        if write_graph:
            tf.train.write_graph(sess.graph_def, model_path, "graph.pb", False)

        sess.run(init)

        # ACNN regularization: load the pretrained autoencoder weights.
        if self.net.regularizer_type == 'anatomical_constraint':
            ae_ckpt = tf.train.get_checkpoint_state(self.net.abs_ae_path)
            if ae_ckpt and ae_ckpt.model_checkpoint_path:
                logging.info("Model restored from file: {:}".format(
                    ae_ckpt.model_checkpoint_path))
                # Map checkpoint names (without the 'cost_function/' scope) to graph variables.
                ae_var_list = {_checkpoint_name(v.name): v for v in self.net.ae_variables}
                self.net.restore(sess, ae_ckpt.model_checkpoint_path, var_list=ae_var_list)

        # Resume from the rolling "go on" checkpoint if requested.
        if restore:
            ckpt = tf.train.get_checkpoint_state(
                model_path, latest_filename='goon_checkpoint')
            if ckpt and ckpt.model_checkpoint_path:
                self.net.restore(sess, ckpt.model_checkpoint_path,
                                 var_list=self.net.training_variables + [tf.train.get_global_step()])

        summary_writer = tf.summary.FileWriter(model_path, graph=sess.graph)

        test_x, test_y, test_affine, _ = val_data_provider(validation_batch_size)
        train_x, train_y, train_affine, _ = train_original_data_provider(
            original_train_batch_size)

        # Baseline evaluation on validation data before any training.
        self.store_prediction(sess, test_x, test_y, test_affine)

        test_acc = np.array([])
        test_dice = np.array([])
        test_auc = np.array([])
        test_sens = np.array([])
        test_spec = np.array([])

        if epochs == 0:
            return save_path, test_acc, test_dice, test_auc, test_sens, test_spec

        logging.info(
            "Start U-net optimization based on loss function: {} and regularizer type: {}"
            .format(self.net.cost_name, self.net.regularizer_type))
        if self.net.regularizer_type is not None:
            logging.info("Current regularization coefficient: {}".format(
                self.net.regularization_coefficient))

        lr = 0.
        avg_gradients = None
        for epoch in range(epochs):
            total_loss = 0.
            for step in range((epoch * training_iters), ((epoch + 1) * training_iters)):
                batch_x, batch_y, _, batch_position = train_data_provider(
                    self.batch_size)

                # Inference-mode forward pass, used only to learn the output shape.
                prediction = sess.run(self.net.predictor, feed_dict={
                    self.net.x: batch_x,
                    self.net.y: batch_y,
                    self.net.p: batch_position,
                    self.net.dropout_rate: 0.,
                    self.net.train_phase: False,
                    self.net.need_pos: False
                })
                cropped_y = util.crop_to_shape(batch_y, prediction.shape)

                # Optimization operation (back-propagation).
                _, loss, lr, gradients = sess.run(
                    [
                        self.train_op, self.net.cost, self.learning_rate_node,
                        self.net.gradients_node
                    ],
                    feed_dict={
                        self.net.x: batch_x,
                        self.net.y: cropped_y,
                        self.net.p: batch_position,
                        self.net.dropout_rate: dropout,
                        self.net.train_phase: True,
                        self.net.need_pos: True
                    })

                if self.net.summaries and self.norm_grads:
                    avg_gradients = _update_avg_gradients(
                        avg_gradients, gradients, step)
                    norm_gradients = [
                        np.linalg.norm(gradient) for gradient in avg_gradients
                    ]
                    # NOTE(review): .assign(...) adds a new op to the graph on every
                    # step; a placeholder-fed assign op would keep the graph fixed.
                    self.norm_gradients_node.assign(norm_gradients).eval()

                if step % display_step == 0:
                    self.output_minibatch_stats(
                        sess, summary_writer, step, batch_x, cropped_y)

                total_loss += loss

            self.output_epoch_stats(epoch, total_loss, training_iters, lr)

            # Per-epoch snapshot plus the rolling checkpoint used by `restore`.
            model_path_per_epoch = os.path.join(
                model_path, "model_{}.ckpt".format(epoch))
            self.net.save(
                sess, model_path_per_epoch,
                latest_filename='model_{}_checkpoint'.format(epoch))
            self.net.save(sess, goon_path, latest_filename='goon_checkpoint')

            acc, dice, auc, sens, spec = self.store_prediction(
                sess, test_x, test_y, test_affine)
            print(
                '#################### result of original train data ######################'
            )
            self.store_prediction(sess, train_x, train_y, train_affine)

            # Keep the best-Dice model. The first epoch always qualifies; previously
            # it was skipped, so the returned best-model path could never be written.
            if test_dice.size == 0 or dice > np.max(test_dice):
                save_path = self.net.save(
                    sess, save_path, latest_filename='best_checkpoint')

            test_acc = np.hstack((test_acc, acc))
            test_dice = np.hstack((test_dice, dice))
            test_auc = np.hstack((test_auc, auc))
            test_sens = np.hstack((test_sens, sens))
            test_spec = np.hstack((test_spec, spec))

        logging.info("Optimization Finished!")
        return save_path, test_acc, test_dice, test_auc, test_sens, test_spec