Example #1
def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_sd_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_sd')
        images_td_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_td')
        labels_sd_pl = tf.placeholder(tf.uint8,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size),
                                      name='labels_sd')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # ================================================================
        images_sd_normalized, _ = model.normalize(images_sd_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=False)
        images_td_normalized, _ = model.normalize(images_td_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=True)
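        # scope_reuse=True shares the variable scope, so the SD and TD images
        # pass through the same normalization module weights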

        # ================================================================
        # get logit predictions from the segmentation network
        # ================================================================
        predicted_seg_sd_logits, _, _ = model.predict_i2l(images_sd_normalized,
                                                          exp_config,
                                                          training_pl,
                                                          scope_reuse=False)

        # ================================================================
        # get all features from the segmentation network
        # ================================================================
        seg_features_sd = model.get_all_features(images_sd_normalized,
                                                 exp_config,
                                                 scope_reuse=True)
        seg_features_td = model.get_all_features(images_td_normalized,
                                                 exp_config,
                                                 scope_reuse=True)

        # ================================================================
        # resize all features to the same size
        # ================================================================
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64),
            'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64),
            'resize_td_xfeat')

        # ================================================================
        # discriminator on features
        # ================================================================
        d_logits_sd = model.discriminator(images_sd_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=False)
        d_logits_td = model.discriminator(images_td_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=True)

        # ================================================================
        # add ops for calculation of the discriminator loss
        # ================================================================
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # ================================================================
        # add ops for calculation of the adversarial loss that tries to get domain invariant features in the normalized image space
        # ================================================================
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)
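        # model.loss_invariance is not shown here; presumably it is the usual
        # generator-style loss that rewards the normalizer for making TD features
        # indistinguishable from SD features to the discriminator, i.e. something like:
        # tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #     labels=tf.ones_like(d_logits_td), logits=d_logits_td))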

        # ================================================================
        # add ops for calculation of the supervised segmentation loss
        # ================================================================
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # ================================================================
        # total training loss for uda
        # ================================================================
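        # exp_config.lambda_uda weighs the adversarial feature-invariance term
        # against the supervised segmentation loss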
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # ================================================================
        # merge all summaries
        # ================================================================
        summary_scalars = tf.summary.merge_all()

        # ================================================================
        # divide the vars into segmentation network, normalization network and the discriminator network
        # ================================================================
        i2l_vars = []
        normalization_vars = []
        discriminator_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                # the normalization vars also need to be restored from the pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # ================================================================
        # add optimization ops
        # ================================================================
        train_i2l_op = model.training_step(
            loss_total_op,
            i2l_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        train_discriminator_op = model.training_step(
            loss_d_op,
            discriminator_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits,
            labels_sd_pl,
            images_sd_pl,
            d_logits_td,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # add ops for adding image summary to tensorboard
        # ================================================================
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl,
            exp_config.nlabels)

        # ================================================================
        # the scalar summaries were already merged above via tf.summary.merge_all()
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a file writer object
        # This writes Summary protocol buffers to event files.
        # https://github.com/tensorflow/docs/blob/r1.12/site/en/api_docs/python/tf/summary/FileWriter.md
        # The FileWriter class provides a mechanism to create an event file in a given directory and add summaries and events to it.
        # The class updates the file contents asynchronously.
        # This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # create savers
        # ================================================================
        saver = tf.train.Saver(var_list=i2l_vars)
        saver_lowest_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
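        # 'saver' is used for the periodic checkpoints written below, while
        # 'saver_lowest_loss' keeps the 3 most recent checkpoints saved at new
        # best (smoothed) validation losses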

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='vl_error_seg')
        vl_error_seg_summary = tf.summary.scalar('validation/loss_seg',
                                                 vl_error_seg)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='vl_error_invariance')
        vl_error_invariance_summary = tf.summary.scalar(
            'validation/loss_invariance', vl_error_invariance)
        vl_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='vl_error_total')
        vl_error_total_summary = tf.summary.scalar('validation/loss_total',
                                                   vl_error_total)
        vl_summary = tf.summary.merge([
            vl_error_seg_summary, vl_dice_summary, vl_error_invariance_summary,
            vl_error_total_summary
        ])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='tr_error_seg')
        tr_error_seg_summary = tf.summary.scalar('training/loss_seg',
                                                 tr_error_seg)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='tr_error_invariance')
        tr_error_invariance_summary = tf.summary.scalar(
            'training/loss_invariance', tr_error_invariance)
        tr_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='tr_error_total')
        tr_error_total_summary = tf.summary.scalar('training/loss_total',
                                                   tr_error_total)
        tr_summary = tf.summary.merge([
            tr_error_seg_summary, tr_dice_summary, tr_error_invariance_summary,
            tr_error_total_summary
        ])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # ================================================================
        if not exp_config.train_from_scratch:
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_lowest_loss.restore(sess, checkpoint_path)

        # ================================================================
        # run training steps
        # ================================================================
        step = 0
        lowest_loss = 10000.0
        validation_total_loss_list = []

        while step < exp_config.max_steps:

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(images_sd=images_sd_tr,
                                             labels_sd=labels_sd_tr,
                                             images_td=images_td_tr,
                                             batch_size=exp_config.batch_size):

                x_sd, y_sd, x_td = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_sd_pl: x_sd,
                    labels_sd_pl: y_sd,
                    images_td_pl: x_td,
                    training_pl: True
                }

                # ================================================
                # update i2l and D successively
                # ================================================
                sess.run(train_i2l_op, feed_dict=feed_dict)
                sess.run(train_discriminator_op, feed_dict=feed_dict)
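                # (standard alternating GAN-style training: one gradient step on
                # the normalizer + segmentation net, then one on the
                # discriminator, both on the same mini-batch)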

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        '============== Updating summary at step %d ' % step)
                    summary_writer.add_summary(
                        sess.run(summary_scalars, feed_dict=feed_dict), step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('============== Training Data Eval:')
                    train_loss_seg, train_dice, train_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_tr, labels_sd_tr,
                        images_td_tr, exp_config.batch_size)

                    # total training loss
                    train_total_loss = train_loss_seg + exp_config.lambda_uda * train_loss_invariance

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error_seg: train_loss_seg,
                                                  tr_dice: train_dice,
                                                  tr_error_invariance:
                                                  train_loss_invariance,
                                                  tr_error_total:
                                                  train_total_loss
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    logging.info(
                        '============== Periodically saving checkpoint:')
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('============== Validation Data Eval:')
                    val_loss_seg, val_dice, val_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_vl, labels_sd_vl,
                        images_td_vl, exp_config.batch_size)

                    # total val loss
                    val_total_loss = val_loss_seg + exp_config.lambda_uda * val_loss_invariance
                    validation_total_loss_list.append(val_total_loss)

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error_seg: val_loss_seg,
                                                  vl_dice: val_dice,
                                                  vl_error_invariance:
                                                  val_loss_invariance,
                                                  vl_error_total:
                                                  val_total_loss
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # update tensorboard summary of images
                    # ===========================
                    summary_writer.add_summary(
                        sess.run(summary_images,
                                 feed_dict={
                                     images_sd_pl: x_sd,
                                     labels_sd_pl: y_sd,
                                     training_pl: False
                                 }), step)
                    summary_writer.flush()

                    # ===========================
                    # save model if the smoothed validation loss is the lowest yet
                    # ===========================
                    window_length = 5
                    if len(validation_total_loss_list) < window_length + 1:
                        expo_moving_avg_loss_value = validation_total_loss_list[
                            -1]
                    else:
                        expo_moving_avg_loss_value = utils.exponential_moving_average(
                            validation_total_loss_list,
                            window=window_length)[-1]
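
                    # utils.exponential_moving_average is not shown here; it is
                    # assumed to return a smoothed copy of the list, roughly
                    # ema[i] = a * x[i] + (1 - a) * ema[i-1], so the checkpointing
                    # decision below reacts to the trend of the validation loss
                    # rather than to a single noisy estimate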

                    if expo_moving_avg_loss_value < lowest_loss:
                        lowest_loss = expo_moving_avg_loss_value
                        lowest_loss_file = os.path.join(
                            log_dir, 'models/lowest_loss.ckpt')
                        saver_lowest_loss.save(sess,
                                               lowest_loss_file,
                                               global_step=step)
                        logging.info(
                            '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                            % step)

                # ================================================
                # increment step
                # ================================================
                step += 1

        # ================================================================
        # close tf session
        # ================================================================
        sess.close()

    return 0
Example #2
def run_training(log_dir,
                 image,
                 label,
                 atlas,
                 continue_run,
                 log_dir_first_TD_subject=''):

    # ============================
    # down sample the atlas - the losses will be evaluated in the downsampled space
    # ============================
    atlas_downsampled = rescale(atlas, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                order=1,
                                preserve_range=True,
                                multichannel=True,
                                mode='constant')
    atlas_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        atlas_downsampled, int(256 / exp_config.downsampling_factor_x))

    label_onehot = utils.make_onehot(label, exp_config.nlabels)
    label_onehot_downsampled = rescale(label_onehot, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                       order=1,
                                       preserve_range=True,
                                       multichannel=True,
                                       mode='constant')
    label_onehot_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        label_onehot_downsampled, int(256 / exp_config.downsampling_factor_x))

    # ============================
    # Initialize step number - this is the number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[exp_config.batch_size] +
                                   list(exp_config.image_size) + [1],
                                   name='images')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl, exp_config, training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # By setting the 'training_pl' to false directly, the update ops for the moments in the BN layer are not created at all.
        # This allows grouping the update ops together with the optimizer training, while training the normalizer - in case the normalizer has BN.
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_1hot_shape = [1] + list(
            exp_config.image_size_downsampled) + [exp_config.nlabels]
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_1hot_3d_pl = tf.placeholder(tf.float32,
                                                  shape=labels_3d_1hot_shape,
                                                  name='predicted_labels_3d')

        # denoise the noisy segmentation
        _, predicted_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                # the normalization vars also need to be restored from the pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # Make a list of trainable i2l vars. This will be used to compute gradients.
        # The other list contains trainable as well as non-trainable parameters. This is required for saving and loading all parameters, but runs into trouble when gradients are asked to be computed for non-trainable parameters
        # ================================================================
        i2l_vars_trainable = []
        for v in i2l_vars:
            if v.trainable:
                i2l_vars_trainable.append(v)

        # ================================================================
        # add ops for calculation of the prior loss - wrt an atlas or the outputs of the DAE
        # ================================================================
        prior_label_1hot_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size_downsampled] + list(
                (exp_config.image_size_downsampled[1],
                 exp_config.image_size_downsampled[2])) + [exp_config.nlabels],
            name='labels_prior')

        # down sample the predicted logits
        predicted_seg_logits_expanded = tf.expand_dims(predicted_seg_logits,
                                                       axis=0)
        # the 'upsample' function will actually downsample the predictions, as the scaling factors have been set appropriately
        predicted_seg_logits_downsampled = layers.bilinear_upsample3D_(
            predicted_seg_logits_expanded,
            name='downsampled_predictions',
            factor_x=1 / exp_config.downsampling_factor_x,
            factor_y=1 / exp_config.downsampling_factor_y,
            factor_z=1 / exp_config.downsampling_factor_z)
        predicted_seg_logits_downsampled = tf.squeeze(
            predicted_seg_logits_downsampled
        )  # the first axis was added only for the downsampling in 3d

        # compute the dice between the predictions and the prior in the downsampled space
        loss_op = model.loss(logits=predicted_seg_logits_downsampled,
                             labels=prior_label_1hot_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_prior,
                             mask_for_loss_within_mask=None,
                             are_labels_1hot=True)

        tf.summary.scalar('tr_losses/loss', loss_op)

        # ================================================================
        # one of the two prior losses will be used in the following manner:
        # the atlas prior will be used when the current prediction is deemed to be very far away from a reasonable solution
        # once a reasonable solution is reached, the dae prior will be used.
        # these 3d computations will be done outside the graph and will be passed via placeholders for logging in tensorboard
        # ================================================================
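        # (the concrete decision rule is applied further below: the DAE output is
        # used as the prior when dice(pred, DAE(pred)) / dice(pred, atlas) exceeds
        # exp_config.dae_atlas_ratio_threshold and dice(pred, atlas) exceeds
        # exp_config.min_atlas_dice; otherwise the atlas is used)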
        lambda_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='lambda_prior_atlas')
        lambda_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='lambda_prior_dae')
        tf.summary.scalar('lambdas/prior_atlas', lambda_prior_atlas_pl)
        tf.summary.scalar('lambdas/prior_dae', lambda_prior_dae_pl)

        dice3d_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='dice3d_prior_atlas')
        dice3d_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='dice3d_prior_dae')
        dice3d_gt_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_gt')
        tf.summary.scalar('dice3d/prior_atlas', dice3d_prior_atlas_pl)
        tf.summary.scalar('dice3d/prior_dae', dice3d_prior_dae_pl)
        tf.summary.scalar('dice3d/gt', dice3d_gt_pl)

        # ================================================================
        # add optimization ops
        # ================================================================
        if exp_config.debug: print('creating training op...')

        # create an instance of the required optimizer
        optimizer = exp_config.optimizer_handle(learning_rate=learning_rate_pl)

        # initialize variables holding the accumulated gradients and create a zero-initialisation op
        accumulated_gradients = [
            tf.Variable(tf.zeros_like(var.initialized_value()),
                        trainable=False) for var in i2l_vars_trainable
        ]

        # accumulated gradients init op
        accumulated_gradients_zero_op = [
            ac.assign(tf.zeros_like(ac)) for ac in accumulated_gradients
        ]

        # calculate gradients and define accumulation op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            gradients = optimizer.compute_gradients(
                loss_op, var_list=i2l_vars_trainable)
            # compute_gradients returns a list of (gradient, variable) pairs.
        accumulate_gradients_op = [
            ac.assign_add(gg[0])
            for ac, gg in zip(accumulated_gradients, gradients)
        ]
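        # note: compute_gradients returns None gradients for variables that do not
        # influence loss_op; ac.assign_add(gg[0]) would fail for those, so every
        # variable in i2l_vars_trainable is assumed to receive a gradient here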

        # define the gradient mean op
        num_accumulation_steps_pl = tf.placeholder(
            dtype=tf.float32, name='num_accumulation_steps')
        accumulated_gradients_mean_op = [
            ag.assign(tf.divide(ag, num_accumulation_steps_pl))
            for ag in accumulated_gradients
        ]

        # reassemble the gradients in the (gradient, variable) format and define the train op
        final_gradients = [(ag, gg[1])
                           for ag, gg in zip(accumulated_gradients, gradients)]
        train_op = optimizer.apply_gradients(final_gradients)

        # ================================================================
        # sequence of running opt ops:
        # 1. at the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
        # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
        # 3. at the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
        # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
        # ================================================================
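        # a minimal sketch of that sequence, mirroring the calls made in the
        # training loop below:
        #   sess.run(accumulated_gradients_zero_op)
        #   for batch in ...:
        #       sess.run(accumulate_gradients_op, feed_dict=...)
        #   sess.run(accumulated_gradients_mean_op,
        #            feed_dict={num_accumulation_steps_pl: n})
        #   sess.run(train_op, feed_dict=...)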
        # ================================================================
        # previous train_op without accumulation of gradients
        # ================================================================
        # train_op = model.training_step(loss_op, normalization_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars = True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries
        # (note: the summaries in this function are run individually rather than via a merged op)
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        prior_dae_dice = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='prior_dae_dice')
        prior_dae_dice_summary = tf.summary.scalar('test_img/prior_dae_dice',
                                                   prior_dae_dice)
        prior_dae_output_dice_wrt_gt = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_output_dice_wrt_gt')
        prior_dae_output_dice_wrt_gt_summary = tf.summary.scalar(
            'test_img/prior_dae_output_dice_wrt_gt',
            prior_dae_output_dice_wrt_gt)
        prior_atlas_dice = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='prior_atlas_dice')
        prior_atlas_dice_summary = tf.summary.scalar(
            'test_img/prior_atlas_dice', prior_atlas_dice)
        prior_dae_atlas_dice_ratio = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_atlas_dice_ratio')
        prior_dae_atlas_dice_ratio_summary = tf.summary.scalar(
            'test_img/prior_dae_atlas_dice_ratio', prior_dae_atlas_dice_ratio)
        prior_dice = tf.placeholder(tf.float32, shape=[], name='prior_dice')
        prior_dice_summary = tf.summary.scalar('test_img/prior_dice',
                                               prior_dice)
        gt_dice = tf.placeholder(tf.float32, shape=[], name='gt_dice')
        gt_dice_summary = tf.summary.scalar('test_img/gt_dice', gt_dice)

        # ================================================================
        # create savers
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)
        saver_test_data = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
        saver_best_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
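        # saver_i2l / saver_l2l restore the pre-trained networks below;
        # saver_test_data writes the per-subject checkpoints (and restores them
        # when continuing a run); saver_best_loss keeps the checkpoints saved at
        # new best prior-agreement scores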

        # ================================================================
        # add operations to compute dice between two 3d volumes
        # ================================================================
        pred_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='pred_3d')
        labl_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='labl_3d')
        atls_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='atls_3d')

        dice_3d_op_dae = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=labl_3d_1hot_pl)
        dice_3d_op_atlas = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=atls_3d_1hot_pl)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the prior network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_l2l.restore(sess, checkpoint_path)

        # ================================================================
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        if log_dir_first_TD_subject != '':
            logging.info(
                '============================================================')
            path_to_model = log_dir_first_TD_subject + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_score.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_test_data.restore(sess, checkpoint_path)
            max_steps_tta = exp_config.max_steps
        else:
            max_steps_tta = 5 * exp_config.max_steps  # run the adaptation of the 1st TD subject for longer

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring normalization module from: %s' %
                         init_checkpoint_path)
            saver_test_data.restore(sess, init_checkpoint_path)

        # ================================================================
        # run training epochs
        # ================================================================
        step = init_step
        best_score = 0.0

        while step < max_steps_tta:

            # ================================================
            # Every check_ood_frequency steps (and at the very first step):
            # get the prediction for the entire volume and evaluate it using the DAE.
            # Then decide whether to use the DAE output or the atlas as the ground truth for the next update.
            # ================================================
            if (step == init_step) or (step % exp_config.check_ood_frequency == 0):

                # ==================
                # 1. compute the current 3d segmentation prediction
                # ==================
                y_pred_soft = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred_soft.append(
                        sess.run(predicted_seg_softmax,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred_soft = np.squeeze(np.array(y_pred_soft)).astype(float)
                y_pred_soft = np.reshape(y_pred_soft, [
                    -1, y_pred_soft.shape[2], y_pred_soft.shape[3],
                    y_pred_soft.shape[4]
                ])

                # ==================
                # 2. downsample it. Let's call this guy 'A'
                # ==================
                y_pred_soft_downsampled = rescale(y_pred_soft, [
                    1 / exp_config.downsampling_factor_x,
                    1 / exp_config.downsampling_factor_y,
                    1 / exp_config.downsampling_factor_z
                ],
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=True,
                                                  mode='constant').astype(
                                                      np.float32)

                # ==================
                # 3. pass the downsampled prediction through the DAE and get its output 'B'
                # ==================
                feed_dict = {
                    predicted_seg_1hot_3d_pl:
                    np.expand_dims(y_pred_soft_downsampled, axis=0)
                }
                y_pred_noisy_denoised_softmax = np.squeeze(
                    sess.run(
                        predicted_seg_softmax_3d_noisy_autoencoded_softmax,
                        feed_dict=feed_dict)).astype(np.float16)
                y_pred_noisy_denoised = np.argmax(
                    y_pred_noisy_denoised_softmax, axis=-1)

                # ==================
                # 4. compute the dice between:
                #       a. 'A' (seg network prediction downsampled) and 'B' (dae network output)
                #       b. 'B' (dae network output) and downsampled gt labels (for debugging, to see if the dae output is close to the gt.)
                #       c. 'A' (seg network prediction downsampled) and 'C' (downsampled atlas)
                # ==================
                dAB = sess.run(dice_3d_op_dae,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   labl_3d_1hot_pl:
                                   y_pred_noisy_denoised_softmax
                               })
                dBgt = sess.run(dice_3d_op_dae,
                                feed_dict={
                                    pred_3d_1hot_pl:
                                    y_pred_noisy_denoised_softmax,
                                    labl_3d_1hot_pl: label_onehot_downsampled
                                })
                dAC = sess.run(dice_3d_op_atlas,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   atls_3d_1hot_pl: atlas_downsampled
                               })

                # ==================
                # 5. compute the ratio dice(AB) / dice(AC). pass the ratio through a threshold and decide whether to use the DAE or the atlas as the prior
                # ==================
                ratio_dice = dAB / (dAC + 1e-5)
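                # the 1e-5 epsilon guards against division by zero when dAC is 0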

                if exp_config.use_gt_for_tta:
                    target_labels_for_this_epoch = label_onehot_downsampled
                    prr = dBgt
                elif (ratio_dice > exp_config.dae_atlas_ratio_threshold) and (
                        dAC > exp_config.min_atlas_dice):
                    target_labels_for_this_epoch = y_pred_noisy_denoised_softmax
                    prr = dAB
                else:
                    target_labels_for_this_epoch = atlas_downsampled
                    prr = dAC

                # ==================
                # update losses on tensorboard
                # ==================
                summary_writer.add_summary(
                    sess.run(prior_dae_dice_summary,
                             feed_dict={prior_dae_dice: dAB}), step)
                summary_writer.add_summary(
                    sess.run(prior_dae_output_dice_wrt_gt_summary,
                             feed_dict={prior_dae_output_dice_wrt_gt: dBgt}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_atlas_dice_summary,
                             feed_dict={prior_atlas_dice: dAC}), step)
                summary_writer.add_summary(
                    sess.run(
                        prior_dae_atlas_dice_ratio_summary,
                        feed_dict={prior_dae_atlas_dice_ratio: ratio_dice}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_dice_summary, feed_dict={prior_dice: prr}),
                    step)

                # ==================
                # save best model so far
                # ==================
                if best_score < prr:
                    best_score = prr
                    best_file = os.path.join(log_dir, 'models/best_score.ckpt')
                    saver_best_loss.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best score (%f) at step %d -  Saving model.'
                        % (best_score, step))

                # ==================
                # dice wrt gt
                # ==================
                y_pred = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred.append(
                        sess.run(predicted_seg,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred = np.squeeze(np.array(y_pred)).astype(float)
                y_pred = np.reshape(y_pred,
                                    [-1, y_pred.shape[2], y_pred.shape[3]])
                dice_wrt_gt = met.f1_score(label.flatten(),
                                           y_pred.flatten(),
                                           average=None)
                summary_writer.add_summary(
                    sess.run(gt_dice_summary,
                             feed_dict={gt_dice: np.mean(dice_wrt_gt[1:])}),
                    step)

                # ==================
                # visualize results
                # ==================
                if step % exp_config.vis_frequency == 0:

                    # ===========================
                    # save checkpoint
                    # ===========================
                    logging.info(
                        '=============== Saving checkpoint at step %d ... ' %
                        step)
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver_test_data.save(sess,
                                         checkpoint_file,
                                         global_step=step)

                    y_pred_noisy_denoised_upscaled = utils.crop_or_pad_volume_to_size_along_x(
                        rescale(y_pred_noisy_denoised, [
                            exp_config.downsampling_factor_x,
                            exp_config.downsampling_factor_y,
                            exp_config.downsampling_factor_z
                        ],
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant'),
                        image.shape[0]).astype(np.uint8)

                    x_norm = []
                    for batch in iterate_minibatches_images(
                            image, batch_size=exp_config.batch_size):
                        x = batch
                        x_norm.append(
                            sess.run(images_normalized,
                                     feed_dict={
                                         images_pl: x,
                                         training_pl: False
                                     }))
                    x_norm = np.squeeze(np.array(x_norm)).astype(float)
                    x_norm = np.reshape(x_norm,
                                        [-1, x_norm.shape[2], x_norm.shape[3]])

                    utils_vis.save_sample_results(
                        x=image,
                        x_norm=x_norm,
                        x_diff=x_norm - image,
                        y=y_pred,
                        y_pred_dae=y_pred_noisy_denoised_upscaled,
                        at=np.argmax(atlas, axis=-1),
                        gt=label,
                        savepath=log_dir + '/results/visualize_images/step' +
                        str(step) + '.png')

            # ================================================
            # Part of training ops sequence:
            # 1. At the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
            # ================================================
            sess.run(accumulated_gradients_zero_op)
            num_accumulation_steps = 0

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches_images_and_downsampled_labels(
                    images=image,
                    batch_size=exp_config.batch_size,
                    labels_downsampled=target_labels_for_this_epoch,
                    batch_size_downsampled=exp_config.batch_size_downsampled):

                x, y = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    prior_label_1hot_pl: y,
                    learning_rate_pl: exp_config.learning_rate,
                    training_pl: True
                }

                # ================================================
                # Part of training ops sequence:
                # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
                # ================================================
                sess.run(accumulate_gradients_op, feed_dict=feed_dict)
                num_accumulation_steps += 1

                step += 1

            # ================================================
            # Part of training ops sequence:
            # 3. At the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
            # ================================================
            sess.run(
                accumulated_gradients_mean_op,
                feed_dict={num_accumulation_steps_pl: num_accumulation_steps})

            # ================================================================
            # Part of training ops sequence:
            # 4. finally, run the train_op. This also requires the input and output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients. (See the standalone sketch right after this function.)
            # ================================================================
            sess.run(train_op, feed_dict=feed_dict)

        # ================================================================
        # ================================================================
        sess.close()

    # ================================================================
    # ================================================================
    gc.collect()

    return 0
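The four-step sequence annotated above (zero, accumulate, average, apply) amounts to gradient accumulation over all batches of a volume. The ops themselves are defined elsewhere; below is a minimal TF1-style sketch of how they could be constructed. It is an illustration under assumptions, not the original code: make_accumulation_ops is a hypothetical helper, and step 4 here applies the averaged gradients straight from the buffers, so unlike the train_op above it would not need the input feed dict.

import tensorflow as tf

def make_accumulation_ops(loss, var_list, optimizer):
    # one (gradient, variable) pair per trainable variable
    grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
    # non-trainable buffers that hold the running gradient sums
    accumulators = [tf.Variable(tf.zeros_like(v), trainable=False)
                    for _, v in grads_and_vars]
    num_steps_pl = tf.placeholder(tf.float32, shape=[],
                                  name='num_accumulation_steps')
    # 1. zero the buffers at the start of each epoch
    zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accumulators])
    # 2. add the mini-batch gradients into the buffers
    accumulate_op = tf.group(*[a.assign_add(g)
                               for a, (g, _) in zip(accumulators,
                                                    grads_and_vars)])
    # 3. divide the accumulated sums by the number of accumulation steps
    mean_op = tf.group(*[a.assign(a / num_steps_pl) for a in accumulators])
    # 4. apply the averaged gradients held in the buffers
    apply_op = optimizer.apply_gradients(
        [(a, v) for a, (_, v) in zip(accumulators, grads_and_vars)])
    return zero_op, accumulate_op, mean_op, apply_op, num_steps_pl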
def predict_segmentation(subject_name,
                         image,
                         normalize=True,
                         post_process=False):

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) +
                                   [1],
                                   name='images')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_shape = [1] + list(exp_config.image_size_downsampled)
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_3d_pl = tf.placeholder(tf.uint8,
                                             shape=labels_3d_shape,
                                             name='true_labels_3d')
        predicted_seg_1hot_3d_pl = tf.one_hot(predicted_seg_3d_pl,
                                              depth=exp_config.nlabels)
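        # tf.one_hot appends a class axis: shape becomes [1] + image_size_downsampled + [nlabels]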

        # denoise the noisy segmentation
        _, pred_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                # the normalization vars also need to be restored
                # from the pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create saver
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_normalizer = tf.train.Saver(var_list=normalization_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        sess.run(init_ops)

        # ================================================================
        # Restore the segmentation network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the prior network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_l2l.restore(sess, checkpoint_path)

        # ================================================================
        # Make predictions for the image at the resolution of the image after pre-processing
        # ================================================================
        mask_predicted = []
        mask_predicted_soft = []
        img_normalized = []

        for b_i in range(image.shape[0]):

            X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)

            mask_predicted.append(
                sess.run(predicted_seg, feed_dict={images_pl: X}))
            mask_predicted_soft.append(
                sess.run(predicted_seg_softmax, feed_dict={images_pl: X}))
            img_normalized.append(
                sess.run(images_normalized, feed_dict={images_pl: X}))

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
        mask_predicted_soft = np.squeeze(
            np.array(mask_predicted_soft)).astype(float)
        img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

        # ================================================================
        # downsample predicted mask and pass it through the DAE
        # ================================================================
        if post_process:

            # downsample mask_predicted_soft
            mask_predicted_soft_downsampled = rescale(mask_predicted_soft, [
                1 / exp_config.downsampling_factor_x,
                1 / exp_config.downsampling_factor_y,
                1 / exp_config.downsampling_factor_z
            ],
                                                      order=1,
                                                      preserve_range=True,
                                                      multichannel=True,
                                                      mode='constant')
            mask_predicted_downsampled = np.argmax(
                mask_predicted_soft_downsampled, axis=-1)

            # pass the downsampled prediction through the DAE
            mask_predicted_denoised = mask_predicted_downsampled
            for _ in range(exp_config.dae_post_process_runs):
                feed_dict = {
                    predicted_seg_3d_pl:
                    np.expand_dims(mask_predicted_denoised, axis=0)
                }
                y_pred_noisy_denoised_softmax = np.squeeze(
                    sess.run(pred_seg_softmax_3d_noisy_autoencoded_softmax,
                             feed_dict=feed_dict)).astype(np.float16)
                mask_predicted_denoised = np.argmax(
                    y_pred_noisy_denoised_softmax, axis=-1)

            # upsample the denoised prediction
            mask_predicted = rescale(mask_predicted_denoised, [
                exp_config.downsampling_factor_x,
                exp_config.downsampling_factor_y,
                exp_config.downsampling_factor_z
            ],
                                     order=0,
                                     preserve_range=True,
                                     multichannel=False,
                                     mode='constant').astype(np.uint8)

        sess.close()

        return mask_predicted, img_normalized
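Both the upsampling block in the training loop above and the post-processing branch here follow the same convention: soft probability maps are resampled with order=1 (linear interpolation is safe for probabilities), while hard label maps use order=0 (nearest neighbour), so resampling never invents fractional label values. A standalone illustration on toy data, assuming an older scikit-image release that still accepts the multichannel keyword (newer releases use channel_axis instead):

import numpy as np
from skimage.transform import rescale

soft = np.random.rand(32, 32, 32, 4)      # toy softmax over 4 labels
# linear interpolation is fine for probabilities
soft_small = rescale(soft, [0.5, 0.5, 0.5], order=1,
                     preserve_range=True, multichannel=True, mode='constant')
labels = np.argmax(soft_small, axis=-1)   # hard label map
# nearest neighbour keeps the label set intact when upsampling
labels_big = rescale(labels, [2.0, 2.0, 2.0], order=0,
                     preserve_range=True, multichannel=False,
                     mode='constant').astype(np.uint8)
assert set(np.unique(labels_big)) <= set(np.unique(labels))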
Example #4
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
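            # TF checkpoint paths end in '-<global_step>', e.g.
            # '.../models/model.ckpt-12000' -> init_step = 12001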
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    if exp_config.train_dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=1040,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'IXI':
        logging.info('Reading IXI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_ixi)
        data_brain_train = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=0,
            idx_end=12,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=12,
            idx_end=17,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    logging.info(
        'Training Images: %s' %
        str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info(
        'Training Labels: %s' %
        str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module will be adapted for each test image
        # ================================================================
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, _, _ = model.predict_i2l(images_normalized,
                                         exp_config,
                                         training_pl=training_pl)

        print('shape of inputs: ',
              images_pl.shape)  # (batch_size, 256, 256, 1)
        print('shape of logits: ',
              logits.shape)  # (batch_size, 256, 256, nlabels)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        i2l_vars = list(tf.trainable_variables())

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits,
                             labels_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       i2l_vars,
                                       exp_config.optimizer_handle,
                                       learning_rate_pl,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_i2l(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of all variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of all variables:')
        for v in tf.trainable_variables():
            print(v.name)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            print(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_dice = 0

        # ================================================================
        # run training epochs
        # ================================================================
        while step < exp_config.max_steps:

            if step % 1000 == 0:
                logging.info(
                    '============================================================'
                )
                logging.info('step %d' % step)

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(imtr,
                                             gttr,
                                             batch_size=exp_config.batch_size,
                                             train_or_eval='train'):

                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y = batch

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                # ===========================
                # create the feed dict for this training iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    labels_pl: y,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # opt step
                # ===========================
                _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.3f (%.3f sec for the last step)' %
                        (step + 1, loss, duration))

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('Training Data Eval:')
                    train_loss, train_dice = do_eval(sess, eval_loss,
                                                     images_pl, labels_pl,
                                                     training_pl, imtr, gttr,
                                                     exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error: train_loss,
                                                  tr_dice: train_dice
                                              })

                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('Validation Data Eval:')
                    val_loss, val_dice = do_eval(sess, eval_loss, images_pl,
                                                 labels_pl, training_pl, imvl,
                                                 gtvl, exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error: val_loss,
                                                  vl_dice: val_dice
                                              })

                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_dice.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best average dice on validation set (%f) - saving model.'
                            % val_dice)

                step += 1

        sess.close()
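The helper do_eval called above is not shown in this example. A plausible minimal sketch, assuming that the eval_loss op evaluates to a (loss, dice) pair (as suggested by model.evaluation_i2l) and that iterate_minibatches accepts train_or_eval='eval':

def do_eval(sess, eval_loss, images_pl, labels_pl, training_pl,
            images, labels, batch_size):
    # average the (loss, dice) pair over all complete mini-batches
    loss_sum, dice_sum, num_batches = 0.0, 0.0, 0
    for batch in iterate_minibatches(images, labels,
                                     batch_size=batch_size,
                                     train_or_eval='eval'):
        x, y = batch
        if y.shape[0] < batch_size:  # skip incomplete batches, as in training
            continue
        loss, dice = sess.run(eval_loss, feed_dict={images_pl: x,
                                                    labels_pl: y,
                                                    training_pl: False})
        loss_sum += loss
        dice_sum += dice
        num_batches += 1
    return loss_sum / num_batches, dice_sum / num_batches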
Example #5
def predict_segmentation(subject_name, image, normalize=True):

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) +
                                   [1],
                                   name='images')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, softmax, preds = model.predict_i2l(images_normalized,
                                                   exp_config,
                                                   training_pl=tf.constant(
                                                       False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            i2l_vars.append(v)
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create saver
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_normalizer = tf.train.Saver(var_list=normalization_vars)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        sess.run(init_ops)

        # ================================================================
        # Restore the segmentation network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the normalization network parameters
        # ================================================================
        if normalize:
            logging.info(
                '============================================================')
            path_to_model = os.path.join(
                sys_config.log_root, exp_config.expname_normalizer
            ) + '/subject_' + subject_name + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_score.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_normalizer.restore(sess, checkpoint_path)
            logging.info(
                '============================================================')

        # ================================================================
        # Make predictions for the image at the resolution of the image after pre-processing
        # ================================================================
        mask_predicted = []
        img_normalized = []

        for b_i in range(image.shape[0]):

            X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)

            mask_predicted.append(sess.run(preds, feed_dict={images_pl: X}))
            img_normalized.append(
                sess.run(images_normalized, feed_dict={images_pl: X}))

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
        img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

        sess.close()

        return mask_predicted, img_normalized
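A hypothetical call (the subject id and test volume are placeholders, not from the original): the function returns the predicted labels and the normalized image at the pre-processed resolution, with the per-subject normalizer restored from its own checkpoint directory.

# test_volume: [num_slices, img_size_x, img_size_y], already pre-processed
pred, norm = predict_segmentation(subject_name='subject1',
                                  image=test_volume,
                                  normalize=True)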
Example #6
def predict_segmentation(image):

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + [exp_config.image_size[0]] +
                                   [exp_config.image_size[1]] + [1],
                                   name='images')

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, softmax, preds = model.predict_i2l(images_pl,
                                                   exp_config,
                                                   training_pl=tf.constant(
                                                       False, dtype=tf.bool))

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create saver
        # ================================================================
        saver_i2l = tf.train.Saver()

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        sess.run(init_ops)

        # ================================================================
        # Restore the segmentation network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # predict segmentation
        # ================================================================
        X = np.expand_dims(np.expand_dims(image, axis=-1), axis=0)
        mask_predicted = sess.run(preds, feed_dict={images_pl: X})
        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)

        sess.close()

        return mask_predicted
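A hypothetical usage of this last variant, which segments a single 2D slice at a time (the file names are placeholders):

import numpy as np

image = np.load('slice.npy')        # 2D slice of shape exp_config.image_size
mask = predict_segmentation(image)  # per-pixel label ids, returned as float
np.save('slice_pred.npy', mask.astype(np.uint8))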