Пример #1
0
def get_latest_checkpoint_and_log(logdir, filename):
    """Look up the latest checkpoint for ``filename`` in ``logdir`` and log it.

    The lookup itself is delegated to
    ``utils.get_latest_model_checkpoint_path``. The step number is taken from
    the text after the last ``-`` in the final ``/``-separated component of
    the returned path; it is only logged, not returned.

    Args:
        logdir: Directory passed through to the checkpoint lookup helper.
        filename: Checkpoint base name passed through to the lookup helper.

    Returns:
        The checkpoint path exactly as returned by the lookup helper.
    """
    ckpt_path = utils.get_latest_model_checkpoint_path(logdir, filename)
    logging.info('Checkpoint path: %s' % ckpt_path)

    # Peel off the directory part first, then the step suffix.
    ckpt_basename = ckpt_path.rsplit('/', 1)[-1]
    step_token = ckpt_basename.rsplit('-', 1)[-1]
    logging.info('Latest step was: %d' % int(step_token))

    return ckpt_path
def main(fs_exp_config, slices, test):
    """Predict segmentation masks for selected slices and save them as images.

    Loads the dataset, picks the requested slices from either the test or the
    training images, restores the latest ``model_best_dice.ckpt`` checkpoint
    found under ``fs_model_path`` and writes one coloured prediction per slice
    via ``print_coloured`` into ``OUTPUT_FOLDER``.

    Args:
        fs_exp_config: Experiment configuration. This function reads its
            ``data_mode``, ``image_size``, ``target_resolution``,
            ``model_handle`` and ``nlabels`` attributes.
        slices: Integer indices of the slices to predict (array or sequence).
            Indices that are not smaller than the number of available images
            are silently dropped.
        test: If truthy, slices index ``data['images_test']``; otherwise
            ``data['images_train']``.
    """
    # Load data
    data = load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=fs_exp_config.data_mode,
        size=fs_exp_config.image_size,
        target_resolution=fs_exp_config.target_resolution,
        force_overwrite=False
    )

    # Get images. The element-wise comparison below needs an ndarray, so
    # coerce here to also accept plain Python sequences of indices.
    slices = np.asarray(slices)
    if test:
        slices = slices[slices < len(data['images_test'])]
        images = data['images_test'][slices, ...]
        prefix = 'test'
    else:
        slices = slices[slices < len(data['images_train'])]
        images = data['images_train'][slices, ...]
        prefix = 'train'

    # The batch size must reflect the slices that survived the bounds filter:
    # it fixes the placeholder's first dimension and the length of the output
    # loop, both of which have to match the number of images actually fed.
    batch_size = len(slices)

    image_tensor_shape = [batch_size] + list(fs_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    feed_dict = {
        # add a trailing channel axis to match the placeholder shape
        images_pl: np.expand_dims(images, -1),
    }

    # Get full supervision prediction
    mask_pl, softmax_pl = model.predict(images_pl, fs_exp_config.model_handle, fs_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(fs_model_path,
                                                                 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)
        fs_predictions, _ = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)

    for i in range(batch_size):
        print_coloured(fs_predictions[i, ...],  filepath=OUTPUT_FOLDER, filename='{}{}_fs_pred'.format(prefix, slices[i]))
Пример #3
0
def run_training(continue_run, log_dir):
    """Build the GAN graph and run the adversarial training loop.

    Workflow, as implemented below:

    * If ``continue_run`` is set, the latest ``model.ckpt`` checkpoint in
      ``log_dir`` is located, training resumes at the step after it, and all
      new output goes to ``log_dir + '_cont'``. If no usable checkpoint is
      found, a warning is logged and training starts from step 0.
    * Data is loaded through ``adni_data_loader_all`` and split into source
      and target index sets by field strength.
    * Each step runs 5 discriminator updates (100 during the first 25 steps
      and on every 500th step) followed by one generator update.
      Discriminator weights are clipped after each discriminator update
      unless ``exp_config.improved_training`` is set.
    * Training summaries, validation and checkpoints are written at the
      frequencies given by ``exp_config.update_tensorboard_frequency``,
      ``exp_config.validation_frequency`` and ``exp_config.save_frequency``.
      The checkpoint ``model_best_d_loss.ckpt`` is written whenever the mean
      validation discriminator loss is at least as low as the best seen in
      this invocation.

    The session and the summary writer are closed when the loop ends or an
    exception propagates.

    Args:
        continue_run: Whether to try resuming from the latest checkpoint in
            ``log_dir``.
        log_dir: Directory for checkpoints and TensorBoard summaries.

    Returns:
        None.
    """

    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            latest_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % latest_step)
            init_step = latest_step + 1  # plus 1 b/c otherwise starts with eval
            log_dir += '_cont'

        except Exception:
            # Deliberately best-effort: any failure to locate or parse the
            # checkpoint falls back to a fresh run. `Exception` (rather than a
            # bare except) lets KeyboardInterrupt/SystemExit propagate.
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    images_train, source_images_train_ind, target_images_train_ind,\
    images_val, source_images_val_ind, target_images_val_ind = data_utils.get_images_and_fieldstrength_indices(
        data, exp_config.source_field_strength, exp_config.target_field_strength)

    generator = exp_config.generator
    discriminator = exp_config.discriminator

    # z: source-domain batches (generator input), x: target-domain batches
    z_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=source_images_train_ind)
    x_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=target_images_train_ind)

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.

        im_s = exp_config.image_size

        training_placeholder = tf.placeholder(tf.bool, name='training_phase')

        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape, minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        x_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels], name='x')

        # source image batch
        z_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels], name='z')

        # generated fake image batch
        x_pl_ = generator(z_pl, noise_in_gen_pl, training_placeholder)

        # difference between generated and source images
        diff_img_pl = x_pl_ - z_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs', tf_utils.put_kernels_on_grid3d(x_pl_, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xs', tf_utils.put_kernels_on_grid3d(x_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_zs', tf_utils.put_kernels_on_grid3d(z_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_difference_gx-x', tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis,
                                                                                  exp_config.cut_index, rescale_mode='centered',
                                                                                  cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(x_pl, training_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(x_pl_, training_placeholder, scope_reuse=True)

        d_hat = None
        x_hat = None
        if exp_config.improved_training:
            # random interpolation between a real and a generated batch,
            # evaluated by the discriminator and handed to the training ops
            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * x_pl + (1 - epsilon) * x_pl_
            d_hat = discriminator(x_hat, training_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        # nr means no regularization, meaning the loss without the regularization term
        discriminator_train_op, generator_train_op, \
        disc_loss_pl, gen_loss_pl, \
        disc_loss_nr_pl, gen_loss_nr_pl = gan_model.training_ops(d_pl, d_pl_,
                                                                 optimizer_handle=exp_config.optimizer_handle,
                                                                 learning_rate=exp_config.learning_rate,
                                                                 l1_img_dist=dist_l1,
                                                                 w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                                                                 w_reg_gen_l1=exp_config.w_reg_gen_l1,
                                                                 w_reg_disc_l1=exp_config.w_reg_disc_l1,
                                                                 w_reg_gen_l2=exp_config.w_reg_gen_l2,
                                                                 w_reg_disc_l2=exp_config.w_reg_disc_l2,
                                                                 d_hat=d_hat, x_hat=x_hat, scale=exp_config.scale)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary.
        # The returned op is not needed by name: merge_all() below collects it.
        tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # validation summaries (created after merge_all, so they are only
        # evaluated through val_summary_op with explicitly fed averages)
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_op = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=3)
        saver_best_disc = tf.train.Saver(max_to_keep=3)  # disc loss is scaled negative EM distance

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        try:
            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session (init_checkpoint_path is always bound here:
                # continue_run is reset to False when the lookup failed)
                saver_latest.restore(sess, init_checkpoint_path)

            # initialize value of lowest (i. e. best) discriminator loss;
            # note this starts from scratch even when continuing a run
            best_d_loss = np.inf

            for step in range(init_step, 1000000):

                start_time = time.time()

                # discriminator training iterations
                d_iters = 5
                if step % 500 == 0 or step < 25:
                    d_iters = 100

                for _ in range(d_iters):

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    # train discriminator
                    sess.run(discriminator_train_op,
                             feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                    if not exp_config.improved_training:
                        sess.run(d_clip_op)

                # covers the discriminator iterations only
                elapsed_time = time.time() - start_time

                # train generator on freshly sampled batches
                x = next(x_sampler_train)
                z = next(z_sampler_train)
                sess.run(generator_train_op,
                         feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                if step % exp_config.update_tensorboard_frequency == 0:

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    g_loss_train, d_loss_train, summary_str = sess.run(
                        [gen_loss_nr_pl, disc_loss_nr_pl, summary_op],
                        feed_dict={z_pl: z, x_pl: x, training_placeholder: False})

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    logging.info("[Step: %d], generator loss: %g, discriminator_loss: %g" % (step, g_loss_train, d_loss_train))
                    logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

                if step % exp_config.validation_frequency == 0:

                    z_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=source_images_val_ind)
                    x_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=target_images_val_ind)

                    # evaluate the validation batch with batch_size images (from each domain) at a time
                    g_loss_val_list = []
                    d_loss_val_list = []
                    for _ in range(exp_config.num_val_batches):
                        x = next(x_sampler_val)
                        z = next(z_sampler_val)
                        g_loss_val, d_loss_val = sess.run(
                            [gen_loss_nr_pl, disc_loss_nr_pl], feed_dict={z_pl: z,
                                                                          x_pl: x,
                                                                          training_placeholder: False})
                        g_loss_val_list.append(g_loss_val)
                        d_loss_val_list.append(d_loss_val)

                    g_loss_val_avg = np.mean(g_loss_val_list)
                    d_loss_val_avg = np.mean(d_loss_val_list)

                    validation_summary_str = sess.run(val_summary_op,
                                                      feed_dict={val_disc_loss_pl: d_loss_val_avg,
                                                                 val_gen_loss_pl: g_loss_val_avg})
                    summary_writer.add_summary(validation_summary_str, step)
                    summary_writer.flush()

                    # save best variables (if discriminator loss is the lowest yet)
                    if d_loss_val_avg <= best_d_loss:
                        best_d_loss = d_loss_val_avg
                        best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                        saver_best_disc.save(sess, best_file, global_step=step)
                        logging.info('Found new best discriminator loss on validation set! - %f -  Saving model_best_d_loss.ckpt' % best_d_loss)

                    logging.info("[Validation], generator loss: %g, discriminator_loss: %g" % (g_loss_val_avg, d_loss_val_avg))

                # Save the latest model state at the configured frequency.
                if step % exp_config.save_frequency == 0:

                    saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)

        finally:
            # Release the event file handle and the session's resources
            # whether the loop finished or an exception is propagating.
            summary_writer.close()
            sess.close()
Пример #4
0
def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_sd_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_sd')
        images_td_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_td')
        labels_sd_pl = tf.placeholder(tf.uint8,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size),
                                      name='labels_sd')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # ================================================================
        images_sd_normalized, _ = model.normalize(images_sd_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=False)
        images_td_normalized, _ = model.normalize(images_td_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=True)

        # ================================================================
        # get logit predictions from the segmentation network
        # ================================================================
        predicted_seg_sd_logits, _, _ = model.predict_i2l(images_sd_normalized,
                                                          exp_config,
                                                          training_pl,
                                                          scope_reuse=False)

        # ================================================================
        # get all features from the segmentation network
        # ================================================================
        seg_features_sd = model.get_all_features(images_sd_normalized,
                                                 exp_config,
                                                 scope_reuse=True)
        seg_features_td = model.get_all_features(images_td_normalized,
                                                 exp_config,
                                                 scope_reuse=True)

        # ================================================================
        # resize all features to the same size
        # ================================================================
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64),
            'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64),
            'resize_td_xfeat')

        # ================================================================
        # discriminator on features
        # ================================================================
        d_logits_sd = model.discriminator(images_sd_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=False)
        d_logits_td = model.discriminator(images_td_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=True)

        # ================================================================
        # add ops for calculation of the discriminator loss
        # ================================================================
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # ================================================================
        # add ops for calculation of the adversarial loss that tries to get domain invariant features in the normalized image space
        # ================================================================
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)

        # ================================================================
        # add ops for calculation of the supervised segmentation loss
        # ================================================================
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # ================================================================
        # total training loss for uda
        # ================================================================
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # ================================================================
        # merge all summaries
        # ================================================================
        summary_scalars = tf.summary.merge_all()

        # ================================================================
        # divide the vars into segmentation network, normalization network and the discriminator network
        # ================================================================
        i2l_vars = []
        normalization_vars = []
        discriminator_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                i2l_vars.append(
                    v
                )  # the normalization vars also need to be restored from the pre-trained i2l mapper
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # ================================================================
        # add optimization ops
        # ================================================================
        train_i2l_op = model.training_step(
            loss_total_op,
            i2l_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        train_discriminator_op = model.training_step(
            loss_d_op,
            discriminator_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits,
            labels_sd_pl,
            images_sd_pl,
            d_logits_td,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # add ops for adding image summary to tensorboard
        # ================================================================
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl,
            exp_config.nlabels)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a file writer object
        # This writes Summary protocol buffers to event files.
        # https://github.com/tensorflow/docs/blob/r1.12/site/en/api_docs/python/tf/summary/FileWriter.md
        # The FileWriter class provides a mechanism to create an event file in a given directory and add summaries and events to it.
        # The class updates the file contents asynchronously.
        # This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # create savers
        # ================================================================
        saver = tf.train.Saver(var_list=i2l_vars)
        saver_lowest_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='vl_error_seg')
        vl_error_seg_summary = tf.summary.scalar('validation/loss_seg',
                                                 vl_error_seg)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='vl_error_invariance')
        vl_error_invariance_summary = tf.summary.scalar(
            'validation/loss_invariance', vl_error_invariance)
        vl_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='vl_error_total')
        vl_error_total_summary = tf.summary.scalar('validation/loss_total',
                                                   vl_error_total)
        vl_summary = tf.summary.merge([
            vl_error_seg_summary, vl_dice_summary, vl_error_invariance_summary,
            vl_error_total_summary
        ])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='tr_error_seg')
        tr_error_seg_summary = tf.summary.scalar('training/loss_seg',
                                                 tr_error_seg)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='tr_error_invariance')
        tr_error_invariance_summary = tf.summary.scalar(
            'training/loss_invariance', tr_error_invariance)
        tr_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='tr_error_total')
        tr_error_total_summary = tf.summary.scalar('training/loss_total',
                                                   tr_error_total)
        tr_summary = tf.summary.merge([
            tr_error_seg_summary, tr_dice_summary, tr_error_invariance_summary,
            tr_error_total_summary
        ])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # ================================================================
        if exp_config.train_from_scratch is False:
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_lowest_loss.restore(sess, checkpoint_path)

        # ================================================================
        # run training steps
        # ================================================================
        step = 0
        lowest_loss = 10000.0
        validation_total_loss_list = []

        while (step < exp_config.max_steps):

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(images_sd=images_sd_tr,
                                             labels_sd=labels_sd_tr,
                                             images_td=images_td_tr,
                                             batch_size=exp_config.batch_size):

                x_sd, y_sd, x_td = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_sd_pl: x_sd,
                    labels_sd_pl: y_sd,
                    images_td_pl: x_td,
                    training_pl: True
                }

                # ================================================
                # update i2l and D successively
                # ================================================
                sess.run(train_i2l_op, feed_dict=feed_dict)
                sess.run(train_discriminator_op, feed_dict=feed_dict)

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        '============== Updating summary at step %d ' % step)
                    summary_writer.add_summary(
                        sess.run(summary_scalars, feed_dict=feed_dict), step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('============== Training Data Eval:')
                    train_loss_seg, train_dice, train_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_tr, labels_sd_tr,
                        images_td_tr, exp_config.batch_size)

                    # total training loss
                    train_total_loss = train_loss_seg + exp_config.lambda_uda * train_loss_invariance

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error_seg: train_loss_seg,
                                                  tr_dice: train_dice,
                                                  tr_error_invariance:
                                                  train_loss_invariance,
                                                  tr_error_total:
                                                  train_total_loss
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    logging.info(
                        '============== Periodically saving checkpoint:')
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('============== Validation Data Eval:')
                    val_loss_seg, val_dice, val_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_vl, labels_sd_vl,
                        images_td_vl, exp_config.batch_size)

                    # total val loss
                    val_total_loss = val_loss_seg + exp_config.lambda_uda * val_loss_invariance
                    validation_total_loss_list.append(val_total_loss)

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error_seg: val_loss_seg,
                                                  vl_dice: val_dice,
                                                  vl_error_invariance:
                                                  val_loss_invariance,
                                                  vl_error_total:
                                                  val_total_loss
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # update tensorboard summary of images
                    # ===========================
                    summary_writer.add_summary(
                        sess.run(summary_images,
                                 feed_dict={
                                     images_sd_pl: x_sd,
                                     labels_sd_pl: y_sd,
                                     training_pl: False
                                 }), step)
                    summary_writer.flush()

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    window_length = 5
                    if len(validation_total_loss_list) < window_length + 1:
                        expo_moving_avg_loss_value = validation_total_loss_list[
                            -1]
                    else:
                        expo_moving_avg_loss_value = utils.exponential_moving_average(
                            validation_total_loss_list,
                            window=window_length)[-1]

                    if expo_moving_avg_loss_value < lowest_loss:
                        lowest_loss = val_total_loss
                        lowest_loss_file = os.path.join(
                            log_dir, 'models/lowest_loss.ckpt')
                        saver_lowest_loss.save(sess,
                                               lowest_loss_file,
                                               global_step=step)
                        logging.info(
                            '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                            % step)

                # ================================================
                # increment step
                # ================================================
                step += 1

        # ================================================================
        # close tf session
        # ================================================================
        sess.close()

    return 0
def predict_segmentation(subject_name,
                         image,
                         normalize=True,
                         post_process=False):
    """Predict a segmentation for ``image`` slice by slice, optionally denoised.

    A fresh TF graph is built containing the normalization module, the
    segmentation network (``model.predict_i2l``) and the prior network
    (``model.predict_l2l``). The segmentation and prior network parameters are
    restored from the latest 'best_dice.ckpt' checkpoints found under
    ``sys_config.log_root``; the normalization variables are restored together
    with the segmentation network.

    Args:
        subject_name: Not used by this function; kept so existing callers
            keep working.
        image: Array that is iterated along axis 0; each slice is fed to the
            network with a trailing channel axis appended.
            NOTE(review): presumably [num_slices, size_x, size_y] - confirm.
        normalize: Not used by this function; kept for interface
            compatibility.
        post_process: Only when this is exactly ``True`` is the soft
            prediction downsampled, passed
            ``exp_config.dae_post_process_runs`` times through the prior
            network and upsampled again.

    Returns:
        Tuple ``(mask_predicted, img_normalized)``. ``mask_predicted`` is the
        squeezed stack of per-slice predictions cast to float, or the
        upsampled denoised mask as uint8 when ``post_process`` is True.
        ``img_normalized`` is the squeezed stack of normalized slices as float.
    """

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) +
                                   [1],
                                   name='images')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_shape = [1] + list(exp_config.image_size_downsampled)
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_3d_pl = tf.placeholder(tf.uint8,
                                             shape=labels_3d_shape,
                                             name='true_labels_3d')
        predicted_seg_1hot_3d_pl = tf.one_hot(predicted_seg_3d_pl,
                                              depth=exp_config.nlabels)

        # denoise the noisy segmentation
        _, pred_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # collect the variables restored by each saver
        # ================================================================
        i2l_vars = []
        l2l_vars = []

        for v in tf.global_variables():
            var_name = v.name
            # the normalization vars also need to be restored from the
            # pre-trained i2l mapper, so they share the i2l saver
            if 'image_normalizer' in var_name or 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create savers
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        tf.get_default_graph().finalize()

        # ================================================================
        # create session; the context manager closes it even if restoring
        # or inference raises
        # ================================================================
        with tf.Session() as sess:

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            sess.run(init_ops)

            # ============================================================
            # Restore the segmentation network parameters
            # ============================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ============================================================
            # Restore the prior network parameters
            # ============================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_l2l.restore(sess, checkpoint_path)

            # ============================================================
            # Make predictions for the image at the resolution of the image after pre-processing
            # ============================================================
            mask_predicted = []
            mask_predicted_soft = []
            img_normalized = []

            # fetch all three outputs in a single run so that each slice
            # goes through the network once instead of three times
            per_slice_fetches = [
                predicted_seg, predicted_seg_softmax, images_normalized
            ]

            for b_i in range(image.shape[0]):

                X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)

                seg, seg_soft, img_norm = sess.run(per_slice_fetches,
                                                   feed_dict={images_pl: X})
                mask_predicted.append(seg)
                mask_predicted_soft.append(seg_soft)
                img_normalized.append(img_norm)

            mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
            mask_predicted_soft = np.squeeze(
                np.array(mask_predicted_soft)).astype(float)
            img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

            # ============================================================
            # downsample predicted mask and pass it through the DAE
            # ============================================================
            if post_process is True:

                # downsample mask_predicted_soft
                mask_predicted_soft_downsampled = rescale(
                    mask_predicted_soft, [
                        1 / exp_config.downsampling_factor_x,
                        1 / exp_config.downsampling_factor_y,
                        1 / exp_config.downsampling_factor_z
                    ],
                    order=1,
                    preserve_range=True,
                    multichannel=True,
                    mode='constant')
                mask_predicted_downsampled = np.argmax(
                    mask_predicted_soft_downsampled, axis=-1)

                # pass the downsampled prediction through the DAE; each run
                # feeds the previous run's hard labels back in
                mask_predicted_denoised = mask_predicted_downsampled
                for _ in range(exp_config.dae_post_process_runs):
                    feed_dict = {
                        predicted_seg_3d_pl:
                        np.expand_dims(mask_predicted_denoised, axis=0)
                    }
                    y_pred_noisy_denoised_softmax = np.squeeze(
                        sess.run(pred_seg_softmax_3d_noisy_autoencoded_softmax,
                                 feed_dict=feed_dict)).astype(np.float16)
                    mask_predicted_denoised = np.argmax(
                        y_pred_noisy_denoised_softmax, axis=-1)

                # upsample the denoised prediction (order=0 keeps the values
                # as integer labels)
                mask_predicted = rescale(mask_predicted_denoised, [
                    exp_config.downsampling_factor_x,
                    exp_config.downsampling_factor_y,
                    exp_config.downsampling_factor_z
                ],
                                         order=0,
                                         preserve_range=True,
                                         multichannel=False,
                                         mode='constant').astype(np.uint8)

        return mask_predicted, img_normalized
def run_training(continue_run):
    """Build the inference/loss graph and run ``predict_next_gt`` once.

    The scribble data is loaded through
    ``acdc_data.load_and_maybe_process_scribbles``. When ``continue_run`` is
    set, the latest checkpoint of the current (or, failing that, the previous)
    recursion is looked up and the step/epoch derived from its file name are
    logged. The graph (placeholders, ``model.inference``, ``model.loss``,
    summaries, savers) is then built inside a session configured with
    ``allow_growth`` and ``allow_soft_placement``, all variables are
    initialized and ``predict_next_gt`` is called on the training images.

    NOTE(review): restoring the checkpoint is commented out below and no
    optimization op is built, so ``predict_next_gt`` runs on freshly
    initialized variables - confirm this is intended.

    Args:
        continue_run: If truthy, try to locate an earlier checkpoint in
            ``log_dir``; when none can be found, continue mode is disabled.

    Returns:
        None.
    """

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)
    already_created_recursion = False
    print("ALready created recursion : " + str(already_created_recursion))
    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder='/scratch_net/biwirender02/cany/scribble/logdir/heart_dropout_welch_non_exp',
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)

    # NOTE(review): the loaded data objects are not explicitly closed in this
    # function; an earlier comment mentioned controlling h5py closing here.
    loaded_previous_recursion = False
    start_epoch = 0
    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                # no checkpoint for the current recursion: fall back to the
                # previous one
                print("EXCEPTE GİRDİ")
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1  # plus 1 b/c otherwise starts with eval
            start_epoch = int(init_step/(len(base_data['images_train'])/4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            # `except Exception` (not a bare except) so that Ctrl-C and
            # SystemExit are not mistaken for a missing checkpoint
            logging.warning('!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0
            start_epoch = 0

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    if loaded_previous_recursion:
        logging.info("Data file exists for recursion {} "
                     "but checkpoints only present up to recursion {}".format(recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")

        if not already_created_recursion:
            recursion_data = acdc_data.load_different_recursion(recursion_data, -1)
            recursion -= 1
        else:
            start_epoch = 0
            init_step = 0

    # load images and validation data
    images_train = np.array(base_data['images_train'])

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)

    # Session options: grow GPU memory on demand, allow soft device placement.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # The graph is built into the default Graph. This configured session is
    # the only one created; it is the one used below and it is closed when
    # the `with` block exits.
    with tf.Session(config=config) as sess:
        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
        labels_placeholder = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])

        keep_prob = tf.placeholder(tf.float32, shape=[])
        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 keep_prob,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        # Add to the Graph the Ops for loss calculation.
        [loss, _, weights_norm] = model.loss(logits,
                                             labels_placeholder,
                                             nlabels=exp_config.nlabels,
                                             loss_type=exp_config.loss_type,
                                             weight_decay=exp_config.weight_decay)  # second output is unregularised loss

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        # Only keep two checkpoints, as checkpoints are kept for every recursion
        # and they can be 300MB +
        saver = tf.train.Saver(max_to_keep=2)
        saver_best_dice = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
        val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

        val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

        train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
        train_error_summary = tf.summary.scalar('training_loss', train_error_)

        train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
        train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

        train_summary = tf.summary.merge([train_error_summary, train_dice_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        # NOTE(review): restoring from a checkpoint is currently disabled:
        #     if continue_run:
        #         saver.restore(sess, init_checkpoint_path)

        # NOTE(review): `recursion` is reset here but not read again in this
        # function.
        recursion = 0

        random_walked = np.array(recursion_data['random_walked'])

        recursion_data = predict_next_gt(data2=recursion_data,
                                         images_train=images_train,
                                         images_placeholder=images_placeholder,
                                         training_time_placeholder=training_time_placeholder,
                                         keep_prob=keep_prob,
                                         logits=logits,
                                         sess=sess,
                                         random_walked=random_walked,
                                         )
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    logging.info('Reading HCP - 3T - T1 images...')
    logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
    imtr, gttr = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                    sys_config.preproc_folder_hcp, 'T1', 1,
                                    exp_config.train_size + 1)
    imvl, gtvl = data_hcp.load_data(
        sys_config.orig_data_root_hcp, sys_config.preproc_folder_hcp, 'T1',
        exp_config.train_size + 1,
        exp_config.train_size + exp_config.val_size + 1)
    logging.info(
        'Training Images: %s' %
        str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info(
        'Training Labels: %s' %
        str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # Define the segmentation network here
    # ================================================================
    mask = ~(imtr == 0).all(axis=(1, 2))
    imtr = imtr[mask]
    gttr = gttr[mask]

    mask = ~(imvl == 0).all(axis=(1, 2))
    imvl = imvl[mask]
    gtvl = gtvl[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
            exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')
        # Placeholders for the learning rate and to indicate whether the model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            images_pl, image_tensor_shape, training_pl, exp_config.nlabels)

        # ================================================================
        # determine trainable variables
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info('================================')
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_segmentation = model.loss(logits_segmentation,
                                       labels_segmentation_pl,
                                       nlabels=exp_config.nlabels,
                                       loss_type=exp_config.loss_type)
        tf.compat.v1.summary.scalar('loss_segmentation', loss_segmentation)

        # ================================================================
        # Add optimization ops.
        # ================================================================
        train_op_train_time = model.training_step(
            loss_segmentation, segmentation_vars,
            exp_config.optimizer_handle_training, learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers for each domain
        # ================================================================
        max_to_keep = 15
        saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
        saver_best_da = tf.compat.v1.train.Saver()

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                shape=[],
                                                name='tr_error_seg')
        tr_error_summary_seg = tf.compat.v1.summary.scalar(
            'training/loss_seg', tr_error_seg)
        tr_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='tr_dice_seg')
        tr_dice_summary_seg = tf.compat.v1.summary.scalar(
            'training/dice_seg', tr_dice_seg)
        tr_summary_seg = tf.compat.v1.summary.merge(
            [tr_error_summary_seg, tr_dice_summary_seg])

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                shape=[],
                                                name='vl_error_seg')
        vl_error_summary_seg = tf.compat.v1.summary.scalar(
            'validation/loss_seg', vl_error_seg)
        vl_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='vl_dice_seg')
        vl_dice_summary_seg = tf.compat.v1.summary.scalar(
            'validation/dice_seg', vl_dice_seg)
        vl_summary_seg = tf.compat.v1.summary.merge(
            [vl_error_summary_seg, vl_dice_summary_seg])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.compat.v1.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_g)
        sess.run(init_l)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            logging.info(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restroring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_da = 0

        # ================================================================
        # run training epochs
        # ================================================================
        for epoch in range(exp_config.max_epochs):

            logging.info(
                '============================================================')
            logging.info('EPOCH %d' % epoch)

            for batch in iterate_minibatches(imtr,
                                             gttr,
                                             batch_size=exp_config.batch_size):

                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y_seg, y_ss = batch

                if step == 0:
                    x_s = x
                    y_seg_s = y_seg
                    y_ss_s = y_ss

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y_seg.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                feed_dict = {
                    images_pl: x,
                    labels_segmentation_pl: y_seg,
                    labels_self_supervised_pl: y_ss,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # update vars
                # ===========================
                _, loss_value = sess.run(
                    [train_op_train_time, loss_segmentation],
                    feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.2f (%.3f sec for the last step)' %
                        (step + 1, loss_value, duration))

                    # ===========================
                    # print values of a parameter (to debug)
                    # ===========================
                    if exp_config.debug is True:
                        var_value = tf.compat.v1.get_collection(
                            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                        )[0].eval(
                            session=sess
                        )  # can add name if you only want parameters with a specific name
                        logging.info('value of one of the parameters %f' %
                                     var_value[0, 0, 0, 0])

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:

                    logging.info('Training Data Eval:')
                    [train_loss_seg, train_dice_seg
                     ] = do_eval(sess, eval_loss_segmentation, images_pl,
                                 labels_segmentation_pl, training_pl, imtr,
                                 gttr, exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary_seg,
                                              feed_dict={
                                                  tr_error_seg: train_loss_seg,
                                                  tr_dice_seg: train_dice_seg
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:

                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically
                # ===========================
                if step % exp_config.val_eval_frequency == 0:

                    # ===========================
                    # Evaluate against the validation set of each domain
                    # ===========================
                    logging.info('Validation Data Eval:')
                    [val_loss_seg,
                     val_dice_seg] = do_eval(sess, eval_loss_segmentation,
                                             images_pl, labels_segmentation_pl,
                                             training_pl, imvl, gtvl,
                                             exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary_seg,
                                              feed_dict={
                                                  vl_error_seg: val_loss_seg,
                                                  vl_dice_seg: val_dice_seg
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice/accuracy is the best yet
                    # ===========================
                    if val_dice_seg > best_da:
                        best_da = val_dice_seg
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_da.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new average best dice on validation sets! - %f -  Saving model.'
                            % val_dice_seg)

                step += 1

        sess.close()
Пример #8
0
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    if exp_config.train_dataset is 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=1040,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    if exp_config.train_dataset is 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset is 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset is 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset is 'IXI':
        logging.info('Reading IXI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_ixi)
        data_brain_train = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=0,
            idx_end=12,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=12,
            idx_end=17,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    logging.info(
        'Training Images: %s' %
        str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info(
        'Training Labels: %s' %
        str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module will be adapted for each test image
        # ================================================================
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, _, _ = model.predict_i2l(images_normalized,
                                         exp_config,
                                         training_pl=training_pl)

        print('shape of inputs: ',
              images_pl.shape)  # (batch_size, 256, 256, 1)
        print('shape of logits: ',
              logits.shape)  # (batch_size, 256, 256, nlabels)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        i2l_vars = []
        for v in tf.trainable_variables():
            i2l_vars.append(v)

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits,
                             labels_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       i2l_vars,
                                       exp_config.optimizer_handle,
                                       learning_rate_pl,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_i2l(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of all variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of all variables:')
        for v in tf.trainable_variables():
            print(v.name)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            print(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restroring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_dice = 0

        # ================================================================
        # run training epochs
        # ================================================================
        while (step < exp_config.max_steps):

            if step % 1000 is 0:
                logging.info(
                    '============================================================'
                )
                logging.info('step %d' % step)

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(imtr,
                                             gttr,
                                             batch_size=exp_config.batch_size,
                                             train_or_eval='train'):

                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y = batch

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                # ===========================
                # create the feed dict for this training iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    labels_pl: y,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # opt step
                # ===========================
                _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.3f (%.3f sec for the last step)' %
                        (step + 1, loss, duration))

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('Training Data Eval:')
                    train_loss, train_dice = do_eval(sess, eval_loss,
                                                     images_pl, labels_pl,
                                                     training_pl, imtr, gttr,
                                                     exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error: train_loss,
                                                  tr_dice: train_dice
                                              })

                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('Validation Data Eval:')
                    val_loss, val_dice = do_eval(sess, eval_loss, images_pl,
                                                 labels_pl, training_pl, imvl,
                                                 gtvl, exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error: val_loss,
                                                  vl_dice: val_dice
                                              })

                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_dice.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new average best dice on validation sets! - %f -  Saving model.'
                            % val_dice)

                step += 1

        sess.close()
def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               recursion=None):
    """Predict every test patient with a restored model and save the volumes.

    Test images, masks and per-patient slice counts are read from a fixed
    HDF5 file. The latest matching checkpoint in ``model_path`` is restored,
    each slice is predicted on its own, the Dice score for labels 1 and 2 is
    logged per patient, and the prediction, ground truth and image volumes
    are written with ``save_nii`` below ``output_folder``.

    Args:
        input_folder: Currently unused; the dataset path is hard-coded below.
        output_folder: Root folder; files are written to its ``prediction``,
            ``ground_truth`` and ``image`` sub-folders.
        model_path: Folder that is searched for the checkpoint to restore.
        exp_config: Experiment configuration; ``batch_size``, ``image_size``,
            ``model_handle`` and ``nlabels`` are read from it.
        do_postprocessing: Currently unused.
        recursion: If ``None`` the ``model_best_dice.ckpt`` checkpoint is
            used. Otherwise the best-dice checkpoint of that recursion is
            used, falling back to its regular checkpoint.

    Returns:
        Always ``0``.
    """
    # NOTE(review): the dataset location is hard-coded and ``input_folder``
    # is ignored - confirm whether this should be derived from the argument.
    data_file = '/scratch_net/biwirender02/cany/scribble/scribble_data/prostate_divided_normalized.h5'

    # Copy everything into memory so the file handle can be released at once.
    with h5py.File(data_file, 'r') as base_data:
        images = np.array(base_data['images_test'])
        labels = np.array(base_data['masks_test'])
        slices_val = np.array(base_data['slices_test'])

    # Number of slices of each patient; used to cut the flat slice arrays
    # into per-patient volumes.
    slice_counts = [int(count) for count in slices_val]
    spatial_shape = tuple(exp_config.image_size)

    training_time_placeholder = tf.placeholder(tf.bool, shape=[])
    image_tensor_shape = [exp_config.batch_size] + list(
        exp_config.image_size) + [1]

    images_placeholder = tf.placeholder(tf.float32,
                                        shape=image_tensor_shape,
                                        name='images')

    logits_op = model.inference(images_placeholder,
                                exp_config.model_handle,
                                training=training_time_placeholder,
                                nlabels=exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Build the prediction ops exactly once. Creating them inside the slice
    # loop would add new nodes to the graph for every single slice.
    # Only the first entry of the batch carries a real image, so only its
    # logits are turned into a mask.
    softmax = tf.nn.softmax(tf.slice(logits_op, [0, 0, 0, 0],
                                     [1, -1, -1, -1]),
                            dim=-1)
    mask_op = tf.arg_max(softmax, dimension=-1)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    with tf.Session(config=config) as sess:

        sess.run(init)
        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path,
                    'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # No best-dice checkpoint for this recursion: fall back to
                # the regular checkpoint of the same recursion.
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model.ckpt'.format(recursion))

        saver.restore(sess, checkpoint_path)

        # Input buffer matching the placeholder shape; entries other than the
        # first one stay zero.
        batch = np.zeros(image_tensor_shape)

        start = 0
        for patient, num_slices in enumerate(slice_counts):
            stop = start + num_slices

            whole_label = np.copy(labels[start:stop, :, :])
            res = np.zeros((num_slices,) + spatial_shape)
            for sli in range(num_slices):
                batch[0, :, :, 0] = images[start + sli, :, :]

                feed_dict = {
                    images_placeholder: batch,
                    training_time_placeholder: False
                }

                mask = sess.run(mask_op, feed_dict=feed_dict)
                res[sli, :, :] = mask[0]

            temp_dices = []
            for c in [1, 2]:
                # Binary masks of label c, built as new arrays so neither the
                # ground truth nor the prediction is altered.
                gt_c_i = (whole_label == c).astype(whole_label.dtype)
                pred_c_i = (res == c).astype(res.dtype)

                temp_dices.append(my_dice(gt_c_i, pred_c_i))

            logging.info("Dice for patient : " + str(temp_dices[0]) + " and " +
                         str(temp_dices[1]))

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction',
                                         'patient' + str(patient) + '.nii.gz')
            save_nii(out_file_name, np.asarray(res))

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth',
                                        'patient' + str(patient) + '.nii.gz')
            save_nii(gt_file_name, np.uint8(np.asarray(whole_label)))

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(
                output_folder, 'image', 'patient' + str(patient) + '.nii.gz')
            save_nii(image_file_name,
                     np.uint8(255 * np.asarray(images[start:stop, :, :])))

            start = stop

    return 0
def run_training(continue_run):
    """Train the model on scribble data over several recursions.

    Each recursion trains for ``exp_config.epochs_per_recursion`` epochs on
    the current ``'random_walked'`` labels. When a recursion ends, new labels
    are produced with ``predict_next_gt`` and ``postprocess_gt`` and the
    checkpoint savers are re-created. Every ``exp_config.val_eval_frequency``
    steps a checkpoint is written and the model is evaluated on the test
    split; the best-dice and best-loss models are saved separately.

    Args:
        continue_run: If true, try to resume from the latest checkpoint of the
            current recursion, or of the previous one if none exists. If no
            checkpoint can be found, training starts from scratch.
    """

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)

    # Both are created further down; they are tracked here so that the
    # `finally` clause can release them however training ends (including a
    # keyboard interrupt).
    # NOTE(review): base_data / recursion_data look like open h5py handles -
    # confirm whether they should be closed here as well.
    sess = None
    summary_writer = None
    try:
        loaded_previous_recursion = False
        start_epoch = 0
        if continue_run:
            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )
            try:
                try:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))
                except Exception:
                    # Nothing saved for this recursion yet: resume from the
                    # previous one instead.
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir,
                        'recursion_{}_model.ckpt'.format(recursion - 1))
                    loaded_previous_recursion = True
                logging.info('Checkpoint path: %s' % init_checkpoint_path)
                init_step = int(
                    init_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1  # plus 1 b/c otherwise starts with eval
                start_epoch = int(
                    init_step /
                    (len(base_data['images_train']) / exp_config.batch_size))
                logging.info('Latest step was: %d' % init_step)
                logging.info('Continuing with epoch: %d' % start_epoch)
            except Exception:
                # Only ordinary errors disable continue mode; KeyboardInterrupt
                # and SystemExit are deliberately not swallowed here.
                logging.warning(
                    '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
                )
                continue_run = False
                init_step = 0
                start_epoch = 0

            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )

        if loaded_previous_recursion:
            logging.info(
                "Data file exists for recursion {} "
                "but checkpoints only present up to recursion {}".format(
                    recursion, recursion - 1))
            logging.info("Likely means postprocessing was terminated")
            recursion_data = acdc_data.load_different_recursion(
                recursion_data, -1)
            recursion -= 1

        # load images and validation data
        images_train = np.array(base_data['images_train'])
        scribbles_train = np.array(base_data['scribbles_train'])
        images_val = np.array(base_data['images_test'])
        labels_val = np.array(base_data['masks_test'])

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.

        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.

            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_placeholder = tf.placeholder(tf.float32,
                                                shape=image_tensor_shape,
                                                name='images')
            labels_placeholder = tf.placeholder(tf.uint8,
                                                shape=mask_tensor_shape,
                                                name='labels')

            learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_placeholder)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_placeholder,
                                     exp_config.model_handle,
                                     training=training_time_placeholder,
                                     nlabels=exp_config.nlabels)

            # Add to the Graph the Ops for loss calculation.
            [loss, _, weights_norm
             ] = model.loss(logits,
                            labels_placeholder,
                            nlabels=exp_config.nlabels,
                            loss_type=exp_config.loss_type,
                            weight_decay=exp_config.weight_decay
                            )  # second output is unregularised loss

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder)

            # Add the Op to compare the logits to the labels during evaluation.
            eval_val_loss = model.evaluation(
                logits,
                labels_placeholder,
                images_placeholder,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weak_supervision=True,
                cnn_threshold=exp_config.cnn_threshold,
                include_bg=False)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            # Only keep two checkpoints, as checkpoints are kept for every recursion
            # and they can be 300MB +
            saver = tf.train.Saver(max_to_keep=2)
            saver_best_dice = tf.train.Saver(max_to_keep=2)
            saver_best_xent = tf.train.Saver(max_to_keep=2)

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            # The training-set summaries are not written anywhere in this
            # function (no training-set evaluation is run); the ops are kept
            # so the graph stays as it was.
            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)

            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            best_val = np.inf
            best_dice = 0
            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            recursion_data = acdc_data.random_walk_epoch(
                recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                exp_config.random_walk)

            #get ground truths
            labels_train = np.array(recursion_data['random_walked'])

            for epoch in range(start_epoch, exp_config.max_epochs):
                if (epoch % exp_config.epochs_per_recursion == 0 and epoch != 0) \
                        or loaded_previous_recursion:
                    loaded_previous_recursion = False
                    #Have reached end of recursion
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        logits=logits,
                        sess=sess)

                    recursion_data = postprocess_gt(
                        data=recursion_data,
                        images_train=images_train,
                        scribbles_train=scribbles_train)
                    recursion += 1
                    # random walk - if it already has been random walked it won't redo
                    recursion_data = acdc_data.random_walk_epoch(
                        recursion_data, exp_config.rw_beta,
                        exp_config.rw_threshold, exp_config.random_walk)
                    #get ground truths
                    labels_train = np.array(recursion_data['random_walked'])

                    #reinitialise savers - otherwise, no checkpoints will be saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)
                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))

                # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvements in the
                # training speed. However, be aware that currently an exception inside this loop may not be caught.
                # The batch generator may just continue running silently without warning even though the code has
                # crashed.

                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    # Optional warm-up: a tenth of the learning rate for the
                    # first 50 steps.
                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(
                            log_dir,
                            'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_val_loss, images_placeholder,
                                     labels_placeholder,
                                     training_time_placeholder, images_val,
                                     labels_val, exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(
                                    recursion))
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.
                                format(val_dice, recursion))

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(
                                    recursion))
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.
                                format(val_loss, recursion))

                    step += 1

    finally:
        # Flush pending summaries and free the session's resources on every
        # exit path; exceptions still propagate to the caller.
        if summary_writer is not None:
            summary_writer.close()
        if sess is not None:
            sess.close()
Пример #11
0
def _save_overlay_pdf(path, image, overlay=None):
    """Draw ``image`` in grayscale, optionally under a translucent overlay, and save it as a PDF.

    The current pyplot figure is cleared afterwards so that consecutive calls
    do not draw on top of each other.
    """
    plt.imshow(np.squeeze(image), cmap='gray')
    if overlay is not None:
        plt.imshow(np.squeeze(overlay), cmap='viridis', alpha=0.7)
    # Hide the axes *before* tight_layout so every saved figure gets the same
    # margins and the PDFs can be compared side by side.
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path, format='pdf')
    plt.clf()


def generate_adversarial_examples(input_folder,
                                  output_path,
                                  model_path,
                                  attack,
                                  attack_args,
                                  exp_config,
                                  add_gaussian=False):
    """Attack a restored segmentation model and dump metrics and figures.

    The latest ``model_best_dice.ckpt`` checkpoint found under ``model_path``
    is restored, the first 20 test images are iterated one at a time, and for
    every evaluated batch the clean and adversarial loss/Dice values, the
    perturbation norms and a set of PDF figures are written to the
    ``eval_results/`` directory (relative to the working directory). A summary
    is written to ``eval_results/<attack>-results.json``.

    Args:
        input_folder: Not used; the data location comes from ``sys_config``.
        output_path: Not used; outputs always go to ``eval_results/``.
        model_path: Directory that holds the checkpoint to restore.
        attack: One of ``'fgsm'``, ``'pgd'`` or ``'spgd'``.
        attack_args: Mapping with the attack parameters. The keys ``'eps'``,
            ``'step_alpha'`` and ``'epochs'`` are always read for reporting;
            ``'sizes'`` and ``'weights'`` are read when ``add_gaussian`` is
            set. For ``'spgd'`` the mapping is expanded into keyword
            arguments.
        exp_config: Experiment configuration (``image_size``, ``nlabels``,
            ``loss_type``, ``data_mode``, ``target_resolution`` are read here).
        add_gaussian: When true, the adversarial example is post-processed by
            ``adv_attack.add_gaussian_noise``.

    Raises:
        NotImplementedError: If ``attack`` is not a supported name.
        RuntimeError: If no batch was evaluated at all.
    """
    import os  # local import keeps this block self-contained

    batch_size = 1

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    mask_tensor_shape = [batch_size] + list(exp_config.image_size)
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    labels_pl = tf.placeholder(tf.uint8,
                               shape=mask_tensor_shape,
                               name='labels')
    logits_pl = model.inference(images_pl,
                                exp_config=exp_config,
                                training=tf.constant(False, dtype=tf.bool))
    eval_loss = model.evaluation(logits_pl,
                                 labels_pl,
                                 images_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type)
    # Build the prediction op once. Creating it inside the loop would add new
    # nodes to the graph on every evaluation.
    predicted_mask = tf.arg_max(tf.nn.softmax(logits_pl), dimension=-1)

    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=True)

    images = data['images_test'][:20]
    labels = data['masks_test'][:20]

    print("Num images train {} test {}".format(len(data['images_train']),
                                               len(images)))

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    baseline_closs = 0.0
    baseline_cdice = 0.0
    attack_closs = 0.0
    attack_cdice = 0.0
    l2_diff_sum = 0.0
    ln_diff_sum = 0.0
    batches = 0
    adv_x = None  # stays None when no batch passes the filter below

    # All per-example files are written here; make sure the directory exists.
    os.makedirs("eval_results", exist_ok=True)

    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        for batch in BackgroundGenerator(
                train.iterate_minibatches(images, labels, batch_size)):
            x, y = batch
            batches += 1

            # NOTE(review): only the 9th batch is evaluated, yet the averages
            # below are still divided by the total number of batches seen.
            # This looks like a leftover from producing figures for a single
            # image -- confirm before relying on the reported averages.
            if batches != 9:
                continue

            non_adv_mask_out = sess.run([predicted_mask],
                                        feed_dict={images_pl: x})

            if attack == 'fgsm':
                adv_x = adv_attack.fgsm_run(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            attack_args)
            elif attack == 'pgd':
                adv_x = adv_attack.pgd(x, y, images_pl, labels_pl, logits_pl,
                                       exp_config, sess, attack_args)
            elif attack == 'spgd':
                adv_x = adv_attack.pgd_conv(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            **attack_args)
            else:
                raise NotImplementedError
            adv_x = [adv_x]

            if add_gaussian:
                print('adding gaussian noise')
                adv_x = adv_attack.add_gaussian_noise(
                    x,
                    adv_x[0],
                    sess,
                    eps=attack_args['eps'],
                    sizes=attack_args['sizes'],
                    weights=attack_args['weights'])

            for i, adv_example in enumerate(adv_x):
                perturbation = adv_example - x
                l2_diff = np.average(
                    np.squeeze(np.linalg.norm(perturbation, axis=(1, 2))))
                ln_diff = np.average(
                    np.squeeze(
                        np.linalg.norm(perturbation, axis=(1, 2), ord=np.inf)))

                l2_diff_sum += l2_diff
                ln_diff_sum += ln_diff

                print(l2_diff, ln_diff)

                adv_mask_out = sess.run([predicted_mask],
                                        feed_dict={images_pl: adv_example})

                closs, cdice = sess.run(eval_loss,
                                        feed_dict={
                                            images_pl: x,
                                            labels_pl: y
                                        })
                baseline_closs = closs + baseline_closs
                baseline_cdice = cdice + baseline_cdice

                adv_closs, adv_cdice = sess.run(eval_loss,
                                                feed_dict={
                                                    images_pl: adv_example,
                                                    labels_pl: y
                                                })
                attack_closs = adv_closs + attack_closs
                attack_cdice = adv_cdice + attack_cdice

                partial_result = {
                    'attack': attack,
                    'attack_args': {
                        k: attack_args[k]
                        for k in ['eps', 'step_alpha', 'epochs']
                    },
                    'baseline_closs': closs,
                    'baseline_cdice': cdice,
                    'attack_closs': adv_closs,
                    'attack_cdice': adv_cdice,
                    'attack_l2_diff': l2_diff,
                    'attack_ln_diff': ln_diff
                }

                # NOTE(review): the file holds the JSON encoding of the dict's
                # ``str()`` form (one quoted string), not a JSON object. Kept
                # as is so that existing readers of these files keep working.
                jsonString = json.dumps(str(partial_result))

                suffix = "{}-{}-{}-{}".format(attack, add_gaussian, batches, i)

                with open("eval_results/{}-metrics.json".format(suffix),
                          "w") as jsonFile:
                    jsonFile.write(jsonString)

                _save_overlay_pdf(
                    "eval_results/ground-truth-{}.pdf".format(suffix), x, y)
                _save_overlay_pdf(
                    "eval_results/benign-{}.pdf".format(suffix), x,
                    non_adv_mask_out)
                _save_overlay_pdf(
                    "eval_results/adversarial-{}.pdf".format(suffix),
                    adv_example, adv_mask_out)
                _save_overlay_pdf(
                    "eval_results/adv-input-{}.pdf".format(suffix),
                    adv_example)
                _save_overlay_pdf(
                    "eval_results/benign-input-{}.pdf".format(suffix), x)

                print(attack_closs, attack_cdice, l2_diff, ln_diff)

    if adv_x is None:
        raise RuntimeError(
            'No batch was evaluated (only batch number 9 is attacked, '
            'but the generator produced {} batches)'.format(batches))

    # One shared divisor so the printed and the stored averages agree.
    num_evaluations = batches * len(adv_x)

    print("Evaluation results")
    print("{} Attack Params {}".format(attack, attack_args))
    print("Baseline metrics: Avg loss {}, Avg DICE Score {} ".format(
        baseline_closs / num_evaluations, baseline_cdice / num_evaluations))
    print(
        "{} Attack effectiveness: Avg loss {}, Avg DICE Score {} ".format(
            attack, attack_closs / num_evaluations,
            attack_cdice / num_evaluations))
    print(
        "{} Attack visibility: Avg l2-norm diff {} Avg l-inf-norm diff {}".
        format(attack, l2_diff_sum / num_evaluations,
               ln_diff_sum / num_evaluations))
    result_dict = {
        'attack': attack,
        'attack_args':
        {k: attack_args[k]
         for k in ['eps', 'step_alpha', 'epochs']},
        'baseline_closs_avg': baseline_closs / num_evaluations,
        'baseline_cdice_avg': baseline_cdice / num_evaluations,
        'attack_closs_avg': attack_closs / num_evaluations,
        'attack_cdice_avg': attack_cdice / num_evaluations,
        'attack_l2_diff': l2_diff_sum / num_evaluations,
        'attack_ln_diff': ln_diff_sum / num_evaluations
    }

    results = [copy.deepcopy(result_dict)]
    print(results)

    jsonString = json.dumps(results)
    with open("eval_results/{}-results.json".format(attack),
              "w") as jsonFile:
        jsonFile.write(jsonString)
Пример #12
0
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True, evaluate_all=False, use_iter=None):
    """Segment every selected patient volume and save the results as NIfTI files.

    The latest ``model_best_dice.ckpt`` checkpoint under ``model_path`` is
    restored. Each sub-folder of ``input_folder`` is treated as one patient;
    its ``Info.cfg`` supplies the ``ED`` and ``ES`` frame numbers and every
    ``patient???_frame??.nii.gz`` file in it is segmented slice by slice.
    Predictions, the input image and (when ``gt_exists``) the processed ground
    truth and a prediction/ground-truth difference mask are written below
    ``output_folder`` in the ``prediction``, ``image``, ``ground_truth`` and
    ``difference`` sub-folders.

    Args:
        input_folder: Directory that contains one sub-folder per patient.
        output_folder: Root directory for the saved volumes.
        model_path: Directory that holds the checkpoint to restore.
        config: Experiment configuration (``image_size``, ``nlabels``,
            ``data_mode`` and ``target_resolution`` are read here).
        do_postprocessing: Keep only the largest connected components of the
            prediction (via ``image_utils``) before saving.
        gt_exists: Whether ``*_gt.nii.gz`` masks accompany the images. When
            false every patient folder is scored and no ground-truth outputs
            are produced.
        evaluate_all: Score every patient folder instead of only those whose
            trailing three-digit number is divisible by 5.
        use_iter: Not used.

    Returns:
        The step number parsed from the restored checkpoint's file name.

    Raises:
        NotImplementedError: If ``config.data_mode`` is not ``'2D'``.
        ValueError: If a frame is neither the ED nor the ES frame.
    """
    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    # According to the experiment config, pick a model and predict the output
    # TODO: Implement majority voting using 3 models.
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test != 'test':
                continue

            # Read the "key: value" pairs; ``with`` closes the handle again.
            infos = {}
            with open(os.path.join(folder_path, 'Info.cfg')) as info_file:
                for line in info_file:
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

            # NOTE(review): lstrip removes leading *characters* from the set
            # "patient", not the prefix; fine as long as the id is numeric.
            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for nii_path in glob.glob(os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % nii_path)
                logging.info(' --------------------------------------------')

                file_base = nii_path.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(nii_path)
                img = img_dat[0].copy()
                print('img')
                print(img.shape)
                print(img.dtype)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if config.data_mode != '2D':
                    # Only the slice-wise path exists; fail clearly instead of
                    # hitting an undefined ``prediction_arr`` further down.
                    raise NotImplementedError(
                        'score_data only supports data_mode "2D", got %r' % (config.data_mode,))

                pixel_size = (img_dat[2].structarr['pixdim'][1], img_dat[2].structarr['pixdim'][2])
                scale_vector = (pixel_size[0] / config.target_resolution[0],
                                pixel_size[1] / config.target_resolution[1])
                print('pixel_size', pixel_size)
                print('scale_vector', scale_vector)

                predictions = []
                mask_arr = []
                img_arr = []

                for zz in range(img.shape[2]):

                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       anti_aliasing=True,
                                                       mode='constant')
                    print('slice_img', slice_img.shape)
                    print('slice_rescaled', slice_rescaled.shape)

                    slice_cropped = acdc_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                    print('slice_cropped', slice_cropped.shape)
                    slice_cropped = np.float32(slice_cropped)
                    x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)

                    # The mask only exists when ground truth is available.
                    if gt_exists:
                        slice_mask = np.squeeze(mask[:, :, zz])
                        # NOTE(review): anti_aliasing=True on a label image may
                        # blend neighbouring labels when downscaling -- confirm
                        # that this is intended for order=0.
                        mask_rescaled = transform.rescale(slice_mask,
                                                          scale_vector,
                                                          order=0,
                                                          preserve_range=True,
                                                          multichannel=False,
                                                          anti_aliasing=True,
                                                          mode='constant')
                        mask_cropped = acdc_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)
                        mask_cropped = np.asarray(mask_cropped, dtype=np.uint8)
                        y = image_utils.reshape_2Dimage_to_tensor(mask_cropped)
                        mask_arr.append(np.squeeze(y))

                    # GET PREDICTION
                    feed_dict = {
                        images_pl: x,
                    }
                    _, softmax_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                    prediction_cropped = np.squeeze(softmax_out[0, ...])

                    # Resample the per-class scores, then take the arg-max.
                    # This does not depend on the ground truth, so it runs
                    # whether or not ``gt_exists``.
                    prediction = transform.resize(prediction_cropped,
                                                  (nx, ny, num_channels),
                                                  order=1,
                                                  preserve_range=True,
                                                  anti_aliasing=True,
                                                  mode='constant')
                    prediction = np.uint8(np.argmax(prediction, axis=-1))

                    predictions.append(prediction)
                    img_arr.append(np.squeeze(x))

                # Stack the slices back into (x, y, z) volumes.
                prediction_arr = np.transpose(np.asarray(predictions, dtype=np.uint8), (1, 2, 0))
                img_arrs = np.transpose(np.asarray(img_arr, dtype=np.float32), (1, 2, 0))
                if gt_exists:
                    mask_arrs = np.transpose(np.asarray(mask_arr, dtype=np.uint8), (1, 2, 0))

                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError('Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d' %
                                     (frame, ED_frame, ES_frame))

                volume_name = 'patient' + patient_id + frame_suffix + '.nii.gz'

                # Save predicted mask
                out_file_name = os.path.join(output_folder, 'prediction', volume_name)
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image', volume_name)
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                if not gt_exists:
                    continue

                # Save GT image
                gt_file_name = os.path.join(output_folder, 'ground_truth', volume_name)
                logging.info('saving to: %s' % gt_file_name)
                utils.save_nii(gt_file_name, mask_arrs, out_affine, out_header)

                # Save difference mask between predictions and ground truth
                difference_mask = np.where(np.abs(prediction_arr-mask_arrs) > 0, [1], [0])
                difference_mask = np.asarray(difference_mask, dtype=np.uint8)

                # Show image, ground truth and prediction for every slice.
                for zz in range(difference_mask.shape[2]):
                    for volume in (img_arrs, mask_arrs, prediction_arr):
                        plt.imshow(volume[:, :, zz])
                        plt.gray()
                        plt.axis('off')
                        plt.show()
                    print('...')

                diff_file_name = os.path.join(output_folder, 'difference', volume_name)
                logging.info('saving to: %s' % diff_file_name)
                utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

        if total_volumes:
            logging.info('Average time per volume: %f' % (total_time/total_volumes))
        else:
            # Avoid a ZeroDivisionError when nothing matched the selection.
            logging.warning('No volumes were evaluated in %s' % input_folder)

    return init_iteration
Пример #13
0
def main(exp_config):
    """Interactively browse model predictions on random test slices.

    The latest ``model_best_dice.ckpt`` checkpoint under the module-level
    ``model_path`` is restored. Then, in an endless loop, a random test image
    is segmented and shown next to its ground-truth mask, the predicted mask
    and the first four softmax channels. The loop only ends through an
    exception (for example ``KeyboardInterrupt``); the data handle is closed
    on the way out.

    Args:
        exp_config: Experiment configuration (``data_mode``, ``image_size``
            and ``target_resolution`` are read here and the object is passed
            on to ``model.predict``).
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False
    )

    # The viewing loop never returns normally, so the close has to live in a
    # ``finally`` block to be reached at all.
    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, exp_config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            while True:

                ind = np.random.randint(data['images_test'].shape[0])

                x = image_utils.reshape_2Dimage_to_tensor(data['images_test'][ind, ...])
                y = image_utils.reshape_2Dimage_to_tensor(data['masks_test'][ind, ...])

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl],
                                                 feed_dict={images_pl: x})

                fig = plt.figure()

                # Top row: input, ground truth, prediction.
                fig.add_subplot(241).imshow(np.squeeze(x), cmap='gray')
                fig.add_subplot(242).imshow(np.squeeze(y))
                fig.add_subplot(243).imshow(np.squeeze(mask_out))

                # Bottom row: softmax channels 0-3 in subplots 245-248.
                for channel in range(4):
                    axis = fig.add_subplot(245 + channel)
                    axis.imshow(np.squeeze(softmax_out[..., channel]))

                plt.show()
    finally:
        data.close()
def run_training(continue_run):
    """Finish the current recursion: predict, post-process and random-walk the labels.

    The scribble data set is loaded, the inference graph is built and, when
    ``continue_run`` is set and a checkpoint can be found, the weights of the
    current (or, failing that, the previous) recursion are restored. The
    network then predicts the next ground truth for the training images, the
    result is post-processed, the recursion counter is advanced and the
    random-walk step is applied.

    Relies on the module-level ``exp_config``, ``sys_config`` and ``log_dir``.

    Args:
        continue_run: Try to resume from the latest checkpoint in ``log_dir``.
            If none can be located a warning is logged and the run continues
            with freshly initialised weights.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # Load data
    # NOTE(review): the returned data objects are not closed in this function;
    # confirm that the caller is responsible for that.
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)

    loaded_previous_recursion = False
    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        # ``except Exception`` (rather than a bare ``except``) so that
        # KeyboardInterrupt / SystemExit are not mistaken for "no checkpoint".
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1  # plus 1 b/c otherwise starts with eval
            start_epoch = int(init_step/(len(base_data['images_train'])/4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            logging.warning('!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    if loaded_previous_recursion:
        logging.info("Data file exists for recursion {} "
                     "but checkpoints only present up to recursion {}".format(recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")
        recursion_data = acdc_data.load_different_recursion(recursion_data, -1)
        recursion -= 1

    # Training images and scribbles are needed as in-memory arrays below.
    images_train = np.array(base_data['images_train'])
    scribbles_train = np.array(base_data['scribbles_train'])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    with tf.Session(config=config) as sess:
        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
        # The label / learning-rate placeholders and the summary ops are not
        # used by the code below; they are kept so the constructed graph is
        # unchanged.
        labels_placeholder = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        summary = tf.summary.merge_all()

        # Initialise all variables, then optionally overwrite them from the
        # checkpoint located above.
        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(max_to_keep=2)

        if continue_run:
            # Restore session
            saver.restore(sess, init_checkpoint_path)

        logging.info('RECURSION {0}'.format(recursion))

        # Have reached end of recursion
        recursion_data = predict_next_gt(data=recursion_data,
                                         images_train=images_train,
                                         images_placeholder=images_placeholder,
                                         training_time_placeholder=training_time_placeholder,
                                         logits=logits,
                                         sess=sess)

        recursion_data = postprocess_gt(data=recursion_data,
                                        images_train=images_train,
                                        scribbles_train=scribbles_train)
        recursion += 1
        # random walk - if it already has been random walked it won't redo
        recursion_data = acdc_data.random_walk_epoch(recursion_data,
                                                     exp_config.rw_beta,
                                                     exp_config.rw_threshold,
                                                     exp_config.random_walk)
Пример #15
0
def predict_segmentation(image):
    """Segment a single 2D image with the trained image-to-label network.

    A fresh TF graph is built on every call. The network weights are restored
    from the latest 'best_dice.ckpt' checkpoint found under
    ``sys_config.log_root + exp_config.expname_i2l + '/models/'``.

    Args:
        image: 2D array. A batch axis (front) and a channel axis (back) are
            added before it is fed to the network, whose placeholder has
            shape ``[None, image_size[0], image_size[1], 1]``.

    Returns:
        The network's ``preds`` output for the image, with singleton axes
        squeezed out and cast to ``float``.
    """

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + [exp_config.image_size[0]] +
                                   [exp_config.image_size[1]] + [1],
                                   name='images')

        # ================================================================
        # build the graph that computes predictions from the inference model
        # (training_pl is fixed to False; only the hard predictions are used)
        # ================================================================
        _, _, preds = model.predict_i2l(images_pl,
                                        exp_config,
                                        training_pl=tf.constant(
                                            False, dtype=tf.bool))

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # The context manager closes the session even when checkpoint lookup,
        # restoring or inference raises (the previous explicit sess.close()
        # was skipped on any exception).
        # ================================================================
        with tf.Session() as sess:

            # ================================================================
            # create saver
            # ================================================================
            saver_i2l = tf.train.Saver()

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            tf.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            sess.run(init_ops)

            # ================================================================
            # Restore the segmentation network parameters
            # ================================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ================================================================
            # predict segmentation
            # ================================================================
            # (nx, ny) -> (1, nx, ny, 1)
            X = np.expand_dims(np.expand_dims(image, axis=-1), axis=0)
            mask_predicted = sess.run(preds, feed_dict={images_pl: X})

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)

        return mask_predicted
Пример #16
0
def main(csv_path=None, batch_size=None):
    """Assess the segmentations of every recursion and plot their progression.

    Steps, all operating on files below the module-level ``model_path``:

    1. Read images and ground-truth masks from 'base_data.hdf5' and the
       initial scribbles from 'recursion_0_data.hdf5'.
    2. Open (or create) 'recursion_evaluation/assessment.hdf5' and make sure
       it holds a 'random_walk_mask' dataset plus, per recursion, copies of
       the 'postprocessed', 'random_walked' and 'predicted' datasets. The
       prediction of recursion r is read from 'recursion_{r+1}_data.hdf5'.
    3. For the last recursion, look for a model checkpoint and, if one can be
       restored, fill 'recursion_{r}/prediction_{r}'.
    4. Compute Dice scores for every dataset with ``calculate_dices`` and
       store them as HDF5 attributes; print a summary table.
    5. Plot the segmentation progression of a few slices with matplotlib.

    Args:
        csv_path: forwarded to ``calculate_dices`` as its ``path`` argument.
        batch_size: number of randomly sampled slices to plot. If ``None``,
            the three slices reported by the last ``calculate_dices`` call in
            the recursion loop are plotted and labelled best/median/worst.
    """
    # get number of recursions
    num_recursions = acdc_data.most_recent_recursion(model_path)

    # get data; the context managers close the files even if a read fails
    with h5py.File(os.path.join(model_path, 'base_data.hdf5'),
                   'r') as base_data:
        masks_gt = np.array(base_data['masks_train'])
        images = np.array(base_data['images_train'])
    with h5py.File(os.path.join(model_path, 'recursion_0_data.hdf5'),
                   'r') as r0_data:
        scribbles = np.array(r0_data['predicted'])

    # The mode is explicit: h5py >= 3.0 opens files read-only by default,
    # which would make every create_dataset / attrs write below fail.
    # 'a' = read/write, create the file if it does not exist.
    with h5py.File(model_path + "/recursion_evaluation/assessment.hdf5",
                   'a') as output_data:
        # FULLY RANDOM WALKED
        if 'random_walk_mask' not in output_data:
            output_data.create_dataset(name='random_walk_mask',
                                       dtype=np.uint8,
                                       shape=images.shape)
            output_data['random_walk_mask'].attrs.create('processed_to',
                                                         dtype=np.uint16,
                                                         data=0)
            initialise_dice_attrs(output_data, 'random_walk_mask',
                                  exp_config.nlabels)

        # process random walk
        # NOTE(review): on a fresh file 'processed_to' is 0, so this loop
        # starts at slice 1 and slice 0 is never written. The final-recursion
        # loop further down uses 'processed_to' without the +1 -- confirm
        # whether the offset is intended.
        processed_to = output_data['random_walk_mask'].attrs.get(
            'processed_to') + 1
        for scr_idx in range(processed_to, len(images), exp_config.batch_size):
            ind = list(
                range(scr_idx, min(scr_idx + exp_config.batch_size,
                                   len(images))))
            logging.info(
                'Preparing random walker segmentation for slices {} to {}'.
                format(ind[0], ind[-1]))
            # The random walker call is disabled: the raw scribbles are
            # stored instead and 'processed_to' is not advanced.
            output_data['random_walk_mask'][ind, ...] = scribbles[ind, ...]
            # output_data['random_walk_mask'][ind, ...] = rw(images[ind, ...],
            #                                                scribbles[ind, ...],
            #                                                beta=exp_config.rw_beta,
            #                                                threshold=0)
            # output_data['random_walk_mask'].attrs.modify('processed_to', ind[-1] + 1)

        # calculate dice
        _, dices, mean_dice, stdevs = calculate_dices(
            np.array(output_data['random_walk_mask']),
            masks_gt,
            exp_config.nlabels,
            path=csv_path,
            filename="fully_random_walked")
        output_data['random_walk_mask'].attrs.modify('mean_dice', mean_dice)
        output_data['random_walk_mask'].attrs.modify('dices', dices)
        output_data['random_walk_mask'].attrs.modify('stdevs', stdevs)

        # iterate through recursions
        # recursion 0 slightly different as it has no 'output from previous epoch' set
        r_data = h5py.File(os.path.join(model_path, 'recursion_0_data.hdf5'),
                           'r')
        # row 0: dataset name templates inside the assessment file
        # row 1: source dataset names inside the recursion data files
        # row 2: human-readable labels (csv file names / plot titles)
        dset_names = np.array([[
            'recursion_{0}/processed_{0}', 'recursion_{0}/random_walked_{0}',
            'recursion_{0}/prediction_{0}'
        ], ['postprocessed', 'random_walked', 'predicted'
            ], ['processed r{0} seg', 'random walked r{0} seg', 'r{0} seg']])

        for recursion in range(0, num_recursions + 1):
            for i in range(3):
                dset_name = dset_names[0, i].format(recursion)
                # once it gets to the prediction it must open the next dataset file
                if i == 2:
                    r_data.close()
                    if recursion == num_recursions:
                        # no data file exists after the last recursion; its
                        # prediction is handled separately below
                        break
                    r_data = h5py.File(
                        os.path.join(
                            model_path,
                            'recursion_{}_data.hdf5'.format(recursion + 1)),
                        'r')

                if dset_name not in output_data:
                    logging.info("Creating dataset {}".format(dset_name))
                    logging.info("Getting data from {}[{}]".format(
                        os.path.basename(r_data.filename).split('.')[0],
                        dset_names[1, i]))
                    output_data.create_dataset(dset_name,
                                               dtype=np.uint8,
                                               shape=images.shape,
                                               data=r_data[dset_names[1, i]])
                    initialise_dice_attrs(output_data, dset_name,
                                          exp_config.nlabels)
                else:
                    logging.info("Getting data from {}[{}]".format(
                        os.path.basename(r_data.filename).split('.')[0],
                        dset_names[1, i]))
                    output_data[dset_name][:] = np.array(r_data[dset_names[1,
                                                                           i]])
                # calculate dices
                # `indices` from the last call is reused for plotting when
                # batch_size is None
                indices, dices, mean_dice, stdevs = calculate_dices(
                    np.array(output_data[dset_name]),
                    masks_gt,
                    exp_config.nlabels,
                    path=csv_path,
                    filename=dset_names[2, i].format(recursion))

                output_data[dset_name].attrs.modify('mean_dice', mean_dice)
                output_data[dset_name].attrs.modify('dices', dices)
                output_data[dset_name].attrs.modify('stdevs', stdevs)

        # for last recursion, need to use prediction from network
        # (`recursion` still holds the last loop value, i.e. num_recursions)
        final_dset = 'recursion_{0}/prediction_{0}'.format(recursion)
        if final_dset not in output_data:
            output_data.create_dataset(final_dset,
                                       dtype=np.uint8,
                                       shape=images.shape,
                                       data=masks_gt)
            output_data[final_dset].attrs.create(name='processed_to',
                                                 data=np.array((0, )),
                                                 shape=(1, ),
                                                 dtype=np.uint16)
            # np.bool_ instead of the np.bool alias, which NumPy 1.24 removed
            output_data[final_dset].attrs.create(name='processed',
                                                 data=np.array(False),
                                                 shape=(),
                                                 dtype=np.bool_)
            initialise_dice_attrs(output_data, final_dset, exp_config.nlabels)

        # A missing checkpoint is not an error: the final prediction is then
        # simply left out of the evaluation.
        try:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'recursion_{}_model*.ckpt'.format(recursion))
            skip_final_recursion = False
        except Exception:
            skip_final_recursion = True

        if not skip_final_recursion:
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            mask_pl, softmax_pl = model.predict(images_pl,
                                                exp_config.model_handle,
                                                exp_config.nlabels)
            saver = tf.train.Saver()
            init = tf.global_variables_initializer()
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])

            with tf.Session() as sess:
                sess.run(init)
                try:
                    saver.restore(sess, checkpoint_path)
                    # the checkpoint file name ends in '-<step>'; convert the
                    # step into an epoch index within the current recursion
                    epoch = int(
                        checkpoint_path.split('/')[-1].split('-')[-1]) + 1
                    epoch = int(epoch / (len(images) / exp_config.batch_size))
                    epoch = epoch % exp_config.epochs_per_recursion
                    skip_final_recursion = False
                except Exception:
                    logging.info(
                        "Failed to load checkpoint for recursion {}".format(
                            recursion))
                    skip_final_recursion = True

                if not skip_final_recursion:
                    scr_max = len(images)
                    processed_to = output_data[final_dset].attrs.get(
                        'processed_to')
                    print(processed_to)
                    processed_to = 0 if processed_to is None else processed_to
                    processed_to = np.squeeze(processed_to)

                    for scr_idx in range(processed_to, len(images),
                                         exp_config.batch_size):
                        if scr_idx + exp_config.batch_size > scr_max:
                            # At the end of the dataset: reuse the last full
                            # batch so the fixed placeholder batch size holds
                            ind = list(
                                range(scr_max - exp_config.batch_size,
                                      scr_max))
                        else:
                            ind = list(
                                range(scr_idx,
                                      scr_idx + exp_config.batch_size))

                        logging.info(
                            "Segmenting images using weights from final "
                            "recursion ({}) for slices {} to {}".format(
                                recursion, ind[0], ind[-1]))
                        # only needed by the disabled sess.run call below
                        feed_dict = {
                            images_pl: np.expand_dims(images[ind, ...], -1),
                            training_time_placeholder: False
                        }
                        # The network call is disabled: the raw scribbles are
                        # stored as the "prediction" instead.
                        output_data[final_dset][ind, ...] = scribbles[ind, ...]
                        # output_data[final_dset][ind, ...], _ = sess.run([mask_pl, softmax_pl],
                        #                                                 feed_dict=feed_dict)
                        output_data[final_dset].attrs.modify(
                            'processed_to', ind[-1] + 1)

                    # Calculate dices
                    _, dices, mean_dice, stdevs = calculate_dices(
                        np.array(output_data[final_dset]),
                        masks_gt,
                        exp_config.nlabels,
                        path=csv_path,
                        filename=dset_names[2, 2].format(recursion))
                    output_data[final_dset].attrs.modify(
                        'mean_dice', mean_dice)
                    output_data[final_dset].attrs.modify('dices', dices)
                    output_data[final_dset].attrs.modify('stdevs', stdevs)
                    output_data[final_dset].attrs.modify('processed', True)

        # print summaries:
        print("DICES:")
        l_str = " " * 40
        for i in range(1, exp_config.nlabels - 1):
            d_str = "Dice 0{} (stdev)".format(i)
            print(d_str)
            print(d_str.center(20))
            l_str += d_str.center(20)
        print(l_str + "       Mean Dice        ")

        print("      Random walker segmentations:".ljust(45) +
              dice_str(output_data['random_walk_mask']))
        for recursion in range(0, num_recursions + 1):
            print("   RECURSION {}".format(recursion))
            print("      Postprocessed input of recursion {}".format(
                recursion).ljust(45) +
                  dice_str(output_data[dset_names[0, 0].format(recursion)]))

            print("      Random walked input of recursion {}".format(
                recursion).ljust(45) +
                  dice_str(output_data[dset_names[0, 1].format(recursion)]))
            if recursion != num_recursions:
                print("      Output of recursion {}".format(recursion).ljust(
                    45) +
                      dice_str(output_data[dset_names[0,
                                                      2].format(recursion)]))

        # Handle last recursion separately
        if not skip_final_recursion:
            print("      Output of recursion {}".format(recursion).ljust(45) +
                  dice_str(output_data[dset_names[0, 2].format(recursion)]))
            print(
                "         Weights for predicting final masks were taken from epoch {} of {}"
                .format(epoch, exp_config.epochs_per_recursion))

        # get graphs
        if batch_size is not None:
            indices = np.sort(np.unique(sample(range(len(images)),
                                               batch_size)))
            logging.info(
                "Showing segmentation progression for slices randomly picked slices {}"
                .format(indices))
            descriptions = ["Slice {:04}".format(index) for index in indices]
        else:
            logging.info(
                "Showing segmentation progression for slices {0} [Best], {1} [Median] and {2} [Worst]"
                .format(indices[0], indices[1], indices[2]))
            descriptions = [
                'best slice ({:04})'.format(indices[0]),
                'median slice ({:04})'.format(indices[1]),
                'worst slice ({:04})'.format(indices[2])
            ]
            batch_size = 3

        for fig_idx in range(batch_size):
            # graphs[0] is left empty: the image itself is drawn separately.
            # graphs[1] = ground truth, graphs[2] = random walker mask,
            # graphs[3 + 3 * r + i] = dataset i of recursion r.
            graphs = np.zeros(
                (6 + num_recursions * 3, exp_config.image_size[0],
                 exp_config.image_size[1]))
            slice_idx = indices[fig_idx]
            # First get image, ground truth and random walked prediction
            graphs[1, ...] = masks_gt[slice_idx, ...]
            graphs[2, ...] = output_data['random_walk_mask'][slice_idx, ...]

            for recursion in range(num_recursions + 1):
                for i in range(3):
                    if i == 2 and recursion == num_recursions and skip_final_recursion:
                        break
                    dset_name = dset_names[0, i].format(recursion)
                    graphs[recursion * 3 + i + 3,
                           ...] = output_data[dset_name][slice_idx, ...]

            fig = plt.figure(fig_idx)
            fig.suptitle('Segmentation progress for {}'.format(
                descriptions[fig_idx]))
            # handle image separately

            ax = fig.add_subplot(num_recursions + 2, 3, 1)
            ax.axis('off')
            ax.set_title('image')
            ax.imshow(np.squeeze(images[slice_idx, ...]),
                      cmap='gray',
                      vmin=0,
                      vmax=exp_config.nlabels - 1)

            for graph_idx in range(1, len(graphs)):
                ax = fig.add_subplot(num_recursions + 2, 3, graph_idx + 1)
                ax.axis('off')

                # This should be cleaned up
                if graph_idx == 1:
                    ax.set_title('ground truth')
                elif graph_idx == 2:
                    ax.set_title('random walker segmentation')
                elif graph_idx == 3:
                    ax.set_title('weak supervision')
                elif graph_idx == 4:
                    ax.set_title('ws random walked')
                else:
                    ax.set_title(
                        dset_names[2, graph_idx %
                                   3].format(int((graph_idx + 1) / 3) - 2))
                ax.imshow(np.squeeze(graphs[graph_idx, ...]),
                          cmap='jet',
                          vmin=0,
                          vmax=exp_config.nlabels - 1)

        plt.show()
Пример #17
0
def predict_segmentation(subject_name, image, normalize=True):
    """Segment an image slice by slice through the normalizer + i2l network.

    A fresh TF graph is built in which a normalization module is placed in
    front of the segmentation network. All global variables are restored from
    the latest 'best_dice.ckpt' checkpoint of the i2l experiment; when
    ``normalize`` is ``True`` the 'image_normalizer' variables are then
    overwritten from the subject-specific 'best_score.ckpt' checkpoint.

    Args:
        subject_name: string appended to '/subject_' to locate the
            subject-specific normalizer checkpoint directory.
        image: array iterated along axis 0, one slice per network call; a
            channel axis is appended to each slice. Presumably shaped
            (n_slices, nx, ny) matching ``exp_config.image_size`` -- confirm
            against the caller.
        normalize: the subject-specific normalizer weights are restored only
            when this is exactly ``True``.

    Returns:
        Tuple ``(mask_predicted, img_normalized)``: the stacked per-slice
        predictions and normalized images, both squeezed of singleton axes
        and cast to ``float``.
    """

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) +
                                   [1],
                                   name='images')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, _ = model.normalize(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        _, _, preds = model.predict_i2l(images_normalized,
                                        exp_config,
                                        training_pl=tf.constant(
                                            False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # NOTE: i2l_vars deliberately keeps *all* global variables (including
        # the normalizer ones), exactly as before; normalization_vars is the
        # 'image_normalizer' subset.
        # ================================================================
        i2l_vars = list(tf.global_variables())
        normalization_vars = [
            v for v in i2l_vars if 'image_normalizer' in v.name
        ]

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # The context manager closes the session even when restoring or
        # inference raises (the previous explicit sess.close() was skipped
        # on any exception).
        # ================================================================
        with tf.Session() as sess:

            # ================================================================
            # create saver
            # ================================================================
            saver_i2l = tf.train.Saver(var_list=i2l_vars)
            saver_normalizer = tf.train.Saver(var_list=normalization_vars)

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            tf.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            sess.run(init_ops)

            # ================================================================
            # Restore the segmentation network parameters
            # ================================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ================================================================
            # Restore the normalization network parameters
            # ================================================================
            if normalize is True:
                logging.info(
                    '============================================================')
                path_to_model = os.path.join(
                    sys_config.log_root, exp_config.expname_normalizer
                ) + '/subject_' + subject_name + '/models/'
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    path_to_model, 'best_score.ckpt')
                logging.info('Restoring the trained parameters from %s...' %
                             checkpoint_path)
                saver_normalizer.restore(sess, checkpoint_path)
                logging.info(
                    '============================================================')

            # ================================================================
            # Make predictions for the image at the resolution of the image after pre-processing
            # ================================================================
            mask_predicted = []
            img_normalized = []

            for b_i in range(image.shape[0]):

                # keep the batch axis (size 1) and append the channel axis
                X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)

                # Fetch both tensors in a single run: one forward pass per
                # slice instead of evaluating the normalizer twice.
                mask_b, img_b = sess.run([preds, images_normalized],
                                         feed_dict={images_pl: X})
                mask_predicted.append(mask_b)
                img_normalized.append(img_b)

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
        img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

        return mask_predicted, img_normalized
Пример #18
0
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True):
    """Run the segmentation network over an image folder tree and save PNG masks.

    The latest 'model_best_dice.ckpt' checkpoint below ``model_path`` is
    restored. Images are then read from ``input_folder/img/<folder>/<phase>/``
    and, when ``gt_exists`` is true, masks from
    ``input_folder/mask/<folder>/<phase>/``. In '2D' data mode every slice is
    rescaled, cropped/padded to ``config.image_size``, passed through the
    network, and the softmax output is resized and arg-maxed into a label map.
    Predicted label maps are written as PNG files.

    Args:
        input_folder: root folder containing the 'img' (and 'mask') subfolders.
        output_folder: not used in the visible part of this function.
        model_path: folder searched for the checkpoint to restore.
        config: experiment config; the attributes read here are image_size,
            nlabels, pixel_size, standardize, normalize, min, max, data_mode.
        do_postprocessing: if true, predictions are passed through
            ``image_utils.keep_largest_connected_components``.
        gt_exists: whether ground-truth masks are available on disk.

    NOTE(review): ``target_resolution`` and ``path_pred`` are not defined in
    this function; presumably module-level globals -- confirm.
    NOTE(review): this listing is truncated; the function ends on a dangling
    ``if gt_exists:`` whose body is not visible.
    """

    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    # According to the experiment config, pick a model and predict the output
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:

        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        # step number parsed from the '-<step>' suffix of the checkpoint name
        # (not used in the visible part of the function)
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0
        # per-axis factor that maps the native pixel size to the target resolution
        scale_vector = [config.pixel_size[0] / target_resolution[0], config.pixel_size[1] / target_resolution[1]]

        path_img = os.path.join(input_folder, 'img')
        if gt_exists:
            path_mask = os.path.join(input_folder, 'mask')
        
        for folder in os.listdir(path_img):
            
            logging.info(' ----- Doing image: -------------------------')
            logging.info('Doing: %s' % folder)
            logging.info(' --------------------------------------------')
            folder_path = os.path.join(path_img, folder)   # loop over the patient folders
            
            utils.makefolder(os.path.join(path_pred, folder))
            
            if os.path.isdir(folder_path):
                
                for phase in os.listdir(folder_path):    # loop over the ED / ES phase folders
                    
                    save_path = os.path.join(path_pred, folder, phase)
                    utils.makefolder(save_path)
                    
                    predictions = []
                    mask_arr = []
                    img_arr = []
                    masks = []
                    imgs = []
                    path = os.path.join(folder_path, phase)
                    # read every slice image of this phase, applying the
                    # intensity preprocessing selected in the config
                    for file in os.listdir(path):
                        img = plt.imread(os.path.join(path,file))
                        if config.standardize:
                            img = image_utils.standardize_image(img)
                        if config.normalize:
                            img = cv2.normalize(img, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX)
                        img_arr.append(img)
                    if  gt_exists:
                        for file in os.listdir(os.path.join(path_mask,folder,phase)):
                            mask_arr.append(plt.imread(os.path.join(path_mask,folder,phase,file)))
                    
                    # stack the slices along the last axis
                    img_arr = np.transpose(np.asarray(img_arr),(1,2,0))      # x,y,N
                    if  gt_exists:
                        mask_arr = np.transpose(np.asarray(mask_arr),(1,2,0))
                    
                    start_time = time.time()
                    
                    if config.data_mode == '2D':
                        
                        for zz in range(img_arr.shape[2]):
                            
                            slice_img = np.squeeze(img_arr[:,:,zz])
                            slice_rescaled = transform.rescale(slice_img,
                                                               scale_vector,
                                                               order=1,
                                                               preserve_range=True,
                                                               multichannel=False,
                                                               anti_aliasing=True,
                                                               mode='constant')
                                
                            # NOTE(review): this runs even when gt_exists is
                            # False, where mask_arr is still the empty list
                            # created above, so the tuple indexing would fail.
                            slice_mask = np.squeeze(mask_arr[:, :, zz])
                            slice_cropped = read_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                            slice_cropped = np.float32(slice_cropped)
                            x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)
                            imgs.append(np.squeeze(x))
                            if gt_exists:
                                # order=0 (nearest neighbour) for the label mask
                                mask_rescaled = transform.rescale(slice_mask,
                                                                  scale_vector,
                                                                  order=0,
                                                                  preserve_range=True,
                                                                  multichannel=False,
                                                                  anti_aliasing=True,
                                                                  mode='constant')

                                mask_cropped = read_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)
                                mask_cropped = np.asarray(mask_cropped, dtype=np.uint8)
                                y = image_utils.reshape_2Dimage_to_tensor(mask_cropped)
                                masks.append(np.squeeze(y))
                                
                            # GET PREDICTION
                            feed_dict = {
                            images_pl: x,
                            }
                                
                            # the second fetch is the softmax output, despite the name
                            mask_out, logits_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                            
                            prediction_cropped = np.squeeze(logits_out[0,...])

                            # ASSEMBLE BACK THE SLICES
                            # (the zero array is immediately replaced by the
                            # cropped prediction on the next line)
                            slice_predictions = np.zeros((nx,ny,num_channels))
                            slice_predictions = prediction_cropped
                            # RESCALING ON THE LOGITS
                            if gt_exists:
                                prediction = transform.resize(slice_predictions,
                                                              (nx, ny, num_channels),
                                                              order=1,
                                                              preserve_range=True,
                                                              anti_aliasing=True,
                                                              mode='constant')
                            else:
                                # undo the in-plane rescaling; channel axis untouched
                                prediction = transform.rescale(slice_predictions,
                                                               (1.0/scale_vector[0], 1.0/scale_vector[1], 1),
                                                               order=1,
                                                               preserve_range=True,
                                                               multichannel=False,
                                                               anti_aliasing=True,
                                                               mode='constant')

                            # per-pixel label = channel with the highest score
                            prediction = np.uint8(np.argmax(prediction, axis=-1))
                            
                            predictions.append(prediction)
                            
       
                        predictions = np.transpose(np.asarray(predictions, dtype=np.uint8), (1,2,0))
                        masks = np.transpose(np.asarray(masks, dtype=np.uint8), (1,2,0))
                        imgs = np.transpose(np.asarray(imgs, dtype=np.float32), (1,2,0))                   

                            
                    # This is the same for 2D and 3D
                    if do_postprocessing:
                        predictions = image_utils.keep_largest_connected_components(predictions)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                    
                    # Save predicted mask
                    # NOTE(review): files go to save_path/'paz'/NNN.png; the
                    # 'paz' subfolder is not created in the visible code.
                    for ii in range(predictions.shape[2]):
                        image_file_name = os.path.join('paz', str(ii).zfill(3) + '.png')
                        cv2.imwrite(os.path.join(save_path , image_file_name), np.squeeze(predictions[:,:,ii]))
                                        
                    # NOTE(review): the listing is cut off here; the body of
                    # this branch is not visible.
                    if gt_exists:
Пример #19
0
def main(exp_config, batch_size=3):
    """Predict a fixed set of test slices with every restorable recursion model.

    For each recursion index in ``0..most_recent_recursion(model_path)`` the
    matching checkpoint is restored, the selected slices are classified, and
    each predicted mask is passed through
    ``image_utils.keep_largest_connected_components``. Recursions whose
    checkpoint cannot be restored are reported with ``print`` and skipped.
    The surviving predictions are then refined with ``segment`` and handed,
    together with the input images and ground-truth masks, to
    ``image_utils.print_grayscale`` / ``image_utils.print_coloured`` with the
    output folder ``base_path + "/poster/"``.

    Args:
        exp_config: Experiment configuration. This function reads
            ``scribble_data``, ``image_size``, ``model_handle``, ``nlabels``,
            ``rw_beta`` and ``experiment_name`` from it.
        batch_size: Kept for backward compatibility only. The slice selection
            is hard-coded, so the effective batch size is always the number
            of hard-coded slices.
    """
    # NOTE: the slice selection is hard-coded; the ``batch_size`` argument is
    # overridden by the number of selected slices.
    slices = [80, 275, 370]
    batch_size = len(slices)

    # Load data. Indexing the HDF5 datasets copies the selected slices into
    # memory, so the file is closed right after reading instead of leaking.
    with h5py.File(sys_config.project_root + exp_config.scribble_data,
                   'r') as data:
        images = data['images_test'][slices, ...]
        masks = data['masks_test'][slices, ...]

    num_recursions = most_recent_recursion(model_path)

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    mask_pl, softmax_pl = model.predict(images_pl, exp_config.model_handle,
                                        exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # One prediction channel per recursion index (0..num_recursions).
    predictions = np.zeros([batch_size] + list(exp_config.image_size) +
                           [num_recursions + 1])

    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    with tf.Session() as sess:
        sess.run(init)
        # Number of recursions that were classified successfully; failed
        # recursions do not occupy a channel in ``predictions``.
        pred_size = 0
        for recursion in range(num_recursions + 1):
            try:
                # Prefer the best-dice checkpoint and fall back to the plain
                # checkpoint of the same recursion.
                try:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model_best_dice.ckpt'.format(recursion))
                except Exception:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model.ckpt'.format(recursion))
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl],
                                       feed_dict=feed_dict)
                for idx in range(batch_size):
                    predictions[idx, ..., pred_size] = \
                        image_utils.keep_largest_connected_components(
                            np.squeeze(mask_out[idx, ...]))
                print("Classified for recursion {}".format(recursion))
                pred_size += 1
            except Exception as e:
                # Best effort: report the failure and continue with the next
                # recursion.
                print(e)

    # Only the successfully classified recursions are processed further.
    num_recursions = pred_size
    plt.figure()

    path = base_path + "/poster/"

    # Refine every prediction with ``segment``
    # (presumably a random-walker step, given ``rw_beta`` -- TODO confirm).
    for recursion in range(num_recursions):
        predictions[..., recursion] = segment(
            images,
            np.squeeze(predictions[..., recursion]),
            beta=exp_config.rw_beta,
            threshold=0)

    for r in range(batch_size):
        # Input image
        image_utils.print_grayscale(
            np.squeeze(images[r, ...]), path,
            '{}_{}_image'.format(exp_config.experiment_name, slices[r]))
        # Ground-truth mask
        image_utils.print_coloured(
            np.squeeze(masks[r, ...]), path,
            '{}_{}_gt'.format(exp_config.experiment_name, slices[r]))

        # One output per recursion
        for recursion in range(num_recursions):
            image_utils.print_coloured(
                np.squeeze(predictions[r, ..., recursion]), path,
                '{}_{}_pred_r{}'.format(exp_config.experiment_name, slices[r],
                                        recursion))

    # Show the figure once. The previous ``while True`` loop never returned
    # and spun forever whenever ``plt.show()`` did not block.
    plt.axis('off')
    plt.show()
Пример #20
0
def run_training(continue_run):
    """Train the segmentation network described by the module-level ``config``.

    Builds the TensorFlow graph (model, loss, optimiser, evaluation and
    summary ops), optionally restores the latest ``model.ckpt`` checkpoint
    from ``log_dir`` and runs ``config.max_epochs`` epochs over the training
    data. Periodically the training loss/dice are evaluated, a checkpoint is
    written and, unless ``config.train_on_all_data`` is set, the validation
    set is evaluated; the best-dice and best-loss models are stored as
    ``model_best_dice.ckpt`` and ``model_best_xent.ckpt``.

    Args:
        continue_run: If true, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``. If no checkpoint can be found the
            run starts from scratch.

    Raises:
        ValueError: If ``config.experiment_name`` is not a supported
            architecture name.
    """
    logging.info('EXPERIMENT NAME: %s' % config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # The step number is the suffix after the last '-' of the file
            # name.
            last_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % last_step)
            init_step = last_step + 1  # plus 1 b/c otherwise starts with eval
        except Exception:
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = config.train_on_all_data

    # Load data
    data = read_data.load_and_maybe_process_data(
        input_folder=config.data_root,
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False)

    # Everything below runs inside try/finally so that the data file is
    # closed even when graph construction or training fails.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_val']
            labels_val = data['masks_val']

        if config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(float(num_images) * config.use_data_fraction)

            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(
                ' - Number of imgs orig: %d, Number of imgs new: %d' %
                (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [config.batch_size] + list(
                config.image_size) + [1]
            mask_tensor_shape = [config.batch_size] + list(config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference
            # model.
            unet_experiments = ('unet2D_valid', 'unet2D_same',
                                'unet2D_same_mod', 'unet2D_light',
                                'Dunet2D_same_mod', 'Dunet2D_same_mod2',
                                'Dunet2D_same_mod3')
            if config.experiment_name in unet_experiments:
                logits = model.inference(images_pl,
                                         config,
                                         training=training_pl)
            elif config.experiment_name == 'ENet':
                # NOTE(review): is_training is hard-coded to True here and
                # does not follow training_pl -- confirm this is intended for
                # evaluation as well.
                with slim.arg_scope(
                        model_structure.ENet_arg_scope(weight_decay=2e-4)):
                    logits = model_structure.ENet(
                        images_pl,
                        num_classes=config.nlabels,
                        batch_size=config.batch_size,
                        is_training=True,
                        reuse=None,
                        num_initial_blocks=1,
                        stage_two_repeat=2,
                        skip_connections=config.skip_connections)
            else:
                # Fail early with a clear message instead of running into a
                # NameError on the undefined ``logits`` below.
                raise ValueError('invalid experiment_name: %s' %
                                 config.experiment_name)

            logging.info('images_pl shape')
            logging.info(images_pl.shape)
            logging.info('labels_pl shape')
            logging.info(labels_pl.shape)
            logging.info('logits shape:')
            logging.info(logits.shape)

            # Add to the Graph the Ops for loss calculation.
            # The second output is the unregularised loss.
            loss, _, weights_norm = model.loss(
                logits,
                labels_pl,
                nlabels=config.nlabels,
                loss_type=config.loss_type,
                weight_decay=config.weight_decay)

            # record how Total loss and weight decay change over time
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if config.momentum is not None:
                train_op = model.training_step(loss,
                                               config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=config.momentum)
            else:
                train_op = model.training_step(loss, config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during
            # evaluation (loss and dice on a minibatch).
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=config.nlabels,
                                         loss_type=config.loss_type)

            # Build the summary Tensor based on the TF collection of
            # Summaries. This must happen before the monitoring summaries
            # below are created, otherwise they would be merged in as well.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            if train_on_all_data:
                max_to_keep = None
            else:
                max_to_keep = 5

            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            configP = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            configP.gpu_options.allow_growth = True
            # If an operation is not defined on the default device, let it
            # execute on another one.
            configP.allow_soft_placement = True
            sess = tf.Session(config=configP)

            # The session is closed in ``finally`` so that it does not leak
            # when training is interrupted by an exception.
            try:
                # Instantiate a SummaryWriter to output summaries and the
                # Graph.
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                # Scalar summaries fed from Python with the evaluation
                # results.
                val_error_ = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss',
                                                      val_error_)

                val_dice_ = tf.placeholder(tf.float32,
                                           shape=[],
                                           name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice',
                                                     val_dice_)

                val_summary = tf.summary.merge(
                    [val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32,
                                              shape=[],
                                              name='train_error')
                train_error_summary = tf.summary.scalar(
                    'training_loss', train_error_)

                train_dice_ = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice',
                                                       train_dice_)

                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = config.learning_rate

                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    for batch in iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=config.batch_size,
                            augment_batch=config.augment_batch):

                        # Use a tenth of the learning rate for the first 50
                        # steps.
                        if config.warmup_training:
                            if step < 50:
                                curr_lr = config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = config.learning_rate

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK (to avoid incomplete batches): the
                        # placeholders have a fixed batch dimension.
                        if y.shape[0] < config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly
                        # often.
                        if step % 20 == 0:
                            # Print status to stdout.
                            logging.info('Step %d: loss = %.3f (%.3f sec)' %
                                         (step, loss_value, duration))
                            # Update the events file.
                            summary_str = sess.run(summary,
                                                   feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % config.train_eval_frequency == 0:

                            logging.info('Training Data Eval:')
                            train_loss, train_dice = do_eval(
                                sess, eval_loss, images_pl, labels_pl,
                                training_pl, images_train, labels_train,
                                config.batch_size)

                            train_summary_msg = sess.run(
                                train_summary,
                                feed_dict={
                                    train_error_: train_loss,
                                    train_dice_: train_dice
                                })
                            summary_writer.add_summary(train_summary_msg,
                                                       step)

                            # Keep the last five training losses and use
                            # their difference as a crude loss gradient.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if (config.schedule_lr and loss_gradient <
                                    config.schedule_gradient_threshold):
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' %
                                             curr_lr)

                                # reset loss history to give the optimisation
                                # some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            if train_loss <= last_train:  # best_train:
                                no_improvement_counter = 0
                                logging.info('Decrease in training error!')
                            else:
                                no_improvement_counter += 1
                                logging.info(
                                    'No improvment in training error for %d steps'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model
                        # periodically.
                        if (step + 1) % config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(
                                log_dir, 'model.ckpt')
                            saver.save(sess,
                                       checkpoint_file,
                                       global_step=step)

                            if not train_on_all_data:

                                # Evaluate against the validation set.
                                logging.info('Validation Data Eval:')
                                val_loss, val_dice = do_eval(
                                    sess, eval_loss, images_pl, labels_pl,
                                    training_pl, images_val, labels_val,
                                    config.batch_size)

                                val_summary_msg = sess.run(
                                    val_summary,
                                    feed_dict={
                                        val_error_: val_loss,
                                        val_dice_: val_dice
                                    })
                                summary_writer.add_summary(
                                    val_summary_msg, step)

                                if val_dice > best_dice:
                                    best_dice = val_dice
                                    best_file = os.path.join(
                                        log_dir, 'model_best_dice.ckpt')
                                    # Remove the previous best-dice files so
                                    # only the newest one remains.
                                    stale_files = glob.glob(
                                        os.path.join(log_dir,
                                                     'model_best_dice*'))
                                    for stale_file in stale_files:
                                        os.remove(stale_file)
                                    saver_best_dice.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                        % val_dice)

                                if val_loss < best_val:
                                    best_val = val_loss
                                    best_file = os.path.join(
                                        log_dir, 'model_best_xent.ckpt')
                                    # Remove the previous best-loss files so
                                    # only the newest one remains.
                                    stale_files = glob.glob(
                                        os.path.join(log_dir,
                                                     'model_best_xent*'))
                                    for stale_file in stale_files:
                                        os.remove(stale_file)
                                    saver_best_xent.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                        % val_loss)

                        step += 1

                    # end epoch: decay the learning rate every
                    # ``config.epoch_freq`` epochs.
                    if (epoch + 1) % config.epoch_freq == 0:
                        curr_lr = curr_lr * 0.98
                    logging.info('Learning rate: %f' % curr_lr)
            finally:
                sess.close()
    finally:
        data.close()
Пример #21
0
def run_training(continue_run):
    """Train the segmentation network described by ``exp_config``.

    Builds the TensorFlow graph (model, loss, optimiser, evaluation and
    summary ops), optionally restores the latest ``model.ckpt`` checkpoint
    from ``log_dir`` and runs ``exp_config.max_epochs`` epochs over the
    training data. Periodically the training loss/dice are evaluated, a
    checkpoint is written and, unless training on all data, the held-out
    ``images_test`` / ``masks_test`` split is evaluated; the best-dice and
    best-loss models are stored as ``model_best_dice.ckpt`` and
    ``model_best_xent.ckpt``.

    Args:
        continue_run: If true, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``. If no checkpoint can be found the
            run starts from scratch.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # The step number is the suffix after the last '-' of the file
            # name.
            last_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % last_step)
            init_step = last_step + 1  # plus 1 b/c otherwise starts with eval
        except Exception:
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # ``train_on_all_data`` is optional in the experiment config.
    train_on_all_data = getattr(exp_config, 'train_on_all_data', False)

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=(not train_on_all_data))

    # Everything below runs inside try/finally so that the data file is
    # closed even when graph construction or training fails.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_test']
            labels_val = data['masks_test']

        if exp_config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(
                float(num_images) * exp_config.use_data_fraction)

            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(
                ' - Number of imgs orig: %d, Number of imgs new: %d' %
                (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference
            # model.
            logits = model.inference(images_pl,
                                     exp_config,
                                     training=training_pl)

            # Add to the Graph the Ops for loss calculation.
            # The second output is the unregularised loss.
            loss, _, weights_norm = model.loss(
                logits,
                labels_pl,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weight_decay=exp_config.weight_decay)

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during
            # evaluation.
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type)

            # Build the summary Tensor based on the TF collection of
            # Summaries. This must happen before the monitoring summaries
            # below are created, otherwise they would be merged in as well.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            if train_on_all_data:
                max_to_keep = None
            else:
                max_to_keep = 5

            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            session_config = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            session_config.gpu_options.allow_growth = True
            # If an operation is not defined on the default device, let it
            # execute on another one.
            session_config.allow_soft_placement = True
            sess = tf.Session(config=session_config)

            # The session is closed in ``finally`` so that it does not leak
            # when training is interrupted by an exception.
            try:
                # Instantiate a SummaryWriter to output summaries and the
                # Graph.
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                # Scalar summaries fed from Python with the evaluation
                # results.
                val_error_ = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss',
                                                      val_error_)

                val_dice_ = tf.placeholder(tf.float32,
                                           shape=[],
                                           name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice',
                                                     val_dice_)

                val_summary = tf.summary.merge(
                    [val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32,
                                              shape=[],
                                              name='train_error')
                train_error_summary = tf.summary.scalar(
                    'training_loss', train_error_)

                train_dice_ = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice',
                                                       train_dice_)

                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = exp_config.learning_rate

                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(exp_config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    # You can run this loop with the BACKGROUND GENERATOR,
                    # which will lead to some improvements in the training
                    # speed. However, be aware that currently an exception
                    # inside this loop may not be caught. The batch generator
                    # may just continue running silently without warning even
                    # though the code has crashed.
                    for batch in iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch):

                        # Use a tenth of the learning rate for the first 50
                        # steps.
                        if exp_config.warmup_training:
                            if step < 50:
                                curr_lr = exp_config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = exp_config.learning_rate

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK (to avoid incomplete batches): the
                        # placeholders have a fixed batch dimension.
                        if y.shape[0] < exp_config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly
                        # often.
                        if step % 10 == 0:
                            # Print status to stdout.
                            logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                         (step, loss_value, duration))
                            # Update the events file.
                            summary_str = sess.run(summary,
                                                   feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % exp_config.train_eval_frequency == 0:

                            logging.info('Training Data Eval:')
                            train_loss, train_dice = do_eval(
                                sess, eval_loss, images_pl, labels_pl,
                                training_pl, images_train, labels_train,
                                exp_config.batch_size)

                            train_summary_msg = sess.run(
                                train_summary,
                                feed_dict={
                                    train_error_: train_loss,
                                    train_dice_: train_dice
                                })
                            summary_writer.add_summary(train_summary_msg,
                                                       step)

                            # Keep the last five training losses and use
                            # their difference as a crude loss gradient.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if (exp_config.schedule_lr and loss_gradient <
                                    exp_config.schedule_gradient_threshold):
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' %
                                             curr_lr)

                                # reset loss history to give the optimisation
                                # some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            # Track how many consecutive evaluations failed
                            # to improve; previously the counter was logged
                            # but never updated.
                            if train_loss <= last_train:  # best_train:
                                no_improvement_counter = 0
                                logging.info('Decrease in training error!')
                            else:
                                no_improvement_counter += 1
                                logging.info(
                                    'No improvment in training error for %d steps'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model
                        # periodically.
                        if (step + 1) % exp_config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(
                                log_dir, 'model.ckpt')
                            saver.save(sess,
                                       checkpoint_file,
                                       global_step=step)

                            if not train_on_all_data:

                                # Evaluate against the validation set.
                                logging.info('Validation Data Eval:')
                                val_loss, val_dice = do_eval(
                                    sess, eval_loss, images_pl, labels_pl,
                                    training_pl, images_val, labels_val,
                                    exp_config.batch_size)

                                val_summary_msg = sess.run(
                                    val_summary,
                                    feed_dict={
                                        val_error_: val_loss,
                                        val_dice_: val_dice
                                    })
                                summary_writer.add_summary(
                                    val_summary_msg, step)

                                if val_dice > best_dice:
                                    best_dice = val_dice
                                    best_file = os.path.join(
                                        log_dir, 'model_best_dice.ckpt')
                                    saver_best_dice.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                        % val_dice)

                                if val_loss < best_val:
                                    best_val = val_loss
                                    best_file = os.path.join(
                                        log_dir, 'model_best_xent.ckpt')
                                    saver_best_xent.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                        % val_loss)

                        step += 1
            finally:
                sess.close()
    finally:
        data.close()
Пример #22
0
def main(exp_config):
    """Plot predictions of a restored model for a fixed range of test slices.

    Builds the inference graph with a batch size of one, restores the latest
    'recursion_1_model_best_dice.ckpt' checkpoint found under the module-level
    ``model_path`` and, for the test slices with indices 10..19, shows a
    figure with the input image, the ground-truth mask, the arg-max
    prediction and one panel per softmax channel.

    Args:
        exp_config: experiment configuration; this function reads
            ``data_mode``, ``image_size``, ``target_resolution``,
            ``model_handle`` and ``nlabels`` from it.
    """

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False)

    batch_size = 1
    nlabels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    training_time_placeholder = tf.placeholder(tf.bool, shape=[])
    logits = model.inference(images_pl,
                             exp_config.model_handle,
                             training=training_time_placeholder,
                             nlabels=exp_config.nlabels)
    softmax_pl = tf.nn.softmax(logits)

    # Channel 0 of the softmax output is replaced by a constant plane of 0.95.
    # The plane takes its shape from the input placeholder instead of a
    # hard-coded 212x212 so that other image sizes work as well.
    threshold = tf.constant(0.95, dtype=tf.float32)
    s = tf.multiply(tf.ones(shape=image_tensor_shape), threshold)
    softmax_pl = tf.concat([s, softmax_pl[..., 1:]], axis=-1)

    # The hard prediction is taken from the raw logits, not from the
    # modified softmax above.
    mask_pl = tf.argmax(logits, axis=-1)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Two plot rows: (image, ground truth, prediction) on top and one panel
    # per softmax channel below. For nlabels == 5 this is the 2x5 grid.
    n_cols = max(3, nlabels)

    with tf.Session() as sess:

        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'recursion_1_model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        for ind in range(10, 20):

            x = data['images_test'][ind, ...]
            y = data['masks_test'][ind, ...]

            x = image_utils.reshape_2Dimage_to_tensor(x)
            y = image_utils.reshape_2Dimage_to_tensor(y)

            feed_dict = {images_pl: x, training_time_placeholder: False}

            mask_out, softmax_out = sess.run([mask_pl, softmax_pl],
                                             feed_dict=feed_dict)

            fig = plt.figure()

            ax = fig.add_subplot(2, n_cols, 1)
            ax.imshow(np.squeeze(x), cmap='gray')
            ax = fig.add_subplot(2, n_cols, 2)
            ax.imshow(np.squeeze(y))
            ax = fig.add_subplot(2, n_cols, 3)
            ax.imshow(np.squeeze(mask_out))

            for channel in range(nlabels):
                ax = fig.add_subplot(2, n_cols, n_cols + 1 + channel)
                ax.imshow(np.squeeze(softmax_out[..., channel]))

            plt.show()
def score_data(input_folder, output_folder, model_path, exp_config, do_postprocessing=False, recursion=None):
    """Predict every slice of the prostate data set, save the results and report Dice.

    For each image the restored network is evaluated, the arg-max of its
    softmax output is (optionally) post-processed, and prediction, ground
    truth and input image are written as NIfTI files below ``output_folder``
    (sub-folders 'prediction', 'ground_truth' and 'image'). The Dice value
    returned by ``get_dice`` is printed per image and averaged at the end.

    Args:
        input_folder: not used by this function; the data path is hard-coded.
        output_folder: root folder for the saved NIfTI files.
        model_path: folder that is searched for the checkpoint to restore.
        exp_config: experiment configuration; ``image_size``,
            ``model_handle`` and ``nlabels`` are read from it.
        do_postprocessing: if true, the prediction is passed through
            ``image_utils.keep_largest_connected_components``.
        recursion: if None, 'model_best_dice.ckpt' is restored; otherwise the
            'recursion_<recursion>_model_best_dice.ckpt' checkpoint, falling
            back to 'recursion_<recursion>_model.ckpt'.

    Returns:
        The integer that follows the last '-' in the checkpoint file name.
    """

    print("KOD YENİ")
    dices = []
    images, labels = read_data('/scratch/cany/scribble/scribble_data/prostate_divided.h5')
    num_images = images.shape[0]
    print(str(num_images))
    print(str(images.shape))
    batch_size = 1

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    _, softmax_pl, _ = model.predict_logits(images_pl, exp_config.model_handle, exp_config.nlabels)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:

        sess.run(init)
        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(model_path,
                                                                         'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # Fall back to the regular checkpoint of this recursion.
                # Catching Exception (not a bare except) keeps Ctrl-C and
                # SystemExit working while the lookup is running.
                checkpoint_path = utils.get_latest_model_checkpoint_path(model_path,
                                                                         'recursion_{}_model.ckpt'.format(recursion))

        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        for k in range(num_images):
            # Add a trailing channel axis and a leading batch axis.
            network_input = np.expand_dims(np.expand_dims(images[k,:,:],2),0)
            softmax_out = sess.run(softmax_pl, feed_dict={images_pl: network_input})
            prediction_cropped = np.squeeze(softmax_out[0, ...])

            # Hard label map: most probable class along the last axis.
            prediction_arr = np.uint8(np.argmax(prediction_cropped, axis=-1))

            mask = labels[k,:,:]

            if do_postprocessing:
                print("Entered post processing " + str(True))
                prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction', 'patient' + str(k) +'.nii.gz')
            logging.info('saving to: %s' % out_file_name)
            save_nii(out_file_name, prediction_arr)

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % gt_file_name)
            save_nii(gt_file_name, np.uint8(mask))

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(output_folder, 'image',
                                    'patient' + str(k)  + '.nii.gz')
            logging.info('saving to: %s' % image_file_name)
            save_nii(image_file_name, images[k,:,:])

            # NOTE(review): this adds new one_hot/get_dice ops to the graph on
            # every iteration, so the graph grows with the number of images.
            # Building the op once from placeholders would avoid that, but
            # needs confirmation that get_dice accepts tensors for both
            # arguments. The class count 4 is hard-coded here rather than
            # taken from exp_config.nlabels - confirm which is intended.
            cdice = sess.run(get_dice(tf.one_hot(np.uint8(prediction_arr), depth=4),np.uint8(np.squeeze(labels[k,:,:])),4))
            print(str(cdice))
            dices.append(cdice)
    print("Average Dice : " + str(np.mean(dices)))
    return init_iteration
def run_training():
    """Run the test-time update loop driven by the self-supervised loss.

    Loads the test images, builds a graph consisting of a normaliser, a
    segmentation network and a self-supervised network, restores the
    segmentation and self-supervised weights from their initial-training
    checkpoints and then repeatedly runs ``train_op_test_time``, which is
    created by ``model.training_step`` from the self-supervised loss and the
    trainable variables whose name contains 'normalization'.

    Summaries, periodic evaluation on the whole test set and periodic
    checkpoints are written to the module-level ``log_dir``. The session and
    the summary writer are closed even if an error interrupts the loop.
    """
    rule = '============================================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(rule)
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # Determine the data set (hard-coded switches)
    # ============================
    target_data = True
    hcp = True

    # ============================
    # Load data
    # ============================
    logging.info(rule)
    logging.info('Loading data...')
    if target_data and hcp:
        logging.info('Reading HCP - 3T - T2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                        sys_config.preproc_folder_hcp,
                                        'T2', 1, exp_config.test_size + 1)
    elif target_data:
        logging.info('Reading ABIDE caltech...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide)
        imts, gtts = data_abide.load_data(sys_config.orig_data_root_abide,
                                          sys_config.preproc_folder_abide,
                                          1, exp_config.test_size + 1)
    else:
        logging.info('Reading HCP - 3T - T1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        idx_start = exp_config.train_size + exp_config.val_size + 1
        imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                        sys_config.preproc_folder_hcp, 'T1',
                                        idx_start,
                                        idx_start + exp_config.test_size)
    # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Test Images: %s' % str(imts.shape))
    logging.info('Test Labels: %s' % str(gtts.shape))
    logging.info(rule)

    # ============================
    # Remove exclusively black images
    # ============================
    mask = ~(imts == 0).all(axis=(1, 2))
    imts = imts[mask]
    gtts = gtts[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [
            exp_config.batch_size, exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')
        # Placeholders for the learning rate and to indicate whether the
        # model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the image normalization function
        # ================================================================
        image_normalized = exp_config.model_handle_normalizer(
            images_pl, image_tensor_shape)

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            image_normalized, image_tensor_shape, training_pl,
            exp_config.nlabels)

        # ================================================================
        # Define the self-supervised network here
        # ================================================================
        logits_self_supervised = exp_config.model_handle_rotation(
            image_normalized, image_tensor_shape, training_pl)

        # ================================================================
        # determine trainable variables (grouped by substring of the name;
        # 'rotation' takes precedence over 'segmentation' and 'normalization')
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info('================================')
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_self_supervised = tf.reduce_mean(
            exp_config.loss_handle_self_supervised(labels_self_supervised_pl,
                                                   logits_self_supervised))
        tf.compat.v1.summary.scalar('loss_self_supervised',
                                    loss_self_supervised)

        # ================================================================
        # Add optimization ops.
        # ================================================================
        train_op_test_time = model.training_step(
            loss_self_supervised, test_time_opt_vars,
            exp_config.optimizer_handle_test_time, learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)
        eval_loss_self_supervised = model.evaluation_rotate(
            logits_self_supervised, labels_self_supervised_pl)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers: one for full checkpoints and one per sub-network
        # for restoring the pre-trained weights
        # ================================================================
        saver = tf.compat.v1.train.Saver(max_to_keep=15)
        saver_seg = tf.compat.v1.train.Saver(var_list=segmentation_vars)
        saver_ss = tf.compat.v1.train.Saver(var_list=self_supervised_vars)

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        summary_writer = None
        try:
            # ============================================================
            # create a summary writer
            # ============================================================
            summary_writer = tf.compat.v1.summary.FileWriter(
                log_dir, sess.graph)

            # ============================================================
            # summaries of the test errors
            # ============================================================
            ts_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='ts_error_seg')
            ts_error_summary_seg = tf.compat.v1.summary.scalar(
                'test/loss_seg', ts_error_seg)
            ts_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='ts_dice_seg')
            ts_dice_summary_seg = tf.compat.v1.summary.scalar(
                'test/dice_seg', ts_dice_seg)
            ts_summary_seg = tf.compat.v1.summary.merge(
                [ts_error_summary_seg, ts_dice_summary_seg])

            ts_error_ss = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='ts_error_ss')
            ts_error_summary_ss = tf.compat.v1.summary.scalar(
                'test/loss_ss', ts_error_ss)
            ts_acc_ss = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[],
                                                 name='ts_acc_ss')
            ts_acc_summary_ss = tf.compat.v1.summary.scalar(
                'test/acc_ss', ts_acc_ss)
            ts_summary_ss = tf.compat.v1.summary.merge(
                [ts_error_summary_ss, ts_acc_summary_ss])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.compat.v1.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(rule)
            logging.info('initializing all variables...')
            sess.run(init_g)
            sess.run(init_l)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(rule)
            logging.info('This is the list of uninitialized variables:')
            for v in sess.run(uninit_vars):
                logging.info(v)

            # ============================================================
            # restore the pre-trained weights of both sub-networks
            # ============================================================
            restore_jobs = (
                ('Restore segmentation...', 'Initial_training_segmentation',
                 saver_seg),
                ('Restore self-supervised...', 'Initial_training_rotation',
                 saver_ss),
            )
            for message, folder, restorer in restore_jobs:
                logging.info(rule)
                logging.info(message)
                model_path = os.path.join(sys_config.log_root, folder)
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'models/best_dice.ckpt')
                logging.info('Restroring session from: %s' % checkpoint_path)
                restorer.restore(sess, checkpoint_path)

            step = init_step

            # ============================================================
            # run training epochs
            # ============================================================
            for epoch in range(exp_config.max_epochs):

                logging.info(rule)
                logging.info('EPOCH %d' % epoch)

                for batch in iterate_minibatches(
                        imts, gtts, batch_size=exp_config.batch_size):
                    start_time = time.time()
                    x, y_seg, y_ss = batch

                    # ===========================
                    # avoid incomplete batches (the step is still counted)
                    # ===========================
                    if y_seg.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_self_supervised_pl: y_ss,
                        learning_rate_pl: exp_config.learning_rate,
                        training_pl: True
                    }

                    # ===========================
                    # update vars
                    # ===========================
                    _, loss_value = sess.run(
                        [train_op_test_time, loss_self_supervised],
                        feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.2f (%.3f sec for the last step)'
                            % (step + 1, loss_value, duration))

                        # ===========================
                        # print values of a parameter (to debug)
                        # ===========================
                        if exp_config.debug is True:
                            first_global_var = tf.compat.v1.get_collection(
                                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)[0]
                            var_value = first_global_var.eval(session=sess)
                            logging.info('value of one of the parameters %f' %
                                         var_value[0, 0, 0, 0])

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # compute the loss on the entire test set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:

                        logging.info('Test Data Eval:')
                        [test_loss_seg, test_dice_seg, test_loss_ss,
                         test_acc_ss] = do_eval(
                             sess, eval_loss_segmentation,
                             eval_loss_self_supervised, images_pl,
                             labels_segmentation_pl,
                             labels_self_supervised_pl, training_pl, imts,
                             gtts, exp_config.batch_size)

                        ts_summary_msg = sess.run(
                            ts_summary_seg,
                            feed_dict={
                                ts_error_seg: test_loss_seg,
                                ts_dice_seg: test_dice_seg
                            })
                        summary_writer.add_summary(ts_summary_msg, step)
                        ts_summary_msg = sess.run(
                            ts_summary_ss,
                            feed_dict={
                                ts_error_ss: test_loss_ss,
                                ts_acc_ss: test_acc_ss
                            })
                        summary_writer.add_summary(ts_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:

                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    step += 1
        finally:
            # Release the event file and the session even when the loop is
            # interrupted by an error; closing the writer also flushes any
            # pending summaries.
            try:
                if summary_writer is not None:
                    summary_writer.close()
            finally:
                sess.close()
Пример #25
0
def run_training(continue_run):
    """Run recursive, scribble-supervised training of the segmentation net.

    The function loads the scribble data, builds the TF1 graph (inference,
    loss, one Adam optimizer for the network variables and one for the
    variables under ``crf_scope``, evaluation ops) and trains for
    ``exp_config.max_epochs`` epochs. Every
    ``exp_config.epochs_per_recursion`` epochs the training labels are
    replaced by the output of ``predict_next_gt`` and the recursion counter
    is incremented. Every ``exp_config.val_eval_frequency`` steps a
    checkpoint is written, the validation set is evaluated with ``do_eval``
    and the best-dice / best-cross-entropy checkpoints are updated.

    Args:
        continue_run: If truthy, the latest checkpoint in ``log_dir`` is
            looked up and logged. NOTE: restoring the checkpoint is
            currently disabled, so training always restarts from step 0 and
            recursion 0 regardless of this flag.

    Returns:
        None. Checkpoints and summaries are written to ``log_dir``; new best
        validation dice values are appended to ``val_results.txt`` in the
        current working directory.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)
    already_created_recursion = True
    print("ALready created recursion : " + str(already_created_recursion))
    init_step = 0

    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)

    loaded_previous_recursion = False
    start_epoch = 0
    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        # `except Exception` (not a bare `except`) so that Ctrl-C and
        # SystemExit are not mistaken for a missing checkpoint.
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                # No checkpoint for the current recursion: fall back to the
                # checkpoint of the previous recursion.
                print("EXCEPTE GİRDİ")
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir,
                    'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 b/c otherwise starts with eval
            start_epoch = int(init_step /
                              (len(base_data['images_train']) / 4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            logging.warning(
                '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
            start_epoch = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    if loaded_previous_recursion:
        logging.info(
            "Data file exists for recursion {} "
            "but checkpoints only present up to recursion {}".format(
                recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")
        start_epoch = 0
        init_step = 0

    # load images and validation data
    images_train = np.array(base_data['images_train'])
    scribbles_train = np.array(base_data['scribbles_train'])
    images_val = np.array(base_data['images_test'])
    labels_val = np.array(base_data['masks_test'])

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)

    # Session options: allocate GPU memory on demand and allow ops without a
    # GPU kernel to fall back to another device.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # The session is bound to the default graph; ops added to that graph
    # below remain runnable through it. This is the ONLY session used, so
    # the options above actually take effect and the session is closed when
    # the `with` block exits.
    with tf.Session(config=config) as sess:
        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32,
                                            shape=image_tensor_shape,
                                            name='images')
        labels_placeholder = tf.placeholder(tf.uint8,
                                            shape=mask_tensor_shape,
                                            name='labels')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])
        keep_prob = tf.placeholder(tf.float32, shape=[])
        crf_learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 keep_prob,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        # Add to the Graph the Ops for loss calculation.
        # The second output is the unregularised loss (unused here).
        [loss, _, weights_norm] = model.loss(
            logits,
            labels_placeholder,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type,
            weight_decay=exp_config.weight_decay)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # Add to the Graph the Ops that calculate and apply gradients.
        global_step = tf.Variable(0, name='global_step', trainable=False)

        crf_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='crf_scope')

        # Everything that is not a CRF variable is handled by the network
        # optimizer. Names are compared with names: testing a name string
        # for membership in the list of Variable objects never matches and
        # would leave the CRF variables in this list.
        crf_variable_names = {v.name for v in crf_variables}
        restore_var = [
            v for v in tf.global_variables()
            if v.name not in crf_variable_names
        ]

        # NOTE(review): this second creation yields a distinct variable
        # (TF uniquifies the name) which is the one actually passed to both
        # optimizers. It is kept so the set of variable names written to
        # checkpoints stays the same as before.
        global_step = tf.Variable(0, name='global_step', trainable=False)

        network_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate_placeholder).minimize(
                loss,
                var_list=restore_var,
                colocate_gradients_with_ops=True,
                global_step=global_step)

        crf_train_op = tf.train.AdamOptimizer(
            learning_rate=crf_learning_rate_placeholder).minimize(
                loss,
                var_list=crf_variables,
                colocate_gradients_with_ops=True,
                global_step=global_step)

        eval_val_loss = model.evaluation(
            logits,
            labels_placeholder,
            images_placeholder,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type,
            weak_supervision=True,
            cnn_threshold=exp_config.cnn_threshold,
            include_bg=False)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create savers for writing training checkpoints.
        # Only keep two checkpoints, as checkpoints are kept for every
        # recursion and they can be 300MB +
        saver = tf.train.Saver(max_to_keep=2)
        saver_best_dice = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # Scalar placeholders used to push python-side metrics to
        # TensorBoard.
        val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
        val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

        val_summary = tf.summary.merge(
            [val_error_summary, val_dice_summary])

        train_error_ = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='train_error')
        train_error_summary = tf.summary.scalar('training_loss',
                                                train_error_)

        train_dice_ = tf.placeholder(tf.float32,
                                     shape=[],
                                     name='train_dice')
        train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

        train_summary = tf.summary.merge(
            [train_error_summary, train_dice_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        # NOTE: no checkpoint is restored here (restoring is disabled), so
        # the resume state computed above is discarded and training starts
        # from scratch.
        init_step = 0
        recursion = 0
        start_epoch = 0

        step = init_step
        curr_lr = exp_config.learning_rate / 10
        crf_curr_lr = 1e-07 / 10
        best_val = np.inf
        best_dice = 0
        logging.info('RECURSION {0}'.format(recursion))

        # random walk - if it already has been random walked it won't redo
        if recursion == 0:
            recursion_data = acdc_data.random_walk_epoch(
                recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                exp_config.random_walk)
            print("Random walku geçti")
            # get ground truths
            labels_train = np.array(recursion_data['random_walked'])
        else:
            labels_train = np.array(recursion_data['predicted'])
        print("Start epoch : " + str(start_epoch) + " : max epochs : " +
              str(exp_config.epochs_per_recursion))

        for epoch in range(start_epoch, exp_config.max_epochs):
            if (epoch % exp_config.epochs_per_recursion == 0
                    and epoch != 0):
                # Have reached end of recursion: replace the training labels
                # by the current predictions of the network.
                recursion_data = predict_next_gt(
                    data=recursion_data,
                    images_train=images_train,
                    images_placeholder=images_placeholder,
                    training_time_placeholder=training_time_placeholder,
                    keep_prob=keep_prob,
                    logits=logits,
                    sess=sess)

                recursion += 1
                # get ground truths
                labels_train = np.array(recursion_data['predicted'])

                # reinitialise savers - otherwise, no checkpoints will be
                # saved for each recursion
                saver = tf.train.Saver(max_to_keep=2)
                saver_best_dice = tf.train.Saver(max_to_keep=2)
                saver_best_xent = tf.train.Saver(max_to_keep=2)

            logging.info(
                'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                    epoch, 1 + epoch % exp_config.epochs_per_recursion,
                    exp_config.epochs_per_recursion, recursion))

            # The BACKGROUND GENERATOR speeds up training, but be aware that
            # an exception inside this loop may not be caught: the batch
            # generator may just continue running silently even though the
            # code has crashed.
            for batch in BackgroundGenerator(
                    iterate_minibatches(
                        images_train,
                        labels_train,
                        batch_size=exp_config.batch_size,
                        augment_batch=exp_config.augment_batch)):

                # Learning-rate schedule: optional warm-up for the first 50
                # steps, then a 10% decay of both rates every 3000 steps.
                if exp_config.warmup_training:
                    if step < 50:
                        curr_lr = exp_config.learning_rate / 10.0
                    elif step == 50:
                        curr_lr = exp_config.learning_rate
                if step % 3000 == 0 and step > 0:
                    curr_lr = curr_lr * 0.9
                    crf_curr_lr = crf_curr_lr * 0.9

                start_time = time.time()

                x, y = batch

                # TEMPORARY HACK (to avoid incomplete batches): the
                # placeholders have a fixed batch dimension.
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                network_feed_dict = {
                    images_placeholder: x,
                    labels_placeholder: y,
                    learning_rate_placeholder: curr_lr,
                    keep_prob: 0.5,
                    training_time_placeholder: True
                }

                crf_feed_dict = {
                    images_placeholder: x,
                    labels_placeholder: y,
                    crf_learning_rate_placeholder: crf_curr_lr,
                    keep_prob: 1,
                    training_time_placeholder: True
                }

                # The CRF variables are updated only every 10th step; the
                # network variables are updated on every step.
                if step % 10 == 0:
                    _, loss_value = sess.run([crf_train_op, loss],
                                             feed_dict=crf_feed_dict)
                _, loss_value = sess.run([network_train_op, loss],
                                         feed_dict=network_feed_dict)
                duration = time.time() - start_time

                # Print an overview fairly often.
                if step % 10 == 0:
                    logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                 (step, loss_value, duration))

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % exp_config.val_eval_frequency == 0:

                    checkpoint_file = os.path.join(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))
                    saver.save(sess, checkpoint_file, global_step=step)

                    # Evaluate against the validation set.
                    logging.info('Validation Data Eval:')
                    [val_loss, val_dice] = do_eval(
                        sess, eval_val_loss, images_placeholder,
                        labels_placeholder, training_time_placeholder,
                        keep_prob, images_val, labels_val,
                        exp_config.batch_size)

                    val_summary_msg = sess.run(val_summary,
                                               feed_dict={
                                                   val_error_: val_loss,
                                                   val_dice_: val_dice
                                               })
                    summary_writer.add_summary(val_summary_msg, step)

                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(
                            log_dir,
                            'recursion_{}_model_best_dice.ckpt'.format(
                                recursion))
                        saver_best_dice.save(sess,
                                             best_file,
                                             global_step=step)
                        logging.info(
                            'Found new best dice on validation set! - {} - '
                            'Saving recursion_{}_model_best_dice.ckpt'.
                            format(val_dice, recursion))
                        # Context manager: the handle is closed even if the
                        # write fails.
                        with open('val_results.txt', "a") as text_file:
                            text_file.write("\nVal dice " + str(step) +
                                            " : " + str(val_dice))

                    if val_loss < best_val:
                        best_val = val_loss
                        best_file = os.path.join(
                            log_dir,
                            'recursion_{}_model_best_xent.ckpt'.format(
                                recursion))
                        saver_best_xent.save(sess,
                                             best_file,
                                             global_step=step)
                        logging.info(
                            'Found new best crossentropy on validation set! - {} - '
                            'Saving recursion_{}_model_best_xent.ckpt'.
                            format(val_loss, recursion))

                step += 1
Пример #26
0
                # ================================================================
                sess = tf.Session()

                # ====================================
                # Initialize
                # ====================================
                sess.run(init_g)
                sess.run(init_l)

                # ====================================
                # Restore training models
                # ====================================
                log_dir = os.path.join(sys_config.log_root,
                                       exp_config.experiment_name)
                path_to_model = log_dir
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    path_to_model, 'models/best_dice.ckpt')
                logging.info(
                    '========================================================')
                logging.info('Restoring the trained parameters from %s...' %
                             checkpoint_path)
                saver.restore(sess, checkpoint_path)

                # ================================
                # open a text file for writing the mean dice scores for each subject that is evaluated
                # ================================
                if test_dataset is 'freiburg':
                    subject_string = '000'
                results_file = open(
                    log_dir + '/results/' + test_dataset + '/' +
                    subject_string + '.txt', "w")
                results_file.write("================================== \n")
Пример #27
0
def run_training(log_dir,
                 image,
                 label,
                 atlas,
                 continue_run,
                 log_dir_first_TD_subject=''):

    # ============================
    # down sample the atlas - the losses will be evaluated in the downsampled space
    # ============================
    atlas_downsampled = rescale(atlas, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                order=1,
                                preserve_range=True,
                                multichannel=True,
                                mode='constant')
    atlas_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        atlas_downsampled, int(256 / exp_config.downsampling_factor_x))

    label_onehot = utils.make_onehot(label, exp_config.nlabels)
    label_onehot_downsampled = rescale(label_onehot, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                       order=1,
                                       preserve_range=True,
                                       multichannel=True,
                                       mode='constant')
    label_onehot_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        label_onehot_downsampled, int(256 / exp_config.downsampling_factor_x))

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[exp_config.batch_size] +
                                   list(exp_config.image_size) + [1],
                                   name='images')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl, exp_config, training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # By setting the 'training_pl' to false directly, the update ops for the moments in the BN layer are not created at all.
        # This allows grouping the update ops together with the optimizer training, while training the normalizer - in case the normalizer has BN.
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_1hot_shape = [1] + list(
            exp_config.image_size_downsampled) + [exp_config.nlabels]
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_1hot_3d_pl = tf.placeholder(tf.float32,
                                                  shape=labels_3d_1hot_shape,
                                                  name='predicted_labels_3d')

        # denoise the noisy segmentation
        _, predicted_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                i2l_vars.append(
                    v
                )  # the normalization vars also need to be restored from the pre-trained i2l mapper
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # Make a list of trainable i2l vars. This will be used to compute gradients.
        # The other list contains trainable as well as non-trainable parameters. This is required for saving and loading all parameters, but runs into trouble when gradients are asked to be computed for non-trainable parameters
        # ================================================================
        i2l_vars_trainable = []
        for v in i2l_vars:
            if v.trainable is True:
                i2l_vars_trainable.append(v)

        # ================================================================
        # add ops for calculation of the prior loss - wrt an atlas or the outputs of the DAE
        # ================================================================
        prior_label_1hot_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size_downsampled] + list(
                (exp_config.image_size_downsampled[1],
                 exp_config.image_size_downsampled[2])) + [exp_config.nlabels],
            name='labels_prior')

        # down sample the predicted logits
        predicted_seg_logits_expanded = tf.expand_dims(predicted_seg_logits,
                                                       axis=0)
        # the 'upsample' function will actually downsample the predictions, as the scaling factors have been set appropriately
        predicted_seg_logits_downsampled = layers.bilinear_upsample3D_(
            predicted_seg_logits_expanded,
            name='downsampled_predictions',
            factor_x=1 / exp_config.downsampling_factor_x,
            factor_y=1 / exp_config.downsampling_factor_y,
            factor_z=1 / exp_config.downsampling_factor_z)
        predicted_seg_logits_downsampled = tf.squeeze(
            predicted_seg_logits_downsampled
        )  # the first axis was added only for the downsampling in 3d

        # compute the dice between the predictions and the prior in the downsampled space
        loss_op = model.loss(logits=predicted_seg_logits_downsampled,
                             labels=prior_label_1hot_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_prior,
                             mask_for_loss_within_mask=None,
                             are_labels_1hot=True)

        tf.summary.scalar('tr_losses/loss', loss_op)

        # ================================================================
        # one of the two prior losses will be used in the following manner:
        # the atlas prior will be used when the current prediction is deemed to be very far away from a reasonable solution
        # once a reasonable solution is reached, the dae prior will be used.
        # these 3d computations will be done outside the graph and will be passed via placeholders for logging in tensorboard
        # ================================================================
        lambda_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='lambda_prior_atlas')
        lambda_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='lambda_prior_dae')
        tf.summary.scalar('lambdas/prior_atlas', lambda_prior_atlas_pl)
        tf.summary.scalar('lambdas/prior_dae', lambda_prior_dae_pl)

        dice3d_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='dice3d_prior_atlas')
        dice3d_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='dice3d_prior_dae')
        dice3d_gt_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_gt')
        tf.summary.scalar('dice3d/prior_atlas', dice3d_prior_atlas_pl)
        tf.summary.scalar('dice3d/prior_dae', dice3d_prior_dae_pl)
        tf.summary.scalar('dice3d/gt', dice3d_gt_pl)

        # ================================================================
        # add optimization ops
        # ================================================================
        if exp_config.debug: print('creating training op...')

        # create an instance of the required optimizer
        optimizer = exp_config.optimizer_handle(learning_rate=learning_rate_pl)

        # initialize variables holding the accumulated gradients and create a zero-initialisation op
        accumulated_gradients = [
            tf.Variable(tf.zeros_like(var.initialized_value()),
                        trainable=False) for var in i2l_vars_trainable
        ]

        # accumulated gradients init op
        accumulated_gradients_zero_op = [
            ac.assign(tf.zeros_like(ac)) for ac in accumulated_gradients
        ]

        # calculate gradients and define accumulation op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            gradients = optimizer.compute_gradients(
                loss_op, var_list=i2l_vars_trainable)
            # compute_gradients returns a list of (gradient, variable) pairs.
        accumulate_gradients_op = [
            ac.assign_add(gg[0])
            for ac, gg in zip(accumulated_gradients, gradients)
        ]

        # define the gradient mean op
        num_accumulation_steps_pl = tf.placeholder(
            dtype=tf.float32, name='num_accumulation_steps')
        accumulated_gradients_mean_op = [
            ag.assign(tf.divide(ag, num_accumulation_steps_pl))
            for ag in accumulated_gradients
        ]

        # reassemble the gradients in the [value, var] format and define the train op
        final_gradients = [(ag, gg[1])
                           for ag, gg in zip(accumulated_gradients, gradients)]
        train_op = optimizer.apply_gradients(final_gradients)

        # ================================================================
        # sequence of running opt ops:
        # 1. at the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
        # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
        # 3. at the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
        # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
        # ================================================================
        # ================================================================
        # previous train_op without accumulation of gradients
        # ================================================================
        # train_op = model.training_step(loss_op, normalization_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars = True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        prior_dae_dice = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='prior_dae_dice')
        prior_dae_dice_summary = tf.summary.scalar('test_img/prior_dae_dice',
                                                   prior_dae_dice)
        prior_dae_output_dice_wrt_gt = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_output_dice_wrt_gt')
        prior_dae_output_dice_wrt_gt_summary = tf.summary.scalar(
            'test_img/prior_dae_output_dice_wrt_gt',
            prior_dae_output_dice_wrt_gt)
        prior_atlas_dice = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='prior_atlas_dice')
        prior_atlas_dice_summary = tf.summary.scalar(
            'test_img/prior_atlas_dice', prior_atlas_dice)
        prior_dae_atlas_dice_ratio = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_atlas_dice_ratio')
        prior_dae_atlas_dice_ratio_summary = tf.summary.scalar(
            'test_img/prior_dae_atlas_dice_ratio', prior_dae_atlas_dice_ratio)
        prior_dice = tf.placeholder(tf.float32, shape=[], name='prior_dice')
        prior_dice_summary = tf.summary.scalar('test_img/prior_dice',
                                               prior_dice)
        gt_dice = tf.placeholder(tf.float32, shape=[], name='gt_dice')
        gt_dice_summary = tf.summary.scalar('test_img/gt_dice', gt_dice)

        # ================================================================
        # create savers
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)
        saver_test_data = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
        saver_best_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

        # ================================================================
        # add operations to compute dice between two 3d volumes
        # ================================================================
        pred_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='pred_3d')
        labl_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='labl_3d')
        atls_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='atls_3d')

        dice_3d_op_dae = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=labl_3d_1hot_pl)
        dice_3d_op_atlas = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=atls_3d_1hot_pl)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the prior network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_l2l.restore(sess, checkpoint_path)

        # ================================================================
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        if log_dir_first_TD_subject is not '':
            logging.info(
                '============================================================')
            path_to_model = log_dir_first_TD_subject + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_score.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_test_data.restore(sess, checkpoint_path)
            max_steps_tta = exp_config.max_steps
        else:
            max_steps_tta = 5 * exp_config.max_steps  # run the adaptation of the 1st TD subject for longer

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restroring normalization module from: %s' %
                         init_checkpoint_path)
            saver_test_data.restore(sess, init_checkpoint_path)

        # ================================================================
        # run training epochs
        # ================================================================
        step = init_step
        best_score = 0.0

        while (step < max_steps_tta):

            # ================================================
            # At the first step, and whenever the step count is a multiple of check_ood_frequency:
            # Get the prediction for the entire volume, evaluate it using the DAE.
            # Now, decide whether to use the DAE output or the atlas as the ground truth for the next update.
            # ================================================
            if (step == init_step) or (step % exp_config.check_ood_frequency is
                                       0):

                # ==================
                # 1. compute the current 3d segmentation prediction
                # ==================
                y_pred_soft = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred_soft.append(
                        sess.run(predicted_seg_softmax,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred_soft = np.squeeze(np.array(y_pred_soft)).astype(float)
                y_pred_soft = np.reshape(y_pred_soft, [
                    -1, y_pred_soft.shape[2], y_pred_soft.shape[3],
                    y_pred_soft.shape[4]
                ])

                # ==================
                # 2. downsample it. Let's call this guy 'A'
                # ==================
                y_pred_soft_downsampled = rescale(y_pred_soft, [
                    1 / exp_config.downsampling_factor_x,
                    1 / exp_config.downsampling_factor_y,
                    1 / exp_config.downsampling_factor_z
                ],
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=True,
                                                  mode='constant').astype(
                                                      np.float32)

                # ==================
                # 3. pass the downsampled prediction through the DAE and get its output 'B'
                # ==================
                feed_dict = {
                    predicted_seg_1hot_3d_pl:
                    np.expand_dims(y_pred_soft_downsampled, axis=0)
                }
                y_pred_noisy_denoised_softmax = np.squeeze(
                    sess.run(
                        predicted_seg_softmax_3d_noisy_autoencoded_softmax,
                        feed_dict=feed_dict)).astype(np.float16)
                y_pred_noisy_denoised = np.argmax(
                    y_pred_noisy_denoised_softmax, axis=-1)

                # ==================
                # 4. compute the dice between:
                #       a. 'A' (seg network prediction downsampled) and 'B' (dae network output)
                #       b. 'B' (dae network output) and downsampled gt labels (for debugging, to see if the dae output is close to the gt.)
                #       c. 'A' (seg network prediction downsampled) and 'C' (downsampled atlas)
                # ==================
                dAB = sess.run(dice_3d_op_dae,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   labl_3d_1hot_pl:
                                   y_pred_noisy_denoised_softmax
                               })
                dBgt = sess.run(dice_3d_op_dae,
                                feed_dict={
                                    pred_3d_1hot_pl:
                                    y_pred_noisy_denoised_softmax,
                                    labl_3d_1hot_pl: label_onehot_downsampled
                                })
                dAC = sess.run(dice_3d_op_atlas,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   atls_3d_1hot_pl: atlas_downsampled
                               })

                # ==================
                # 5. compute the ratio dice(AB) / dice(AC). pass the ratio through a threshold and decide whether to use the DAE or the atlas as the prior
                # ==================
                ratio_dice = dAB / (dAC + 1e-5)

                if exp_config.use_gt_for_tta is True:
                    target_labels_for_this_epoch = label_onehot_downsampled
                    prr = dBgt
                elif (ratio_dice > exp_config.dae_atlas_ratio_threshold) and (
                        dAC > exp_config.min_atlas_dice):
                    target_labels_for_this_epoch = y_pred_noisy_denoised_softmax
                    prr = dAB
                else:
                    target_labels_for_this_epoch = atlas_downsampled
                    prr = dAC

                # ==================
                # update losses on tensorboard
                # ==================
                summary_writer.add_summary(
                    sess.run(prior_dae_dice_summary,
                             feed_dict={prior_dae_dice: dAB}), step)
                summary_writer.add_summary(
                    sess.run(prior_dae_output_dice_wrt_gt_summary,
                             feed_dict={prior_dae_output_dice_wrt_gt: dBgt}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_atlas_dice_summary,
                             feed_dict={prior_atlas_dice: dAC}), step)
                summary_writer.add_summary(
                    sess.run(
                        prior_dae_atlas_dice_ratio_summary,
                        feed_dict={prior_dae_atlas_dice_ratio: ratio_dice}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_dice_summary, feed_dict={prior_dice: prr}),
                    step)

                # ==================
                # save best model so far
                # ==================
                if best_score < prr:
                    best_score = prr
                    best_file = os.path.join(log_dir, 'models/best_score.ckpt')
                    saver_best_loss.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best score (%f) at step %d -  Saving model.'
                        % (best_score, step))

                # ==================
                # dice wrt gt
                # ==================
                y_pred = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred.append(
                        sess.run(predicted_seg,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred = np.squeeze(np.array(y_pred)).astype(float)
                y_pred = np.reshape(y_pred,
                                    [-1, y_pred.shape[2], y_pred.shape[3]])
                dice_wrt_gt = met.f1_score(label.flatten(),
                                           y_pred.flatten(),
                                           average=None)
                summary_writer.add_summary(
                    sess.run(gt_dice_summary,
                             feed_dict={gt_dice: np.mean(dice_wrt_gt[1:])}),
                    step)

                # ==================
                # visualize results
                # ==================
                if step % exp_config.vis_frequency is 0:

                    # ===========================
                    # save checkpoint
                    # ===========================
                    logging.info(
                        '=============== Saving checkkpoint at step %d ... ' %
                        step)
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver_test_data.save(sess,
                                         checkpoint_file,
                                         global_step=step)

                    y_pred_noisy_denoised_upscaled = utils.crop_or_pad_volume_to_size_along_x(
                        rescale(y_pred_noisy_denoised, [
                            exp_config.downsampling_factor_x,
                            exp_config.downsampling_factor_y,
                            exp_config.downsampling_factor_z
                        ],
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant'),
                        image.shape[0]).astype(np.uint8)

                    x_norm = []
                    for batch in iterate_minibatches_images(
                            image, batch_size=exp_config.batch_size):
                        x = batch
                        x_norm.append(
                            sess.run(images_normalized,
                                     feed_dict={
                                         images_pl: x,
                                         training_pl: False
                                     }))
                    x_norm = np.squeeze(np.array(x_norm)).astype(float)
                    x_norm = np.reshape(x_norm,
                                        [-1, x_norm.shape[2], x_norm.shape[3]])

                    utils_vis.save_sample_results(
                        x=image,
                        x_norm=x_norm,
                        x_diff=x_norm - image,
                        y=y_pred,
                        y_pred_dae=y_pred_noisy_denoised_upscaled,
                        at=np.argmax(atlas, axis=-1),
                        gt=label,
                        savepath=log_dir + '/results/visualize_images/step' +
                        str(step) + '.png')

            # ================================================
            # Part of training ops sequence:
            # 1. At the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
            # ================================================
            sess.run(accumulated_gradients_zero_op)
            num_accumulation_steps = 0

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches_images_and_downsampled_labels(
                    images=image,
                    batch_size=exp_config.batch_size,
                    labels_downsampled=target_labels_for_this_epoch,
                    batch_size_downsampled=exp_config.batch_size_downsampled):

                x, y = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    prior_label_1hot_pl: y,
                    learning_rate_pl: exp_config.learning_rate,
                    training_pl: True
                }

                # ================================================
                # Part of training ops sequence:
                # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
                # ================================================
                sess.run(accumulate_gradients_op, feed_dict=feed_dict)
                num_accumulation_steps = num_accumulation_steps + 1

                step += 1

            # ================================================
            # Part of training ops sequence:
            # 3. At the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
            # ================================================
            sess.run(
                accumulated_gradients_mean_op,
                feed_dict={num_accumulation_steps_pl: num_accumulation_steps})

            # ================================================================
            # sequence of running opt ops:
            # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
            # ================================================================
            sess.run(train_op, feed_dict=feed_dict)

        # ================================================================
        # ================================================================
        sess.close()

    # ================================================================
    # ================================================================
    gc.collect()

    return 0
Пример #28
0
def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=True,
        force_overwrite=False)

    # the following are HDF5 datasets, not numpy arrays
    images_train = data['images_train']
    fieldstr_train = data['field_strength_train']
    labels_train = utils.fstr_to_label(fieldstr_train,
                                       exp_config.field_strength_list,
                                       exp_config.fs_label_list)
    ages_train = data['age_train']

    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(ages_train,
                                                     bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None

    images_val = data['images_val']
    fieldstr_val = data['field_strength_val']
    labels_val = utils.fstr_to_label(fieldstr_val,
                                     exp_config.field_strength_list,
                                     exp_config.fs_label_list)
    ages_val = data['age_val']

    if exp_config.age_ordinal_regression:
        ages_val = utils.age_to_ordinal_reg_format(ages_val,
                                                   bins=exp_config.age_bins)
    else:
        ages_val = utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    if exp_config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * exp_config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]

    logging.info('Data summary:')
    logging.info('TRAINING')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)
    logging.info('VALIDATiON')
    logging.info(' - Images:')
    logging.info(images_val.shape)
    logging.info(images_val.dtype)
    logging.info(' - Labels:')
    logging.info(labels_val.shape)
    logging.info(labels_val.dtype)

    # Tell TensorFlow that the model will be built into the default Graph.

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.

        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        labels_tensor_shape = [exp_config.batch_size]

        if exp_config.age_ordinal_regression:
            ages_tensor_shape = [
                exp_config.batch_size,
                len(exp_config.age_bins)
            ]
        else:
            ages_tensor_shape = [exp_config.batch_size]

        images_placeholder = tf.placeholder(tf.float32,
                                            shape=image_tensor_shape,
                                            name='images')
        diag_placeholder = tf.placeholder(tf.uint8,
                                          shape=labels_tensor_shape,
                                          name='labels')
        ages_placeholder = tf.placeholder(tf.uint8,
                                          shape=ages_tensor_shape,
                                          name='ages')

        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   shape=[],
                                                   name='learning_rate')
        training_time_placeholder = tf.placeholder(tf.bool,
                                                   shape=[],
                                                   name='training_time')

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        diag_logits, ages_logits = exp_config.clf_model_handle(
            images_placeholder,
            nlabels=exp_config.nlabels,
            training=training_time_placeholder,
            n_age_thresholds=len(exp_config.age_bins),
            bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.

        [loss, diag_loss, age_loss, weights_norm
         ] = model_mt.loss(diag_logits,
                           ages_logits,
                           diag_placeholder,
                           ages_placeholder,
                           nlabels=exp_config.nlabels,
                           weight_decay=exp_config.weight_decay,
                           diag_weight=exp_config.diag_weight,
                           age_weight=exp_config.age_weight,
                           use_ordinal_reg=exp_config.age_ordinal_regression,
                           ordinal_reg_weights=ordinal_reg_weights)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        if exp_config.momentum is not None:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder,
                momentum=exp_config.momentum)
        else:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder)

        # create a copy of all trainable variables with `0` as initial values
        t_vars = tf.global_variables()  #tf.trainable_variables()
        accum_tvars = [
            tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
            for tv in t_vars
        ]

        # create a op to initialize all accums vars
        zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

        # compute gradients for a batch
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            batch_grads_vars = optimiser.compute_gradients(loss, t_vars)

        # collect the batch gradient into accumulated vars

        accum_ops = [
            accum_tvar.assign_add(batch_grad_var[0]) for accum_tvar,
            batch_grad_var in zip(accum_tvars, batch_grads_vars)
        ]

        accum_normaliser_pl = tf.placeholder(dtype=tf.float32,
                                             name='accum_normaliser')
        accum_mean_op = [
            accum_tvar.assign(tf.divide(accum_tvar, accum_normaliser_pl))
            for accum_tvar in accum_tvars
        ]

        # apply accums gradients
        with tf.control_dependencies(update_ops):
            train_op = optimiser.apply_gradients([
                (accum_tvar, batch_grad_var[1])
                for accum_tvar, batch_grad_var in zip(accum_tvars,
                                                      batch_grads_vars)
            ])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = model_mt.evaluation(
            diag_logits,
            ages_logits,
            diag_placeholder,
            ages_placeholder,
            images_placeholder,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            nlabels=exp_config.nlabels,
            use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver(max_to_keep=3)
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # with tf.name_scope('monitoring'):

        val_error_ = tf.placeholder(tf.float32,
                                    shape=[],
                                    name='val_error_diag')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_diag_f1_score_ = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='val_diag_f1')
        val_f1_diag_summary = tf.summary.scalar('validation_diag_f1',
                                                val_diag_f1_score_)

        val_ages_f1_score_ = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='val_ages_f1')

        val_summary = tf.summary.merge(
            [val_error_summary, val_f1_diag_summary])

        train_error_ = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='train_error_diag')
        train_error_summary = tf.summary.scalar('training_loss', train_error_)

        train_diag_f1_score_ = tf.placeholder(tf.float32,
                                              shape=[],
                                              name='train_diag_f1')
        train_diag_f1_summary = tf.summary.scalar('training_diag_f1',
                                                  train_diag_f1_score_)

        train_ages_f1_score_ = tf.placeholder(tf.float32,
                                              shape=[],
                                              name='train_ages_f1')

        train_summary = tf.summary.merge(
            [train_error_summary, train_diag_f1_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        if continue_run:
            # Restore session
            saver.restore(sess, init_checkpoint_path)

        step = init_step
        curr_lr = exp_config.learning_rate

        no_improvement_counter = 0
        best_val = np.inf
        last_train = np.inf
        loss_history = []
        loss_gradient = np.inf
        best_diag_f1_score = 0

        # acum_manual = 0  #np.zeros((2,3,3,3,1,32))

        for epoch in range(exp_config.max_epochs):

            logging.info('EPOCH %d' % epoch)
            sess.run(zero_ops)
            accum_counter = 0

            for batch in iterate_minibatches(
                    images_train, [labels_train, ages_train],
                    batch_size=exp_config.batch_size,
                    augmentation_function=exp_config.augmentation_function,
                    exp_config=exp_config):

                if exp_config.warmup_training:
                    if step < 50:
                        curr_lr = exp_config.learning_rate / 10.0
                    elif step == 50:
                        curr_lr = exp_config.learning_rate

                start_time = time.time()

                # get a batch
                x, [y, a] = batch

                # TEMPORARY HACK (to avoid incomplete batches)
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                # Run accumulation
                feed_dict = {
                    images_placeholder: x,
                    diag_placeholder: y,
                    ages_placeholder: a,
                    learning_rate_placeholder: curr_lr,
                    training_time_placeholder: True
                }

                _, loss_value = sess.run([accum_ops, loss],
                                         feed_dict=feed_dict)

                accum_counter += 1

                if accum_counter == exp_config.n_accum_batches:

                    # Average gradient over batches
                    sess.run(accum_mean_op,
                             feed_dict={
                                 accum_normaliser_pl:
                                 float(exp_config.n_accum_batches)
                             })
                    sess.run(train_op, feed_dict=feed_dict)

                    # Reset all counters etc.
                    sess.run(zero_ops)
                    accum_counter = 0

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.

                        logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % exp_config.train_eval_frequency == 0:

                        # Evaluate against the training set
                        logging.info('Training Data Eval:')
                        [train_loss, train_diag_f1, train_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_train, [labels_train, ages_train],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        train_summary_msg = sess.run(train_summary,
                                                     feed_dict={
                                                         train_error_:
                                                         train_loss,
                                                         train_diag_f1_score_:
                                                         train_diag_f1,
                                                         train_ages_f1_score_:
                                                         train_ages_f1
                                                     })
                        summary_writer.add_summary(train_summary_msg, step)

                        loss_history.append(train_loss)
                        if len(loss_history) > 5:
                            loss_history.pop(0)
                            loss_gradient = (loss_history[-5] -
                                             loss_history[-1]) / 2

                        logging.info('loss gradient is currently %f' %
                                     loss_gradient)

                        if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                            logging.warning('Reducing learning rate!')
                            curr_lr /= 10.0
                            logging.info('Learning rate changed to: %f' %
                                         curr_lr)

                            # reset loss history to give the optimisation some time to start decreasing again
                            loss_gradient = np.inf
                            loss_history = []

                        if train_loss <= last_train:  # best_train:
                            logging.info('Decrease in training error!')
                        else:
                            logging.info(
                                'No improvment in training error for %d steps'
                                % no_improvement_counter)

                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')

                        [val_loss, val_diag_f1, val_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_val, [labels_val, ages_val],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_:
                                                       val_loss,
                                                       val_diag_f1_score_:
                                                       val_diag_f1,
                                                       val_ages_f1_score_:
                                                       val_ages_f1
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_diag_f1 >= best_diag_f1_score:
                            best_diag_f1_score = val_diag_f1
                            best_file = os.path.join(
                                log_dir, 'model_best_diag_f1.ckpt')
                            saver_best_diag_f1.save(sess,
                                                    best_file,
                                                    global_step=step)
                            logging.info(
                                'Found new best DIAGNOSIS F1 score on validation set! - %f -  Saving model_best_diag_f1.ckpt'
                                % val_diag_f1)

                        if val_loss <= best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir,
                                                     'model_best_xent.ckpt')
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                % val_loss)

                    step += 1

        sess.close()
def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               use_iter=None):
    """Segment every selected patient volume and write the results to disk.

    The function builds the prediction graph, restores a checkpoint and then
    walks over all sub-directories of ``input_folder``. For every file matching
    ``patient???_frame??.nii.gz`` the image is normalised, rescaled to
    ``exp_config.target_resolution``, centre-cropped / zero-padded to the
    network input size, classified, and the soft-max output is mapped back to
    the original image grid before taking the arg-max.

    Args:
        input_folder: Directory containing one sub-folder per patient. Each
            sub-folder must contain an ``Info.cfg`` with ``ED`` and ``ES``
            entries.
        output_folder: Root directory for the results. Files are written to
            the sub-folders ``prediction`` and ``image`` and, when
            ``gt_exists`` is true, ``ground_truth`` and ``difference``.
        model_path: Directory containing the checkpoints.
        exp_config: Experiment configuration; the attributes ``image_size``,
            ``nlabels``, ``data_mode`` and ``target_resolution`` are read here.
        do_postprocessing: If true, the prediction is passed through
            ``image_utils.keep_largest_connected_components``.
        gt_exists: If true, ``<file>_gt.nii.gz`` is loaded next to every image
            and used as resize target and for the difference mask. If false,
            every patient folder is evaluated.
        evaluate_all: If true, every patient folder is evaluated; otherwise
            only folders whose trailing three-digit number is divisible by 5.
        use_iter: Iteration of the ``model.ckpt-<iter>`` checkpoint to load.
            Any falsy value selects the latest ``model_best_dice.ckpt``.

    Returns:
        The iteration number parsed from the restored checkpoint's file name.

    Raises:
        ValueError: If ``exp_config.data_mode`` is neither ``'2D'`` nor
            ``'3D'``, if a frame is neither the ED nor the ES frame, or if the
            3D network output does not contain the expected number of slices.
    """

    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')

    mask_pl, softmax_pl = model.predict(images_pl, exp_config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path,
                                           'model.ckpt-%d' % use_iter)

        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                # Every fifth patient is held out as test data.
                train_test = 'test' if (int(folder[-3:]) %
                                        5 == 0) else 'train'

            if train_test != 'test':
                continue

            infos = _read_info_cfg(os.path.join(folder_path, 'Info.cfg'))

            # Strip the literal prefix; str.lstrip('patient') would strip a
            # *character set* and could eat into the identifier.
            if folder.startswith('patient'):
                patient_id = folder[len('patient'):]
            else:
                patient_id = folder
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file_path in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file_path)
                logging.info(' --------------------------------------------')

                file_base = file_path.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(file_path)
                img = img_dat[0].copy()
                img = image_utils.normalise_image(img)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if exp_config.data_mode == '2D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(
                            slice_img,
                            scale_vector,
                            order=1,
                            preserve_range=True,
                            multichannel=False,
                            mode='constant')

                        x, y = slice_rescaled.shape

                        # Crop section of image for prediction
                        slice_cropped = _crop_or_pad_slice(
                            slice_rescaled, nx, ny)

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        _, logits_out = sess.run(
                            [mask_pl, softmax_pl],
                            feed_dict={images_pl: network_input})
                        prediction_cropped = np.squeeze(logits_out[0, ...])

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = _uncrop_prediction(
                            prediction_cropped, (x, y, num_channels), nx, ny)

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # This can occasionally lead to wrong volume
                            # size, therefore if gt_exists we use the gt mask
                            # size for resizing.
                            prediction = transform.rescale(
                                slice_predictions,
                                (1.0 / scale_vector[0],
                                 1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        predictions.append(prediction)

                    # Stack of (x, y) slices -> (x, y, z) volume.
                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.data_mode == '3D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2],
                                  img_dat[2].structarr['pixdim'][3])

                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1],
                                    pixel_size[2] /
                                    exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    # Centre the rescaled slices inside the fixed-depth
                    # network input; the remaining slices stay zero.
                    nz_curr = vol_scaled.shape[2]
                    stack_from = (nz_max - nz_curr) // 2
                    stack_to = stack_from + nz_curr

                    for zz in range(nz_curr):
                        slice_vol[:, :, stack_from + zz] = _crop_or_pad_slice(
                            vol_scaled[:, :, zz], nx, ny)

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))

                    # For 3D only the network evaluation onwards is timed.
                    start_time = time.time()
                    _, logits_out = sess.run(
                        [mask_pl, softmax_pl],
                        feed_dict={images_pl: network_input})

                    logging.info('Classified 3D: %f secs' %
                                 (time.time() - start_time))

                    # non-zero-slices
                    prediction_nzs = logits_out[0, :, :, stack_from:stack_to,
                                                ...]

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES (last dim is for the classes)
                    prediction_scaled = _uncrop_prediction(
                        prediction_nzs,
                        tuple(vol_scaled.shape) + (num_channels, ), nx, ny)

                    logging.info('Prediction_scaled mean %f' %
                                 (np.mean(prediction_scaled)))

                    # Without a ground-truth mask fall back to the shape of
                    # the original image volume as resize target.
                    target_shape = mask.shape if gt_exists else img.shape

                    prediction = transform.resize(
                        prediction_scaled,
                        (target_shape[0], target_shape[1], target_shape[2],
                         num_channels),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                else:
                    raise ValueError('Unknown data_mode: %s' %
                                     exp_config.data_mode)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' %
                             elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # All outputs of one volume share the same file name and only
                # differ in the sub-folder they are written to.
                out_name = 'patient' + patient_id + frame_suffix + '.nii.gz'

                # Save predicted mask
                out_file_name = os.path.join(output_folder, 'prediction',
                                             out_name)
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine,
                               out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image',
                                               out_name)
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine,
                               out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(output_folder, 'ground_truth',
                                                out_name)
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(
                        np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask,
                                                 dtype=np.uint8)
                    diff_file_name = os.path.join(output_folder, 'difference',
                                                  out_name)
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine,
                                   out_header)

        if total_volumes > 0:
            logging.info('Average time per volume: %f' %
                         (total_time / total_volumes))
        else:
            # Avoid a ZeroDivisionError when nothing matched the selection.
            logging.warning('No volumes were evaluated in: %s' % input_folder)

    return init_iteration


def _read_info_cfg(cfg_path):
    """Parse a ``label: value`` text file into a dict of stripped strings."""
    infos = {}
    with open(cfg_path) as cfg_file:
        for line in cfg_file:
            if ':' not in line:
                # Blank or malformed line; nothing to record.
                continue
            # Split on the first colon only so values may contain colons.
            label, value = line.split(':', 1)
            infos[label] = value.rstrip('\n').lstrip(' ')
    return infos


def _crop_or_pad_slice(slice_rescaled, nx, ny):
    """Centre-crop and/or zero-pad a 2D array to exactly ``(nx, ny)``."""
    x, y = slice_rescaled.shape

    x_s = (x - nx) // 2  # crop offsets (input larger than target)
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2  # pad offsets (input smaller than target)
    y_c = (ny - y) // 2

    if x > nx and y > ny:
        return slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]

    slice_cropped = np.zeros((nx, ny))
    if x <= nx and y > ny:
        slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
    elif x > nx and y <= ny:
        slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
    else:
        slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]
    return slice_cropped


def _uncrop_prediction(prediction_cropped, out_shape, nx, ny):
    """Undo ``_crop_or_pad_slice`` on the first two axes of a prediction.

    ``prediction_cropped`` has leading extent ``(nx, ny)``; the result is a
    zero-initialised array of shape ``out_shape`` with the prediction pasted
    back at the position the network input was taken from. Any trailing axes
    (slices, classes) are carried along unchanged.
    """
    x, y = out_shape[0], out_shape[1]

    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    restored = np.zeros(out_shape)
    if x > nx and y > ny:
        restored[x_s:x_s + nx, y_s:y_s + ny, ...] = prediction_cropped
    elif x <= nx and y > ny:
        restored[:, y_s:y_s + ny, ...] = prediction_cropped[x_c:x_c + x, ...]
    elif x > nx and y <= ny:
        restored[x_s:x_s + nx, :, ...] = prediction_cropped[:, y_c:y_c + y,
                                                            ...]
    else:
        restored[...] = prediction_cropped[x_c:x_c + x, y_c:y_c + y, ...]
    return restored
Пример #30
0
def main(config):
    """Interactively visualise predictions on randomly chosen test slices.

    Loads the dataset described by ``config``, builds the inference graph
    for a single-image batch, restores the latest ``model_best_dice.ckpt``
    checkpoint found under the module-level ``model_path`` and then loops
    forever: each iteration picks a random slice from ``images_test``,
    optionally standardises/normalises it, runs the network and shows the
    input, the ground truth, the predicted mask and the softmax channels.

    The loop has no exit condition; the viewer is stopped by interrupting
    the process (e.g. Ctrl+C). The dataset handle is closed on the way out.

    Args:
        config: Experiment configuration. The attributes read here are
            ``data_root``, ``preprocessing_folder``, ``data_mode``,
            ``image_size``, ``target_resolution``, ``standardize``,
            ``normalize``, ``min`` and ``max``; the object is also passed
            to ``model.predict``.
    """

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=config.data_root,  # data_root / test_data_root
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False
    )

    # The viewing loop below never terminates normally, so the dataset must
    # be released in a ``finally`` clause; a trailing ``data.close()`` after
    # the loop would be unreachable.
    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            while True:

                ind = np.random.randint(data['images_test'].shape[0])

                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                # Pre-process the whole slice and keep the result. Iterating
                # over ``x`` and rebinding the loop variable would discard
                # the processed values and leave ``x`` untouched.
                if config.standardize:
                    x = image_utils.standardize_image(x)
                if config.normalize:
                    x = cv2.normalize(x, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX)

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)
                logging.info('x')
                logging.info(x.shape)
                logging.info(x.dtype)
                logging.info(x.min())
                logging.info(x.max())

                # Show the network input on its own first.
                plt.imshow(np.squeeze(x))
                plt.gray()
                plt.axis('off')
                plt.show()

                logging.info('y')
                logging.info(y.shape)
                logging.info(y.dtype)

                feed_dict = {
                    images_pl: x,
                }

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                logging.info('mask_out')
                logging.info(mask_out.shape)
                logging.info('softmax_out')
                logging.info(softmax_out.shape)

                # 2x4 grid: top row is input / ground truth / prediction,
                # bottom row is softmax channels 0..3.
                # NOTE(review): four softmax channels are hard-coded here;
                # confirm this matches the number of labels of the model.
                panels = [
                    (241, np.squeeze(x), 'a', 'gray'),
                    (242, np.squeeze(y), 'b', None),
                    (243, np.squeeze(mask_out), 'c', None),
                ]
                panels += [
                    (245 + channel, np.squeeze(softmax_out[..., channel]), title, None)
                    for channel, title in enumerate('defg')
                ]

                fig = plt.figure(1)
                for position, panel, title, cmap in panels:
                    ax = fig.add_subplot(position)
                    ax.set_axis_off()
                    ax.imshow(panel, cmap=cmap)
                    ax.title.set_text(title)
                plt.gray()
                plt.show()

                logging.info('mask_out type')
                logging.info(mask_out.dtype)
    finally:
        data.close()