Example #1
def fit_one_cycle(epochs,
                  max_lr,
                  model,
                  train_loader,
                  val_loader,
                  weight_decay=0,
                  grad_clip=None,
                  opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
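The loop above also depends on two helpers that are not part of this snippet, get_lr and evaluate. A minimal sketch of what they could look like, assuming the model additionally defines validation_step and validation_epoch_end (hypothetical names, mirroring the training_step and epoch_end methods used above):

import torch

def get_lr(optimizer):
    # Return the learning rate currently set on the optimizer
    for param_group in optimizer.param_groups:
        return param_group['lr']

@torch.no_grad()
def evaluate(model, val_loader):
    # One pass over the validation set; aggregation is delegated to the model
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

# Hypothetical call (loader names and hyperparameters are placeholders):
# history = fit_one_cycle(epochs=10, max_lr=0.01, model=model,
#                         train_loader=train_dl, val_loader=val_dl,
#                         weight_decay=1e-4, grad_clip=0.1,
#                         opt_func=torch.optim.Adam)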
Example #2
def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_sd_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_sd')
        images_td_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size) + [1],
                                      name='images_td')
        labels_sd_pl = tf.placeholder(tf.uint8,
                                      shape=[exp_config.batch_size] +
                                      list(exp_config.image_size),
                                      name='labels_sd')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # ================================================================
        images_sd_normalized, _ = model.normalize(images_sd_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=False)
        images_td_normalized, _ = model.normalize(images_td_pl,
                                                  exp_config,
                                                  training_pl,
                                                  scope_reuse=True)

        # ================================================================
        # get logit predictions from the segmentation network
        # ================================================================
        predicted_seg_sd_logits, _, _ = model.predict_i2l(images_sd_normalized,
                                                          exp_config,
                                                          training_pl,
                                                          scope_reuse=False)

        # ================================================================
        # get all features from the segmentation network
        # ================================================================
        seg_features_sd = model.get_all_features(images_sd_normalized,
                                                 exp_config,
                                                 scope_reuse=True)
        seg_features_td = model.get_all_features(images_td_normalized,
                                                 exp_config,
                                                 scope_reuse=True)

        # ================================================================
        # resize all features to the same size
        # ================================================================
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64),
            'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64),
            'resize_td_xfeat')

        # ================================================================
        # discriminator on features
        # ================================================================
        d_logits_sd = model.discriminator(images_sd_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=False)
        d_logits_td = model.discriminator(images_td_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=True)

        # ================================================================
        # add ops for calculation of the discriminator loss
        # ================================================================
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # ================================================================
        # add ops for calculation of the adversarial loss that tries to get domain invariant features in the normalized image space
        # ================================================================
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)

        # ================================================================
        # add ops for calculation of the supervised segmentation loss
        # ================================================================
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # ================================================================
        # total training loss for uda
        # ================================================================
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # ================================================================
        # merge all summaries
        # ================================================================
        summary_scalars = tf.summary.merge_all()

        # ================================================================
        # divide the vars into segmentation network, normalization network and the discriminator network
        # ================================================================
        i2l_vars = []
        normalization_vars = []
        discriminator_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                i2l_vars.append(
                    v
                )  # the normalization vars also need to be restored from the pre-trained i2l mapper
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # ================================================================
        # add optimization ops
        # ================================================================
        train_i2l_op = model.training_step(
            loss_total_op,
            i2l_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        train_discriminator_op = model.training_step(
            loss_d_op,
            discriminator_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits,
            labels_sd_pl,
            images_sd_pl,
            d_logits_td,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # add ops for adding image summary to tensorboard
        # ================================================================
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl,
            exp_config.nlabels)

        # ================================================================
        # note: the scalar summary op (summary_scalars) was already merged above
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a file writer object
        # This writes Summary protocol buffers to event files.
        # https://github.com/tensorflow/docs/blob/r1.12/site/en/api_docs/python/tf/summary/FileWriter.md
        # The FileWriter class provides a mechanism to create an event file in a given directory and add summaries and events to it.
        # The class updates the file contents asynchronously.
        # This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # create savers
        # ================================================================
        saver = tf.train.Saver(var_list=i2l_vars)
        saver_lowest_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='vl_error_seg')
        vl_error_seg_summary = tf.summary.scalar('validation/loss_seg',
                                                 vl_error_seg)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='vl_error_invariance')
        vl_error_invariance_summary = tf.summary.scalar(
            'validation/loss_invariance', vl_error_invariance)
        vl_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='vl_error_total')
        vl_error_total_summary = tf.summary.scalar('validation/loss_total',
                                                   vl_error_total)
        vl_summary = tf.summary.merge([
            vl_error_seg_summary, vl_dice_summary, vl_error_invariance_summary,
            vl_error_total_summary
        ])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error_seg = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='tr_error_seg')
        tr_error_seg_summary = tf.summary.scalar('training/loss_seg',
                                                 tr_error_seg)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_error_invariance = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='tr_error_invariance')
        tr_error_invariance_summary = tf.summary.scalar(
            'training/loss_invariance', tr_error_invariance)
        tr_error_total = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='tr_error_total')
        tr_error_total_summary = tf.summary.scalar('training/loss_total',
                                                   tr_error_total)
        tr_summary = tf.summary.merge([
            tr_error_seg_summary, tr_dice_summary, tr_error_invariance_summary,
            tr_error_total_summary
        ])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2l mapper parameters
        # ================================================================
        if exp_config.train_from_scratch is False:
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_lowest_loss.restore(sess, checkpoint_path)

        # ================================================================
        # run training steps
        # ================================================================
        step = 0
        lowest_loss = 10000.0
        validation_total_loss_list = []

        while (step < exp_config.max_steps):

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(images_sd=images_sd_tr,
                                             labels_sd=labels_sd_tr,
                                             images_td=images_td_tr,
                                             batch_size=exp_config.batch_size):

                x_sd, y_sd, x_td = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_sd_pl: x_sd,
                    labels_sd_pl: y_sd,
                    images_td_pl: x_td,
                    training_pl: True
                }

                # ================================================
                # update i2l and D successively
                # ================================================
                sess.run(train_i2l_op, feed_dict=feed_dict)
                sess.run(train_discriminator_op, feed_dict=feed_dict)

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        '============== Updating summary at step %d ' % step)
                    summary_writer.add_summary(
                        sess.run(summary_scalars, feed_dict=feed_dict), step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('============== Training Data Eval:')
                    train_loss_seg, train_dice, train_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_tr, labels_sd_tr,
                        images_td_tr, exp_config.batch_size)

                    # total training loss
                    train_total_loss = train_loss_seg + exp_config.lambda_uda * train_loss_invariance

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error_seg: train_loss_seg,
                                                  tr_dice: train_dice,
                                                  tr_error_invariance:
                                                  train_loss_invariance,
                                                  tr_error_total:
                                                  train_total_loss
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    logging.info(
                        '============== Periodically saving checkpoint:')
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('============== Validation Data Eval:')
                    val_loss_seg, val_dice, val_loss_invariance = do_eval(
                        sess, eval_loss, images_sd_pl, labels_sd_pl,
                        images_td_pl, training_pl, images_sd_vl, labels_sd_vl,
                        images_td_vl, exp_config.batch_size)

                    # total val loss
                    val_total_loss = val_loss_seg + exp_config.lambda_uda * val_loss_invariance
                    validation_total_loss_list.append(val_total_loss)

                    # ===========================
                    # update tensorboard summary of scalars
                    # ===========================
                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error_seg: val_loss_seg,
                                                  vl_dice: val_dice,
                                                  vl_error_invariance:
                                                  val_loss_invariance,
                                                  vl_error_total:
                                                  val_total_loss
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # update tensorboard summary of images
                    # ===========================
                    summary_writer.add_summary(
                        sess.run(summary_images,
                                 feed_dict={
                                     images_sd_pl: x_sd,
                                     labels_sd_pl: y_sd,
                                     training_pl: False
                                 }), step)
                    summary_writer.flush()

                    # ===========================
                    # save model if the exponential moving average of the validation loss is the lowest yet
                    # ===========================
                    window_length = 5
                    if len(validation_total_loss_list) < window_length + 1:
                        expo_moving_avg_loss_value = validation_total_loss_list[
                            -1]
                    else:
                        expo_moving_avg_loss_value = utils.exponential_moving_average(
                            validation_total_loss_list,
                            window=window_length)[-1]

                    if expo_moving_avg_loss_value < lowest_loss:
                        lowest_loss = expo_moving_avg_loss_value  # track the smoothed loss used in the comparison
                        lowest_loss_file = os.path.join(
                            log_dir, 'models/lowest_loss.ckpt')
                        saver_lowest_loss.save(sess,
                                               lowest_loss_file,
                                               global_step=step)
                        logging.info(
                            '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                            % step)

                # ================================================
                # increment step
                # ================================================
                step += 1

        # ================================================================
        # close tf session
        # ================================================================
        sess.close()

    return 0
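run_uda_training relies on two helpers that are not shown in this excerpt: iterate_minibatches and utils.exponential_moving_average. The sketches below are assumptions about their behaviour, based only on how they are called above (source-domain image/label arrays plus unlabelled target-domain images, and a smoothed validation loss):

import numpy as np

def iterate_minibatches(images_sd, labels_sd, images_td, batch_size):
    # Pair each labelled source-domain batch with a randomly drawn,
    # unlabelled target-domain batch of the same size.
    n_sd, n_td = images_sd.shape[0], images_td.shape[0]
    sd_order = np.random.permutation(n_sd)
    for start in range(0, n_sd - batch_size + 1, batch_size):
        sd_idx = np.sort(sd_order[start:start + batch_size])
        td_idx = np.sort(np.random.choice(n_td, batch_size, replace=False))
        x_sd = np.expand_dims(images_sd[sd_idx], axis=-1)  # [b, x, y, 1]
        y_sd = labels_sd[sd_idx]                           # [b, x, y]
        x_td = np.expand_dims(images_td[td_idx], axis=-1)  # [b, x, y, 1]
        yield x_sd, y_sd, x_td

def exponential_moving_average(values, window=5):
    # Smooth a list of scalars; the decay factor is derived from the window length.
    alpha = 2.0 / (window + 1)
    smoothed = [values[0]]
    for v in values[1:]:
        smoothed.append(alpha * v + (1.0 - alpha) * smoothed[-1])
    return smoothed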
Example #3
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize the step number - this is the number of mini-batch runs so far
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    logging.info('Reading HCP - 3T - T1 images...')
    logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
    imtr, gttr = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                    sys_config.preproc_folder_hcp, 'T1', 1,
                                    exp_config.train_size + 1)
    imvl, gtvl = data_hcp.load_data(
        sys_config.orig_data_root_hcp, sys_config.preproc_folder_hcp, 'T1',
        exp_config.train_size + 1,
        exp_config.train_size + exp_config.val_size + 1)
    logging.info(
        'Training Images: %s' %
        str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info(
        'Training Labels: %s' %
        str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # remove slices that contain only zeros (i.e., only background)
    # ================================================================
    mask = ~(imtr == 0).all(axis=(1, 2))
    imtr = imtr[mask]
    gttr = gttr[mask]

    mask = ~(imvl == 0).all(axis=(1, 2))
    imvl = imvl[mask]
    gtvl = gtvl[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
            exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')
        # Placeholders for the learning rate and to indicate whether the model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            images_pl, image_tensor_shape, training_pl, exp_config.nlabels)

        # ================================================================
        # determine trainable variables
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info('================================')
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_segmentation = model.loss(logits_segmentation,
                                       labels_segmentation_pl,
                                       nlabels=exp_config.nlabels,
                                       loss_type=exp_config.loss_type)
        tf.compat.v1.summary.scalar('loss_segmentation', loss_segmentation)

        # ================================================================
        # Add optimization ops.
        # ================================================================
        train_op_train_time = model.training_step(
            loss_segmentation, segmentation_vars,
            exp_config.optimizer_handle_training, learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers for each domain
        # ================================================================
        max_to_keep = 15
        saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
        saver_best_da = tf.compat.v1.train.Saver()

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                shape=[],
                                                name='tr_error_seg')
        tr_error_summary_seg = tf.compat.v1.summary.scalar(
            'training/loss_seg', tr_error_seg)
        tr_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='tr_dice_seg')
        tr_dice_summary_seg = tf.compat.v1.summary.scalar(
            'training/dice_seg', tr_dice_seg)
        tr_summary_seg = tf.compat.v1.summary.merge(
            [tr_error_summary_seg, tr_dice_summary_seg])

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                shape=[],
                                                name='vl_error_seg')
        vl_error_summary_seg = tf.compat.v1.summary.scalar(
            'validation/loss_seg', vl_error_seg)
        vl_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='vl_dice_seg')
        vl_dice_summary_seg = tf.compat.v1.summary.scalar(
            'validation/dice_seg', vl_dice_seg)
        vl_summary_seg = tf.compat.v1.summary.merge(
            [vl_error_summary_seg, vl_dice_summary_seg])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.compat.v1.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_g)
        sess.run(init_l)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            logging.info(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_da = 0

        # ================================================================
        # run training epochs
        # ================================================================
        for epoch in range(exp_config.max_epochs):

            logging.info(
                '============================================================')
            logging.info('EPOCH %d' % epoch)

            for batch in iterate_minibatches(imtr,
                                             gttr,
                                             batch_size=exp_config.batch_size):

                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y_seg, y_ss = batch

                if step == 0:
                    x_s = x
                    y_seg_s = y_seg
                    y_ss_s = y_ss

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y_seg.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                feed_dict = {
                    images_pl: x,
                    labels_segmentation_pl: y_seg,
                    labels_self_supervised_pl: y_ss,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # update vars
                # ===========================
                _, loss_value = sess.run(
                    [train_op_train_time, loss_segmentation],
                    feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.2f (%.3f sec for the last step)' %
                        (step + 1, loss_value, duration))

                    # ===========================
                    # print values of a parameter (to debug)
                    # ===========================
                    if exp_config.debug is True:
                        var_value = tf.compat.v1.get_collection(
                            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                        )[0].eval(
                            session=sess
                        )  # can add name if you only want parameters with a specific name
                        logging.info('value of one of the parameters %f' %
                                     var_value[0, 0, 0, 0])

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:

                    logging.info('Training Data Eval:')
                    [train_loss_seg, train_dice_seg
                     ] = do_eval(sess, eval_loss_segmentation, images_pl,
                                 labels_segmentation_pl, training_pl, imtr,
                                 gttr, exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary_seg,
                                              feed_dict={
                                                  tr_error_seg: train_loss_seg,
                                                  tr_dice_seg: train_dice_seg
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:

                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically
                # ===========================
                if step % exp_config.val_eval_frequency == 0:

                    # ===========================
                    # Evaluate against the validation set of each domain
                    # ===========================
                    logging.info('Validation Data Eval:')
                    [val_loss_seg,
                     val_dice_seg] = do_eval(sess, eval_loss_segmentation,
                                             images_pl, labels_segmentation_pl,
                                             training_pl, imvl, gtvl,
                                             exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary_seg,
                                              feed_dict={
                                                  vl_error_seg: val_loss_seg,
                                                  vl_dice_seg: val_dice_seg
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice/accuracy is the best yet
                    # ===========================
                    if val_dice_seg > best_da:
                        best_da = val_dice_seg
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_da.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best average dice on the validation set! - %f -  Saving model.'
                            % val_dice_seg)

                step += 1

        sess.close()
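This script (and Example #4 below) calls a do_eval helper that is not included in the excerpt. A minimal sketch of it, assuming the evaluation op returns a (loss, dice) pair and the data are stacks of 2D slices:

import numpy as np

def do_eval(sess, eval_loss, images_pl, labels_pl, training_pl,
            images, labels, batch_size):
    # Average the (loss, dice) evaluation ops over the whole dataset.
    loss_sum, dice_sum, num_batches = 0.0, 0.0, 0
    for start in range(0, images.shape[0] - batch_size + 1, batch_size):
        x = np.expand_dims(images[start:start + batch_size], axis=-1)
        y = labels[start:start + batch_size]
        loss_batch, dice_batch = sess.run(eval_loss,
                                          feed_dict={images_pl: x,
                                                     labels_pl: y,
                                                     training_pl: False})
        loss_sum += loss_batch
        dice_sum += dice_batch
        num_batches += 1
    return loss_sum / num_batches, dice_sum / num_batches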
Example #4
def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    if hasattr(exp_config, 'train_on_all_data'):
        train_on_all_data = exp_config.train_on_all_data
    else:
        train_on_all_data = False

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=(not train_on_all_data))

    # the following are HDF5 datasets, not numpy arrays
    images_train = data['images_train']
    labels_train = data['masks_train']

    if not train_on_all_data:
        images_val = data['images_test']
        labels_val = data['masks_test']

    if exp_config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * exp_config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)

    # Tell TensorFlow that the model will be built into the default Graph.

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.

        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)

        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')

        learning_rate_pl = tf.placeholder(tf.float32, shape=[])
        training_pl = tf.placeholder(tf.bool, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_pl)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_pl, exp_config, training=training_pl)

        # Add to the Graph the Ops for loss calculation.
        [loss, _,
         weights_norm] = model.loss(logits,
                                    labels_pl,
                                    nlabels=exp_config.nlabels,
                                    loss_type=exp_config.loss_type,
                                    weight_decay=exp_config.weight_decay
                                    )  # second output is unregularised loss

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # Add to the Graph the Ops that calculate and apply gradients.
        if exp_config.momentum is not None:
            train_op = model.training_step(loss,
                                           exp_config.optimizer_handle,
                                           learning_rate_pl,
                                           momentum=exp_config.momentum)
        else:
            train_op = model.training_step(loss, exp_config.optimizer_handle,
                                           learning_rate_pl)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_loss = model.evaluation(logits,
                                     labels_pl,
                                     images_pl,
                                     nlabels=exp_config.nlabels,
                                     loss_type=exp_config.loss_type)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.

        if train_on_all_data:
            max_to_keep = None
        else:
            max_to_keep = 5

        saver = tf.train.Saver(max_to_keep=max_to_keep)
        saver_best_dice = tf.train.Saver()
        saver_best_xent = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not grab all GPU memory up front; allocate it as needed
        config.allow_soft_placement = True  # If an op cannot be placed on the requested device, run it on another one
        sess = tf.Session(config=config)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # with tf.name_scope('monitoring'):

        val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
        val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

        val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

        train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
        train_error_summary = tf.summary.scalar('training_loss', train_error_)

        train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
        train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

        train_summary = tf.summary.merge(
            [train_error_summary, train_dice_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        if continue_run:
            # Restore session
            saver.restore(sess, init_checkpoint_path)

        step = init_step
        curr_lr = exp_config.learning_rate

        no_improvement_counter = 0
        best_val = np.inf
        last_train = np.inf
        loss_history = []
        loss_gradient = np.inf
        best_dice = 0

        for epoch in range(exp_config.max_epochs):

            logging.info('EPOCH %d' % epoch)

            for batch in iterate_minibatches(
                    images_train,
                    labels_train,
                    batch_size=exp_config.batch_size,
                    augment_batch=exp_config.augment_batch):

                # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvement in
                # training speed. However, be aware that currently an exception inside this loop may not be caught.
                # The batch generator may just continue running silently, without warning, even though the code has
                # crashed.
                # for batch in BackgroundGenerator(iterate_minibatches(images_train,
                #                                                      labels_train,
                #                                                      batch_size=exp_config.batch_size,
                #                                                      augment_batch=exp_config.augment_batch)):

                if exp_config.warmup_training:
                    if step < 50:
                        curr_lr = exp_config.learning_rate / 10.0
                    elif step == 50:
                        curr_lr = exp_config.learning_rate

                start_time = time.time()

                # batch = bgn_train.retrieve()
                x, y = batch

                # TEMPORARY HACK (to avoid incomplete batches)
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                feed_dict = {
                    images_pl: x,
                    labels_pl: y,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

                duration = time.time() - start_time

                # Write the summaries and print an overview fairly often.
                if step % 10 == 0:
                    # Print status to stdout.
                    logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                 (step, loss_value, duration))
                    # Update the events file.

                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                if (step + 1) % exp_config.train_eval_frequency == 0:

                    logging.info('Training Data Eval:')
                    [train_loss,
                     train_dice] = do_eval(sess, eval_loss, images_pl,
                                           labels_pl, training_pl,
                                           images_train, labels_train,
                                           exp_config.batch_size)

                    train_summary_msg = sess.run(train_summary,
                                                 feed_dict={
                                                     train_error_: train_loss,
                                                     train_dice_: train_dice
                                                 })
                    summary_writer.add_summary(train_summary_msg, step)

                    loss_history.append(train_loss)
                    if len(loss_history) > 5:
                        loss_history.pop(0)
                        loss_gradient = (loss_history[-5] -
                                         loss_history[-1]) / 2

                    logging.info('loss gradient is currently %f' %
                                 loss_gradient)

                    if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                        logging.warning('Reducing learning rate!')
                        curr_lr /= 10.0
                        logging.info('Learning rate changed to: %f' % curr_lr)

                        # reset loss history to give the optimisation some time to start decreasing again
                        loss_gradient = np.inf
                        loss_history = []

                    if train_loss <= last_train:  # best_train:
                        no_improvement_counter = 0
                        logging.info('Decrease in training error!')
                    else:
                        no_improvement_counter += 1
                        logging.info(
                            'No improvement in training error for %d steps' %
                            no_improvement_counter)

                    last_train = train_loss

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % exp_config.val_eval_frequency == 0:

                    checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)
                    # Evaluate against the training set.

                    if not train_on_all_data:

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss,
                         val_dice] = do_eval(sess, eval_loss, images_pl,
                                             labels_pl, training_pl,
                                             images_val, labels_val,
                                             exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'model_best_dice.ckpt')
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                % val_dice)

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir,
                                                     'model_best_xent.ckpt')
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                % val_loss)

                step += 1

        sess.close()
    data.close()
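The loop above (and example #7 below) implements a plateau-based learning-rate schedule inline: it keeps the last five training-set losses, estimates how far the loss fell over that window, and divides the learning rate by 10 once the decrease falls below a threshold. A minimal standalone sketch of that logic, with assumed function and parameter names that are not part of the original code, could look like this:

def maybe_reduce_lr(loss_history, train_loss, curr_lr, gradient_threshold):
    # Keep a sliding window of the last five evaluated training losses.
    loss_history.append(train_loss)
    if len(loss_history) > 5:
        loss_history.pop(0)
        # Rough estimate of how much the loss decreased over the window.
        loss_gradient = (loss_history[0] - loss_history[-1]) / 2
        if loss_gradient < gradient_threshold:
            # Progress has stalled: reduce the learning rate and clear the
            # history so the optimisation gets time to start decreasing again.
            curr_lr /= 10.0
            loss_history.clear()
    return curr_lr

Called once per training-set evaluation, e.g. curr_lr = maybe_reduce_lr(loss_history, train_loss, curr_lr, exp_config.schedule_gradient_threshold).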
예제 #5
0
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    if exp_config.train_dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=1040,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')
        data_brain_train = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    elif exp_config.train_dataset == 'IXI':
        logging.info('Reading IXI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_ixi)
        data_brain_train = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=0,
            idx_end=12,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imtr, gttr = [data_brain_train['images'], data_brain_train['labels']]

        data_brain_val = data_ixi.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            idx_start=12,
            idx_end=17,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)
        imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']]

    logging.info(
        'Training Images: %s' %
        str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info(
        'Training Labels: %s' %
        str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module will be adapted for each test image
        # ================================================================
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, _, _ = model.predict_i2l(images_normalized,
                                         exp_config,
                                         training_pl=training_pl)

        print('shape of inputs: ',
              images_pl.shape)  # (batch_size, 256, 256, 1)
        print('shape of logits: ',
              logits.shape)  # (batch_size, 256, 256, nlabels)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        i2l_vars = []
        for v in tf.trainable_variables():
            i2l_vars.append(v)

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits,
                             labels_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       i2l_vars,
                                       exp_config.optimizer_handle,
                                       learning_rate_pl,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_i2l(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of all variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of all variables:')
        for v in tf.trainable_variables():
            print(v.name)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            print(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_dice = 0

        # ================================================================
        # run training epochs
        # ================================================================
        while (step < exp_config.max_steps):

            if step % 1000 == 0:
                logging.info(
                    '============================================================'
                )
                logging.info('step %d' % step)

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches(imtr,
                                             gttr,
                                             batch_size=exp_config.batch_size,
                                             train_or_eval='train'):

                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y = batch

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                # ===========================
                # create the feed dict for this training iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    labels_pl: y,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # opt step
                # ===========================
                _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.3f (%.3f sec for the last step)' %
                        (step + 1, loss, duration))

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('Training Data Eval:')
                    train_loss, train_dice = do_eval(sess, eval_loss,
                                                     images_pl, labels_pl,
                                                     training_pl, imtr, gttr,
                                                     exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error: train_loss,
                                                  tr_dice: train_dice
                                              })

                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('Validation Data Eval:')
                    val_loss, val_dice = do_eval(sess, eval_loss, images_pl,
                                                 labels_pl, training_pl, imvl,
                                                 gtvl, exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error: val_loss,
                                                  vl_dice: val_dice
                                              })

                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_dice.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new average best dice on validation sets! - %f -  Saving model.'
                            % val_dice)

                step += 1

        sess.close()
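The training loops in these examples call a do_eval helper that is not shown in the snippets. Assuming eval_loss evaluates to a (loss, dice) pair per minibatch and that an iterate_minibatches generator like the one used above is available, a minimal sketch of the eight-argument form could be:

import logging


def do_eval(sess, eval_loss, images_pl, labels_pl, training_pl,
            images, labels, batch_size):
    # Average the per-batch loss and Dice score over the whole dataset,
    # skipping any incomplete final batch.
    loss_sum, dice_sum, num_batches = 0.0, 0.0, 0
    for x, y in iterate_minibatches(images, labels, batch_size=batch_size):
        if y.shape[0] < batch_size:
            continue
        loss_val, dice_val = sess.run(eval_loss,
                                      feed_dict={images_pl: x,
                                                 labels_pl: y,
                                                 training_pl: False})
        loss_sum += loss_val
        dice_sum += dice_val
        num_batches += 1
    avg_loss = loss_sum / max(num_batches, 1)
    avg_dice = dice_sum / max(num_batches, 1)
    logging.info('  Average loss: %.4f, average dice: %.4f' %
                 (avg_loss, avg_dice))
    return avg_loss, avg_dice

Example #6 below uses an extended variant with a second loss for the self-supervised head; the averaging pattern is the same.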
예제 #6
0
def run_training():

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # Determine the data set
    # ============================
    target_data = True

    # ============================
    # Load data
    # ============================

    if target_data:
        hcp = True
        if hcp:
            logging.info(
                '============================================================')
            logging.info('Loading data...')
            logging.info('Reading HCP - 3T - T2 images...')
            logging.info('Data root directory: ' +
                         sys_config.orig_data_root_hcp)
            imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                            sys_config.preproc_folder_hcp,
                                            'T2', 1, exp_config.test_size + 1)
            logging.info('Test Images: %s' % str(
                imts.shape))  # expected: [num_slices, img_size_x, img_size_y]
            logging.info('Test Labels: %s' % str(
                gtts.shape))  # expected: [num_slices, img_size_x, img_size_y]
            logging.info(
                '============================================================')
        else:
            logging.info(
                '============================================================')
            logging.info('Loading data...')
            logging.info('Reading ABIDE caltech...')
            logging.info('Data root directory: ' +
                         sys_config.orig_data_root_abide)
            imts, gtts = data_abide.load_data(sys_config.orig_data_root_abide,
                                              sys_config.preproc_folder_abide,
                                              1, exp_config.test_size + 1)
            logging.info('Test Images: %s' % str(
                imts.shape))  # expected: [num_slices, img_size_x, img_size_y]
            logging.info('Test Labels: %s' % str(
                gtts.shape))  # expected: [num_slices, img_size_x, img_size_y]
            logging.info(
                '============================================================')
    else:
        logging.info(
            '============================================================')
        logging.info('Loading data...')
        logging.info('Reading HCP - 3T - T1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        imts, gtts = data_hcp.load_data(
            sys_config.orig_data_root_hcp, sys_config.preproc_folder_hcp, 'T1',
            exp_config.train_size + exp_config.val_size + 1,
            exp_config.train_size + exp_config.val_size +
            exp_config.test_size + 1)
        logging.info(
            'Test Images: %s' %
            str(imts.shape))  # expected: [num_slices, img_size_x, img_size_y]
        logging.info(
            'Test Labels: %s' %
            str(gtts.shape))  # expected: [num_slices, img_size_x, img_size_y]
        logging.info(
            '============================================================')

    # ============================
    # Remove exclusively black images
    # ============================
    mask = ~(imts == 0).all(axis=(1, 2))
    imts = imts[mask]
    gtts = gtts[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
            exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')
        # Placeholders for the learning rate and to indicate whether the model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the image normalization function - these parameters will be updated for each test image
        # ================================================================
        image_normalized = exp_config.model_handle_normalizer(
            images_pl, image_tensor_shape)

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            image_normalized, image_tensor_shape, training_pl,
            exp_config.nlabels)

        # ================================================================
        # Define the self-supervised network here
        # ================================================================
        logits_self_supervised = exp_config.model_handle_rotation(
            image_normalized, image_tensor_shape, training_pl)

        # ================================================================
        # determine trainable variables
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info('================================')
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_self_supervised = tf.reduce_mean(
            exp_config.loss_handle_self_supervised(labels_self_supervised_pl,
                                                   logits_self_supervised))
        tf.compat.v1.summary.scalar('loss_self_supervised',
                                    loss_self_supervised)

        # ================================================================
        # Add optimization ops.
        # ================================================================
        train_op_test_time = model.training_step(
            loss_self_supervised, test_time_opt_vars,
            exp_config.optimizer_handle_test_time, learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)
        eval_loss_self_supervised = model.evaluation_rotate(
            logits_self_supervised, labels_self_supervised_pl)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of uninitialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers for each domain
        # ================================================================
        max_to_keep = 15
        saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
        saver_best_da = tf.compat.v1.train.Saver()
        saver_seg = tf.compat.v1.train.Saver(var_list=segmentation_vars)
        saver_ss = tf.compat.v1.train.Saver(var_list=self_supervised_vars)

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the test errors
        # ================================================================
        ts_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                shape=[],
                                                name='ts_error_seg')
        ts_error_summary_seg = tf.compat.v1.summary.scalar(
            'test/loss_seg', ts_error_seg)
        ts_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='ts_dice_seg')
        ts_dice_summary_seg = tf.compat.v1.summary.scalar(
            'test/dice_seg', ts_dice_seg)
        ts_summary_seg = tf.compat.v1.summary.merge(
            [ts_error_summary_seg, ts_dice_summary_seg])

        ts_error_ss = tf.compat.v1.placeholder(tf.float32,
                                               shape=[],
                                               name='ts_error_ss')
        ts_error_summary_ss = tf.compat.v1.summary.scalar(
            'test/loss_ss', ts_error_ss)
        ts_acc_ss = tf.compat.v1.placeholder(tf.float32,
                                             shape=[],
                                             name='ts_acc_ss')
        ts_acc_summary_ss = tf.compat.v1.summary.scalar(
            'test/acc_ss', ts_acc_ss)
        ts_summary_ss = tf.compat.v1.summary.merge(
            [ts_error_summary_ss, ts_acc_summary_ss])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.compat.v1.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_g)
        sess.run(init_l)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            logging.info(v)

        # ================================================================
        # restore shared weights
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('Restore segmentation...')
        model_path = os.path.join(sys_config.log_root,
                                  'Initial_training_segmentation')
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'models/best_dice.ckpt')
        logging.info('Restoring session from: %s' % checkpoint_path)
        saver_seg.restore(sess, checkpoint_path)

        logging.info(
            '============================================================')
        logging.info('Restore self-supervised...')
        model_path = os.path.join(sys_config.log_root,
                                  'Initial_training_rotation')
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'models/best_dice.ckpt')
        logging.info('Restoring session from: %s' % checkpoint_path)
        saver_ss.restore(sess, checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        curr_lr = exp_config.learning_rate
        best_da = 0

        # ================================================================
        # run training epochs
        # ================================================================
        for epoch in range(exp_config.max_epochs):

            logging.info(
                '============================================================')
            logging.info('EPOCH %d' % epoch)

            for batch in iterate_minibatches(imts,
                                             gtts,
                                             batch_size=exp_config.batch_size):
                curr_lr = exp_config.learning_rate
                start_time = time.time()
                x, y_seg, y_ss = batch

                # ===========================
                # avoid incomplete batches
                # ===========================
                if y_seg.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                feed_dict = {
                    images_pl: x,
                    labels_self_supervised_pl: y_ss,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                # ===========================
                # update vars
                # ===========================
                _, loss_value = sess.run(
                    [train_op_test_time, loss_self_supervised],
                    feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.2f (%.3f sec for the last step)' %
                        (step + 1, loss_value, duration))

                    # ===========================
                    # print values of a parameter (to debug)
                    # ===========================
                    if exp_config.debug is True:
                        var_value = tf.compat.v1.get_collection(
                            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                        )[0].eval(
                            session=sess
                        )  # can add name if you only want parameters with a specific name
                        logging.info('value of one of the parameters %f' %
                                     var_value[0, 0, 0, 0])

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # compute the loss on the entire test set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:

                    logging.info('Test Data Eval:')
                    [test_loss_seg, test_dice_seg, test_loss_ss, test_acc_ss
                     ] = do_eval(sess, eval_loss_segmentation,
                                 eval_loss_self_supervised, images_pl,
                                 labels_segmentation_pl,
                                 labels_self_supervised_pl, training_pl, imts,
                                 gtts, exp_config.batch_size)

                    ts_summary_msg = sess.run(ts_summary_seg,
                                              feed_dict={
                                                  ts_error_seg: test_loss_seg,
                                                  ts_dice_seg: test_dice_seg
                                              })
                    summary_writer.add_summary(ts_summary_msg, step)
                    ts_summary_msg = sess.run(ts_summary_ss,
                                              feed_dict={
                                                  ts_error_ss: test_loss_ss,
                                                  ts_acc_ss: test_acc_ss
                                              })
                    summary_writer.add_summary(ts_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:

                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                step += 1

        sess.close()
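All of these loops draw batches from an iterate_minibatches generator that the examples leave out. A minimal NumPy sketch of the two-output form (shuffled (x, y) pairs with the augmentation hook left as a no-op; the names are assumptions, and HDF5-backed datasets would need sorted indices) could be:

import numpy as np


def iterate_minibatches(images, labels, batch_size, augment_batch=False):
    # Yield shuffled (x, y) minibatches from arrays with a matching first dimension.
    indices = np.random.permutation(images.shape[0])
    for start in range(0, indices.shape[0], batch_size):
        batch_idx = indices[start:start + batch_size]
        x, y = images[batch_idx], labels[batch_idx]
        if augment_batch:
            # Hook for data augmentation (e.g. random flips or rotations);
            # left as the identity in this sketch.
            pass
        yield x, y

Example #6 above unpacks a third element (rotation labels) from each batch, so its generator would additionally yield self-supervision targets.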
예제 #7
0
def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = config.train_on_all_data

    # Load data
    data = read_data.load_and_maybe_process_data(
        input_folder=config.data_root,
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False)

    # the following are HDF5 datasets, not numpy arrays
    images_train = data['images_train']
    labels_train = data['masks_train']

    if not train_on_all_data:
        images_val = data['images_val']
        labels_val = data['masks_val']

    if config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)

    # if config.prob:  # if prob is not 0
    #     logging.info('Before data_augmentation the number of training images is:')
    #     logging.info(images_train.shape[0])
    #     # augmentation
    #     image_aug, label_aug = aug.augmentation_function(images_train, labels_train)
    #     num_aug = image_aug.shape[0]
    #     # ids of augmented images will be b'0.0'
    #     id_aug = np.zeros([num_aug, ]).astype('|S9')
    #     # concatenate
    #     id_train = np.concatenate((id_train, id_aug))
    #     images_train = np.concatenate((images_train, image_aug))
    #     labels_train = np.concatenate((labels_train, label_aug))
    #     logging.info('After data_augmentation the number of training images is:')
    #     logging.info(images_train.shape[0])
    # else:
    #     logging.info('No data_augmentation. Number of training images is:')
    #     logging.info(images_train.shape[0])

    # Tell TensorFlow that the model will be built into the default Graph.

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.

        image_tensor_shape = [config.batch_size] + list(
            config.image_size) + [1]
        mask_tensor_shape = [config.batch_size] + list(config.image_size)

        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')

        learning_rate_pl = tf.placeholder(tf.float32, shape=[])
        training_pl = tf.placeholder(tf.bool, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_pl)

        # Build a Graph that computes predictions from the inference model.
        if (config.experiment_name == 'unet2D_valid'
                or config.experiment_name == 'unet2D_same'
                or config.experiment_name == 'unet2D_same_mod'
                or config.experiment_name == 'unet2D_light'
                or config.experiment_name == 'Dunet2D_same_mod'
                or config.experiment_name == 'Dunet2D_same_mod2'
                or config.experiment_name == 'Dunet2D_same_mod3'):
            logits = model.inference(images_pl, config, training=training_pl)
        elif config.experiment_name == 'ENet':
            with slim.arg_scope(
                    model_structure.ENet_arg_scope(weight_decay=2e-4)):
                logits = model_structure.ENet(
                    images_pl,
                    num_classes=config.nlabels,
                    batch_size=config.batch_size,
                    is_training=True,
                    reuse=None,
                    num_initial_blocks=1,
                    stage_two_repeat=2,
                    skip_connections=config.skip_connections)
        else:
            raise ValueError('Invalid experiment_name: %s' %
                             config.experiment_name)

        logging.info('images_pl shape')
        logging.info(images_pl.shape)
        logging.info('labels_pl shape')
        logging.info(labels_pl.shape)
        logging.info('logits shape:')
        logging.info(logits.shape)
        # Add to the Graph the Ops for loss calculation.
        [loss, _,
         weights_norm] = model.loss(logits,
                                    labels_pl,
                                    nlabels=config.nlabels,
                                    loss_type=config.loss_type,
                                    weight_decay=config.weight_decay
                                    )  # second output is unregularised loss

        # record how Total loss and weight decay change over time
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # Add to the Graph the Ops that calculate and apply gradients.
        if config.momentum is not None:
            train_op = model.training_step(loss,
                                           config.optimizer_handle,
                                           learning_rate_pl,
                                           momentum=config.momentum)
        else:
            train_op = model.training_step(loss, config.optimizer_handle,
                                           learning_rate_pl)

        # Add the Op to compare the logits to the labels during evaluation.
        # loss and dice on a minibatch
        eval_loss = model.evaluation(logits,
                                     labels_pl,
                                     images_pl,
                                     nlabels=config.nlabels,
                                     loss_type=config.loss_type)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.

        if train_on_all_data:
            max_to_keep = None
        else:
            max_to_keep = 5

        saver = tf.train.Saver(max_to_keep=max_to_keep)
        saver_best_dice = tf.train.Saver()
        saver_best_xent = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        configP = tf.ConfigProto()
        configP.gpu_options.allow_growth = True  # Do not grab all GPU memory up front; allocate it as needed
        configP.allow_soft_placement = True  # If an op cannot be placed on the requested device, let it run on another one
        sess = tf.Session(config=configP)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # with tf.name_scope('monitoring'):

        val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
        val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

        val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

        train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
        train_error_summary = tf.summary.scalar('training_loss', train_error_)

        train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
        train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

        train_summary = tf.summary.merge(
            [train_error_summary, train_dice_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        if continue_run:
            # Restore session
            saver.restore(sess, init_checkpoint_path)

        step = init_step
        curr_lr = config.learning_rate

        no_improvement_counter = 0
        best_val = np.inf
        last_train = np.inf
        loss_history = []
        loss_gradient = np.inf
        best_dice = 0

        for epoch in range(config.max_epochs):

            logging.info('EPOCH %d' % epoch)

            for batch in iterate_minibatches(
                    images_train,
                    labels_train,
                    batch_size=config.batch_size,
                    augment_batch=config.augment_batch):

                if config.warmup_training:
                    if step < 50:
                        curr_lr = config.learning_rate / 10.0
                    elif step == 50:
                        curr_lr = config.learning_rate

                start_time = time.time()

                # batch = bgn_train.retrieve()
                x, y = batch

                # TEMPORARY HACK (to avoid incomplete batches)
                if y.shape[0] < config.batch_size:
                    step += 1
                    continue

                feed_dict = {
                    images_pl: x,
                    labels_pl: y,
                    learning_rate_pl: curr_lr,
                    training_pl: True
                }

                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

                duration = time.time() - start_time

                # Write the summaries and print an overview fairly often.
                if step % 20 == 0:
                    # Print status to stdout.
                    logging.info('Step %d: loss = %.3f (%.3f sec)' %
                                 (step, loss_value, duration))
                    # Update the events file.

                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                if (step + 1) % config.train_eval_frequency == 0:

                    logging.info('Training Data Eval:')
                    [train_loss,
                     train_dice] = do_eval(sess, eval_loss, images_pl,
                                           labels_pl, training_pl,
                                           images_train, labels_train,
                                           config.batch_size)

                    train_summary_msg = sess.run(train_summary,
                                                 feed_dict={
                                                     train_error_: train_loss,
                                                     train_dice_: train_dice
                                                 })
                    summary_writer.add_summary(train_summary_msg, step)

                    loss_history.append(train_loss)
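                    # keep a sliding window of the last 5 training losses; the
                    # 'gradient' below is based on the drop across that window
                    # and drives the learning rate schedule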
                    if len(loss_history) > 5:
                        loss_history.pop(0)
                        loss_gradient = (loss_history[-5] -
                                         loss_history[-1]) / 2

                    logging.info('loss gradient is currently %f' %
                                 loss_gradient)

                    if config.schedule_lr and loss_gradient < config.schedule_gradient_threshold:
                        logging.warning('Reducing learning rate!')
                        curr_lr /= 10.0
                        logging.info('Learning rate changed to: %f' % curr_lr)

                        # reset loss history to give the optimisation some time to start decreasing again
                        loss_gradient = np.inf
                        loss_history = []

                    if train_loss <= last_train:  # best_train:
                        no_improvement_counter = 0
                        logging.info('Decrease in training error!')
                    else:
                        no_improvement_counter = no_improvement_counter + 1
                        logging.info(
                            'No improvement in training error for %d steps' %
                            no_improvement_counter)

                    last_train = train_loss

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % config.val_eval_frequency == 0:

                    checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)
                    # Evaluate against the training set.

                    if not train_on_all_data:

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss,
                         val_dice] = do_eval(sess, eval_loss, images_pl,
                                             labels_pl, training_pl,
                                             images_val, labels_val,
                                             config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'model_best_dice.ckpt')
                            filelist = glob.glob(
                                os.path.join(log_dir, 'model_best_dice*'))
                            for file in filelist:
                                os.remove(file)
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                % val_dice)

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir,
                                                     'model_best_xent.ckpt')
                            filelist = glob.glob(
                                os.path.join(log_dir, 'model_best_xent*'))
                            for file in filelist:
                                os.remove(file)
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                % val_loss)

                step += 1

            # end epoch
            if (epoch + 1) % config.epoch_freq == 0:
                curr_lr = curr_lr * 0.98
            logging.info('Learning rate: %f' % curr_lr)
        sess.close()
    data.close()


def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)
    # wrap everything from this point onwards in a try/except so that a keyboard
    # interrupt can be caught and the h5py data files closed in a controlled way
    try:
        loaded_previous_recursion = False
        start_epoch = 0
        if continue_run:
            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )
            try:
                try:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))
                except Exception:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir,
                        'recursion_{}_model.ckpt'.format(recursion - 1))
                    loaded_previous_recursion = True
                logging.info('Checkpoint path: %s' % init_checkpoint_path)
                init_step = int(
                    init_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1  # plus 1 b/c otherwise starts with eval
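                # derive the epoch to resume from: steps completed so far divided
                # by the number of mini-batches per epoch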
                start_epoch = int(
                    init_step /
                    (len(base_data['images_train']) / exp_config.batch_size))
                logging.info('Latest step was: %d' % init_step)
                logging.info('Continuing with epoch: %d' % start_epoch)
            except Exception:
                logging.warning(
                    '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
                )
                continue_run = False
                init_step = 0
                start_epoch = 0

            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )

        if loaded_previous_recursion:
            logging.info(
                "Data file exists for recursion {} "
                "but checkpoints only present up to recursion {}".format(
                    recursion, recursion - 1))
            logging.info("Likely means postprocessing was terminated")
            recursion_data = acdc_data.load_different_recursion(
                recursion_data, -1)
            recursion -= 1

        # load images and validation data
        images_train = np.array(base_data['images_train'])
        scribbles_train = np.array(base_data['scribbles_train'])
        images_val = np.array(base_data['images_test'])
        labels_val = np.array(base_data['masks_test'])

        # if exp_config.use_data_fraction:
        #     num_images = images_train.shape[0]
        #     new_last_index = int(float(num_images)*exp_config.use_data_fraction)
        #
        #     logging.warning('USING ONLY FRACTION OF DATA!')
        #     logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' % (num_images, new_last_index))
        #     images_train = images_train[0:new_last_index,...]
        #     labels_train = labels_train[0:new_last_index,...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        #logging.info(' - Labels:')
        #logging.info(labels_train.shape)
        #logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.

        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.

            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_placeholder = tf.placeholder(tf.float32,
                                                shape=image_tensor_shape,
                                                name='images')
            labels_placeholder = tf.placeholder(tf.uint8,
                                                shape=mask_tensor_shape,
                                                name='labels')

            learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_placeholder)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_placeholder,
                                     exp_config.model_handle,
                                     training=training_time_placeholder,
                                     nlabels=exp_config.nlabels)

            # Add to the Graph the Ops for loss calculation.
            [loss, _, weights_norm
             ] = model.loss(logits,
                            labels_placeholder,
                            nlabels=exp_config.nlabels,
                            loss_type=exp_config.loss_type,
                            weight_decay=exp_config.weight_decay
                            )  # second output is unregularised loss

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder)

            # Add the Op to compare the logits to the labels during evaluation.
#            eval_loss = model.evaluation(logits,
#                                         labels_placeholder,
#                                         images_placeholder,
#                                         nlabels=exp_config.nlabels,
#                                         loss_type=exp_config.loss_type,
#                                         weak_supervision=True,
#                                         cnn_threshold=exp_config.cnn_threshold,
#                                         include_bg=True)
            eval_val_loss = model.evaluation(
                logits,
                labels_placeholder,
                images_placeholder,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weak_supervision=True,
                cnn_threshold=exp_config.cnn_threshold,
                include_bg=False)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            # Only keep two checkpoints, as checkpoints are kept for every recursion
            # and they can be 300MB +
            saver = tf.train.Saver(max_to_keep=2)
            saver_best_dice = tf.train.Saver(max_to_keep=2)
            saver_best_xent = tf.train.Saver(max_to_keep=2)

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # with tf.name_scope('monitoring'):

            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)

            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            # Restore session
            #            crf_weights = []
            #            for v in tf.all_variables():
            #
            #                if v.name[0:4]=='bila':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #                elif v.name[0:4] =='spat':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #                elif v.name[0:4] =='comp':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #            restore_var = [v for v in tf.all_variables() if v.name not in crf_weights]
            #
            #            load_saver = tf.train.Saver(var_list=restore_var)
            #            load_saver.restore(sess, '/scratch_net/biwirender02/cany/basil/logdir/unet2D_ws_spot_blur/recursion_0_model.ckpt-5699')

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0
            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            recursion_data = acdc_data.random_walk_epoch(
                recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                exp_config.random_walk)

            #get ground truths
            labels_train = np.array(recursion_data['random_walked'])

            for epoch in range(start_epoch, exp_config.max_epochs):
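                # At each recursion boundary (every epochs_per_recursion epochs),
                # regenerate the pseudo ground truths: predict with the current
                # network, postprocess using the scribbles, and random walk the
                # result before continuing training.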
                if (epoch % exp_config.epochs_per_recursion == 0 and epoch != 0) \
                        or loaded_previous_recursion:
                    loaded_previous_recursion = False
                    #Have reached end of recursion
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        logits=logits,
                        sess=sess)

                    recursion_data = postprocess_gt(
                        data=recursion_data,
                        images_train=images_train,
                        scribbles_train=scribbles_train)
                    recursion += 1
                    # random walk - if it already has been random walked it won't redo
                    recursion_data = acdc_data.random_walk_epoch(
                        recursion_data, exp_config.rw_beta,
                        exp_config.rw_threshold, exp_config.random_walk)
                    #get ground truths
                    labels_train = np.array(recursion_data['random_walked'])

                    #reinitialise savers - otherwise, no checkpoints will be saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)
                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))
                # for batch in iterate_minibatches(images_train,
                #                                  labels_train,
                #                                  batch_size=exp_config.batch_size,
                #                                  augment_batch=exp_config.augment_batch):

                # This loop uses the BACKGROUND GENERATOR, which prefetches batches and
                # speeds up training. Be aware, however, that an exception raised inside
                # this loop may not be caught: the batch generator can keep running
                # silently even after the rest of the code has crashed.
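                # If debugging matters more than speed, iterating over
                # iterate_minibatches(...) directly (without the background
                # thread) lets such exceptions surface immediately.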

                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    # batch = bgn_train.retrieve()
                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()


#                    if (step + 1) % exp_config.train_eval_frequency == 0:
#
#                        logging.info('Training Data Eval:')
#                        [train_loss, train_dice] = do_eval(sess,
#                                                           eval_loss,
#                                                           images_placeholder,
#                                                           labels_placeholder,
#                                                           training_time_placeholder,
#                                                           images_train,
#                                                           labels_train,
#                                                           exp_config.batch_size)
#
#                        train_summary_msg = sess.run(train_summary, feed_dict={train_error_: train_loss,
#                                                                               train_dice_: train_dice}
#                                                     )
#                        summary_writer.add_summary(train_summary_msg, step)
#
#                        loss_history.append(train_loss)
#                        if len(loss_history) > 5:
#                            loss_history.pop(0)
#                            loss_gradient = (loss_history[-5] - loss_history[-1]) / 2
#
#                        logging.info('loss gradient is currently %f' % loss_gradient)
#
#                        if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
#                            logging.warning('Reducing learning rate!')
#                            curr_lr /= 10.0
#                            logging.info('Learning rate changed to: %f' % curr_lr)
#
#                            # reset loss history to give the optimisation some time to start decreasing again
#                            loss_gradient = np.inf
#                            loss_history = []
#
#                        if train_loss <= last_train:  # best_train:
#                            logging.info('Decrease in training error!')
#                        else:
#                            logging.info('No improvement in training error for %d steps' % no_improvement_counter)
#
#                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(
                            log_dir,
                            'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)
                        # Evaluate against the training set.

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_val_loss, images_placeholder,
                                     labels_placeholder,
                                     training_time_placeholder, images_val,
                                     labels_val, exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(
                                    recursion))
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.
                                format(val_dice, recursion))

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(
                                    recursion))
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.
                                format(val_loss, recursion))

                    step += 1

    except Exception:
        raise
예제 #9
0
def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_l2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    if exp_config.train_dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth,
            target_resolution=exp_config.target_resolution_brain)
        gttr = data_brain_train['labels']

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth,
            target_resolution=exp_config.target_resolution_brain)
        gtvl = data_brain_val['labels']

    logging.info(
        'Training Labels: %s' % str(gttr.shape)
    )  # expected: [num_subjects, img_size_x, img_size_y, img_size_z]
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # visualize downsampled volumes
    #    for subject_num in range(gttr.shape[0]):
    #        utils_vis.save_samples_downsampled(gttr[subject_num, ::2, :, :],
    #                                       savepath = log_dir + '/training_image_' + str(subject_num+1) + '.png')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        true_labels_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        true_labels_pl = tf.placeholder(tf.uint8,
                                        shape=true_labels_shape,
                                        name='true_labels')

        # ================================================================
        # This will be a mask with all zeros in locations of pixels that we want to alter the labels of.
        # Multiply with this mask to have zero vectors for all those pixels.
        # ================================================================
        blank_masks_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nlabels]
        blank_masks_pl = tf.placeholder(tf.float32,
                                        shape=blank_masks_shape,
                                        name='blank_masks')

        # ================================================================
        # This will hold the (1-hot) wrong labels that are written into the blanked
        # pixel locations; adding it to the masked true labels yields the noisy labels.
        # ================================================================
        wrong_labels_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nlabels]
        wrong_labels_pl = tf.placeholder(tf.float32,
                                         shape=wrong_labels_shape,
                                         name='wrong_labels')

        # ================================================================
        # Training placeholder
        # ================================================================
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # make true labels 1-hot
        # ================================================================
        true_labels_1hot = tf.one_hot(true_labels_pl, depth=exp_config.nlabels)

        # ================================================================
        # Blank certain locations and write wrong labels in those locations
        # ================================================================
        noisy_labels_1hot = tf.math.multiply(true_labels_1hot,
                                             blank_masks_pl) + wrong_labels_pl
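        # e.g. for a single pixel (hypothetical values, nlabels=3): a true 1-hot
        # label [0,1,0] with blank mask [0,0,0] and wrong label [0,0,1] becomes
        # the noisy label [0,0,1]; with mask [1,1,1] and wrong label [0,0,0] the
        # true label is kept unchanged.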

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        autoencoded_logits, _, _ = model.predict_l2l(noisy_labels_1hot,
                                                     exp_config,
                                                     training_pl=training_pl)

        print('shape of input tensor: ',
              true_labels_pl.shape)  # (batch_size, 64, 256, 256)
        print('shape of input tensor converted to 1-hot: ',
              true_labels_1hot.shape)  # (batch_size, 64, 256, 256, 15)
        print('shape of predicted logits: ',
              autoencoded_logits.shape)  # (batch_size, 64, 256, 256, 15)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        l2l_vars = []
        for v in tf.trainable_variables():
            print(v.name)
            l2l_vars.append(v)

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits=autoencoded_logits,
                             labels=true_labels_1hot,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_l2l,
                             are_labels_1hot=True)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       l2l_vars,
                                       exp_config.optimizer_handle,
                                       exp_config.learning_rate_l2l,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_l2l(logits=autoencoded_logits,
                                         labels=true_labels_1hot,
                                         labels_masked=noisy_labels_1hot,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_l2l,
                                         are_labels_1hot=True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of all variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of all variables:')
        for v in tf.trainable_variables():
            print(v.name)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            print(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # ================================================================
        step = init_step
        best_dice = 0

        # ================================================================
        # run training epochs
        # ================================================================
        while (step < exp_config.max_steps_l2l):
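            # 'step' counts mini-batch updates; the while loop keeps cycling
            # through the training labels until max_steps_l2l updates are done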

            if step % 1000 == 0:
                logging.info(
                    '============================================================'
                )
                logging.info('step %d' % step)

            # ================================================
            # batches: each yields (true_labels, blank_masks, wrong_labels)
            # ================================================
            for batch in iterate_minibatches(gttr, exp_config.batch_size):

                start_time = time.time()
                true_labels, blank_masks, wrong_labels = batch

                # ===========================
                # avoid incomplete batches
                # ===========================
                if true_labels.shape[0] < exp_config.batch_size:
                    step += 1
                    continue

                # ===========================
                # create the feed dict for this training iteration
                # ===========================
                feed_dict = {
                    true_labels_pl: true_labels,
                    blank_masks_pl: blank_masks,
                    wrong_labels_pl: wrong_labels,
                    training_pl: True
                }

                # ===========================
                # opt step
                # ===========================
                _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)

                # ===========================
                # compute the time for this mini-batch computation
                # ===========================
                duration = time.time() - start_time

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info(
                        'Step %d: loss = %.3f (%.3f sec for the last step)' %
                        (step + 1, loss, duration))

                    # ===========================
                    # Update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # Compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('Training Data Eval:')
                    train_loss, train_dice = do_eval(sess, eval_loss,
                                                     true_labels_pl,
                                                     blank_masks_pl,
                                                     wrong_labels_pl,
                                                     training_pl, gttr,
                                                     exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error: train_loss,
                                                  tr_dice: train_dice
                                              })

                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # Save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                    # at some frequency, visualize the noisy and clean segmentation pairs used for training the DAE
                    noisy_labels = sess.run(noisy_labels_1hot,
                                            feed_dict={
                                                true_labels_pl: true_labels,
                                                blank_masks_pl: blank_masks,
                                                wrong_labels_pl: wrong_labels
                                            })
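                    # convert the 1-hot noisy labels back to integer label maps
                    # for visualization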
                    noisy_labels = np.argmax(noisy_labels, axis=-1)

                    basepath = log_dir + '/training_data/iter' + str(step)
                    for zz in np.arange(20, 50, 10):
                        utils_vis.save_single_image(
                            true_labels[0, zz, :, :],
                            basepath + '_slice' + str(zz) + '_clean.png', 15,
                            True, 'tab20', False)
                        utils_vis.save_single_image(
                            noisy_labels[0, zz, :, :],
                            basepath + '_slice' + str(zz) + '_noisy.png', 15,
                            True, 'tab20', False)

                # ===========================
                # Evaluate the model periodically on a validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    logging.info('Validation Data Eval:')
                    val_loss, val_dice = do_eval(sess, eval_loss,
                                                 true_labels_pl,
                                                 blank_masks_pl,
                                                 wrong_labels_pl, training_pl,
                                                 gtvl, exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error: val_loss,
                                                  vl_dice: val_dice
                                              })

                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_dice.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new average best dice on validation sets! - %f -  Saving model.'
                            % val_dice)

                step += 1

        sess.close()


def run_training(continue_run):

    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 as otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading training data from: ' + sys_config.project_data_root)
    data_tr = data_freiburg_numpy_to_hdf5.load_data(
        basepath=sys_config.project_data_root,
        idx_start=0,
        idx_end=19,
        train_test='train')
    images_tr = data_tr['images_train']
    labels_tr = data_tr['labels_train']

    logging.info(
        'Shape of training images: %s' % str(images_tr.shape)
    )  # expected: [img_size_z*num_images, img_size_x, img_size_y, img_size_t, n_channels]
    logging.info(
        'Shape of training labels: %s' % str(labels_tr.shape)
    )  # expected: [img_size_z*num_images, img_size_x, img_size_y, img_size_t]

    logging.info(
        '============================================================')
    logging.info('Loading validation data from: ' +
                 sys_config.project_data_root)
    data_vl = data_freiburg_numpy_to_hdf5.load_data(
        basepath=sys_config.project_data_root,
        idx_start=20,
        idx_end=24,
        train_test='validation')
    images_vl = data_vl['images_validation']
    labels_vl = data_vl['labels_validation']

    logging.info('Shape of validation images: %s' % str(images_vl.shape))
    logging.info('Shape of validation labels: %s' % str(labels_vl.shape))
    logging.info(
        '============================================================')

    if exp_config.nchannels == 1:
        logging.info(
            '============================================================')
        logging.info(
            'Only the proton density images (channel 0) will be used for the segmentation...'
        )
        logging.info(
            '============================================================')

    # visualize some training images and their labels
    visualize_images = False
    if visualize_images:
        for sub_tr in range(20):
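            # each subject appears to contribute 32 consecutive z-slices, hence
            # the indexing in blocks of 32 below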
            utils.save_sample_image_and_labels_across_z(
                images_tr[sub_tr * 32:(sub_tr + 1) * 32,
                          ...], labels_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                log_dir + '/training_subject' + str(sub_tr))
            utils.save_sample_image_and_labels_across_t(
                images_tr[sub_tr * 32:(sub_tr + 1) * 32,
                          ...], labels_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                log_dir + '/training_subject' + str(sub_tr))

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nchannels]
        label_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=label_tensor_shape,
                                   name='labels')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # Build the graph that computes predictions from the inference model
        # ================================================================
        logits = model.inference(images_pl, exp_config.model_handle,
                                 training_pl)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss = model.loss(logits,
                          labels_pl,
                          exp_config.nlabels,
                          loss_type=exp_config.loss_type)

        # ================================================================
        # Add the loss to tensorboard for visualizing its evolution as training proceeds
        # ================================================================
        tf.summary.scalar('loss', loss)

        # ================================================================
        # Add optimization ops
        # ================================================================
        train_op = model.training_step(loss, exp_config.optimizer_handle,
                                       exp_config.learning_rate)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation(logits,
                                     labels_pl,
                                     images_pl,
                                     nlabels=exp_config.nlabels,
                                     loss_type=exp_config.loss_type)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_op = tf.global_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create savers for each domain
        # ================================================================
        max_to_keep = 15
        saver = tf.train.Saver(max_to_keep=max_to_keep)
        saver_best_dice = tf.train.Saver()
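        # 'saver' writes the periodic checkpoints saved during training (at
        # most max_to_keep of them are retained), while 'saver_best_dice'
        # keeps a separate checkpoint of the model with the best validation
        # dice seen so far.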

        # ================================================================
        # Create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the validation errors
        # ================================================================
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # ================================================================
        # summaries of the training errors
        # ================================================================
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])
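        # tr_error/tr_dice and vl_error/vl_dice are plain scalar placeholders:
        # they are fed from Python with metrics aggregated over the whole
        # training / validation set (see the do_eval calls below), so these
        # summaries can be written to TensorBoard without being part of the
        # training graph.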

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()
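        # finalize() makes the graph read-only: any attempt to add new ops
        # from here on (e.g. accidentally creating tensors inside the training
        # loop) raises an error instead of silently growing the graph.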

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('initializing all variables...')
        sess.run(init_op)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        logging.info(
            '============================================================')
        logging.info('This is the list of uninitialized variables:')
        uninit_variables = sess.run(uninit_vars)
        for v in uninit_variables:
            logging.info(v)

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restoring session from: %s' % init_checkpoint_path)
            saver.restore(sess, init_checkpoint_path)

        # ================================================================
        # initialize the step counter and the best validation dice seen so far
        # ================================================================
        step = init_step
        best_dice = 0

        # ================================================================
        # run training: loop over epochs until max_steps iterations are reached
        # ================================================================
        while step < exp_config.max_steps:

            logging.info(
                '============================================================')
            logging.info('Step %d' % step)

            for batch in iterate_minibatches(images_tr,
                                             labels_tr,
                                             batch_size=exp_config.batch_size):

                x, y = batch

                # ===========================
                # run training iteration
                # ===========================
                feed_dict = {images_pl: x, labels_pl: y, training_pl: True}
                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
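                # note: fetching 'loss' in the same sess.run call as 'train_op'
                # returns the loss of this batch as computed in the forward
                # pass, i.e. with the weights before this update is applied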

                # ===========================
                # write the summaries and print an overview fairly often
                # ===========================
                if (step + 1) % exp_config.summary_writing_frequency == 0:
                    logging.info('Step %d: loss = %.2f' %
                                 (step + 1, loss_value))
                    # ===========================
                    # update the events file
                    # ===========================
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # ===========================
                # compute the loss on the entire training set
                # ===========================
                if step % exp_config.train_eval_frequency == 0:
                    logging.info('Training Data Eval:')
                    [train_loss,
                     train_dice] = do_eval(sess, eval_loss, images_pl,
                                           labels_pl, training_pl, images_tr,
                                           labels_tr, exp_config.batch_size)

                    tr_summary_msg = sess.run(tr_summary,
                                              feed_dict={
                                                  tr_error: train_loss,
                                                  tr_dice: train_dice
                                              })
                    summary_writer.add_summary(tr_summary_msg, step)

                # ===========================
                # save a checkpoint periodically
                # ===========================
                if step % exp_config.save_frequency == 0:
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)
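                    # global_step=step appends the current step number to the
                    # checkpoint filename; together with max_to_keep=15 only
                    # the 15 most recent of these checkpoints are retained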

                # ===========================
                # evaluate the model on the validation set
                # ===========================
                if step % exp_config.val_eval_frequency == 0:
                    # ===========================
                    # Evaluate against the validation set
                    # ===========================
                    logging.info('Validation Data Eval:')
                    [val_loss,
                     val_dice] = do_eval(sess, eval_loss, images_pl, labels_pl,
                                         training_pl, images_vl, labels_vl,
                                         exp_config.batch_size)

                    vl_summary_msg = sess.run(vl_summary,
                                              feed_dict={
                                                  vl_error: val_loss,
                                                  vl_dice: val_dice
                                              })
                    summary_writer.add_summary(vl_summary_msg, step)

                    # ===========================
                    # save model if the val dice is the best yet
                    # ===========================
                    if val_dice > best_dice:
                        best_dice = val_dice
                        best_file = os.path.join(log_dir,
                                                 'models/best_dice.ckpt')
                        saver_best_dice.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best dice on the validation set: %f - Saving model.'
                            % val_dice)

                step += 1

        sess.close()
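
# ================================================================
# NOTE: the helpers iterate_minibatches() and do_eval() are called above but
# defined elsewhere in the original script. The sketch below is only an
# illustration of what they might look like, inferred from how they are
# invoked; the names, signatures and averaging scheme are assumptions, not
# the original implementation.
# ================================================================
import numpy as np


def iterate_minibatches(images, labels, batch_size):
    # shuffle once per pass over the data and yield only full batches
    indices = np.random.permutation(images.shape[0])
    for start in range(0, images.shape[0] - batch_size + 1, batch_size):
        batch = indices[start:start + batch_size]
        yield images[batch], labels[batch]


def do_eval(sess, eval_loss, images_pl, labels_pl, training_pl,
            images, labels, batch_size):
    # average the evaluation op (assumed to return a (loss, dice) pair per
    # batch) over the whole dataset, with the network in inference mode
    loss_sum, dice_sum, num_batches = 0.0, 0.0, 0
    for x, y in iterate_minibatches(images, labels, batch_size):
        feed_dict = {images_pl: x, labels_pl: y, training_pl: False}
        batch_loss, batch_dice = sess.run(eval_loss, feed_dict=feed_dict)
        loss_sum += batch_loss
        dice_sum += batch_dice
        num_batches += 1
    return [loss_sum / num_batches, dice_sum / num_batches]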