def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):
    """Train the normalizer + segmentation net with adversarial feature-level UDA.

    Builds a fresh TF graph in which:
      * the normalization module and the i2l segmentation network are trained
        on the supervised source-domain (SD) segmentation loss plus
        ``exp_config.lambda_uda`` times an adversarial invariance loss computed
        from the discriminator's logits on target-domain (TD) features;
      * a discriminator is trained (sigmoid cross-entropy) to label resized SD
        features as 1 and TD features as 0.
    The i2l and discriminator updates are run one after the other at every step.

    Args:
        log_dir: directory for TensorBoard event files; checkpoints are saved
            under ``<log_dir>/models/``.
        images_sd_tr, labels_sd_tr: source-domain training images and labels.
        images_sd_vl, labels_sd_vl: source-domain validation images and labels.
        images_td_tr, images_td_vl: target-domain training / validation images
            (no target-domain labels are used).

    Returns:
        0 after ``exp_config.max_steps`` training steps.
    """
    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_sd_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size] + list(exp_config.image_size) + [1],
            name='images_sd')
        images_td_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size] + list(exp_config.image_size) + [1],
            name='images_td')
        labels_sd_pl = tf.placeholder(
            tf.uint8,
            shape=[exp_config.batch_size] + list(exp_config.image_size),
            name='labels_sd')
        training_pl = tf.placeholder(tf.bool, shape=[], name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # (the TD call reuses the variables created by the SD call)
        # ================================================================
        images_sd_normalized, _ = model.normalize(
            images_sd_pl, exp_config, training_pl, scope_reuse=False)
        images_td_normalized, _ = model.normalize(
            images_td_pl, exp_config, training_pl, scope_reuse=True)

        # ================================================================
        # get logit predictions from the segmentation network
        # ================================================================
        predicted_seg_sd_logits, _, _ = model.predict_i2l(
            images_sd_normalized, exp_config, training_pl, scope_reuse=False)

        # ================================================================
        # get all features from the segmentation network
        # ================================================================
        seg_features_sd = model.get_all_features(
            images_sd_normalized, exp_config, scope_reuse=True)
        seg_features_td = model.get_all_features(
            images_td_normalized, exp_config, scope_reuse=True)

        # ================================================================
        # resize all features (plus the normalized image) to the same size
        # ================================================================
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64),
            'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64),
            'resize_td_xfeat')

        # ================================================================
        # discriminator on features (shared weights for SD and TD)
        # ================================================================
        d_logits_sd = model.discriminator(images_sd_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=False)
        d_logits_td = model.discriminator(images_td_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=True)

        # ================================================================
        # discriminator loss: SD features -> label 1, TD features -> label 0
        # ================================================================
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # ================================================================
        # adversarial loss that pushes the TD features towards domain invariance
        # ================================================================
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)

        # ================================================================
        # supervised segmentation loss (source domain only)
        # ================================================================
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # ================================================================
        # total training loss for uda
        # ================================================================
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # ================================================================
        # merge all summaries defined so far (the tr_losses/* scalars only;
        # summaries created later in this function are merged separately)
        # ================================================================
        summary_scalars = tf.summary.merge_all()

        # ================================================================
        # divide the vars into segmentation network (incl. normalizer) and
        # discriminator network
        # ================================================================
        i2l_vars = []
        discriminator_vars = []
        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                # the normalization vars also need to be restored from the
                # pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # ================================================================
        # add optimization ops
        # NOTE(review): these lists come from tf.global_variables() and may
        # contain non-trainable variables (e.g. BN statistics); confirm that
        # model.training_step handles them.
        # ================================================================
        train_i2l_op = model.training_step(
            loss_total_op,
            i2l_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)
        train_discriminator_op = model.training_step(
            loss_d_op,
            discriminator_vars,
            exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits,
            labels_sd_pl,
            images_sd_pl,
            d_logits_td,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # add ops for adding image summary to tensorboard
        # ================================================================
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl,
            exp_config.nlabels)

        if exp_config.debug:
            print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session and a file writer for TensorBoard event files
        # ================================================================
        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # Everything below uses the session; make sure the writer is flushed /
        # closed and the session released even if training raises.
        try:
            # ================================================================
            # create savers
            # ================================================================
            saver = tf.train.Saver(var_list=i2l_vars)
            saver_lowest_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

            # ================================================================
            # summaries of the validation errors
            # ================================================================
            vl_error_seg = tf.placeholder(tf.float32, shape=[], name='vl_error_seg')
            vl_error_seg_summary = tf.summary.scalar('validation/loss_seg',
                                                     vl_error_seg)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_error_invariance = tf.placeholder(tf.float32,
                                                 shape=[],
                                                 name='vl_error_invariance')
            vl_error_invariance_summary = tf.summary.scalar(
                'validation/loss_invariance', vl_error_invariance)
            vl_error_total = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='vl_error_total')
            vl_error_total_summary = tf.summary.scalar('validation/loss_total',
                                                       vl_error_total)
            vl_summary = tf.summary.merge([
                vl_error_seg_summary, vl_dice_summary,
                vl_error_invariance_summary, vl_error_total_summary
            ])

            # ================================================================
            # summaries of the training errors
            # ================================================================
            tr_error_seg = tf.placeholder(tf.float32, shape=[], name='tr_error_seg')
            tr_error_seg_summary = tf.summary.scalar('training/loss_seg',
                                                     tr_error_seg)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_error_invariance = tf.placeholder(tf.float32,
                                                 shape=[],
                                                 name='tr_error_invariance')
            tr_error_invariance_summary = tf.summary.scalar(
                'training/loss_invariance', tr_error_invariance)
            tr_error_total = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='tr_error_total')
            tr_error_total_summary = tf.summary.scalar('training/loss_total',
                                                       tr_error_total)
            tr_summary = tf.summary.merge([
                tr_error_seg_summary, tr_dice_summary,
                tr_error_invariance_summary, tr_error_total_summary
            ])

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('initializing all variables...')
            sess.run(init_ops)

            # ================================================================
            # print names of uninitialized variables
            # ================================================================
            uninit_variables = sess.run(uninit_vars)
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('This is the list of uninitialized variables:')
                for v in uninit_variables:
                    print(v)

            # ================================================================
            # Restore the segmentation network parameters (incl. the
            # normalizer) from the pre-trained i2l experiment
            # ================================================================
            if exp_config.train_from_scratch is False:
                logging.info(
                    '============================================================')
                path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    path_to_model, 'best_dice.ckpt')
                logging.info('Restoring the trained parameters from %s...',
                             checkpoint_path)
                saver_lowest_loss.restore(sess, checkpoint_path)

            # ================================================================
            # run training steps
            # ================================================================
            step = 0
            lowest_loss = 10000.0
            validation_total_loss_list = []
            # number of recent validation losses smoothed before model selection
            window_length = 5

            while step < exp_config.max_steps:

                for batch in iterate_minibatches(
                        images_sd=images_sd_tr,
                        labels_sd=labels_sd_tr,
                        images_td=images_td_tr,
                        batch_size=exp_config.batch_size):

                    # Stop mid-epoch once the step budget is used up; without
                    # this check the inner loop could overrun max_steps by up
                    # to one epoch.
                    if step >= exp_config.max_steps:
                        break

                    x_sd, y_sd, x_td = batch

                    # ===========================
                    # define feed dict for this iteration
                    # ===========================
                    feed_dict = {
                        images_sd_pl: x_sd,
                        labels_sd_pl: y_sd,
                        images_td_pl: x_td,
                        training_pl: True
                    }

                    # ================================================
                    # update i2l and D successively
                    # ================================================
                    sess.run(train_i2l_op, feed_dict=feed_dict)
                    sess.run(train_discriminator_op, feed_dict=feed_dict)

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            '============== Updating summary at step %d ' % step)
                        summary_writer.add_summary(
                            sess.run(summary_scalars, feed_dict=feed_dict), step)
                        summary_writer.flush()

                    # ===========================
                    # Compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('============== Training Data Eval:')
                        train_loss_seg, train_dice, train_loss_invariance = do_eval(
                            sess, eval_loss, images_sd_pl, labels_sd_pl,
                            images_td_pl, training_pl, images_sd_tr, labels_sd_tr,
                            images_td_tr, exp_config.batch_size)

                        # total training loss
                        train_total_loss = train_loss_seg + exp_config.lambda_uda * train_loss_invariance

                        # update tensorboard summary of scalars
                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error_seg: train_loss_seg,
                                                      tr_dice: train_dice,
                                                      tr_error_invariance: train_loss_invariance,
                                                      tr_error_total: train_total_loss
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        logging.info(
                            '============== Periodically saving checkpoint:')
                        checkpoint_file = os.path.join(log_dir, 'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # ===========================
                    # Evaluate the model periodically on a validation set
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('============== Validation Data Eval:')
                        val_loss_seg, val_dice, val_loss_invariance = do_eval(
                            sess, eval_loss, images_sd_pl, labels_sd_pl,
                            images_td_pl, training_pl, images_sd_vl, labels_sd_vl,
                            images_td_vl, exp_config.batch_size)

                        # total val loss
                        val_total_loss = val_loss_seg + exp_config.lambda_uda * val_loss_invariance
                        validation_total_loss_list.append(val_total_loss)

                        # update tensorboard summary of scalars
                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error_seg: val_loss_seg,
                                                      vl_dice: val_dice,
                                                      vl_error_invariance: val_loss_invariance,
                                                      vl_error_total: val_total_loss
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # update tensorboard summary of images
                        # (drawn from the most recent *training* batch)
                        summary_writer.add_summary(
                            sess.run(summary_images,
                                     feed_dict={
                                         images_sd_pl: x_sd,
                                         labels_sd_pl: y_sd,
                                         training_pl: False
                                     }), step)
                        summary_writer.flush()

                        # ===========================
                        # save the model if the smoothed validation loss is
                        # the lowest seen so far
                        # ===========================
                        if len(validation_total_loss_list) < window_length + 1:
                            expo_moving_avg_loss_value = validation_total_loss_list[-1]
                        else:
                            expo_moving_avg_loss_value = utils.exponential_moving_average(
                                validation_total_loss_list,
                                window=window_length)[-1]

                        if expo_moving_avg_loss_value < lowest_loss:
                            # Store the smoothed value (not the raw loss) so the
                            # next comparison is like-for-like.
                            lowest_loss = expo_moving_avg_loss_value
                            lowest_loss_file = os.path.join(
                                log_dir, 'models/lowest_loss.ckpt')
                            saver_lowest_loss.save(sess,
                                                   lowest_loss_file,
                                                   global_step=step)
                            logging.info(
                                '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                                % step)

                    # ================================================
                    # increment step
                    # ================================================
                    step += 1
        finally:
            # ================================================================
            # flush pending events and close tf session
            # ================================================================
            summary_writer.close()
            sess.close()

    return 0
def run_training(log_dir, image, label, atlas, continue_run, log_dir_first_TD_subject=''): # ============================ # down sample the atlas - the losses will be evaluated in the downsampled space # ============================ atlas_downsampled = rescale(atlas, [ 1 / exp_config.downsampling_factor_x, 1 / exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z ], order=1, preserve_range=True, multichannel=True, mode='constant') atlas_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot( atlas_downsampled, int(256 / exp_config.downsampling_factor_x)) label_onehot = utils.make_onehot(label, exp_config.nlabels) label_onehot_downsampled = rescale(label_onehot, [ 1 / exp_config.downsampling_factor_x, 1 / exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z ], order=1, preserve_range=True, multichannel=True, mode='constant') label_onehot_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot( label_onehot_downsampled, int(256 / exp_config.downsampling_factor_x)) # ============================ # Initialize step number - this is number of mini-batch runs # ============================ init_step = 0 # ============================ # if continue_run is set to True, load the model parameters saved earlier # else start training from scratch # ============================ if continue_run: logging.info( '============================================================') logging.info('Continuing previous run') try: init_checkpoint_path = utils.get_latest_model_checkpoint_path( log_dir, 'models/model.ckpt') logging.info('Checkpoint path: %s' % init_checkpoint_path) init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) logging.info('Latest step was: %d' % init_step) except: logging.warning( 'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...' 
) continue_run = False init_step = 0 logging.info( '============================================================') # ================================================================ # reset the graph built so far and build a new TF graph # ================================================================ tf.reset_default_graph() with tf.Graph().as_default(): # ============================ # set random seed for reproducibility # ============================ tf.random.set_random_seed(exp_config.run_number) np.random.seed(exp_config.run_number) # ================================================================ # create placeholders - segmentation net # ================================================================ images_pl = tf.placeholder(tf.float32, shape=[exp_config.batch_size] + list(exp_config.image_size) + [1], name='images') learning_rate_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate') training_pl = tf.placeholder(tf.bool, shape=[], name='training_or_testing') # ================================================================ # insert a normalization module in front of the segmentation network # the normalization module is trained for each test image # ================================================================ images_normalized, added_residual = model.normalize( images_pl, exp_config, training_pl) # ================================================================ # build the graph that computes predictions from the inference model # By setting the 'training_pl' to false directly, the update ops for the moments in the BN layer are not created at all. # This allows grouping the update ops together with the optimizer training, while training the normalizer - in case the normalizer has BN. 
# ================================================================ predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l( images_normalized, exp_config, training_pl=tf.constant(False, dtype=tf.bool)) # ================================================================ # 3d prior # ================================================================ labels_3d_1hot_shape = [1] + list( exp_config.image_size_downsampled) + [exp_config.nlabels] # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder predicted_seg_1hot_3d_pl = tf.placeholder(tf.float32, shape=labels_3d_1hot_shape, name='predicted_labels_3d') # denoise the noisy segmentation _, predicted_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l( predicted_seg_1hot_3d_pl, exp_config, training_pl=tf.constant(False, dtype=tf.bool)) # ================================================================ # divide the vars into segmentation network and normalization network # ================================================================ i2l_vars = [] l2l_vars = [] normalization_vars = [] for v in tf.global_variables(): var_name = v.name if 'image_normalizer' in var_name: normalization_vars.append(v) i2l_vars.append( v ) # the normalization vars also need to be restored from the pre-trained i2l mapper elif 'i2l_mapper' in var_name: i2l_vars.append(v) elif 'l2l_mapper' in var_name: l2l_vars.append(v) # ================================================================ # Make a list of trainable i2l vars. This will be used to compute gradients. # The other list contains trainable as well as non-trainable parameters. 
This is required for saving and loading all parameters, but runs into trouble when gradients are asked to be computed for non-trainable parameters # ================================================================ i2l_vars_trainable = [] for v in i2l_vars: if v.trainable is True: i2l_vars_trainable.append(v) # ================================================================ # add ops for calculation of the prior loss - wrt an atlas or the outputs of the DAE # ================================================================ prior_label_1hot_pl = tf.placeholder( tf.float32, shape=[exp_config.batch_size_downsampled] + list( (exp_config.image_size_downsampled[1], exp_config.image_size_downsampled[2])) + [exp_config.nlabels], name='labels_prior') # down sample the predicted logits predicted_seg_logits_expanded = tf.expand_dims(predicted_seg_logits, axis=0) # the 'upsample' function will actually downsample the predictions, as the scaling factors have been set appropriately predicted_seg_logits_downsampled = layers.bilinear_upsample3D_( predicted_seg_logits_expanded, name='downsampled_predictions', factor_x=1 / exp_config.downsampling_factor_x, factor_y=1 / exp_config.downsampling_factor_y, factor_z=1 / exp_config.downsampling_factor_z) predicted_seg_logits_downsampled = tf.squeeze( predicted_seg_logits_downsampled ) # the first axis was added only for the downsampling in 3d # compute the dice between the predictions and the prior in the downsampled space loss_op = model.loss(logits=predicted_seg_logits_downsampled, labels=prior_label_1hot_pl, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_prior, mask_for_loss_within_mask=None, are_labels_1hot=True) tf.summary.scalar('tr_losses/loss', loss_op) # ================================================================ # one of the two prior losses will be used in the following manner: # the atlas prior will be used when the current prediction is deemed to be very far away from a reasonable solution # once a reasonable 
solution is reached, the dae prior will be used. # these 3d computations will be done outside the graph and will be passed via placeholders for logging in tensorboard # ================================================================ lambda_prior_atlas_pl = tf.placeholder(tf.float32, shape=[], name='lambda_prior_atlas') lambda_prior_dae_pl = tf.placeholder(tf.float32, shape=[], name='lambda_prior_dae') tf.summary.scalar('lambdas/prior_atlas', lambda_prior_atlas_pl) tf.summary.scalar('lambdas/prior_dae', lambda_prior_dae_pl) dice3d_prior_atlas_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_prior_atlas') dice3d_prior_dae_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_prior_dae') dice3d_gt_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_gt') tf.summary.scalar('dice3d/prior_atlas', dice3d_prior_atlas_pl) tf.summary.scalar('dice3d/prior_dae', dice3d_prior_dae_pl) tf.summary.scalar('dice3d/gt', dice3d_gt_pl) # ================================================================ # add optimization ops # ================================================================ if exp_config.debug: print('creating training op...') # create an instance of the required optimizer optimizer = exp_config.optimizer_handle(learning_rate=learning_rate_pl) # initialize variable holding the accumlated gradients and create a zero-initialisation op accumulated_gradients = [ tf.Variable(tf.zeros_like(var.initialized_value()), trainable=False) for var in i2l_vars_trainable ] # accumulated gradients init op accumulated_gradients_zero_op = [ ac.assign(tf.zeros_like(ac)) for ac in accumulated_gradients ] # calculate gradients and define accumulation op update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): gradients = optimizer.compute_gradients( loss_op, var_list=i2l_vars_trainable) # compute_gradients return a list of (gradient, variable) pairs. 
accumulate_gradients_op = [ ac.assign_add(gg[0]) for ac, gg in zip(accumulated_gradients, gradients) ] # define the gradient mean op num_accumulation_steps_pl = tf.placeholder( dtype=tf.float32, name='num_accumulation_steps') accumulated_gradients_mean_op = [ ag.assign(tf.divide(ag, num_accumulation_steps_pl)) for ag in accumulated_gradients ] # reassemble the gradients in the [value, var] format and do define train op final_gradients = [(ag, gg[1]) for ag, gg in zip(accumulated_gradients, gradients)] train_op = optimizer.apply_gradients(final_gradients) # ================================================================ # sequence of running opt ops: # 1. at the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders) # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs # 3. at the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients. # ================================================================ # ================================================================ # previous train_op without accumulation of gradients # ================================================================ # train_op = model.training_step(loss_op, normalization_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars = True) # ================================================================ # build the summary Tensor based on the TF collection of Summaries. 
# ================================================================ if exp_config.debug: print('creating summary op...') # ================================================================ # add init ops # ================================================================ init_ops = tf.global_variables_initializer() # ================================================================ # find if any vars are uninitialized # ================================================================ if exp_config.debug: logging.info( 'Adding the op to get a list of initialized variables...') uninit_vars = tf.report_uninitialized_variables() # ================================================================ # create session # ================================================================ sess = tf.Session() # ================================================================ # create a summary writer # ================================================================ summary_writer = tf.summary.FileWriter(log_dir, sess.graph) # ================================================================ # summaries of the training errors # ================================================================ prior_dae_dice = tf.placeholder(tf.float32, shape=[], name='prior_dae_dice') prior_dae_dice_summary = tf.summary.scalar('test_img/prior_dae_dice', prior_dae_dice) prior_dae_output_dice_wrt_gt = tf.placeholder( tf.float32, shape=[], name='prior_dae_output_dice_wrt_gt') prior_dae_output_dice_wrt_gt_summary = tf.summary.scalar( 'test_img/prior_dae_output_dice_wrt_gt', prior_dae_output_dice_wrt_gt) prior_atlas_dice = tf.placeholder(tf.float32, shape=[], name='prior_atlas_dice') prior_atlas_dice_summary = tf.summary.scalar( 'test_img/prior_atlas_dice', prior_atlas_dice) prior_dae_atlas_dice_ratio = tf.placeholder( tf.float32, shape=[], name='prior_dae_atlas_dice_ratio') prior_dae_atlas_dice_ratio_summary = tf.summary.scalar( 'test_img/prior_dae_atlas_dice_ratio', prior_dae_atlas_dice_ratio) prior_dice = 
tf.placeholder(tf.float32, shape=[], name='prior_dice') prior_dice_summary = tf.summary.scalar('test_img/prior_dice', prior_dice) gt_dice = tf.placeholder(tf.float32, shape=[], name='gt_dice') gt_dice_summary = tf.summary.scalar('test_img/gt_dice', gt_dice) # ================================================================ # create savers # ================================================================ saver_i2l = tf.train.Saver(var_list=i2l_vars) saver_l2l = tf.train.Saver(var_list=l2l_vars) saver_test_data = tf.train.Saver(var_list=i2l_vars, max_to_keep=3) saver_best_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3) # ================================================================ # add operations to compute dice between two 3d volumes # ================================================================ pred_3d_1hot_pl = tf.placeholder( tf.float32, shape=list(exp_config.image_size_downsampled) + [exp_config.nlabels], name='pred_3d') labl_3d_1hot_pl = tf.placeholder( tf.float32, shape=list(exp_config.image_size_downsampled) + [exp_config.nlabels], name='labl_3d') atls_3d_1hot_pl = tf.placeholder( tf.float32, shape=list(exp_config.image_size_downsampled) + [exp_config.nlabels], name='atls_3d') dice_3d_op_dae = losses.compute_dice_3d_without_batch_axis( prediction=pred_3d_1hot_pl, labels=labl_3d_1hot_pl) dice_3d_op_atlas = losses.compute_dice_3d_without_batch_axis( prediction=pred_3d_1hot_pl, labels=atls_3d_1hot_pl) # ================================================================ # freeze the graph before execution # ================================================================ if exp_config.debug: logging.info( '============================================================') logging.info('Freezing the graph now!') tf.get_default_graph().finalize() # ================================================================ # Run the Op to initialize the variables. 
# ================================================================ if exp_config.debug: logging.info( '============================================================') logging.info('initializing all variables...') sess.run(init_ops) # ================================================================ # print names of uninitialized variables # ================================================================ uninit_variables = sess.run(uninit_vars) if exp_config.debug: logging.info( '============================================================') logging.info('This is the list of uninitialized variables:') for v in uninit_variables: print(v) # ================================================================ # Restore the segmentation network parameters and the pre-trained i2i mapper parameters # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters # ================================================================ logging.info( '============================================================') path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/' checkpoint_path = utils.get_latest_model_checkpoint_path( path_to_model, 'best_dice.ckpt') logging.info('Restoring the trained parameters from %s...' % checkpoint_path) saver_i2l.restore(sess, checkpoint_path) # ================================================================ # Restore the prior network parameters # ================================================================ logging.info( '============================================================') path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/' checkpoint_path = utils.get_latest_model_checkpoint_path( path_to_model, 'best_dice.ckpt') logging.info('Restoring the trained parameters from %s...' 
% checkpoint_path) saver_l2l.restore(sess, checkpoint_path) # ================================================================ # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters # ================================================================ if log_dir_first_TD_subject is not '': logging.info( '============================================================') path_to_model = log_dir_first_TD_subject + '/models/' checkpoint_path = utils.get_latest_model_checkpoint_path( path_to_model, 'best_score.ckpt') logging.info('Restoring the trained parameters from %s...' % checkpoint_path) saver_test_data.restore(sess, checkpoint_path) max_steps_tta = exp_config.max_steps else: max_steps_tta = 5 * exp_config.max_steps # run the adaptation of the 1st TD subject for longer # ================================================================ # continue run from a saved checkpoint # ================================================================ if continue_run: # Restore session logging.info( '============================================================') logging.info('Restroring normalization module from: %s' % init_checkpoint_path) saver_test_data.restore(sess, init_checkpoint_path) # ================================================================ # run training epochs # ================================================================ step = init_step best_score = 0.0 while (step < max_steps_tta): # ================================================ # After every some epochs, # Get the prediction for the entire volume, evaluate it using the DAE. # Now, decide whether to use the DAE output or the atlas as the ground truth for the next update. # ================================================ if (step == init_step) or (step % exp_config.check_ood_frequency is 0): # ================== # 1. 
compute the current 3d segmentation prediction # ================== y_pred_soft = [] for batch in iterate_minibatches_images( image, batch_size=exp_config.batch_size): y_pred_soft.append( sess.run(predicted_seg_softmax, feed_dict={ images_pl: batch, training_pl: False })) y_pred_soft = np.squeeze(np.array(y_pred_soft)).astype(float) y_pred_soft = np.reshape(y_pred_soft, [ -1, y_pred_soft.shape[2], y_pred_soft.shape[3], y_pred_soft.shape[4] ]) # ================== # 2. downsample it. Let's call this guy 'A' # ================== y_pred_soft_downsampled = rescale(y_pred_soft, [ 1 / exp_config.downsampling_factor_x, 1 / exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z ], order=1, preserve_range=True, multichannel=True, mode='constant').astype( np.float32) # ================== # 3. pass the downsampled prediction through the DAE and get its output 'B' # ================== feed_dict = { predicted_seg_1hot_3d_pl: np.expand_dims(y_pred_soft_downsampled, axis=0) } y_pred_noisy_denoised_softmax = np.squeeze( sess.run( predicted_seg_softmax_3d_noisy_autoencoded_softmax, feed_dict=feed_dict)).astype(np.float16) y_pred_noisy_denoised = np.argmax( y_pred_noisy_denoised_softmax, axis=-1) # ================== # 4. compute the dice between: # a. 'A' (seg network prediction downsampled) and 'B' (dae network output) # b. 'B' (dae network output) and downsampled gt labels (for debugging, to see if the dae output is close to the gt.) # c. 'A' (seg network prediction downsampled) and 'C' (downsampled atlas) # ================== dAB = sess.run(dice_3d_op_dae, feed_dict={ pred_3d_1hot_pl: y_pred_soft_downsampled, labl_3d_1hot_pl: y_pred_noisy_denoised_softmax }) dBgt = sess.run(dice_3d_op_dae, feed_dict={ pred_3d_1hot_pl: y_pred_noisy_denoised_softmax, labl_3d_1hot_pl: label_onehot_downsampled }) dAC = sess.run(dice_3d_op_atlas, feed_dict={ pred_3d_1hot_pl: y_pred_soft_downsampled, atls_3d_1hot_pl: atlas_downsampled }) # ================== # 5. 
compute the ratio dice(AB) / dice(AC). pass the ratio through a threshold and decide whether to use the DAE or the atlas as the prior # ================== ratio_dice = dAB / (dAC + 1e-5) if exp_config.use_gt_for_tta is True: target_labels_for_this_epoch = label_onehot_downsampled prr = dBgt elif (ratio_dice > exp_config.dae_atlas_ratio_threshold) and ( dAC > exp_config.min_atlas_dice): target_labels_for_this_epoch = y_pred_noisy_denoised_softmax prr = dAB else: target_labels_for_this_epoch = atlas_downsampled prr = dAC # ================== # update losses on tensorboard # ================== summary_writer.add_summary( sess.run(prior_dae_dice_summary, feed_dict={prior_dae_dice: dAB}), step) summary_writer.add_summary( sess.run(prior_dae_output_dice_wrt_gt_summary, feed_dict={prior_dae_output_dice_wrt_gt: dBgt}), step) summary_writer.add_summary( sess.run(prior_atlas_dice_summary, feed_dict={prior_atlas_dice: dAC}), step) summary_writer.add_summary( sess.run( prior_dae_atlas_dice_ratio_summary, feed_dict={prior_dae_atlas_dice_ratio: ratio_dice}), step) summary_writer.add_summary( sess.run(prior_dice_summary, feed_dict={prior_dice: prr}), step) # ================== # save best model so far # ================== if best_score < prr: best_score = prr best_file = os.path.join(log_dir, 'models/best_score.ckpt') saver_best_loss.save(sess, best_file, global_step=step) logging.info( 'Found new best score (%f) at step %d - Saving model.' 
% (best_score, step)) # ================== # dice wrt gt # ================== y_pred = [] for batch in iterate_minibatches_images( image, batch_size=exp_config.batch_size): y_pred.append( sess.run(predicted_seg, feed_dict={ images_pl: batch, training_pl: False })) y_pred = np.squeeze(np.array(y_pred)).astype(float) y_pred = np.reshape(y_pred, [-1, y_pred.shape[2], y_pred.shape[3]]) dice_wrt_gt = met.f1_score(label.flatten(), y_pred.flatten(), average=None) summary_writer.add_summary( sess.run(gt_dice_summary, feed_dict={gt_dice: np.mean(dice_wrt_gt[1:])}), step) # ================== # visualize results # ================== if step % exp_config.vis_frequency is 0: # =========================== # save checkpoint # =========================== logging.info( '=============== Saving checkkpoint at step %d ... ' % step) checkpoint_file = os.path.join(log_dir, 'models/model.ckpt') saver_test_data.save(sess, checkpoint_file, global_step=step) y_pred_noisy_denoised_upscaled = utils.crop_or_pad_volume_to_size_along_x( rescale(y_pred_noisy_denoised, [ exp_config.downsampling_factor_x, exp_config.downsampling_factor_y, exp_config.downsampling_factor_z ], order=0, preserve_range=True, multichannel=False, mode='constant'), image.shape[0]).astype(np.uint8) x_norm = [] for batch in iterate_minibatches_images( image, batch_size=exp_config.batch_size): x = batch x_norm.append( sess.run(images_normalized, feed_dict={ images_pl: x, training_pl: False })) x_norm = np.squeeze(np.array(x_norm)).astype(float) x_norm = np.reshape(x_norm, [-1, x_norm.shape[2], x_norm.shape[3]]) utils_vis.save_sample_results( x=image, x_norm=x_norm, x_diff=x_norm - image, y=y_pred, y_pred_dae=y_pred_noisy_denoised_upscaled, at=np.argmax(atlas, axis=-1), gt=label, savepath=log_dir + '/results/visualize_images/step' + str(step) + '.png') # ================================================ # Part of training ops sequence: # 1. 
At the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders) # ================================================ sess.run(accumulated_gradients_zero_op) num_accumulation_steps = 0 # ================================================ # batches # ================================================ for batch in iterate_minibatches_images_and_downsampled_labels( images=image, batch_size=exp_config.batch_size, labels_downsampled=target_labels_for_this_epoch, batch_size_downsampled=exp_config.batch_size_downsampled): x, y = batch # =========================== # define feed dict for this iteration # =========================== feed_dict = { images_pl: x, prior_label_1hot_pl: y, learning_rate_pl: exp_config.learning_rate, training_pl: True } # ================================================ # Part of training ops sequence: # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs # ================================================ sess.run(accumulate_gradients_op, feed_dict=feed_dict) num_accumulation_steps = num_accumulation_steps + 1 step += 1 # ================================================ # Part of training ops sequence: # 3. At the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl # ================================================ sess.run( accumulated_gradients_mean_op, feed_dict={num_accumulation_steps_pl: num_accumulation_steps}) # ================================================================ # sequence of running opt ops: # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients. 
# ================================================================ sess.run(train_op, feed_dict=feed_dict) # ================================================================ # ================================================================ sess.close() # ================================================================ # ================================================================ gc.collect() return 0
def predict_segmentation(subject_name, image, normalize=True, post_process=False):
    """Segment a pre-processed volume slice by slice, optionally refining it with the DAE.

    Builds a fresh inference graph (normalization module -> i2l segmentation
    network, plus the l2l denoising autoencoder), restores the i2l and l2l
    weights from their latest 'best_dice.ckpt' checkpoints under
    ``sys_config.log_root`` and runs every slice of ``image`` through it.

    Args:
        subject_name: Not used by this function; kept for interface compatibility.
        image: Volume indexed by slice along axis 0. Each slice gets a trailing
            channel axis and is fed into a placeholder of shape
            ``[None] + exp_config.image_size + [1]``, so slices presumably have
            shape ``exp_config.image_size`` (TODO confirm against callers).
        normalize: Not used by this function. The normalizer weights are
            restored together with the i2l checkpoint. Kept for interface
            compatibility.
        post_process: If truthy, the soft prediction is downsampled, passed
            ``exp_config.dae_post_process_runs`` times through the DAE and the
            denoised labels are upsampled back with nearest-neighbour
            interpolation (``order=0``).

    Returns:
        (mask_predicted, img_normalized): the hard segmentation (float array,
        or uint8 when ``post_process`` is set) and the normalized image (float
        array), both with singleton axes squeezed.
    """
    with tf.Graph().as_default():
        # Constant flag that keeps every sub-network in inference mode.
        inference_mode = tf.constant(False, dtype=tf.bool)

        # ================================================================
        # 2D pipeline: normalization module -> segmentation network (i2l)
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) + [1],
                                   name='images')
        images_normalized, _ = model.normalize(images_pl,
                                               exp_config,
                                               training_pl=inference_mode)
        _, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized, exp_config, training_pl=inference_mode)

        # ================================================================
        # 3D prior: the downsampled hard prediction of the whole volume is fed
        # through this placeholder and denoised by the l2l autoencoder (DAE)
        # ================================================================
        labels_3d_shape = [1] + list(exp_config.image_size_downsampled)
        predicted_seg_3d_pl = tf.placeholder(tf.uint8,
                                             shape=labels_3d_shape,
                                             name='true_labels_3d')
        predicted_seg_1hot_3d = tf.one_hot(predicted_seg_3d_pl,
                                           depth=exp_config.nlabels)
        _, dae_softmax_3d, _ = model.predict_l2l(predicted_seg_1hot_3d,
                                                 exp_config,
                                                 training_pl=inference_mode)

        # ================================================================
        # Split the variables by the checkpoint they are restored from.
        # The normalization vars are restored from the pre-trained i2l mapper.
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        for v in tf.global_variables():
            if 'image_normalizer' in v.name or 'i2l_mapper' in v.name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in v.name:
                l2l_vars.append(v)

        init_ops = tf.global_variables_initializer()
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)

        # freeze the graph before execution
        tf.get_default_graph().finalize()

        sess = tf.Session()
        try:
            sess.run(init_ops)

            # ============================================================
            # Restore the segmentation network (and normalizer) parameters
            # ============================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ============================================================
            # Restore the prior (DAE) network parameters
            # ============================================================
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_l2l.restore(sess, checkpoint_path)

            # ============================================================
            # Predict slice by slice at the resolution of the pre-processed image
            # ============================================================
            mask_predicted = []
            mask_predicted_soft = []
            img_normalized = []
            for b_i in range(image.shape[0]):
                x_slice = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)
                # Fetch all three outputs in a single run so each slice needs
                # only one forward pass instead of three.
                seg, seg_soft, img_norm = sess.run(
                    [predicted_seg, predicted_seg_softmax, images_normalized],
                    feed_dict={images_pl: x_slice})
                mask_predicted.append(seg)
                mask_predicted_soft.append(seg_soft)
                img_normalized.append(img_norm)

            mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
            mask_predicted_soft = np.squeeze(
                np.array(mask_predicted_soft)).astype(float)
            img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

            # ============================================================
            # Optionally downsample the prediction and refine it with the DAE
            # ============================================================
            if post_process:
                mask_predicted_soft_downsampled = rescale(
                    mask_predicted_soft, [
                        1 / exp_config.downsampling_factor_x,
                        1 / exp_config.downsampling_factor_y,
                        1 / exp_config.downsampling_factor_z
                    ],
                    order=1,
                    preserve_range=True,
                    multichannel=True,
                    mode='constant')
                mask_predicted_denoised = np.argmax(
                    mask_predicted_soft_downsampled, axis=-1)

                # Feed the DAE its own hard output repeatedly.
                for _ in range(exp_config.dae_post_process_runs):
                    feed_dict = {
                        predicted_seg_3d_pl:
                        np.expand_dims(mask_predicted_denoised, axis=0)
                    }
                    y_pred_noisy_denoised_softmax = np.squeeze(
                        sess.run(dae_softmax_3d,
                                 feed_dict=feed_dict)).astype(np.float16)
                    mask_predicted_denoised = np.argmax(
                        y_pred_noisy_denoised_softmax, axis=-1)

                # Upsample the denoised labels back (nearest neighbour).
                mask_predicted = rescale(mask_predicted_denoised, [
                    exp_config.downsampling_factor_x,
                    exp_config.downsampling_factor_y,
                    exp_config.downsampling_factor_z
                ],
                                         order=0,
                                         preserve_range=True,
                                         multichannel=False,
                                         mode='constant').astype(np.uint8)
        finally:
            # Always release the session, even if a restore or run fails.
            sess.close()

    return mask_predicted, img_normalized
def _load_i2l_train_val_data(dataset):
    """Load the training and validation slices of one source-domain dataset.

    Args:
        dataset: One of 'HCPT1', 'HCPT2', 'CALTECH', 'STANFORD', 'IXI'.

    Returns:
        (imtr, gttr, imvl, gtvl): training images/labels and validation
        images/labels, taken from the 'images' / 'labels' entries returned by
        the dataset-specific ``load_and_maybe_process_data``.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
    """
    if dataset in ('HCPT1', 'HCPT2'):
        logging.info('Reading %s images...' % dataset)
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        protocol = 'T1' if dataset == 'HCPT1' else 'T2'
        # NOTE(review): the HCPT1 training range (0-1040) contains the
        # validation subjects (20-25), so validation is not held out for
        # HCPT1 (HCPT2 trains on 0-20). Confirm whether 1040 is intended.
        train_idx = (0, 1040) if dataset == 'HCPT1' else (0, 20)
        val_idx = (20, 25)

        def load(idx_start, idx_end):
            return data_hcp.load_and_maybe_process_data(
                input_folder=sys_config.orig_data_root_hcp,
                preprocessing_folder=sys_config.preproc_folder_hcp,
                idx_start=idx_start,
                idx_end=idx_end,
                protocol=protocol,
                size=exp_config.image_size,
                depth=exp_config.image_depth_hcp,
                target_resolution=exp_config.target_resolution_brain)

    elif dataset in ('CALTECH', 'STANFORD'):
        logging.info('Reading %s images...' % dataset)
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + dataset + '/')
        if dataset == 'CALTECH':
            depth = exp_config.image_depth_caltech
        else:
            depth = exp_config.image_depth_stanford
        train_idx = (0, 10)
        val_idx = (10, 15)

        def load(idx_start, idx_end):
            return data_abide.load_and_maybe_process_data(
                input_folder=sys_config.orig_data_root_abide,
                preprocessing_folder=sys_config.preproc_folder_abide,
                site_name=dataset,
                idx_start=idx_start,
                idx_end=idx_end,
                protocol='T1',
                size=exp_config.image_size,
                depth=depth,
                target_resolution=exp_config.target_resolution_brain)

    elif dataset == 'IXI':
        logging.info('Reading IXI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_ixi)
        train_idx = (0, 12)
        val_idx = (12, 17)

        def load(idx_start, idx_end):
            return data_ixi.load_and_maybe_process_data(
                input_folder=sys_config.orig_data_root_ixi,
                preprocessing_folder=sys_config.preproc_folder_ixi,
                idx_start=idx_start,
                idx_end=idx_end,
                protocol='T2',
                size=exp_config.image_size,
                depth=exp_config.image_depth_ixi,
                target_resolution=exp_config.target_resolution_brain)

    else:
        raise ValueError('Unknown training dataset: %s' % dataset)

    data_train = load(*train_idx)
    data_val = load(*val_idx)
    return (data_train['images'], data_train['labels'],
            data_val['images'], data_val['labels'])


def run_training(continue_run):
    """Train the i2l network (normalization module + segmentation network).

    The source dataset is ``exp_config.train_dataset``. ``log_dir``,
    ``exp_config`` and ``sys_config`` are read from module scope.

    Periodic checkpoints are saved to ``<log_dir>/models/model.ckpt`` every
    ``exp_config.save_frequency`` steps. The best model is saved to
    ``<log_dir>/models/best_dice.ckpt`` whenever the validation dice improves.
    TensorBoard summaries are written to ``log_dir``.

    Args:
        continue_run: If True, resume from the latest ``models/model.ckpt`` in
            ``log_dir``. If no checkpoint can be found, a warning is logged and
            training starts from step 0.

    Raises:
        ValueError: if ``exp_config.train_dataset`` is not a supported dataset.
    """
    # ============================
    # log experiment details
    # ============================
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l)

    # ============================
    # Initialize step number - this is the number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set, load the model parameters saved earlier;
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # the step is encoded as the '-<step>' suffix of the file name;
            # plus 1 as otherwise the run starts with an eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # narrow from a bare except so Ctrl-C / SystemExit still propagate
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ============================
    # Load data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    imtr, gttr, imvl, gtvl = _load_i2l_train_val_data(exp_config.train_dataset)

    logging.info('Training Images: %s' %
                 str(imtr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Training Labels: %s' %
                 str(gttr.shape))  # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():
        # set random seeds for reproducibility
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ------------------------------------------------------------
        # placeholders
        # ------------------------------------------------------------
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ------------------------------------------------------------
        # normalization module in front of the segmentation network
        # (it is later adapted for each test image), then the i2l network
        # ------------------------------------------------------------
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl)
        logits, _, _ = model.predict_i2l(images_normalized,
                                         exp_config,
                                         training_pl=training_pl)
        print('shape of inputs: ', images_pl.shape)  # (batch_size, 256, 256, 1)
        print('shape of logits: ', logits.shape)  # (batch_size, 256, 256, nlabels)

        # all trainable variables are optimized
        i2l_vars = list(tf.trainable_variables())

        # ------------------------------------------------------------
        # supervised loss, training op and evaluation op
        # ------------------------------------------------------------
        loss_op = model.loss(logits,
                             labels_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('loss', loss_op)

        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       i2l_vars,
                                       exp_config.optimizer_handle,
                                       learning_rate_pl,
                                       update_bn_nontrainable_vars=True)

        print('creating eval op...')
        eval_loss = model.evaluation_i2l(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_i2l)

        # merge_all must run before the validation/training summary ops
        # below are created, so that those are not part of `summary`
        print('creating summary op...')
        summary = tf.summary.merge_all()

        init_ops = tf.global_variables_initializer()
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # summaries of the validation errors
        vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
        vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
        vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
        vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
        vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

        # summaries of the training errors
        tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
        tr_error_summary = tf.summary.scalar('training/loss', tr_error)
        tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
        tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
        tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

        # freeze the graph before execution
        logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        try:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
            sess.run(init_ops)

            logging.info(
                '============================================================')
            logging.info('This is the list of all variables:')
            for v in tf.trainable_variables():
                print(v.name)

            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in sess.run(uninit_vars):
                print(v)

            # continue run from a saved checkpoint
            if continue_run:
                logging.info(
                    '============================================================')
                logging.info('Restroring session from: %s' %
                             init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            # ========================================================
            # run training epochs
            # ========================================================
            step = init_step
            curr_lr = exp_config.learning_rate
            best_dice = 0

            while step < exp_config.max_steps:
                if step % 1000 == 0:
                    logging.info(
                        '============================================================'
                    )
                    logging.info('step %d' % step)

                for batch in iterate_minibatches(imtr,
                                                 gttr,
                                                 batch_size=exp_config.batch_size,
                                                 train_or_eval='train'):
                    start_time = time.time()
                    x, y = batch

                    # avoid incomplete batches (the placeholders have a fixed
                    # batch dimension)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_pl: y,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # optimization step
                    _, loss = sess.run([train_op, loss_op],
                                       feed_dict=feed_dict)
                    duration = time.time() - start_time

                    # write the summaries and print an overview fairly often
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.3f (%.3f sec for the last step)'
                            % (step + 1, loss, duration))
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # compute the loss on the entire training set
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        train_loss, train_dice = do_eval(
                            sess, eval_loss, images_pl, labels_pl,
                            training_pl, imtr, gttr, exp_config.batch_size)
                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error: train_loss,
                                                      tr_dice: train_dice
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # save a checkpoint periodically
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # evaluate the model periodically on the validation set
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('Validation Data Eval:')
                        val_loss, val_dice = do_eval(
                            sess, eval_loss, images_pl, labels_pl,
                            training_pl, imvl, gtvl, exp_config.batch_size)
                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error: val_loss,
                                                      vl_dice: val_dice
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # save the model if the val dice is the best yet
                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f - Saving model.'
                                % val_dice)

                    step += 1
        finally:
            # flush pending summaries to disk and release the session even if
            # training is interrupted by an exception
            summary_writer.close()
            sess.close()
def predict_segmentation(subject_name, image, normalize=True):
    """Run the normalizer + segmentation (i2l) network over every slice of `image`.

    Builds a fresh TF graph (normalization module followed by the i2l network,
    both in inference mode) and restores the i2l weights from the latest
    'best_dice.ckpt' under `exp_config.expname_i2l`. When `normalize` is truthy,
    the 'image_normalizer' variables are then overridden with the subject-specific
    'best_score.ckpt' weights.

    Args:
        subject_name: string used to build the per-subject normalizer checkpoint
            directory ('.../subject_<subject_name>/models/').
        image: array indexed by slice along axis 0; each slice is fed as a batch
            of one after a trailing channel axis is appended. Presumably
            (num_slices, *exp_config.image_size) -- TODO confirm.
        normalize: whether to restore the subject-specific normalizer weights.

    Returns:
        (mask_predicted, img_normalized): float arrays built by stacking the
        per-slice outputs and applying np.squeeze, so singleton axes (including
        the slice axis for a one-slice input) are removed.
    """
    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) + [1],
                                   name='images')
        inference_mode = tf.constant(False, dtype=tf.bool)

        # normalization module in front of the segmentation network
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl=inference_mode)

        # predictions of the segmentation network on the normalized images
        _, _, preds = model.predict_i2l(images_normalized, exp_config,
                                        training_pl=inference_mode)

        # The i2l saver covers ALL global variables (normalizer included); the
        # normalizer saver covers only the 'image_normalizer' variables, so
        # restoring it afterwards overrides just those.
        i2l_vars = tf.global_variables()
        normalization_vars = [v for v in i2l_vars if 'image_normalizer' in v.name]

        init_ops = tf.global_variables_initializer()
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_normalizer = tf.train.Saver(var_list=normalization_vars)

        # freeze the graph before execution
        tf.get_default_graph().finalize()

        # the context manager closes the session even if a restore fails
        with tf.Session() as sess:
            sess.run(init_ops)

            # ========================================================
            # Restore the segmentation network parameters
            # ========================================================
            logging.info('============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...', checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ========================================================
            # Restore the normalization network parameters
            # ========================================================
            if normalize:
                logging.info('============================================================')
                path_to_model = os.path.join(
                    sys_config.log_root, exp_config.expname_normalizer
                ) + '/subject_' + subject_name + '/models/'
                checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_score.ckpt')
                logging.info('Restoring the trained parameters from %s...', checkpoint_path)
                saver_normalizer.restore(sess, checkpoint_path)
                logging.info('============================================================')

            # ========================================================
            # Predict slice by slice; fetch prediction and normalized image
            # in a single run so the normalizer is evaluated once per slice.
            # ========================================================
            mask_predicted = []
            img_normalized = []
            for b_i in range(image.shape[0]):
                X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)
                pred_slice, norm_slice = sess.run([preds, images_normalized],
                                                  feed_dict={images_pl: X})
                mask_predicted.append(pred_slice)
                img_normalized.append(norm_slice)

    mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
    img_normalized = np.squeeze(np.array(img_normalized)).astype(float)
    return mask_predicted, img_normalized
def predict_segmentation(image):
    """Predict a segmentation mask for a single 2D image with the i2l network.

    Builds a fresh TF graph containing only the segmentation network (inference
    mode, no normalization module), restores all of its variables from the latest
    'best_dice.ckpt' under `exp_config.expname_i2l`, and runs it on `image`.

    NOTE(review): another `predict_segmentation` with a different signature
    (subject_name, image, normalize=True) appears earlier in this file; at import
    time this later definition replaces it -- confirm that is intended.

    Args:
        image: 2D array; it is reshaped to (1, H, W, 1) to match the placeholder
            shape [None, image_size[0], image_size[1], 1].

    Returns:
        The network's prediction as a float array, with singleton axes removed
        by np.squeeze.
    """
    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():
        images_pl = tf.placeholder(
            tf.float32,
            shape=[None, exp_config.image_size[0], exp_config.image_size[1], 1],
            name='images')

        _, _, preds = model.predict_i2l(images_pl, exp_config,
                                        training_pl=tf.constant(False, dtype=tf.bool))

        init_ops = tf.global_variables_initializer()
        saver_i2l = tf.train.Saver()

        # freeze the graph before execution
        tf.get_default_graph().finalize()

        # the context manager closes the session even if the restore fails
        with tf.Session() as sess:
            sess.run(init_ops)

            # ========================================================
            # Restore the segmentation network parameters
            # ========================================================
            logging.info('============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...', checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ========================================================
            # predict segmentation: (H, W) -> (1, H, W, 1)
            # ========================================================
            X = np.expand_dims(np.expand_dims(image, axis=-1), axis=0)
            mask_predicted = sess.run(preds, feed_dict={images_pl: X})

    return np.squeeze(np.array(mask_predicted)).astype(float)