def get_latest_checkpoint_and_log(logdir, filename):
    """Look up the newest checkpoint for ``filename`` in ``logdir`` and log it.

    The lookup itself is delegated to
    ``utils.get_latest_model_checkpoint_path``. The training step is read from
    the text following the last ``-`` in the checkpoint's final path component;
    it is only logged, not returned.

    Args:
        logdir: Directory passed through to the checkpoint lookup helper.
        filename: Checkpoint base name passed through to the lookup helper.

    Returns:
        The checkpoint path produced by the lookup helper.

    Raises:
        ValueError: If the text after the last ``-`` is not an integer.
    """
    checkpoint = utils.get_latest_model_checkpoint_path(logdir, filename)
    logging.info('Checkpoint path: %s' % checkpoint)

    # Only the last path component carries the "<name>-<step>" suffix.
    basename = checkpoint.rsplit('/', 1)[-1]
    step_text = basename.rsplit('-', 1)[-1]
    logging.info('Latest step was: %d' % int(step_text))

    return checkpoint
def main(fs_exp_config, slices, test):
    """Predict masks for selected slices with the best-dice model and save them.

    Loads the data set, keeps only the requested slice indices that exist in
    the chosen split, runs ``model.predict`` with the weights of the latest
    ``model_best_dice.ckpt`` checkpoint found under the module-level
    ``fs_model_path``, and writes one rendering per slice via
    ``print_coloured`` into the module-level ``OUTPUT_FOLDER``.

    Args:
        fs_exp_config: Experiment configuration; ``data_mode``, ``image_size``,
            ``target_resolution``, ``model_handle`` and ``nlabels`` are read.
        slices: Indices of the slices to predict. Must support boolean-mask
            indexing (``slices[slices < n]``) -- presumably a NumPy integer
            array; verify against callers.
        test: If truthy, slices are taken from ``data['images_test']``,
            otherwise from ``data['images_train']``.
    """
    # Load data
    data = load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=fs_exp_config.data_mode,
        size=fs_exp_config.image_size,
        target_resolution=fs_exp_config.target_resolution,
        force_overwrite=False
    )

    # Get images
    if test:
        image_key = 'images_test'
        prefix = 'test'
    else:
        image_key = 'images_train'
        prefix = 'train'

    # Drop indices that lie beyond the end of the chosen split.
    slices = slices[slices < len(data[image_key])]
    images = data[image_key][slices, ...]

    # The batch size must be taken AFTER filtering: the placeholder shape has
    # to match the array that is fed, and the output loop below must not index
    # past the surviving slices.
    batch_size = len(slices)

    image_tensor_shape = [batch_size] + list(fs_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    # Get full supervision prediction
    mask_pl, softmax_pl = model.predict(images_pl,
                                        fs_exp_config.model_handle,
                                        fs_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(fs_model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)
        fs_predictions, _ = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)

    for i in range(batch_size):
        print_coloured(fs_predictions[i, ...],
                       filepath=OUTPUT_FOLDER,
                       filename='{}{}_fs_pred'.format(prefix, slices[i]))
def run_training(continue_run, log_dir):
    """Train the generator/discriminator pair configured in ``exp_config``.

    Builds the graph (generator, discriminator, losses, summaries), optionally
    restores the latest ``model.ckpt`` checkpoint from ``log_dir`` and then
    alternates discriminator and generator updates. Summaries, validation
    losses and checkpoints are written at the frequencies given in
    ``exp_config``.

    Args:
        continue_run: If truthy, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``. When the lookup fails, training starts
            from step 0 instead.
        log_dir: Directory for summaries and checkpoints. When a run is
            resumed, ``'_cont'`` is appended and the new directory is used for
            all output.
    """
    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            last_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            init_step = last_step + 1  # plus 1 b/c otherwise starts with eval
            # Log the step that was actually stored, not the resume step.
            logging.info('Latest step was: %d' % last_step)
            log_dir += '_cont'
        except Exception:
            # Deliberately best-effort, but no longer a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    images_train, source_images_train_ind, target_images_train_ind, \
        images_val, source_images_val_ind, target_images_val_ind = \
        data_utils.get_images_and_fieldstrength_indices(
            data, exp_config.source_field_strength, exp_config.target_field_strength)

    generator = exp_config.generator
    discriminator = exp_config.discriminator

    # z batches come from the source indices, x batches from the target indices
    z_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=source_images_train_ind)
    x_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=target_images_train_ind)

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.
        im_s = exp_config.image_size

        training_placeholder = tf.placeholder(tf.bool, name='training_phase')

        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape,
                                                minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        x_pl = tf.placeholder(tf.float32,
                              [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels],
                              name='x')

        # source image batch
        z_pl = tf.placeholder(tf.float32,
                              [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels],
                              name='z')

        # generated fake image batch
        x_pl_ = generator(z_pl, noise_in_gen_pl, training_placeholder)

        # difference between generated and source images
        diff_img_pl = x_pl_ - z_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs',
                         tf_utils.put_kernels_on_grid3d(x_pl_, exp_config.cut_axis, exp_config.cut_index,
                                                        rescale_mode='manual',
                                                        input_range=exp_config.image_range))
        tf.summary.image('sample_xs',
                         tf_utils.put_kernels_on_grid3d(x_pl, exp_config.cut_axis, exp_config.cut_index,
                                                        rescale_mode='manual',
                                                        input_range=exp_config.image_range))
        tf.summary.image('sample_zs',
                         tf_utils.put_kernels_on_grid3d(z_pl, exp_config.cut_axis, exp_config.cut_index,
                                                        rescale_mode='manual',
                                                        input_range=exp_config.image_range))
        tf.summary.image('sample_difference_gx-x',
                         tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis, exp_config.cut_index,
                                                        rescale_mode='centered',
                                                        cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(x_pl, training_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(x_pl_, training_placeholder, scope_reuse=True)

        d_hat = None
        x_hat = None
        if exp_config.improved_training:
            # Random interpolation between a real and a generated batch; it is
            # handed to gan_model.training_ops below (presumably for a gradient
            # penalty term -- confirm in gan_model).
            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * x_pl + (1 - epsilon) * x_pl_
            d_hat = discriminator(x_hat, training_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        # nr means no regularization, meaning the loss without the regularization term
        discriminator_train_op, generator_train_op, \
            disc_loss_pl, gen_loss_pl, \
            disc_loss_nr_pl, gen_loss_nr_pl = gan_model.training_ops(
                d_pl, d_pl_,
                optimizer_handle=exp_config.optimizer_handle,
                learning_rate=exp_config.learning_rate,
                l1_img_dist=dist_l1,
                w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                w_reg_gen_l1=exp_config.w_reg_gen_l1,
                w_reg_disc_l1=exp_config.w_reg_disc_l1,
                w_reg_gen_l2=exp_config.w_reg_gen_l2,
                w_reg_disc_l2=exp_config.w_reg_disc_l2,
                d_hat=d_hat, x_hat=x_hat, scale=exp_config.scale)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary
        dist_l1_summary_op = tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Build the summary Tensor based on the TF collection of Summaries.
        # Must happen before the validation summaries are created so that
        # merge_all() does not pick them up.
        summary_op = tf.summary.merge_all()

        # validation summaries
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_op = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=3)
        saver_best_disc = tf.train.Saver(max_to_keep=3)  # disc loss is scaled negative EM distance

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)
        summary_writer = None
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver_latest.restore(sess, init_checkpoint_path)

            # initialize value of lowest (i. e. best) discriminator loss
            best_d_loss = np.inf

            for step in range(init_step, 1000000):

                start_time = time.time()

                # discriminator training iterations
                d_iters = 5
                if step % 500 == 0 or step < 25:
                    d_iters = 100

                for _ in range(d_iters):

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    # train discriminator
                    sess.run(discriminator_train_op,
                             feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                    if not exp_config.improved_training:
                        sess.run(d_clip_op)

                # Only the discriminator updates are timed.
                elapsed_time = time.time() - start_time

                # train generator
                x = next(x_sampler_train)  # why not sample a new x??
                z = next(z_sampler_train)
                sess.run(generator_train_op,
                         feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                if step % exp_config.update_tensorboard_frequency == 0:

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    g_loss_train, d_loss_train, summary_str = sess.run(
                        [gen_loss_nr_pl, disc_loss_nr_pl, summary_op],
                        feed_dict={z_pl: z, x_pl: x, training_placeholder: False})

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    logging.info("[Step: %d], generator loss: %g, discriminator_loss: %g"
                                 % (step, g_loss_train, d_loss_train))
                    logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

                if step % exp_config.validation_frequency == 0:

                    z_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=source_images_val_ind)
                    x_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=target_images_val_ind)

                    # evaluate the validation batch with batch_size images (from each domain) at a time
                    g_loss_val_list = []
                    d_loss_val_list = []
                    for _ in range(exp_config.num_val_batches):
                        x = next(x_sampler_val)
                        z = next(z_sampler_val)
                        g_loss_val, d_loss_val = sess.run(
                            [gen_loss_nr_pl, disc_loss_nr_pl],
                            feed_dict={z_pl: z, x_pl: x, training_placeholder: False})
                        g_loss_val_list.append(g_loss_val)
                        d_loss_val_list.append(d_loss_val)

                    g_loss_val_avg = np.mean(g_loss_val_list)
                    d_loss_val_avg = np.mean(d_loss_val_list)

                    validation_summary_str = sess.run(val_summary_op,
                                                      feed_dict={val_disc_loss_pl: d_loss_val_avg,
                                                                 val_gen_loss_pl: g_loss_val_avg}
                                                      )
                    summary_writer.add_summary(validation_summary_str, step)
                    summary_writer.flush()

                    # save best variables (if discriminator loss is the lowest yet)
                    if d_loss_val_avg <= best_d_loss:
                        best_d_loss = d_loss_val_avg
                        best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                        saver_best_disc.save(sess, best_file, global_step=step)
                        logging.info('Found new best discriminator loss on validation set! - %f - Saving model_best_d_loss.ckpt' % best_d_loss)

                    logging.info("[Validation], generator loss: %g, discriminator_loss: %g"
                                 % (g_loss_val_avg, d_loss_val_avg))

                # Save a checkpoint of the latest state periodically.
                if step % exp_config.save_frequency == 0:
                    saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)
        finally:
            # Release the event file and the session even when training is
            # interrupted or fails.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):
    """Train the segmentation network with an adversarial feature-invariance loss.

    A normalization module and the segmentation network are applied to source
    domain (``sd``) and target domain (``td``) images. A discriminator is
    trained to tell the resized features of the two domains apart, while the
    segmentation/normalization variables are trained on
    ``loss_seg + exp_config.lambda_uda * loss_invariance``.

    Args:
        log_dir: Directory for summaries; checkpoints go to ``log_dir/models``.
        images_sd_tr: Source-domain training images.
        labels_sd_tr: Source-domain training labels.
        images_sd_vl: Source-domain validation images.
        labels_sd_vl: Source-domain validation labels.
        images_td_tr: Target-domain training images.
        images_td_vl: Target-domain validation images.

    Returns:
        ``0`` once training has finished.
    """
    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_sd_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] + list(exp_config.image_size) + [1],
                                      name='images_sd')
        images_td_pl = tf.placeholder(tf.float32,
                                      shape=[exp_config.batch_size] + list(exp_config.image_size) + [1],
                                      name='images_td')
        labels_sd_pl = tf.placeholder(tf.uint8,
                                      shape=[exp_config.batch_size] + list(exp_config.image_size),
                                      name='labels_sd')
        training_pl = tf.placeholder(tf.bool, shape=[], name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # ================================================================
        images_sd_normalized, _ = model.normalize(images_sd_pl, exp_config, training_pl, scope_reuse=False)
        images_td_normalized, _ = model.normalize(images_td_pl, exp_config, training_pl, scope_reuse=True)

        # ================================================================
        # get logit predictions from the segmentation network
        # ================================================================
        predicted_seg_sd_logits, _, _ = model.predict_i2l(images_sd_normalized,
                                                          exp_config,
                                                          training_pl,
                                                          scope_reuse=False)

        # ================================================================
        # get all features from the segmentation network
        # ================================================================
        seg_features_sd = model.get_all_features(images_sd_normalized, exp_config, scope_reuse=True)
        seg_features_td = model.get_all_features(images_td_normalized, exp_config, scope_reuse=True)

        # ================================================================
        # resize all features to the same size
        # ================================================================
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64), 'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64), 'resize_td_xfeat')

        # ================================================================
        # discriminator on features
        # ================================================================
        d_logits_sd = model.discriminator(images_sd_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=False)
        d_logits_td = model.discriminator(images_td_features_resized,
                                          exp_config,
                                          training_pl,
                                          scope_name='discriminator',
                                          scope_reuse=True)

        # ================================================================
        # add ops for calculation of the discriminator loss
        # (source domain is labelled 1, target domain is labelled 0)
        # ================================================================
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # ================================================================
        # add ops for calculation of the adversarial loss that tries to get
        # domain invariant features in the normalized image space
        # ================================================================
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)

        # ================================================================
        # add ops for calculation of the supervised segmentation loss
        # ================================================================
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # ================================================================
        # total training loss for uda
        # ================================================================
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # ================================================================
        # merge all summaries
        # ================================================================
        summary_scalars = tf.summary.merge_all()

        # ================================================================
        # divide the vars into segmentation network, normalization network
        # and the discriminator network
        # ================================================================
        i2l_vars = []
        normalization_vars = []
        discriminator_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                # the normalization vars also need to be restored from the
                # pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # ================================================================
        # add optimization ops
        # ================================================================
        train_i2l_op = model.training_step(loss_total_op,
                                           i2l_vars,
                                           exp_config.optimizer_handle,
                                           learning_rate=exp_config.learning_rate)
        train_discriminator_op = model.training_step(loss_d_op,
                                                     discriminator_vars,
                                                     exp_config.optimizer_handle,
                                                     learning_rate=exp_config.learning_rate)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits,
            labels_sd_pl,
            images_sd_pl,
            d_logits_td,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # add ops for adding image summary to tensorboard
        # ================================================================
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl, exp_config.nlabels)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        if exp_config.debug:
            print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()
        summary_writer = None
        try:
            # ================================================================
            # create a file writer object
            # This writes Summary protocol buffers to event files.
            # https://github.com/tensorflow/docs/blob/r1.12/site/en/api_docs/python/tf/summary/FileWriter.md
            # The FileWriter class provides a mechanism to create an event file
            # in a given directory and add summaries and events to it.
            # The class updates the file contents asynchronously.
            # This allows a training program to call methods to add data to the
            # file directly from the training loop, without slowing down training.
            # ================================================================
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # ================================================================
            # create savers
            # ================================================================
            saver = tf.train.Saver(var_list=i2l_vars)
            saver_lowest_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

            # ================================================================
            # summaries of the validation errors
            # ================================================================
            vl_error_seg = tf.placeholder(tf.float32, shape=[], name='vl_error_seg')
            vl_error_seg_summary = tf.summary.scalar('validation/loss_seg', vl_error_seg)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_error_invariance = tf.placeholder(tf.float32, shape=[], name='vl_error_invariance')
            vl_error_invariance_summary = tf.summary.scalar('validation/loss_invariance',
                                                            vl_error_invariance)
            vl_error_total = tf.placeholder(tf.float32, shape=[], name='vl_error_total')
            vl_error_total_summary = tf.summary.scalar('validation/loss_total', vl_error_total)
            vl_summary = tf.summary.merge([
                vl_error_seg_summary, vl_dice_summary,
                vl_error_invariance_summary, vl_error_total_summary
            ])

            # ================================================================
            # summaries of the training errors
            # ================================================================
            tr_error_seg = tf.placeholder(tf.float32, shape=[], name='tr_error_seg')
            tr_error_seg_summary = tf.summary.scalar('training/loss_seg', tr_error_seg)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_error_invariance = tf.placeholder(tf.float32, shape=[], name='tr_error_invariance')
            tr_error_invariance_summary = tf.summary.scalar('training/loss_invariance',
                                                            tr_error_invariance)
            tr_error_total = tf.placeholder(tf.float32, shape=[], name='tr_error_total')
            tr_error_total_summary = tf.summary.scalar('training/loss_total', tr_error_total)
            tr_summary = tf.summary.merge([
                tr_error_seg_summary, tr_dice_summary,
                tr_error_invariance_summary, tr_error_total_summary
            ])

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            if exp_config.debug:
                logging.info('============================================================')
                logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            if exp_config.debug:
                logging.info('============================================================')
                logging.info('initializing all variables...')
            sess.run(init_ops)

            # ================================================================
            # print names of uninitialized variables
            # ================================================================
            uninit_variables = sess.run(uninit_vars)
            if exp_config.debug:
                logging.info('============================================================')
                logging.info('This is the list of uninitialized variables:')
                for v in uninit_variables:
                    print(v)

            # ================================================================
            # Restore the segmentation network parameters and the pre-trained
            # i2i mapper parameters
            # ================================================================
            if exp_config.train_from_scratch is False:
                logging.info('============================================================')
                path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
                checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_dice.ckpt')
                logging.info('Restoring the trained parameters from %s...' % checkpoint_path)
                saver_lowest_loss.restore(sess, checkpoint_path)

            # ================================================================
            # run training steps
            # ================================================================
            step = 0
            lowest_loss = 10000.0
            validation_total_loss_list = []

            # NOTE: max_steps is only checked between passes over the data, so
            # the pass that is running when the limit is reached is completed.
            while (step < exp_config.max_steps):

                # ================================================
                # batches
                # ================================================
                for batch in iterate_minibatches(images_sd=images_sd_tr,
                                                 labels_sd=labels_sd_tr,
                                                 images_td=images_td_tr,
                                                 batch_size=exp_config.batch_size):

                    x_sd, y_sd, x_td = batch

                    # ===========================
                    # define feed dict for this iteration
                    # ===========================
                    feed_dict = {
                        images_sd_pl: x_sd,
                        labels_sd_pl: y_sd,
                        images_td_pl: x_td,
                        training_pl: True
                    }

                    # ================================================
                    # update i2l and D successively
                    # ================================================
                    sess.run(train_i2l_op, feed_dict=feed_dict)
                    sess.run(train_discriminator_op, feed_dict=feed_dict)

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info('============== Updating summary at step %d ' % step)
                        summary_writer.add_summary(
                            sess.run(summary_scalars, feed_dict=feed_dict), step)
                        summary_writer.flush()

                    # ===========================
                    # Compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('============== Training Data Eval:')
                        train_loss_seg, train_dice, train_loss_invariance = do_eval(
                            sess, eval_loss, images_sd_pl, labels_sd_pl,
                            images_td_pl, training_pl, images_sd_tr,
                            labels_sd_tr, images_td_tr, exp_config.batch_size)

                        # total training loss
                        train_total_loss = train_loss_seg + exp_config.lambda_uda * train_loss_invariance

                        # ===========================
                        # update tensorboard summary of scalars
                        # ===========================
                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error_seg: train_loss_seg,
                                                      tr_dice: train_dice,
                                                      tr_error_invariance: train_loss_invariance,
                                                      tr_error_total: train_total_loss
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        logging.info('============== Periodically saving checkpoint:')
                        checkpoint_file = os.path.join(log_dir, 'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # ===========================
                    # Evaluate the model periodically on a validation set
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('============== Validation Data Eval:')
                        val_loss_seg, val_dice, val_loss_invariance = do_eval(
                            sess, eval_loss, images_sd_pl, labels_sd_pl,
                            images_td_pl, training_pl, images_sd_vl,
                            labels_sd_vl, images_td_vl, exp_config.batch_size)

                        # total val loss
                        val_total_loss = val_loss_seg + exp_config.lambda_uda * val_loss_invariance
                        validation_total_loss_list.append(val_total_loss)

                        # ===========================
                        # update tensorboard summary of scalars
                        # ===========================
                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error_seg: val_loss_seg,
                                                      vl_dice: val_dice,
                                                      vl_error_invariance: val_loss_invariance,
                                                      vl_error_total: val_total_loss
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # ===========================
                        # update tensorboard summary of images
                        # (uses the current training batch, in inference mode)
                        # ===========================
                        summary_writer.add_summary(
                            sess.run(summary_images,
                                     feed_dict={
                                         images_sd_pl: x_sd,
                                         labels_sd_pl: y_sd,
                                         training_pl: False
                                     }), step)
                        summary_writer.flush()

                        # ===========================
                        # save model if the smoothed total validation loss is
                        # the lowest yet
                        # ===========================
                        window_length = 5
                        if len(validation_total_loss_list) < window_length + 1:
                            expo_moving_avg_loss_value = validation_total_loss_list[-1]
                        else:
                            expo_moving_avg_loss_value = utils.exponential_moving_average(
                                validation_total_loss_list, window=window_length)[-1]

                        if expo_moving_avg_loss_value < lowest_loss:
                            # NOTE(review): the comparison uses the moving
                            # average but the raw validation loss is stored as
                            # the new threshold -- confirm this is intended.
                            lowest_loss = val_total_loss
                            lowest_loss_file = os.path.join(log_dir, 'models/lowest_loss.ckpt')
                            saver_lowest_loss.save(sess, lowest_loss_file, global_step=step)
                            logging.info(
                                '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                                % step)

                    # ================================================
                    # increment step
                    # ================================================
                    step += 1
        finally:
            # ================================================================
            # close the event file and the tf session, also when training
            # fails or is interrupted
            # ================================================================
            if summary_writer is not None:
                summary_writer.close()
            sess.close()

    return 0
def predict_segmentation(subject_name, image, normalize=True, post_process=False):
    """Segment ``image`` slice by slice, optionally denoising the result.

    The normalization module and the segmentation network are restored from
    the latest ``best_dice.ckpt`` checkpoint of ``exp_config.expname_i2l``; the
    label-to-label network is restored from the one of
    ``exp_config.expname_l2l``.

    Args:
        subject_name: Not used by this function.
        image: Array that is processed one entry at a time along its first
            axis; a trailing channel axis is added before feeding. Presumably
            shaped ``(n_slices,) + exp_config.image_size`` -- TODO confirm.
        normalize: Not used by this function.
        post_process: If ``True``, the soft prediction is downsampled, passed
            ``exp_config.dae_post_process_runs`` times through the
            label-to-label network and upsampled again; the returned mask is
            then a ``uint8`` array.

    Returns:
        Tuple ``(mask_predicted, img_normalized)``: the predicted label mask
        and the output of the normalization module, both with singleton axes
        squeezed out.
    """
    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) + [1],
                                   name='images')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl, exp_config, training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized, exp_config, training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_shape = [1] + list(exp_config.image_size_downsampled)
        # predict the current segmentation for the entire volume, downsample it
        # and pass it through this placeholder
        predicted_seg_3d_pl = tf.placeholder(tf.uint8, shape=labels_3d_shape, name='true_labels_3d')
        predicted_seg_1hot_3d_pl = tf.one_hot(predicted_seg_3d_pl, depth=exp_config.nlabels)

        # denoise the noisy segmentation
        _, pred_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl, exp_config, training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                # the normalization vars also need to be restored from the
                # pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()
        try:
            # ================================================================
            # create saver
            # ================================================================
            saver_i2l = tf.train.Saver(var_list=i2l_vars)
            saver_normalizer = tf.train.Saver(var_list=normalization_vars)
            saver_l2l = tf.train.Saver(var_list=l2l_vars)

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            tf.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            sess.run(init_ops)

            # ================================================================
            # Restore the segmentation network parameters
            # ================================================================
            logging.info('============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' % checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # ================================================================
            # Restore the prior network parameters
            # ================================================================
            logging.info('============================================================')
            path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' % checkpoint_path)
            saver_l2l.restore(sess, checkpoint_path)

            # ================================================================
            # Make predictions for the image at the resolution of the image
            # after pre-processing
            # ================================================================
            mask_predicted = []
            mask_predicted_soft = []
            img_normalized = []

            for b_i in range(0, image.shape[0], 1):
                X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)
                # Fetch all three outputs with a single forward pass instead
                # of running the network once per output.
                seg_b, seg_soft_b, img_norm_b = sess.run(
                    [predicted_seg, predicted_seg_softmax, images_normalized],
                    feed_dict={images_pl: X})
                mask_predicted.append(seg_b)
                mask_predicted_soft.append(seg_soft_b)
                img_normalized.append(img_norm_b)

            mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
            mask_predicted_soft = np.squeeze(np.array(mask_predicted_soft)).astype(float)
            img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

            # ================================================================
            # downsample predicted mask and pass it through the DAE
            # ================================================================
            if post_process is True:

                # downsample mask_predicted_soft
                mask_predicted_soft_downsampled = rescale(mask_predicted_soft, [
                    1 / exp_config.downsampling_factor_x,
                    1 / exp_config.downsampling_factor_y,
                    1 / exp_config.downsampling_factor_z
                ],
                    order=1,
                    preserve_range=True,
                    multichannel=True,
                    mode='constant')
                mask_predicted_downsampled = np.argmax(mask_predicted_soft_downsampled, axis=-1)

                # pass the downsampled prediction through the DAE
                mask_predicted_denoised = mask_predicted_downsampled
                for _ in range(exp_config.dae_post_process_runs):
                    feed_dict = {
                        predicted_seg_3d_pl: np.expand_dims(mask_predicted_denoised, axis=0)
                    }
                    y_pred_noisy_denoised_softmax = np.squeeze(
                        sess.run(pred_seg_softmax_3d_noisy_autoencoded_softmax,
                                 feed_dict=feed_dict)).astype(np.float16)
                    mask_predicted_denoised = np.argmax(y_pred_noisy_denoised_softmax, axis=-1)

                # upsample the denoised prediction (order=0: nearest neighbour,
                # so label values are preserved)
                mask_predicted = rescale(mask_predicted_denoised, [
                    exp_config.downsampling_factor_x,
                    exp_config.downsampling_factor_y,
                    exp_config.downsampling_factor_z
                ],
                    order=0,
                    preserve_range=True,
                    multichannel=False,
                    mode='constant').astype(np.uint8)
        finally:
            # Release the session even if restoring or predicting fails.
            sess.close()

        return mask_predicted, img_normalized
def run_training(continue_run):
    """Build the segmentation graph and regenerate the recursion ground truth.

    Loads the scribble data, optionally looks up the latest checkpoint of the
    current (or, failing that, the previous) recursion, builds the
    inference/loss graph, runs the variable initializer and finally calls
    ``predict_next_gt`` on the training images.

    Relies on the module-level names ``exp_config``, ``sys_config`` and
    ``log_dir``.

    NOTE(review): restoring the located checkpoint is disabled in this
    function (the ``saver.restore`` call is commented out), so
    ``predict_next_gt`` is run right after ``sess.run(init)``. Confirm that
    this is intended.

    NOTE(review): the original comment said the body was wrapped in a
    try/except "to catch keyboard interrupt so can control h5py closing data",
    but the handler only re-raised. The no-op wrapper was removed; the data
    handles returned by the loader are not closed here.

    Args:
        continue_run: If truthy, try to locate the latest checkpoint in
            ``log_dir`` and derive the step/epoch to continue from. When no
            checkpoint can be found, continue mode is disabled.

    Returns:
        None.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)
    already_created_recursion = False
    print("ALready created recursion : " + str(already_created_recursion))
    init_step = 0

    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder='/scratch_net/biwirender02/cany/scribble/logdir/heart_dropout_welch_non_exp',
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio
    )

    loaded_previous_recursion = False
    start_epoch = 0
    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        # ``except Exception`` (instead of a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit from being silently swallowed.
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                # No checkpoint for the current recursion: fall back to the
                # previous one.
                print("EXCEPTE GİRDİ")
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            # NOTE(review): the divisor 4 is presumably the batch size - confirm.
            start_epoch = int(init_step/(len(base_data['images_train'])/4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            logging.warning('!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0
            start_epoch = 0
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    if loaded_previous_recursion:
        logging.info("Data file exists for recursion {} "
                     "but checkpoints only present up to recursion {}".format(recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")
        if not already_created_recursion:
            # Step the recursion data back by one so that it matches the
            # checkpoint that was actually found.
            recursion_data = acdc_data.load_different_recursion(recursion_data, -1)
            recursion -= 1
        else:
            start_epoch = 0
            init_step = 0

    # load images and validation data
    images_train = np.array(base_data['images_train'])

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)

    # Session configuration: grow GPU memory on demand and allow soft placement.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # The configured session is the one that is actually used. Previously a
    # second, unconfigured ``tf.Session()`` was created inside this block; it
    # shadowed ``sess`` and was never closed.
    with tf.Session(config=config) as sess:

        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
        labels_placeholder = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])
        keep_prob = tf.placeholder(tf.float32, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 keep_prob,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        # Add to the Graph the Ops for loss calculation.
        # second output is unregularised loss
        [loss, _, weights_norm] = model.loss(logits,
                                             labels_placeholder,
                                             nlabels=exp_config.nlabels,
                                             loss_type=exp_config.loss_type,
                                             weight_decay=exp_config.weight_decay)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # The summary ops, savers and writer below are not used further in
        # this function; they are kept so the constructed graph is unchanged.

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        # Only keep two checkpoints, as checkpoints are kept for every
        # recursion and they can be 300MB +
        saver = tf.train.Saver(max_to_keep=2)
        saver_best_dice = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # Monitoring placeholders and summaries (validation / training).
        val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
        val_error_summary = tf.summary.scalar('validation_loss', val_error_)

        val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
        val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

        val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

        train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
        train_error_summary = tf.summary.scalar('training_loss', train_error_)

        train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
        train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

        train_summary = tf.summary.merge([train_error_summary, train_dice_summary])

        # Run the Op to initialize the variables.
        sess.run(init)

        # Checkpoint restore is currently disabled:
        # if continue_run:
        #     saver.restore(sess, init_checkpoint_path)

        random_walked = np.array(recursion_data['random_walked'])

        recursion_data = predict_next_gt(data2=recursion_data,
                                         images_train=images_train,
                                         images_placeholder=images_placeholder,
                                         training_time_placeholder=training_time_placeholder,
                                         keep_prob=keep_prob,
                                         logits=logits,
                                         sess=sess,
                                         random_walked=random_walked,
                                         )
def run_training(continue_run):
    """Train the segmentation network on HCP T1 images.

    Loads training and validation data, drops slices whose image is entirely
    zero, builds the TF graph (placeholders, network, loss, optimizer,
    evaluation and summary ops), optionally restores the latest checkpoint and
    then runs ``exp_config.max_epochs`` epochs of mini-batch training.
    Depending on the frequencies configured in ``exp_config`` it periodically
    writes summaries, evaluates on the training and validation data, saves a
    checkpoint, and saves an extra checkpoint whenever the validation dice
    improves.

    Relies on the module-level names ``exp_config``, ``sys_config`` and
    ``log_dir``.

    Args:
        continue_run: If truthy, try to resume from the latest
            ``models/model.ckpt`` checkpoint in ``log_dir``. When none can be
            found, continue mode is disabled and training starts from step 0.

    Returns:
        None.
    """
    rule = '============================================================'
    debug_rule = '================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(rule)
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(rule)
        logging.info('Continuing previous run')
        # ``except Exception`` (instead of a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit from being silently swallowed.
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 as otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(rule)

    # ============================
    # Load data
    # ============================
    logging.info(rule)
    logging.info('Loading data...')
    logging.info('Reading HCP - 3T - T1 images...')
    logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
    imtr, gttr = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                    sys_config.preproc_folder_hcp,
                                    'T1',
                                    1,
                                    exp_config.train_size + 1)
    imvl, gtvl = data_hcp.load_data(
        sys_config.orig_data_root_hcp,
        sys_config.preproc_folder_hcp,
        'T1',
        exp_config.train_size + 1,
        exp_config.train_size + exp_config.val_size + 1)
    # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Training Images: %s' % str(imtr.shape))
    logging.info('Training Labels: %s' % str(gttr.shape))
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(rule)

    # ================================================================
    # Drop slices whose image is zero everywhere (and their labels)
    # ================================================================
    mask = ~(imtr == 0).all(axis=(1, 2))
    imtr = imtr[mask]
    gttr = gttr[mask]

    mask = ~(imvl == 0).all(axis=(1, 2))
    imvl = imvl[mask]
    gtvl = gtvl[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
            exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')
        # Placeholders for the learning rate and to indicate whether the
        # model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            images_pl, image_tensor_shape, training_pl, exp_config.nlabels)

        # ================================================================
        # determine trainable variables (grouped by substring of their name)
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info(debug_rule)
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info(debug_rule)
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info(debug_rule)
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info(debug_rule)
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_segmentation = model.loss(logits_segmentation,
                                       labels_segmentation_pl,
                                       nlabels=exp_config.nlabels,
                                       loss_type=exp_config.loss_type)
        tf.compat.v1.summary.scalar('loss_segmentation', loss_segmentation)

        # ================================================================
        # Add optimization ops (only the segmentation variables are updated)
        # ================================================================
        train_op_train_time = model.training_step(
            loss_segmentation, segmentation_vars,
            exp_config.optimizer_handle_training, learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers: periodic checkpoints and best-validation-dice model
        # ================================================================
        max_to_keep = 15
        saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
        saver_best_da = tf.compat.v1.train.Saver()

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

        # From here on the session and the writer are released in ``finally``
        # so that an exception during training does not leak them and
        # pending summary events are written out.
        try:
            # ============================================================
            # summaries of the training errors
            # ============================================================
            tr_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='tr_error_seg')
            tr_error_summary_seg = tf.compat.v1.summary.scalar(
                'training/loss_seg', tr_error_seg)
            tr_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='tr_dice_seg')
            tr_dice_summary_seg = tf.compat.v1.summary.scalar(
                'training/dice_seg', tr_dice_seg)
            tr_summary_seg = tf.compat.v1.summary.merge(
                [tr_error_summary_seg, tr_dice_summary_seg])

            # ============================================================
            # summaries of the validation errors
            # ============================================================
            vl_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='vl_error_seg')
            vl_error_summary_seg = tf.compat.v1.summary.scalar(
                'validation/loss_seg', vl_error_seg)
            vl_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='vl_dice_seg')
            vl_dice_summary_seg = tf.compat.v1.summary.scalar(
                'validation/dice_seg', vl_dice_seg)
            vl_summary_seg = tf.compat.v1.summary.merge(
                [vl_error_summary_seg, vl_dice_summary_seg])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.compat.v1.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(rule)
            logging.info('initializing all variables...')
            sess.run(init_g)
            sess.run(init_l)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(rule)
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                logging.info(v)

            # ============================================================
            # continue run from a saved checkpoint
            # ============================================================
            if continue_run:
                # Restore session
                logging.info(rule)
                logging.info('Restroring session from: %s' % init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            # The learning rate is constant over the whole run.
            curr_lr = exp_config.learning_rate
            best_da = 0

            # ============================================================
            # run training epochs
            # ============================================================
            for epoch in range(exp_config.max_epochs):

                logging.info(rule)
                logging.info('EPOCH %d' % epoch)

                for batch in iterate_minibatches(imtr,
                                                 gttr,
                                                 batch_size=exp_config.batch_size):

                    start_time = time.time()
                    x, y_seg, y_ss = batch

                    # ===========================
                    # avoid incomplete batches (the placeholders have a
                    # fixed batch dimension); the step is still counted
                    # ===========================
                    if y_seg.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_segmentation_pl: y_seg,
                        labels_self_supervised_pl: y_ss,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # ===========================
                    # update vars
                    # ===========================
                    _, loss_value = sess.run(
                        [train_op_train_time, loss_segmentation],
                        feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.2f (%.3f sec for the last step)'
                            % (step + 1, loss_value, duration))

                        # ===========================
                        # print values of a parameter (to debug)
                        # ===========================
                        if exp_config.debug is True:
                            # can add name if you only want parameters with
                            # a specific name
                            var_value = tf.compat.v1.get_collection(
                                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                            )[0].eval(session=sess)
                            logging.info('value of one of the parameters %f' %
                                         var_value[0, 0, 0, 0])

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        [train_loss_seg, train_dice_seg] = do_eval(
                            sess, eval_loss_segmentation, images_pl,
                            labels_segmentation_pl, training_pl, imtr, gttr,
                            exp_config.batch_size)

                        tr_summary_msg = sess.run(tr_summary_seg,
                                                  feed_dict={
                                                      tr_error_seg: train_loss_seg,
                                                      tr_dice_seg: train_dice_seg
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # ===========================
                    # Evaluate the model periodically
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:

                        # ===========================
                        # Evaluate against the validation set
                        # ===========================
                        logging.info('Validation Data Eval:')
                        [val_loss_seg, val_dice_seg] = do_eval(
                            sess, eval_loss_segmentation, images_pl,
                            labels_segmentation_pl, training_pl, imvl, gtvl,
                            exp_config.batch_size)

                        vl_summary_msg = sess.run(vl_summary_seg,
                                                  feed_dict={
                                                      vl_error_seg: val_loss_seg,
                                                      vl_dice_seg: val_dice_seg
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # ===========================
                        # save model if the val dice is the best yet
                        # ===========================
                        if val_dice_seg > best_da:
                            best_da = val_dice_seg
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_da.save(sess,
                                               best_file,
                                               global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f -  Saving model.'
                                % val_dice_seg)

                    step += 1
        finally:
            summary_writer.close()
            sess.close()
def run_training(continue_run): # ============================ # log experiment details # ============================ logging.info( '============================================================') logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l) # ============================ # Initialize step number - this is number of mini-batch runs # ============================ init_step = 0 # ============================ # if continue_run is set to True, load the model parameters saved earlier # else start training from scratch # ============================ if continue_run: logging.info( '============================================================') logging.info('Continuing previous run') try: init_checkpoint_path = utils.get_latest_model_checkpoint_path( log_dir, 'models/model.ckpt') logging.info('Checkpoint path: %s' % init_checkpoint_path) init_step = int( init_checkpoint_path.split('/')[-1].split('-') [-1]) + 1 # plus 1 as otherwise starts with eval logging.info('Latest step was: %d' % init_step) except: logging.warning( 'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...' 
) continue_run = False init_step = 0 logging.info( '============================================================') # ============================ # Load data # ============================ logging.info( '============================================================') logging.info('Loading data...') if exp_config.train_dataset is 'HCPT1': logging.info('Reading HCPT1 images...') logging.info('Data root directory: ' + sys_config.orig_data_root_hcp) data_brain_train = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=0, idx_end=1040, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=20, idx_end=25, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] if exp_config.train_dataset is 'HCPT2': logging.info('Reading HCPT2 images...') logging.info('Data root directory: ' + sys_config.orig_data_root_hcp) data_brain_train = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=0, idx_end=20, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=20, idx_end=25, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_hcp, 
target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'CALTECH': logging.info('Reading CALTECH images...') logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/') data_brain_train = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='CALTECH', idx_start=0, idx_end=10, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_caltech, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='CALTECH', idx_start=10, idx_end=15, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_caltech, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'STANFORD': logging.info('Reading STANFORD images...') logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'STANFORD/') data_brain_train = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='STANFORD', idx_start=0, idx_end=10, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_stanford, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='STANFORD', idx_start=10, idx_end=15, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_stanford, 
target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'IXI': logging.info('Reading IXI images...') logging.info('Data root directory: ' + sys_config.orig_data_root_ixi) data_brain_train = data_ixi.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_ixi, preprocessing_folder=sys_config.preproc_folder_ixi, idx_start=0, idx_end=12, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_ixi, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_ixi.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_ixi, preprocessing_folder=sys_config.preproc_folder_ixi, idx_start=12, idx_end=17, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_ixi, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] logging.info( 'Training Images: %s' % str(imtr.shape)) # expected: [num_slices, img_size_x, img_size_y] logging.info( 'Training Labels: %s' % str(gttr.shape)) # expected: [num_slices, img_size_x, img_size_y] logging.info('Validation Images: %s' % str(imvl.shape)) logging.info('Validation Labels: %s' % str(gtvl.shape)) logging.info( '============================================================') # ================================================================ # build the TF graph # ================================================================ with tf.Graph().as_default(): # ============================ # set random seed for reproducibility # ============================ tf.random.set_random_seed(exp_config.run_number) np.random.seed(exp_config.run_number) # ================================================================ # create placeholders # ================================================================ logging.info('Creating placeholders...') 
image_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) + [1] mask_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') labels_pl = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels') learning_rate_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate') training_pl = tf.placeholder(tf.bool, shape=[], name='training_or_testing') # ================================================================ # insert a normalization module in front of the segmentation network # the normalization module will be adapted for each test image # ================================================================ images_normalized, _ = model.normalize(images_pl, exp_config, training_pl) # ================================================================ # build the graph that computes predictions from the inference model # ================================================================ logits, _, _ = model.predict_i2l(images_normalized, exp_config, training_pl=training_pl) print('shape of inputs: ', images_pl.shape) # (batch_size, 256, 256, 1) print('shape of logits: ', logits.shape) # (batch_size, 256, 256, nlabels) # ================================================================ # create a list of all vars that must be optimized wrt # ================================================================ i2l_vars = [] for v in tf.trainable_variables(): i2l_vars.append(v) # ================================================================ # add ops for calculation of the supervised training loss # ================================================================ loss_op = model.loss(logits, labels_pl, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_i2l) tf.summary.scalar('loss', loss_op) # ================================================================ # add optimization ops. 
# Create different ops according to the variables that must be trained # ================================================================ print('creating training op...') train_op = model.training_step(loss_op, i2l_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars=True) # ================================================================ # add ops for model evaluation # ================================================================ print('creating eval op...') eval_loss = model.evaluation_i2l(logits, labels_pl, images_pl, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_i2l) # ================================================================ # build the summary Tensor based on the TF collection of Summaries. # ================================================================ print('creating summary op...') summary = tf.summary.merge_all() # ================================================================ # add init ops # ================================================================ init_ops = tf.global_variables_initializer() # ================================================================ # find if any vars are uninitialized # ================================================================ logging.info('Adding the op to get a list of initialized variables...') uninit_vars = tf.report_uninitialized_variables() # ================================================================ # create saver # ================================================================ saver = tf.train.Saver(max_to_keep=10) saver_best_dice = tf.train.Saver(max_to_keep=3) # ================================================================ # create session # ================================================================ sess = tf.Session() # ================================================================ # create a summary writer # ================================================================ summary_writer = tf.summary.FileWriter(log_dir, 
sess.graph) # ================================================================ # summaries of the validation errors # ================================================================ vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error') vl_error_summary = tf.summary.scalar('validation/loss', vl_error) vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice') vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice) vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary]) # ================================================================ # summaries of the training errors # ================================================================ tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error') tr_error_summary = tf.summary.scalar('training/loss', tr_error) tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice') tr_dice_summary = tf.summary.scalar('training/dice', tr_dice) tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary]) # ================================================================ # freeze the graph before execution # ================================================================ logging.info('Freezing the graph now!') tf.get_default_graph().finalize() # ================================================================ # Run the Op to initialize the variables. 
# ================================================================ logging.info( '============================================================') logging.info('initializing all variables...') sess.run(init_ops) # ================================================================ # print names of all variables # ================================================================ logging.info( '============================================================') logging.info('This is the list of all variables:') for v in tf.trainable_variables(): print(v.name) # ================================================================ # print names of uninitialized variables # ================================================================ logging.info( '============================================================') logging.info('This is the list of uninitialized variables:') uninit_variables = sess.run(uninit_vars) for v in uninit_variables: print(v) # ================================================================ # continue run from a saved checkpoint # ================================================================ if continue_run: # Restore session logging.info( '============================================================') logging.info('Restroring session from: %s' % init_checkpoint_path) saver.restore(sess, init_checkpoint_path) # ================================================================ # ================================================================ step = init_step curr_lr = exp_config.learning_rate best_dice = 0 # ================================================================ # run training epochs # ================================================================ while (step < exp_config.max_steps): if step % 1000 is 0: logging.info( '============================================================' ) logging.info('step %d' % step) # ================================================ # batches # ================================================ for batch in 
iterate_minibatches(imtr, gttr, batch_size=exp_config.batch_size, train_or_eval='train'): curr_lr = exp_config.learning_rate start_time = time.time() x, y = batch # =========================== # avoid incomplete batches # =========================== if y.shape[0] < exp_config.batch_size: step += 1 continue # =========================== # create the feed dict for this training iteration # =========================== feed_dict = { images_pl: x, labels_pl: y, learning_rate_pl: curr_lr, training_pl: True } # =========================== # opt step # =========================== _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict) # =========================== # compute the time for this mini-batch computation # =========================== duration = time.time() - start_time # =========================== # write the summaries and print an overview fairly often # =========================== if (step + 1) % exp_config.summary_writing_frequency == 0: logging.info( 'Step %d: loss = %.3f (%.3f sec for the last step)' % (step + 1, loss, duration)) # =========================== # Update the events file # =========================== summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) summary_writer.flush() # =========================== # Compute the loss on the entire training set # =========================== if step % exp_config.train_eval_frequency == 0: logging.info('Training Data Eval:') train_loss, train_dice = do_eval(sess, eval_loss, images_pl, labels_pl, training_pl, imtr, gttr, exp_config.batch_size) tr_summary_msg = sess.run(tr_summary, feed_dict={ tr_error: train_loss, tr_dice: train_dice }) summary_writer.add_summary(tr_summary_msg, step) # =========================== # Save a checkpoint periodically # =========================== if step % exp_config.save_frequency == 0: checkpoint_file = os.path.join(log_dir, 'models/model.ckpt') saver.save(sess, checkpoint_file, global_step=step) # =========================== # 
Evaluate the model periodically on a validation set # =========================== if step % exp_config.val_eval_frequency == 0: logging.info('Validation Data Eval:') val_loss, val_dice = do_eval(sess, eval_loss, images_pl, labels_pl, training_pl, imvl, gtvl, exp_config.batch_size) vl_summary_msg = sess.run(vl_summary, feed_dict={ vl_error: val_loss, vl_dice: val_dice }) summary_writer.add_summary(vl_summary_msg, step) # =========================== # save model if the val dice is the best yet # =========================== if val_dice > best_dice: best_dice = val_dice best_file = os.path.join(log_dir, 'models/best_dice.ckpt') saver_best_dice.save(sess, best_file, global_step=step) logging.info( 'Found new average best dice on validation sets! - %f - Saving model.' % val_dice) step += 1 sess.close()
def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               recursion=None):
    """Segment every test volume slice by slice, log Dice scores and save NIfTI files.

    The test images, masks and per-patient slice counts are read from a
    hard-coded HDF5 file. The network is restored from the newest matching
    checkpoint in ``model_path``. For every patient the predicted label
    volume, the ground truth and the image volume are written below
    ``output_folder`` in the ``prediction``, ``ground_truth`` and ``image``
    sub-folders.

    Args:
        input_folder: Not used by this function; kept for interface
            compatibility.
        output_folder: Root folder that receives the NIfTI outputs.
        model_path: Folder searched for checkpoints.
        exp_config: Experiment configuration; ``batch_size``, ``image_size``,
            ``model_handle`` and ``nlabels`` are read.
        do_postprocessing: Not used by this function (the post-processing
            call was already disabled).
        recursion: If ``None``, ``model_best_dice.ckpt`` is restored.
            Otherwise ``recursion_<n>_model_best_dice.ckpt`` is tried first
            and ``recursion_<n>_model.ckpt`` is the fallback.

    Returns:
        Always ``0``.
    """
    # Copy everything into memory so the HDF5 handle is released immediately
    # instead of staying open for the whole evaluation.
    with h5py.File(
            '/scratch_net/biwirender02/cany/scribble/scribble_data/prostate_divided_normalized.h5',
            'r') as base_data:
        images = np.array(base_data['images_test'])
        labels = np.array(base_data['masks_test'])
        slices_val = np.array(base_data['slices_test'])

    training_time_placeholder = tf.placeholder(tf.bool, shape=[])
    image_tensor_shape = [exp_config.batch_size] + list(
        exp_config.image_size) + [1]
    images_placeholder = tf.placeholder(tf.float32,
                                        shape=image_tensor_shape,
                                        name='images')
    logits_op = model.inference(images_placeholder,
                                exp_config.model_handle,
                                training=training_time_placeholder,
                                nlabels=exp_config.nlabels)

    # Build the prediction ops exactly once. Creating them inside the slice
    # loop would add new nodes to the graph for every single slice.
    softmax = tf.nn.softmax(tf.slice(logits_op, [0, 0, 0, 0],
                                     [1, -1, -1, -1]),
                            dim=-1)
    mask_op = tf.arg_max(softmax, dimension=-1)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    height, width = exp_config.image_size[:2]

    with tf.Session(config=config) as sess:
        sess.run(init)

        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path,
                    'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # No best-dice checkpoint for this recursion: fall back to
                # the regular periodic checkpoint.
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model.ckpt'.format(recursion))
        saver.restore(sess, checkpoint_path)

        # Only the first batch entry carries a real slice; the remaining
        # entries stay zero and are discarded by the tf.slice above.
        batch = np.zeros(image_tensor_shape)

        start = 0
        for patient, n_slices in enumerate(int(count) for count in slices_val):
            stop = start + n_slices
            whole_label = np.copy(labels[start:stop, :, :])
            res = np.zeros((n_slices, height, width))

            for sli in range(n_slices):
                # axis=-1 appends the channel axis to the 2-D slice
                # (axis=3 is out of range for a 2-D array).
                batch[0, :, :, :] = np.expand_dims(images[start + sli, :, :],
                                                   axis=-1)
                mask = sess.run(mask_op,
                                feed_dict={
                                    images_placeholder: batch,
                                    training_time_placeholder: False
                                })
                res[sli, :, :] = mask[0]

            # Per-class Dice on binary masks of labels 1 and 2.
            patient_dices = []
            for c in [1, 2]:
                gt_c = (whole_label == c).astype(whole_label.dtype)
                pred_c = (res == c).astype(res.dtype)
                patient_dices.append(my_dice(gt_c, pred_c))

            logging.info("Dice for patient : " + str(patient_dices[0]) +
                         " and " + str(patient_dices[1]))

            file_name = 'patient' + str(patient) + '.nii.gz'

            # Save predicted mask
            save_nii(os.path.join(output_folder, 'prediction', file_name),
                     np.asarray(res))

            # Save GT image
            save_nii(os.path.join(output_folder, 'ground_truth', file_name),
                     np.uint8(np.asarray(whole_label)))

            # Save image data to the same folder for convenience
            save_nii(os.path.join(output_folder, 'image', file_name),
                     np.uint8(255 * np.asarray(images[start:stop, :, :])))

            start = stop

    return 0
def run_training(continue_run):
    """Train the network on scribble annotations over successive recursions.

    Loads the scribble data, optionally resumes from the newest
    ``recursion_<n>_model.ckpt`` checkpoint in ``log_dir``, builds the
    TensorFlow graph and then loops over epochs. Every
    ``exp_config.epochs_per_recursion`` epochs the current network is used
    to predict new training targets (``predict_next_gt`` followed by
    ``postprocess_gt``) and the recursion counter is incremented.

    ``exp_config``, ``sys_config`` and ``log_dir`` are not parameters; they
    are read from the enclosing (module) scope.

    Args:
        continue_run: If truthy, try to resume from an existing checkpoint.
            When no checkpoint can be found the flag is reset and training
            starts from step 0.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)
    #wrap everything from this point onwards in a try-except to catch keyboard interrupt so
    #can control h5py closing data
    # NOTE(review): the handler at the bottom only re-raises `Exception`, so
    # no cleanup happens here and KeyboardInterrupt is not intercepted.
    try:
        loaded_previous_recursion = False
        start_epoch = 0
        if continue_run:
            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )
            try:
                try:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))

                # NOTE(review): bare except - any failure of the lookup above
                # falls back to the previous recursion's checkpoint.
                except:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir,
                        'recursion_{}_model.ckpt'.format(recursion - 1))
                    loaded_previous_recursion = True
                logging.info('Checkpoint path: %s' % init_checkpoint_path)
                # The step number is the suffix after the last '-' in the
                # checkpoint file name.
                init_step = int(
                    init_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1  # plus 1 b/c otherwise starts with eval
                # Convert the global step into an epoch index using the
                # number of batches per epoch.
                start_epoch = int(
                    init_step /
                    (len(base_data['images_train']) / exp_config.batch_size))
                logging.info('Latest step was: %d' % init_step)
                logging.info('Continuing with epoch: %d' % start_epoch)

            # NOTE(review): bare except - every error raised in the block
            # above silently disables continue mode.
            except:
                logging.warning(
                    '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
                )
                continue_run = False
                init_step = 0
                start_epoch = 0

            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )

        # Only the checkpoint of the previous recursion was found: step the
        # recursion data and counter back by one so they match it.
        if loaded_previous_recursion:
            logging.info(
                "Data file exists for recursion {} "
                "but checkpoints only present up to recursion {}".format(
                    recursion, recursion - 1))
            logging.info("Likely means postprocessing was terminated")
            recursion_data = acdc_data.load_different_recursion(
                recursion_data, -1)
            recursion -= 1

        # load images and validation data
        images_train = np.array(base_data['images_train'])
        scribbles_train = np.array(base_data['scribbles_train'])
        # The 'test' split of the data file is used as validation set here.
        images_val = np.array(base_data['images_test'])
        labels_val = np.array(base_data['masks_test'])

        # if exp_config.use_data_fraction:
        #     num_images = images_train.shape[0]
        #     new_last_index = int(float(num_images)*exp_config.use_data_fraction)
        #
        #     logging.warning('USING ONLY FRACTION OF DATA!')
        #     logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' % (num_images, new_last_index))
        #     images_train = images_train[0:new_last_index,...]
        #     labels_train = labels_train[0:new_last_index,...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        #logging.info(' - Labels:')
        #logging.info(labels_train.shape)
        #logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_placeholder = tf.placeholder(tf.float32,
                                                shape=image_tensor_shape,
                                                name='images')
            labels_placeholder = tf.placeholder(tf.uint8,
                                                shape=mask_tensor_shape,
                                                name='labels')

            learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_placeholder)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_placeholder,
                                     exp_config.model_handle,
                                     training=training_time_placeholder,
                                     nlabels=exp_config.nlabels)

            # Add to the Graph the Ops for loss calculation.
            [loss, _, weights_norm
             ] = model.loss(logits,
                            labels_placeholder,
                            nlabels=exp_config.nlabels,
                            loss_type=exp_config.loss_type,
                            weight_decay=exp_config.weight_decay
                            )  # second output is unregularised loss

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder)

            # Add the Op to compare the logits to the labels during evaluation.
            # eval_loss = model.evaluation(logits,
            #                              labels_placeholder,
            #                              images_placeholder,
            #                              nlabels=exp_config.nlabels,
            #                              loss_type=exp_config.loss_type,
            #                              weak_supervision=True,
            #                              cnn_threshold=exp_config.cnn_threshold,
            #                              include_bg=True)

            eval_val_loss = model.evaluation(
                logits,
                labels_placeholder,
                images_placeholder,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weak_supervision=True,
                cnn_threshold=exp_config.cnn_threshold,
                include_bg=False)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            # Only keep two checkpoints, as checkpoints are kept for every recursion
            # and they can be 300MB +
            saver = tf.train.Saver(max_to_keep=2)
            saver_best_dice = tf.train.Saver(max_to_keep=2)
            saver_best_xent = tf.train.Saver(max_to_keep=2)

            # Create a session for running Ops on the Graph.
            # NOTE(review): the session is never closed explicitly in this function.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # with tf.name_scope('monitoring'):

            # Scalar placeholders used to push Python-side metric values
            # into the summary files.
            val_error_ = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)

            # The training summary is only consumed by the commented-out
            # training-set evaluation further down.
            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            # Restore session
            # crf_weights = []
            # for v in tf.all_variables():
            #
            #     if v.name[0:4]=='bila':
            #         print(str(v))
            #         crf_weights.append(v.name)
            #     elif v.name[0:4] =='spat':
            #         print(str(v))
            #         crf_weights.append(v.name)
            #     elif v.name[0:4] =='comp':
            #         print(str(v))
            #         crf_weights.append(v.name)
            # restore_var = [v for v in tf.all_variables() if v.name not in crf_weights]
            #
            # load_saver = tf.train.Saver(var_list=restore_var)
            # load_saver.restore(sess, '/scratch_net/biwirender02/cany/basil/logdir/unet2D_ws_spot_blur/recursion_0_model.ckpt-5699')

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            # no_improvement_counter, last_train, loss_history and
            # loss_gradient are only referenced by the commented-out
            # learning-rate scheduling code below.
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0
            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            recursion_data = acdc_data.random_walk_epoch(
                recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                exp_config.random_walk)

            #get ground truths
            labels_train = np.array(recursion_data['random_walked'])

            for epoch in range(start_epoch, exp_config.max_epochs):
                # A new recursion starts every `epochs_per_recursion` epochs,
                # or immediately when resuming from the previous recursion's
                # checkpoint.
                if (epoch % exp_config.epochs_per_recursion == 0 and epoch != 0) \
                        or loaded_previous_recursion:
                    loaded_previous_recursion = False
                    #Have reached end of recursion
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        logits=logits,
                        sess=sess)

                    recursion_data = postprocess_gt(
                        data=recursion_data,
                        images_train=images_train,
                        scribbles_train=scribbles_train)
                    recursion += 1
                    # random walk - if it already has been random walked it won't redo
                    recursion_data = acdc_data.random_walk_epoch(
                        recursion_data, exp_config.rw_beta,
                        exp_config.rw_threshold, exp_config.random_walk)
                    #get ground truths
                    labels_train = np.array(recursion_data['random_walked'])

                    #reinitialise savers - otherwise, no checkpoints will be saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)

                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))
                # for batch in iterate_minibatches(images_train,
                #                                  labels_train,
                #                                  batch_size=exp_config.batch_size,
                #                                  augment_batch=exp_config.augment_batch):

                # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvements in the
                # training speed. However, be aware that currently an exception inside this loop may not be caught.
                # The batch generator may just continue running silently without warning even though the code has
                # crashed.

                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    # Warm-up: a tenth of the learning rate for the first 50
                    # steps, the configured rate from step 50 on.
                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    # batch = bgn_train.retrieve()
                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches
                    # The step counter still advances for a skipped batch.
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # if (step + 1) % exp_config.train_eval_frequency == 0:
                    #
                    #     logging.info('Training Data Eval:')
                    #     [train_loss, train_dice] = do_eval(sess,
                    #                                        eval_loss,
                    #                                        images_placeholder,
                    #                                        labels_placeholder,
                    #                                        training_time_placeholder,
                    #                                        images_train,
                    #                                        labels_train,
                    #                                        exp_config.batch_size)
                    #
                    #     train_summary_msg = sess.run(train_summary, feed_dict={train_error_: train_loss,
                    #                                                            train_dice_: train_dice}
                    #                                  )
                    #     summary_writer.add_summary(train_summary_msg, step)
                    #
                    #     loss_history.append(train_loss)
                    #     if len(loss_history) > 5:
                    #         loss_history.pop(0)
                    #         loss_gradient = (loss_history[-5] - loss_history[-1]) / 2
                    #
                    #     logging.info('loss gradient is currently %f' % loss_gradient)
                    #
                    #     if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                    #         logging.warning('Reducing learning rate!')
                    #         curr_lr /= 10.0
                    #         logging.info('Learning rate changed to: %f' % curr_lr)
                    #
                    #         # reset loss history to give the optimisation some time to start decreasing again
                    #         loss_gradient = np.inf
                    #         loss_history = []
                    #
                    #     if train_loss <= last_train:  # best_train:
                    #         logging.info('Decrease in training error!')
                    #     else:
                    #         logging.info('No improvement in training error for %d steps' % no_improvement_counter)
                    #
                    #     last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        # Checkpoint names carry the recursion number so each
                        # recursion keeps its own files.
                        checkpoint_file = os.path.join(
                            log_dir,
                            'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)
                        # Evaluate against the training set.

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_val_loss, images_placeholder,
                                     labels_placeholder,
                                     training_time_placeholder, images_val,
                                     labels_val, exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        # best_dice / best_val are tracked across the whole
                        # run; they are not reset when a new recursion starts.
                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(
                                    recursion))
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.
                                format(val_dice, recursion))

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(
                                    recursion))
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.
                                format(val_loss, recursion))

                    step += 1

    except Exception:
        raise
def _save_overlay_pdf(path, image, overlay=None):
    """Draw `image` in grayscale, optionally blend `overlay` on top, save as PDF."""
    plt.imshow(np.squeeze(image), cmap='gray')
    if overlay is not None:
        plt.imshow(np.squeeze(overlay), cmap='viridis', alpha=0.7)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path, format='pdf')
    plt.clf()


def generate_adversarial_examples(input_folder,
                                  output_path,
                                  model_path,
                                  attack,
                                  attack_args,
                                  exp_config,
                                  add_gaussian=False,
                                  only_batch=9):
    """Attack test images, then write metrics (JSON) and figures (PDF).

    The model is restored from the newest ``model_best_dice.ckpt`` in
    ``model_path`` and evaluated on the first 20 test images with batch
    size 1. For every processed image the clean and the adversarial input
    are scored with ``model.evaluation``; per-example metrics, five figures
    and a final summary are written to the relative folder
    ``eval_results/``.

    Args:
        input_folder: Not used by this function.
        output_path: Not used by this function; outputs go to
            ``eval_results/``.
        model_path: Folder searched for the checkpoint.
        attack: ``'fgsm'``, ``'pgd'`` or ``'spgd'``.
        attack_args: Attack parameters. The keys ``'eps'``, ``'step_alpha'``
            and ``'epochs'`` are copied into the result files; ``'sizes'``
            and ``'weights'`` are additionally read when ``add_gaussian``
            is set.
        exp_config: Experiment configuration.
        add_gaussian: If true, pass the adversarial image through
            ``adv_attack.add_gaussian_noise``.
        only_batch: 1-based index of the single batch to process. The
            default of 9 reproduces the historical behaviour; ``None``
            processes every batch.

    Raises:
        NotImplementedError: If ``attack`` is not one of the supported names.
    """
    batch_size = 1

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    mask_tensor_shape = [batch_size] + list(exp_config.image_size)
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    labels_pl = tf.placeholder(tf.uint8,
                               shape=mask_tensor_shape,
                               name='labels')
    logits_pl = model.inference(images_pl,
                                exp_config=exp_config,
                                training=tf.constant(False, dtype=tf.bool))
    eval_loss = model.evaluation(logits_pl,
                                 labels_pl,
                                 images_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type)
    # Built once; creating this op per batch would keep adding graph nodes.
    mask_op = tf.arg_max(tf.nn.softmax(logits_pl), dimension=-1)

    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=True)

    images = data['images_test'][:20]
    labels = data['masks_test'][:20]

    print("Num images train {} test {}".format(len(data['images_train']),
                                               len(images)))

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    baseline_closs = 0.0
    baseline_cdice = 0.0
    attack_closs = 0.0
    attack_cdice = 0.0
    l2_diff_sum = 0.0
    ln_diff_sum = 0.0
    batches = 0
    # Number of adversarial examples actually scored. The averages must be
    # taken over this count, not over the number of batches iterated.
    num_evaluated = 0
    reported_args = {k: attack_args[k] for k in ['eps', 'step_alpha', 'epochs']}

    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        for x, y in BackgroundGenerator(
                train.iterate_minibatches(images, labels, batch_size)):
            batches += 1
            if only_batch is not None and batches != only_batch:
                continue

            non_adv_mask_out = sess.run([mask_op], feed_dict={images_pl: x})

            if attack == 'fgsm':
                adv_x = adv_attack.fgsm_run(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            attack_args)
            elif attack == 'pgd':
                adv_x = adv_attack.pgd(x, y, images_pl, labels_pl, logits_pl,
                                       exp_config, sess, attack_args)
            elif attack == 'spgd':
                adv_x = adv_attack.pgd_conv(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            **attack_args)
            else:
                raise NotImplementedError
            adv_x = [adv_x]

            if add_gaussian:
                print('adding gaussian noise')
                adv_x = adv_attack.add_gaussian_noise(
                    x,
                    adv_x[0],
                    sess,
                    eps=attack_args['eps'],
                    sizes=attack_args['sizes'],
                    weights=attack_args['weights'])

            for i, adv_example in enumerate(adv_x):
                perturbation = adv_example - x
                l2_diff = np.average(
                    np.squeeze(np.linalg.norm(perturbation, axis=(1, 2))))
                ln_diff = np.average(
                    np.squeeze(
                        np.linalg.norm(perturbation, axis=(1, 2),
                                       ord=np.inf)))
                l2_diff_sum += l2_diff
                ln_diff_sum += ln_diff
                print(l2_diff, ln_diff)

                adv_mask_out = sess.run([mask_op],
                                        feed_dict={images_pl: adv_example})

                closs, cdice = sess.run(eval_loss,
                                        feed_dict={
                                            images_pl: x,
                                            labels_pl: y
                                        })
                baseline_closs = closs + baseline_closs
                baseline_cdice = cdice + baseline_cdice

                adv_closs, adv_cdice = sess.run(eval_loss,
                                                feed_dict={
                                                    images_pl: adv_example,
                                                    labels_pl: y
                                                })
                attack_closs = adv_closs + attack_closs
                attack_cdice = adv_cdice + attack_cdice
                num_evaluated += 1

                # float() turns NumPy scalars into plain floats so that a
                # real JSON object (not the repr of a dict) can be written.
                partial_result = {
                    'attack': attack,
                    'attack_args': reported_args,
                    'baseline_closs': float(closs),
                    'baseline_cdice': float(cdice),
                    'attack_closs': float(adv_closs),
                    'attack_cdice': float(adv_cdice),
                    'attack_l2_diff': float(l2_diff),
                    'attack_ln_diff': float(ln_diff)
                }
                tag = "{}-{}-{}-{}".format(attack, add_gaussian, batches, i)
                with open("eval_results/{}-metrics.json".format(tag),
                          "w") as jsonFile:
                    jsonFile.write(json.dumps(partial_result))

                _save_overlay_pdf(
                    "eval_results/ground-truth-{}.pdf".format(tag), x, y)
                _save_overlay_pdf("eval_results/benign-{}.pdf".format(tag), x,
                                  non_adv_mask_out)
                _save_overlay_pdf(
                    "eval_results/adversarial-{}.pdf".format(tag),
                    adv_example, adv_mask_out)
                _save_overlay_pdf("eval_results/adv-input-{}.pdf".format(tag),
                                  adv_example)
                _save_overlay_pdf(
                    "eval_results/benign-input-{}.pdf".format(tag), x)

            print(attack_closs, attack_cdice, l2_diff, ln_diff)

    if num_evaluated == 0:
        # Nothing was attacked (e.g. fewer batches than `only_batch`), so
        # there is nothing to average or to write.
        print("No batch was evaluated; no results written.")
        return

    print("Evaluation results")
    print("{} Attack Params {}".format(attack, attack_args))
    print("Baseline metrics: Avg loss {}, Avg DICE Score {} ".format(
        baseline_closs / num_evaluated, baseline_cdice / num_evaluated))
    print("{} Attack effectiveness: Avg loss {}, Avg DICE Score {} ".format(
        attack, attack_closs / num_evaluated, attack_cdice / num_evaluated))
    print("{} Attack visibility: Avg l2-norm diff {} Avg l-inf-norm diff {}".
          format(attack, l2_diff_sum / num_evaluated,
                 ln_diff_sum / num_evaluated))

    results = [{
        'attack': attack,
        'attack_args': reported_args,
        'baseline_closs_avg': float(baseline_closs / num_evaluated),
        'baseline_cdice_avg': float(baseline_cdice / num_evaluated),
        'attack_closs_avg': float(attack_closs / num_evaluated),
        'attack_cdice_avg': float(attack_cdice / num_evaluated),
        'attack_l2_diff': float(l2_diff_sum / num_evaluated),
        'attack_ln_diff': float(ln_diff_sum / num_evaluated)
    }]
    print(results)

    with open("eval_results/{}-results.json".format(attack), "w") as jsonFile:
        jsonFile.write(json.dumps(results))
def score_data(input_folder,
               output_folder,
               model_path,
               config,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               use_iter=None):
    """Predict segmentations for patient folders and save them as NIfTI files.

    Every sub-folder of ``input_folder`` is a patient. When ground truth
    exists and ``evaluate_all`` is false, only folders whose last three
    characters parse to a multiple of 5 are scored. Each
    ``patient???_frame??.nii.gz`` volume is rescaled to
    ``config.target_resolution``, cropped or padded to the network input
    size and predicted slice by slice. Predictions and images (and, with
    ground truth, the masks and a difference mask) are written below
    ``output_folder``.

    Args:
        input_folder: Folder containing one sub-folder per patient, each
            with an ``Info.cfg`` file providing ``ED`` and ``ES`` entries.
        output_folder: Root of the ``prediction``, ``image``,
            ``ground_truth`` and ``difference`` output folders.
        model_path: Folder searched for ``model_best_dice.ckpt``.
        config: Experiment configuration; ``image_size``, ``nlabels``,
            ``data_mode`` and ``target_resolution`` are read.
        do_postprocessing: If true, keep only the largest connected
            components of the prediction.
        gt_exists: Whether ``*_gt.nii.gz`` files are available. Without
            ground truth every folder is scored and no ground-truth or
            difference volumes are produced.
        evaluate_all: Score every patient folder, not only the test split.
        use_iter: Not used by this function.

    Returns:
        The iteration number parsed from the restored checkpoint's name.

    Raises:
        NotImplementedError: If ``config.data_mode`` is not ``'2D'``.
        ValueError: If a volume's frame number is neither the ED nor the
            ES frame.
    """
    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')

    # According to the experiment config, pick a model and predict the output
    # TODO: Implement majority voting using 3 models.
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:
        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):
            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            # Every fifth patient belongs to the test split unless all
            # folders are requested.
            is_test = (evaluate_test_set or evaluate_all
                       or int(folder[-3:]) % 5 == 0)
            if not is_test:
                continue

            infos = {}
            with open(os.path.join(folder_path, 'Info.cfg')) as info_file:
                for line in info_file:
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for nii_path in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % nii_path)
                logging.info(' --------------------------------------------')

                file_base = nii_path.split('.nii.gz')[0]
                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(nii_path)
                img = img_dat[0].copy()
                print('img')
                print(img.shape)
                print(img.dtype)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if config.data_mode != '2D':
                    raise NotImplementedError(
                        'score_data only supports data_mode "2D", got %r' %
                        (config.data_mode, ))

                pixel_size = (img_dat[2].structarr['pixdim'][1],
                              img_dat[2].structarr['pixdim'][2])
                scale_vector = (pixel_size[0] / config.target_resolution[0],
                                pixel_size[1] / config.target_resolution[1])
                print('pixel_size', pixel_size)
                print('scale_vector', scale_vector)

                predictions = []
                mask_arr = []
                img_arr = []

                for zz in range(img.shape[2]):
                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       anti_aliasing=True,
                                                       mode='constant')
                    print('slice_img', slice_img.shape)
                    print('slice_rescaled', slice_rescaled.shape)

                    slice_cropped = acdc_data.crop_or_pad_slice_to_size(
                        slice_rescaled, nx, ny)
                    print('slice_cropped', slice_cropped.shape)
                    slice_cropped = np.float32(slice_cropped)
                    x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)

                    # The mask only exists with ground truth; touching it
                    # unconditionally would raise a NameError.
                    if gt_exists:
                        slice_mask = np.squeeze(mask[:, :, zz])
                        # NOTE(review): anti_aliasing=True on a label map
                        # smooths it before nearest-neighbour resampling -
                        # confirm this is intended.
                        mask_rescaled = transform.rescale(slice_mask,
                                                          scale_vector,
                                                          order=0,
                                                          preserve_range=True,
                                                          multichannel=False,
                                                          anti_aliasing=True,
                                                          mode='constant')
                        mask_cropped = acdc_data.crop_or_pad_slice_to_size(
                            mask_rescaled, nx, ny)
                        mask_cropped = np.asarray(mask_cropped,
                                                  dtype=np.uint8)
                        y = image_utils.reshape_2Dimage_to_tensor(
                            mask_cropped)
                        mask_arr.append(np.squeeze(y))

                    # GET PREDICTION
                    _, softmax_out = sess.run([mask_pl, softmax_pl],
                                              feed_dict={images_pl: x})
                    slice_predictions = np.squeeze(softmax_out[0, ...])

                    # RESCALING ON THE SOFTMAX OUTPUT
                    # Done regardless of gt_exists so that `prediction` is
                    # always defined.
                    prediction = transform.resize(slice_predictions,
                                                  (nx, ny, num_channels),
                                                  order=1,
                                                  preserve_range=True,
                                                  anti_aliasing=True,
                                                  mode='constant')
                    prediction = np.uint8(np.argmax(prediction, axis=-1))
                    predictions.append(prediction)
                    img_arr.append(np.squeeze(x))

                # Stack the slices back into (x, y, z) volumes.
                prediction_arr = np.transpose(
                    np.asarray(predictions, dtype=np.uint8), (1, 2, 0))
                img_arrs = np.transpose(np.asarray(img_arr, dtype=np.float32),
                                        (1, 2, 0))
                if gt_exists:
                    mask_arrs = np.transpose(
                        np.asarray(mask_arr, dtype=np.uint8), (1, 2, 0))

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' %
                             elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                out_name = 'patient' + patient_id + frame_suffix + '.nii.gz'

                # Save prediced mask
                out_file_name = os.path.join(output_folder, 'prediction',
                                             out_name)
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine,
                               out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image',
                                               out_name)
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine,
                               out_header)

                if gt_exists:
                    # Save GT image
                    gt_file_name = os.path.join(output_folder, 'ground_truth',
                                                out_name)
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask_arrs, out_affine,
                                   out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.asarray(prediction_arr != mask_arrs,
                                                 dtype=np.uint8)

                    # Show image, ground truth and prediction for every
                    # slice, one figure at a time.
                    for zz in range(difference_mask.shape[2]):
                        for volume in (img_arrs, mask_arrs, prediction_arr):
                            plt.imshow(volume[:, :, zz])
                            plt.gray()
                            plt.axis('off')
                            plt.show()
                        print('...')

                    diff_file_name = os.path.join(output_folder, 'difference',
                                                  out_name)
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask,
                                   out_affine, out_header)

        if total_volumes > 0:
            logging.info('Average time per volume: %f' %
                         (total_time / total_volumes))
        else:
            # Avoid a ZeroDivisionError when no volume matched.
            logging.warning('No volumes were evaluated in %s' % input_folder)

    return init_iteration
def main(exp_config):
    """Interactively show network predictions for randomly chosen test slices.

    Loads the data through ``acdc_data.load_and_maybe_process_data``, restores
    the latest ``model_best_dice.ckpt`` checkpoint found under the module-level
    ``model_path`` and then loops forever: pick a random entry of
    ``data['images_test']``, run the network on it and display a 2x4 figure
    with the image, the ground-truth mask, the predicted mask and softmax
    channels 0 to 3. The loop has no exit condition; stop it with Ctrl-C.

    Args:
        exp_config: Experiment configuration. ``data_mode``, ``image_size`` and
            ``target_resolution`` are read here; the object itself is passed to
            ``model.predict``.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False
    )

    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, exp_config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            while True:

                ind = np.random.randint(data['images_test'].shape[0])

                x = image_utils.reshape_2Dimage_to_tensor(data['images_test'][ind, ...])
                y = image_utils.reshape_2Dimage_to_tensor(data['masks_test'][ind, ...])

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl],
                                                 feed_dict={images_pl: x})

                fig = plt.figure()

                # Top row: input image, ground truth, prediction.
                fig.add_subplot(241).imshow(np.squeeze(x), cmap='gray')
                fig.add_subplot(242).imshow(np.squeeze(y))
                fig.add_subplot(243).imshow(np.squeeze(mask_out))

                # Bottom row (grid positions 5-8): softmax channels 0-3.
                for channel in range(4):
                    ax = fig.add_subplot(2, 4, 5 + channel)
                    ax.imshow(np.squeeze(softmax_out[..., channel]))

                plt.show()
    finally:
        # The display loop only ends through an exception (e.g. Ctrl-C), so
        # closing after the loop would never run; close the data handle here.
        data.close()
def run_training(continue_run):
    """Prepare a scribble-training run and advance the data to the next recursion.

    The function
      1. loads the base data and the current recursion data via
         ``acdc_data.load_and_maybe_process_scribbles`` (target folder is the
         module-level ``log_dir``),
      2. when ``continue_run`` is true, looks for the latest checkpoint of the
         current recursion and falls back to the previous recursion's
         checkpoint; if neither can be used, continue mode is switched off,
      3. builds the inference graph, initialises the variables and restores the
         checkpoint when continuing,
      4. produces the next ground truth with ``predict_next_gt`` and
         ``postprocess_gt``, increments the recursion counter and calls
         ``acdc_data.random_walk_epoch`` on the new recursion data.

    Args:
        continue_run: If true, try to resume from an existing checkpoint in
            ``log_dir``.

    Notes:
        Checkpoint lookup failures are caught with ``except Exception`` rather
        than a bare ``except`` so that Ctrl-C / SystemExit are not mistaken for
        a missing checkpoint. The former outer ``try ... except Exception:
        raise`` wrapper only re-raised and has been removed.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio
    )

    loaded_previous_recursion = False
    start_epoch = 0
    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                # No checkpoint for the current recursion: fall back to the
                # previous one.
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            start_epoch = int(init_step/(len(base_data['images_train'])/4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            logging.warning('!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0
            start_epoch = 0
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    if loaded_previous_recursion:
        logging.info("Data file exists for recursion {} "
                     "but checkpoints only present up to recursion {}".format(recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")
        recursion_data = acdc_data.load_different_recursion(recursion_data, -1)
        recursion -= 1

    # load the training images and scribbles into memory
    images_train = np.array(base_data['images_train'])
    scribbles_train = np.array(base_data['scribbles_train'])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
        labels_placeholder = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Run the Op to initialize the variables.
        sess.run(init)

        saver = tf.train.Saver(max_to_keep=2)

        if continue_run:
            # Restore session
            saver.restore(sess, init_checkpoint_path)

        logging.info('RECURSION {0}'.format(recursion))

        # Have reached end of recursion
        recursion_data = predict_next_gt(data=recursion_data,
                                         images_train=images_train,
                                         images_placeholder=images_placeholder,
                                         training_time_placeholder=training_time_placeholder,
                                         logits=logits,
                                         sess=sess)

        recursion_data = postprocess_gt(data=recursion_data,
                                        images_train=images_train,
                                        scribbles_train=scribbles_train)
        recursion += 1

        # random walk - if it already has been random walked it won't redo
        recursion_data = acdc_data.random_walk_epoch(recursion_data,
                                                     exp_config.rw_beta,
                                                     exp_config.rw_threshold,
                                                     exp_config.random_walk)
def predict_segmentation(image):
    """Predict a segmentation mask for one image with the trained i2l network.

    A fresh graph is built, the latest ``best_dice.ckpt`` checkpoint found in
    ``sys_config.log_root + exp_config.expname_i2l + '/models/'`` is restored
    and the network is run once on ``image``.

    Args:
        image: Array holding a single image; a leading batch axis and a
            trailing channel axis are added before it is fed to the network.
            Presumably its size matches ``exp_config.image_size[:2]``, which is
            what the placeholder is declared with (verify against callers).

    Returns:
        The predicted mask with singleton dimensions squeezed out, cast to
        ``float``.
    """
    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # create placeholders
        images_pl = tf.placeholder(
            tf.float32,
            shape=[None, exp_config.image_size[0], exp_config.image_size[1], 1],
            name='images')

        # build the graph that computes predictions from the inference model
        # (only the hard predictions are needed here)
        _, _, preds = model.predict_i2l(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # add init ops and the saver, then freeze the graph before execution
        init_ops = tf.global_variables_initializer()
        saver_i2l = tf.train.Saver()
        tf.get_default_graph().finalize()

        # The context manager guarantees the session is closed even when
        # restoring the checkpoint or running the network raises.
        with tf.Session() as sess:

            # Run the Op to initialize the variables.
            sess.run(init_ops)

            # Restore the segmentation network parameters
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # predict segmentation
            X = np.expand_dims(np.expand_dims(image, axis=-1), axis=0)
            mask_predicted = sess.run(preds, feed_dict={images_pl: X})

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)

    return mask_predicted
# main(csv_path=None, batch_size=None)
#
# Evaluates the segmentations stored for every recursion of a recursively
# trained model against the ground-truth training masks and visualises them.
#
# What the code below does, in order:
#   1. Reads 'masks_train' and 'images_train' from <model_path>/base_data.hdf5
#      and the 'predicted' dataset of recursion_0_data.hdf5 (kept as
#      `scribbles`).
#   2. Opens <model_path>/recursion_evaluation/assessment.hdf5, creates the
#      'random_walk_mask' dataset if it is missing and fills it batch-wise with
#      `scribbles` (the actual random-walker call `rw(...)` is commented out),
#      then stores the Dice results from calculate_dices as attributes.
#   3. For each recursion copies the 'postprocessed', 'random_walked' and
#      'predicted' datasets of the recursion data files into
#      'recursion_{r}/processed_{r}', '.../random_walked_{r}' and
#      '.../prediction_{r}' and scores each of them with calculate_dices.
#   4. For the last recursion looks for a model checkpoint; if one is found and
#      can be restored, fills the final prediction dataset batch-wise (the
#      network call is commented out, `scribbles` are copied instead) and
#      scores it.
#   5. Prints a Dice summary table and shows one matplotlib figure per selected
#      slice with the image, the ground truth and every stored segmentation.
#
# Parameters:
#   csv_path:   forwarded as `path=` to calculate_dices.
#   batch_size: if not None, that many slice indices are drawn with `sample`
#               and plotted; if None, the first three entries of the `indices`
#               returned by the last calculate_dices call are plotted (the log
#               message labels them best / median / worst).
#
# NOTE(review): assessment.hdf5 is opened with h5py.File(...) without a mode
#   although the block writes to it; h5py >= 3.0 defaults to read-only ('r'),
#   so an explicit mode (presumably 'a') looks necessary -- confirm.
# NOTE(review): `dtype=np.bool` uses an alias that was removed in NumPy 1.24;
#   confirm against the NumPy version in use.
# NOTE(review): the random-walk loop starts at attrs 'processed_to' + 1 while
#   the attribute is created as 0, which looks like it skips slice 0 -- confirm.
# NOTE(review): this block has lost its original line breaks/indentation, so
#   the nesting of several statements (e.g. the attrs.modify call right after
#   the commented-out rw(...) call, and the final plt.show()) cannot be read
#   off the text; check against the original file before editing.
def main(csv_path=None, batch_size=None): #get number of recursions num_recursions = acdc_data.most_recent_recursion(model_path) #get data base_data = h5py.File(os.path.join(model_path, 'base_data.hdf5'), 'r') masks_gt = np.array(base_data['masks_train']) images = np.array(base_data['images_train']) base_data.close() r_data = h5py.File(os.path.join(model_path, 'recursion_0_data.hdf5'), 'r') scribbles = np.array(r_data['predicted']) r_data.close() with h5py.File(model_path + "/recursion_evaluation/assessment.hdf5") as output_data: #FULLY RANDOM WALKED if not 'random_walk_mask' in output_data: output_data.create_dataset(name='random_walk_mask', dtype=np.uint8, shape=images.shape) output_data['random_walk_mask'].attrs.create('processed_to', dtype=np.uint16, data=0) initialise_dice_attrs(output_data, 'random_walk_mask', exp_config.nlabels) #process random walk processed_to = output_data['random_walk_mask'].attrs.get( 'processed_to') + 1 for scr_idx in range(processed_to, len(images), exp_config.batch_size): ind = list( range(scr_idx, min(scr_idx + exp_config.batch_size, len(images)))) logging.info( 'Preparing random walker segmentation for slices {} to {}'. format(ind[0], ind[-1])) output_data['random_walk_mask'][ind, ...] = scribbles[ind, ...] # output_data['random_walk_mask'][ind, ...] 
= rw(images[ind, ...], # scribbles[ind, ...], # beta=exp_config.rw_beta, # threshold=0) # output_data['random_walk_mask'].attrs.modify('processed_to', ind[-1] + 1) #calculate dice _, dices, mean_dice, stdevs = calculate_dices( np.array(output_data['random_walk_mask']), masks_gt, exp_config.nlabels, path=csv_path, filename="fully_random_walked") output_data['random_walk_mask'].attrs.modify('mean_dice', mean_dice) output_data['random_walk_mask'].attrs.modify('dices', dices) output_data['random_walk_mask'].attrs.modify('stdevs', stdevs) #iterate through recursions #recursion 0 slightly different as it has no 'output from previous epoch' set r_data = h5py.File(os.path.join(model_path, 'recursion_0_data.hdf5'), 'r') dset_names = np.array([[ 'recursion_{0}/processed_{0}', 'recursion_{0}/random_walked_{0}', 'recursion_{0}/prediction_{0}' ], ['postprocessed', 'random_walked', 'predicted' ], ['processed r{0} seg', 'random walked r{0} seg', 'r{0} seg']]) for recursion in range(0, num_recursions + 1): for i in range(3): dset_name = dset_names[0, i].format(recursion) #once it gets to the prediction it must open the next dataset file if i == 2: r_data.close() if recursion == num_recursions: break r_data = h5py.File( os.path.join( model_path, 'recursion_{}_data.hdf5'.format(recursion + 1)), 'r') if not dset_name in output_data: logging.info("Creating dataset {}".format(dset_name)) logging.info("Getting data from {}[{}]".format( os.path.basename(r_data.filename).split('.')[0], dset_names[1, i])) output_data.create_dataset(dset_name, dtype=np.uint8, shape=images.shape, data=r_data[dset_names[1, i]]) initialise_dice_attrs(output_data, dset_name, exp_config.nlabels) else: logging.info("Getting data from {}[{}]".format( os.path.basename(r_data.filename).split('.')[0], dset_names[1, i])) output_data[dset_name][:] = np.array(r_data[dset_names[1, i]]) #calculate dices indices, dices, mean_dice, stdevs = calculate_dices( np.array(output_data[dset_name]), masks_gt, exp_config.nlabels, 
path=csv_path, filename=dset_names[2, i].format(recursion)) output_data[dset_name].attrs.modify('mean_dice', mean_dice) output_data[dset_name].attrs.modify('dices', dices) output_data[dset_name].attrs.modify('stdevs', stdevs) #for last recursion, need to use prediction from network final_dset = 'recursion_{0}/prediction_{0}'.format(recursion) if not final_dset in output_data: output_data.create_dataset(final_dset, dtype=np.uint8, shape=images.shape, data=masks_gt) output_data[final_dset].attrs.create(name='processed_to', data=np.array((0, )), shape=(1, ), dtype=np.uint16) output_data[final_dset].attrs.create(name='processed', data=np.array(False), shape=(), dtype=np.bool) initialise_dice_attrs(output_data, final_dset, exp_config.nlabels) try: checkpoint_path = utils.get_latest_model_checkpoint_path( model_path, 'recursion_{}_model*.ckpt'.format(recursion)) skip_final_recursion = False except: skip_final_recursion = True if not skip_final_recursion: image_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) + [1] images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') mask_pl, softmax_pl = model.predict(images_pl, exp_config.model_handle, exp_config.nlabels) saver = tf.train.Saver() init = tf.global_variables_initializer() training_time_placeholder = tf.placeholder(tf.bool, shape=[]) with tf.Session() as sess: sess.run(init) try: saver.restore(sess, checkpoint_path) epoch = int( checkpoint_path.split('/')[-1].split('-')[-1]) + 1 epoch = int(epoch / (len(images) / exp_config.batch_size)) epoch = epoch % exp_config.epochs_per_recursion skip_final_recursion = False except: logging.info( "Failed to load checkpoint for recursion {}".format( recursion)) skip_final_recursion = True if not skip_final_recursion: scr_max = len(images) processed_to = output_data[final_dset].attrs.get( 'processed_to') print(processed_to) processed_to = 0 if processed_to is None else processed_to processed_to = np.squeeze(processed_to) for scr_idx in 
range(processed_to, len(images), exp_config.batch_size): if scr_idx + exp_config.batch_size > scr_max: # At the end of the dataset ind = list( range(scr_max - exp_config.batch_size, scr_max)) else: ind = list( range(scr_idx, scr_idx + exp_config.batch_size)) logging.info( "Segmenting images using weights from final " "recursion ({}) for slices {} to {}".format( recursion, ind[0], ind[-1])) feed_dict = { images_pl: np.expand_dims(images[ind, ...], -1), training_time_placeholder: False } output_data[final_dset][ind, ...] = scribbles[ind, ...] # output_data[final_dset][ind, ...], _ = sess.run([mask_pl, softmax_pl], # feed_dict=feed_dict) output_data[final_dset].attrs.modify( 'processed_to', ind[-1] + 1) #Calculate dices _, dices, mean_dice, stdevs = calculate_dices( np.array(output_data[final_dset]), masks_gt, exp_config.nlabels, path=csv_path, filename=dset_names[2, 2].format(recursion)) output_data[final_dset].attrs.modify( 'mean_dice', mean_dice) output_data[final_dset].attrs.modify('dices', dices) output_data[final_dset].attrs.modify('stdevs', stdevs) output_data[final_dset].attrs.modify('processed', True) #print summaries: print("DICES:") l_str = " " * 40 for i in range(1, exp_config.nlabels - 1): d_str = "Dice 0{} (stdev)".format(i) print(d_str) print(d_str.center(20)) l_str += d_str.center(20) print(l_str + " Mean Dice ") print(" Random walker segmentations:".ljust(45) + dice_str(output_data['random_walk_mask'])) for recursion in range(0, num_recursions + 1): print(" RECURSION {}".format(recursion)) print(" Postprocessed input of recursion {}".format( recursion).ljust(45) + dice_str(output_data[dset_names[0, 0].format(recursion)])) print(" Random walked input of recursion {}".format( recursion).ljust(45) + dice_str(output_data[dset_names[0, 1].format(recursion)])) if recursion != num_recursions: print(" Output of recursion {}".format(recursion).ljust( 45) + dice_str(output_data[dset_names[0, 2].format(recursion)])) #Handle last recursion seperately if not 
skip_final_recursion: print(" Output of recursion {}".format(recursion).ljust(45) + dice_str(output_data[dset_names[0, 2].format(recursion)])) print( " Weights for predicting final masks were taken from epoch {} of {}" .format(epoch, exp_config.epochs_per_recursion)) #get graphs Mean Dice if not batch_size is None: indices = np.sort(np.unique(sample(range(len(images)), batch_size))) logging.info( "Showing segmentation progression for slices randomly picked slices {}" .format(indices)) descriptions = [] for index in indices: descriptions.append("Slice {:04}".format(index)) else: logging.info( "Showing segmentation progression for slices {0} [Best], {1} [Median] and {2} [Worst]" .format(indices[0], indices[1], indices[2])) descriptions = [ 'best slice ({:04})'.format(indices[0]), 'median slice ({:04})'.format(indices[1]), 'worst slice ({:04})'.format(indices[2]) ] batch_size = 3 for fig_idx in range(batch_size): graphs = np.zeros( (6 + num_recursions * 3, exp_config.image_size[0], exp_config.image_size[1])) slice_idx = indices[fig_idx] #First get image, ground truth and random walked prediction graphs[1, ...] = masks_gt[slice_idx, ...] graphs[2, ...] = output_data['random_walk_mask'][slice_idx, ...] for recursion in range(num_recursions + 1): for i in range(3): if i == 2 and recursion == num_recursions and skip_final_recursion: break dset_name = dset_names[0, i].format(recursion) graphs[recursion * 3 + i + 3, ...] = output_data[dset_name][slice_idx, ...] 
fig = plt.figure(fig_idx) fig.suptitle('Segmentation progress for {}'.format( descriptions[fig_idx])) #handle image seperately ax = fig.add_subplot(num_recursions + 2, 3, 1) ax.axis('off') ax.set_title('image') ax.imshow(np.squeeze(images[slice_idx, ...]), cmap='gray', vmin=0, vmax=exp_config.nlabels - 1) for graph_idx in range(1, len(graphs)): ax = fig.add_subplot(num_recursions + 2, 3, graph_idx + 1) ax.axis('off') # This should be cleaned up if graph_idx == 1: ax.set_title('ground truth') elif graph_idx == 2: ax.set_title('random walker segmentation') elif graph_idx == 3: ax.set_title('weak supervision') elif graph_idx == 4: ax.set_title('ws random walked') else: ax.set_title( dset_names[2, graph_idx % 3].format(int((graph_idx + 1) / 3) - 2)) ax.imshow(np.squeeze(graphs[graph_idx, ...]), cmap='jet', vmin=0, vmax=exp_config.nlabels - 1) plt.show()
def predict_segmentation(subject_name, image, normalize=True):
    """Segment an image stack with the i2l network behind a normalisation module.

    A fresh graph is built in which ``model.normalize`` feeds
    ``model.predict_i2l``. The segmentation weights are restored from the
    latest ``best_dice.ckpt`` under
    ``<log_root>i2l_mapper/<expname_i2l>/models/``; when ``normalize`` is
    ``True`` the per-subject normaliser weights are additionally restored from
    the latest ``best_score.ckpt`` under
    ``<log_root>/<expname_normalizer>/subject_<subject_name>/models/``.
    The stack is then processed one slice (index along axis 0) at a time.

    Args:
        subject_name: Subject identifier; only used to build the path of the
            normaliser checkpoint.
        image: Array that is sliced along its first axis; a trailing channel
            axis is added to every slice before it is fed to the network.
        normalize: The normaliser checkpoint is restored only when this is
            exactly ``True`` (identity check, as before).

    Returns:
        Tuple ``(mask_predicted, img_normalized)``: the stacked predictions and
        the stacked outputs of the normalisation module, both squeezed and
        cast to ``float``.
    """
    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # create placeholders
        images_pl = tf.placeholder(tf.float32,
                                   shape=[None] + list(exp_config.image_size) + [1],
                                   name='images')

        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        images_normalized, _ = model.normalize(
            images_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # build the graph that computes predictions from the inference model
        _, _, preds = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # divide the vars into segmentation network and normalization network
        # (the i2l saver deliberately covers every global variable, the
        # normaliser saver only those whose name contains 'image_normalizer')
        i2l_vars = list(tf.global_variables())
        normalization_vars = [v for v in i2l_vars if 'image_normalizer' in v.name]

        # add init ops and savers, then freeze the graph before execution
        init_ops = tf.global_variables_initializer()
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_normalizer = tf.train.Saver(var_list=normalization_vars)
        tf.get_default_graph().finalize()

        # The context manager guarantees the session is closed even when a
        # restore or a forward pass raises.
        with tf.Session() as sess:

            # Run the Op to initialize the variables.
            sess.run(init_ops)

            # Restore the segmentation network parameters
            logging.info(
                '============================================================')
            path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_dice.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_i2l.restore(sess, checkpoint_path)

            # Restore the normalization network parameters
            if normalize is True:
                logging.info(
                    '============================================================')
                path_to_model = os.path.join(
                    sys_config.log_root, exp_config.expname_normalizer
                ) + '/subject_' + subject_name + '/models/'
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    path_to_model, 'best_score.ckpt')
                logging.info('Restoring the trained parameters from %s...' %
                             checkpoint_path)
                saver_normalizer.restore(sess, checkpoint_path)
                logging.info(
                    '============================================================')

            # Make predictions for the image at the resolution of the image
            # after pre-processing, one slice at a time.
            mask_predicted = []
            img_normalized = []
            for b_i in range(image.shape[0]):
                X = np.expand_dims(image[b_i:b_i + 1, ...], axis=-1)
                # Fetch both tensors in a single run instead of executing the
                # normalisation module twice per slice.
                mask_b, normalized_b = sess.run([preds, images_normalized],
                                                feed_dict={images_pl: X})
                mask_predicted.append(mask_b)
                img_normalized.append(normalized_b)

        mask_predicted = np.squeeze(np.array(mask_predicted)).astype(float)
        img_normalized = np.squeeze(np.array(img_normalized)).astype(float)

    return mask_predicted, img_normalized
# score_data(input_folder, output_folder, model_path, config,
#            do_postprocessing=False, gt_exists=True)
#
# Runs the segmentation network slice by slice over every folder found below
# <input_folder>/img and writes the predicted masks as PNG files.
#
# What the code below does, in order:
#   1. Builds the prediction graph with model.predict and restores the latest
#      'model_best_dice.ckpt' checkpoint found in `model_path`.
#   2. For every folder in <input_folder>/img and every sub-folder in it, reads
#      all image files with plt.imread (optionally standardised / min-max
#      normalised according to `config`) and, when gt_exists, all mask files
#      from the matching <input_folder>/mask sub-folder.
#   3. In '2D' data mode rescales each slice by `scale_vector`, crops/pads it
#      to (nx, ny), runs the network and takes the argmax over the (resized)
#      softmax output as the prediction for that slice.
#   4. Optionally keeps only the largest connected components, logs the
#      elapsed time and writes each predicted slice with cv2.imwrite.
#
# Parameters (as used in the visible code):
#   input_folder:      root folder containing 'img' and, when gt_exists,
#                      'mask' sub-folders.
#   output_folder:     not referenced in the visible part of the function.
#   model_path:        directory searched for the checkpoint.
#   config:            experiment configuration; image_size, nlabels,
#                      pixel_size, standardize, normalize, min, max and
#                      data_mode are read here and the object is passed to
#                      model.predict.
#   do_postprocessing: if true, image_utils.keep_largest_connected_components
#                      is applied to the stacked predictions.
#   gt_exists:         whether ground-truth masks are loaded and rescaled.
#
# NOTE(review): `target_resolution` and `path_pred` are not defined inside this
#   function -- presumably module-level names; confirm.
# NOTE(review): `slice_mask = np.squeeze(mask_arr[:, :, zz])` runs regardless
#   of gt_exists, but mask_arr is only turned into an array when gt_exists is
#   true -- looks like this fails for gt_exists=False; confirm.
# NOTE(review): the text of this function ends on a bare `if gt_exists:`; the
#   remainder of the body is not present here.
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True): nx, ny = config.image_size[:2] batch_size = 1 num_channels = config.nlabels image_tensor_shape = [batch_size] + list(config.image_size) + [1] images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') # According to the experiment config, pick a model and predict the output mask_pl, softmax_pl = model.predict(images_pl, config) saver = tf.train.Saver() init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt') saver.restore(sess, checkpoint_path) init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1]) total_time = 0 total_volumes = 0 scale_vector = [config.pixel_size[0] / target_resolution[0], config.pixel_size[1] / target_resolution[1]] path_img = os.path.join(input_folder, 'img') if gt_exists: path_mask = os.path.join(input_folder, 'mask') for folder in os.listdir(path_img): logging.info(' ----- Doing image: -------------------------') logging.info('Doing: %s' % folder) logging.info(' --------------------------------------------') folder_path = os.path.join(path_img, folder) #loop over patient folders utils.makefolder(os.path.join(path_pred, folder)) if os.path.isdir(folder_path): for phase in os.listdir(folder_path): #loop over ED/ES phase folders save_path = os.path.join(path_pred, folder, phase) utils.makefolder(save_path) predictions = [] mask_arr = [] img_arr = [] masks = [] imgs = [] path = os.path.join(folder_path, phase) for file in os.listdir(path): img = plt.imread(os.path.join(path,file)) if config.standardize: img = image_utils.standardize_image(img) if config.normalize: img = cv2.normalize(img, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX) img_arr.append(img) if gt_exists: for file in os.listdir(os.path.join(path_mask,folder,phase)): 
mask_arr.append(plt.imread(os.path.join(path_mask,folder,phase,file))) img_arr = np.transpose(np.asarray(img_arr),(1,2,0)) # x,y,N if gt_exists: mask_arr = np.transpose(np.asarray(mask_arr),(1,2,0)) start_time = time.time() if config.data_mode == '2D': for zz in range(img_arr.shape[2]): slice_img = np.squeeze(img_arr[:,:,zz]) slice_rescaled = transform.rescale(slice_img, scale_vector, order=1, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') slice_mask = np.squeeze(mask_arr[:, :, zz]) slice_cropped = read_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny) slice_cropped = np.float32(slice_cropped) x = image_utils.reshape_2Dimage_to_tensor(slice_cropped) imgs.append(np.squeeze(x)) if gt_exists: mask_rescaled = transform.rescale(slice_mask, scale_vector, order=0, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') mask_cropped = read_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny) mask_cropped = np.asarray(mask_cropped, dtype=np.uint8) y = image_utils.reshape_2Dimage_to_tensor(mask_cropped) masks.append(np.squeeze(y)) # GET PREDICTION feed_dict = { images_pl: x, } mask_out, logits_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict) prediction_cropped = np.squeeze(logits_out[0,...]) # ASSEMBLE BACK THE SLICES slice_predictions = np.zeros((nx,ny,num_channels)) slice_predictions = prediction_cropped # RESCALING ON THE LOGITS if gt_exists: prediction = transform.resize(slice_predictions, (nx, ny, num_channels), order=1, preserve_range=True, anti_aliasing=True, mode='constant') else: prediction = transform.rescale(slice_predictions, (1.0/scale_vector[0], 1.0/scale_vector[1], 1), order=1, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') prediction = np.uint8(np.argmax(prediction, axis=-1)) predictions.append(prediction) predictions = np.transpose(np.asarray(predictions, dtype=np.uint8), (1,2,0)) masks = np.transpose(np.asarray(masks, dtype=np.uint8), (1,2,0)) imgs = 
np.transpose(np.asarray(imgs, dtype=np.float32), (1,2,0)) # This is the same for 2D and 3D if do_postprocessing: predictions = image_utils.keep_largest_connected_components(predictions) elapsed_time = time.time() - start_time total_time += elapsed_time total_volumes += 1 logging.info('Evaluation of volume took %f secs.' % elapsed_time) # Save predicted mask for ii in range(predictions.shape[2]): image_file_name = os.path.join('paz', str(ii).zfill(3) + '.png') cv2.imwrite(os.path.join(save_path , image_file_name), np.squeeze(predictions[:,:,ii])) if gt_exists:
# main(exp_config, batch_size=3)
#
# Exports images of a few test slices: the input image, the ground-truth mask
# and the prediction of every recursion's model.
#
# What the code below does, in order:
#   1. Opens the scribble HDF5 file read-only and reads 'images_test' /
#      'masks_test' for the selected slices.
#   2. Builds the prediction graph and, for each recursion up to
#      most_recent_recursion(model_path), restores the latest
#      'recursion_{r}_model_best_dice.ckpt' checkpoint (falling back to
#      'recursion_{r}_model.ckpt'), predicts the masks and keeps the largest
#      connected components per slice. If anything fails for a recursion the
#      exception is printed and that recursion contributes no prediction.
#   3. Passes each recursion's predictions through `segment` (beta taken from
#      exp_config.rw_beta, threshold 0).
#   4. Writes the images with image_utils.print_grayscale / print_coloured
#      into base_path + "/poster/".
#   5. Ends in a `while True` loop that calls plt.show().
#
# Parameters:
#   exp_config: experiment configuration; scribble_data, image_size,
#               model_handle, nlabels, rw_beta and experiment_name are read.
#   batch_size: number of random slice indices to draw (but see the first
#               note below).
#
# NOTE(review): the randomly drawn `slices` are immediately replaced by the
#   hard-coded [80, 275, 370] and batch_size is recomputed from them, so the
#   batch_size argument currently has no effect -- confirm this is intended.
# NOTE(review): path = '/scratch_net/' is overwritten by base_path + "/poster/"
#   before it is used, and `fig` / `num_cols` are only referenced by
#   commented-out code.
# NOTE(review): the HDF5 file `data` is never closed and, because of the final
#   endless loop, the function does not return.
# NOTE(review): this block has lost its original line breaks/indentation, so
#   the exact nesting (e.g. of the print / pred_size update and of the final
#   `while True`) cannot be read off the text; check the original file before
#   editing.
def main(exp_config, batch_size=3): # Load data data = h5py.File(sys_config.project_root + exp_config.scribble_data, 'r') slices = np.random.randint(low=0, high=data['images_test'].shape[0], size=batch_size) slices = np.sort(np.unique(slices)) slices = [80, 275, 370] batch_size = len(slices) images = data['images_test'][slices, ...] masks = data['masks_test'][slices, ...] #masks[masks == 0] = 4 num_recursions = most_recent_recursion(model_path) image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1] images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') mask_pl, softmax_pl = model.predict(images_pl, exp_config.model_handle, exp_config.nlabels) #mask_fs_pl, softmax_fs_pl = model.predict(images_pl, unet2D_bn_modified, 4) saver = tf.train.Saver() init = tf.global_variables_initializer() predictions = np.zeros([batch_size] + list(exp_config.image_size) + [num_recursions + 1]) feed_dict = { images_pl: np.expand_dims(images, -1), } path = '/scratch_net/' with tf.Session() as sess: sess.run(init) pred_size = 0 for recursion in range(num_recursions + 1): try: try: checkpoint_path = utils.get_latest_model_checkpoint_path( model_path, 'recursion_{}_model_best_dice.ckpt'.format(recursion)) except: checkpoint_path = utils.get_latest_model_checkpoint_path( model_path, 'recursion_{}_model.ckpt'.format(recursion)) saver.restore(sess, checkpoint_path) mask_out, _ = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict) for mask in range(batch_size): predictions[ mask, ..., pred_size] = image_utils.keep_largest_connected_components( np.squeeze(mask_out[mask, ...])) print("Classified for recursion {}".format(recursion)) pred_size += 1 except Exception as e: print(e) num_recursions = pred_size fig = plt.figure() num_cols = num_recursions + 3 #RW: path = base_path + "/poster/" for recursion in range(num_recursions): predictions[..., recursion] = segment(images, np.squeeze(predictions[..., recursion]), beta=exp_config.rw_beta, threshold=0) for r in 
range(batch_size): #Add the image # ax = fig.add_subplot(batch_size, num_cols, 1 + r*num_cols) # ax.axis('off') # ax.imshow(np.squeeze(images[r, ...]), cmap='gray') image_utils.print_grayscale( np.squeeze(images[r, ...]), path, '{}_{}_image'.format(exp_config.experiment_name, slices[r])) #Add the mask # ax = fig.add_subplot(batch_size, num_cols, 2 + r*num_cols) # ax.axis('off') # ax.imshow(np.squeeze(masks[r, ...]), vmin=0, vmax=4, cmap='jet') image_utils.print_coloured( np.squeeze(masks[r, ...]), path, '{}_{}_gt'.format(exp_config.experiment_name, slices[r])) #predictions[r, ...] = segment(images, np.squeeze(predictions[r, ...]), beta=exp_config.rw_beta, threshold=0) for recursion in range(num_recursions): #Add each prediction image_utils.print_coloured( np.squeeze(predictions[r, ..., recursion]), path, '{}_{}_pred_r{}'.format(exp_config.experiment_name, slices[r], recursion)) #ax = fig.add_subplot(batch_size, num_cols, 3 + recursion + r*num_cols) #ax.axis('off') #ax.imshow(np.squeeze(predictions[r, ..., recursion]), vmin=0, vmax=4, cmap='jet') while True: plt.axis('off') plt.show()
def run_training(continue_run):
    """Train the network selected by ``config.experiment_name``.

    Loads the data, builds the TF graph (model, loss, optimiser, summaries),
    then runs the minibatch loop for ``config.max_epochs`` epochs. Training
    and validation metrics are evaluated every ``config.train_eval_frequency``
    / ``config.val_eval_frequency`` steps; checkpoints and summaries are
    written to the module-level ``log_dir``.

    Args:
        continue_run: if truthy, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``. When the checkpoint lookup fails the
            run silently starts from scratch.

    Raises:
        ValueError: if ``config.experiment_name`` is not a supported
            architecture (previously this only logged a warning and then
            failed with a NameError on ``logits``).
    """
    logging.info('EXPERIMENT NAME: %s' % config.experiment_name)

    init_step = 0
    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # The step is the numeric suffix of the checkpoint file name.
            # Plus 1 b/c otherwise training starts with an eval.
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = config.train_on_all_data

    # Load data
    data = read_data.load_and_maybe_process_data(
        input_folder=config.data_root,
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False)

    # From here on the data handle is always released, even if training fails.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_val']
            labels_val = data['masks_val']

        if config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(float(num_images) * config.use_data_fraction)
            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                            (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [config.batch_size] + list(config.image_size) + [1]
            mask_tensor_shape = [config.batch_size] + list(config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference model.
            unet_experiments = ('unet2D_valid', 'unet2D_same',
                                'unet2D_same_mod', 'unet2D_light',
                                'Dunet2D_same_mod', 'Dunet2D_same_mod2',
                                'Dunet2D_same_mod3')
            if config.experiment_name in unet_experiments:
                logits = model.inference(images_pl, config, training=training_pl)
            elif config.experiment_name == 'ENet':
                # NOTE(review): is_training is hard-wired to True here rather
                # than following training_pl - confirm this is intended.
                with slim.arg_scope(
                        model_structure.ENet_arg_scope(weight_decay=2e-4)):
                    logits = model_structure.ENet(
                        images_pl,
                        num_classes=config.nlabels,
                        batch_size=config.batch_size,
                        is_training=True,
                        reuse=None,
                        num_initial_blocks=1,
                        stage_two_repeat=2,
                        skip_connections=config.skip_connections)
            else:
                # Fail fast: without logits nothing below can be built.
                raise ValueError('invalid experiment_name: %r' %
                                 (config.experiment_name,))

            logging.info('images_pl shape')
            logging.info(images_pl.shape)
            logging.info('labels_pl shape')
            logging.info(labels_pl.shape)
            logging.info('logits shape:')
            logging.info(logits.shape)

            # Add to the Graph the Ops for loss calculation.
            # The second output is the unregularised loss.
            [loss, _, weights_norm] = model.loss(logits,
                                                 labels_pl,
                                                 nlabels=config.nlabels,
                                                 loss_type=config.loss_type,
                                                 weight_decay=config.weight_decay)

            # record how Total loss and weight decay change over time
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if config.momentum is not None:
                train_op = model.training_step(loss,
                                               config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=config.momentum)
            else:
                train_op = model.training_step(loss, config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during evaluation.
            # loss and dice on a minibatch
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=config.nlabels,
                                         loss_type=config.loss_type)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            max_to_keep = None if train_on_all_data else 5
            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            configP = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            configP.gpu_options.allow_growth = True
            # If an operation is not defined on the default device, let it
            # execute on another.
            configP.allow_soft_placement = True
            sess = tf.Session(config=configP)

            try:
                # Instantiate a SummaryWriter to output summaries and the Graph.
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                # Scalar placeholders used to push Python-side evaluation
                # results into the summaries.
                val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss', val_error_)
                val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)
                val_summary = tf.summary.merge(
                    [val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
                train_error_summary = tf.summary.scalar('training_loss', train_error_)
                train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice', train_dice_)
                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = config.learning_rate

                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    for batch in iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=config.batch_size,
                            augment_batch=config.augment_batch):

                        # Warm-up: a tenth of the learning rate for the first
                        # 50 steps, then back to the configured value.
                        if config.warmup_training:
                            if step < 50:
                                curr_lr = config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = config.learning_rate

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK (to avoid incomplete batches)
                        if y.shape[0] < config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly often.
                        if step % 20 == 0:
                            # Print status to stdout.
                            logging.info('Step %d: loss = %.3f (%.3f sec)' %
                                         (step, loss_value, duration))
                            # Update the events file.
                            summary_str = sess.run(summary, feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % config.train_eval_frequency == 0:

                            logging.info('Training Data Eval:')
                            [train_loss, train_dice] = do_eval(
                                sess, eval_loss, images_pl, labels_pl,
                                training_pl, images_train, labels_train,
                                config.batch_size)

                            train_summary_msg = sess.run(train_summary,
                                                         feed_dict={
                                                             train_error_: train_loss,
                                                             train_dice_: train_dice
                                                         })
                            summary_writer.add_summary(train_summary_msg, step)

                            # Keep a window of the last five training losses
                            # and derive a crude slope from its two ends.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if config.schedule_lr and loss_gradient < config.schedule_gradient_threshold:
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' % curr_lr)

                                # reset loss history to give the optimisation
                                # some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            if train_loss <= last_train:  # best_train:
                                no_improvement_counter = 0
                                logging.info('Decrease in training error!')
                            else:
                                no_improvement_counter = no_improvement_counter + 1
                                logging.info(
                                    'No improvment in training error for %d steps'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model periodically.
                        if (step + 1) % config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                            saver.save(sess, checkpoint_file, global_step=step)

                            if not train_on_all_data:

                                # Evaluate against the validation set.
                                logging.info('Validation Data Eval:')
                                [val_loss, val_dice] = do_eval(
                                    sess, eval_loss, images_pl, labels_pl,
                                    training_pl, images_val, labels_val,
                                    config.batch_size)

                                val_summary_msg = sess.run(val_summary,
                                                           feed_dict={
                                                               val_error_: val_loss,
                                                               val_dice_: val_dice
                                                           })
                                summary_writer.add_summary(val_summary_msg, step)

                                if val_dice > best_dice:
                                    best_dice = val_dice
                                    best_file = os.path.join(
                                        log_dir, 'model_best_dice.ckpt')
                                    # Only the newest "best dice" checkpoint
                                    # is kept on disk.
                                    for stale_path in glob.glob(
                                            os.path.join(log_dir, 'model_best_dice*')):
                                        os.remove(stale_path)
                                    saver_best_dice.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best dice on validation set! - %f - Saving model_best_dice.ckpt'
                                        % val_dice)

                                if val_loss < best_val:
                                    best_val = val_loss
                                    best_file = os.path.join(
                                        log_dir, 'model_best_xent.ckpt')
                                    # Only the newest "best xent" checkpoint
                                    # is kept on disk.
                                    for stale_path in glob.glob(
                                            os.path.join(log_dir, 'model_best_xent*')):
                                        os.remove(stale_path)
                                    saver_best_xent.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                        % val_loss)

                        step += 1

                    # end epoch: decay the learning rate every
                    # config.epoch_freq epochs
                    if (epoch + 1) % config.epoch_freq == 0:
                        curr_lr = curr_lr * 0.98
                        logging.info('Learning rate: %f' % curr_lr)
            finally:
                sess.close()
    finally:
        data.close()
def run_training(continue_run):
    """Train the segmentation model described by ``exp_config``.

    Loads the data, builds the TF graph (model, loss, optimiser, summaries),
    then runs the minibatch loop for ``exp_config.max_epochs`` epochs.
    Training and validation metrics are evaluated every
    ``exp_config.train_eval_frequency`` / ``exp_config.val_eval_frequency``
    steps; checkpoints and summaries are written to the module-level
    ``log_dir``.

    Args:
        continue_run: if truthy, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``. When the checkpoint lookup fails the
            run silently starts from scratch.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # The step is the numeric suffix of the checkpoint file name.
            # Plus 1 b/c otherwise training starts with an eval.
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Older experiment configs may not define this switch at all.
    train_on_all_data = getattr(exp_config, 'train_on_all_data', False)

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=(not train_on_all_data))

    # From here on the data handle is always released, even if training fails.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            # The held-out "test" split doubles as the validation set here.
            images_val = data['images_test']
            labels_val = data['masks_test']

        if exp_config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(
                float(num_images) * exp_config.use_data_fraction)
            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                            (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_pl, exp_config, training=training_pl)

            # Add to the Graph the Ops for loss calculation.
            # The second output is the unregularised loss.
            [loss, _, weights_norm] = model.loss(
                logits,
                labels_pl,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weight_decay=exp_config.weight_decay)

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during evaluation.
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            max_to_keep = None if train_on_all_data else 5
            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            session_config = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            session_config.gpu_options.allow_growth = True
            # If an operation is not defined on the default device, let it
            # execute on another.
            session_config.allow_soft_placement = True
            sess = tf.Session(config=session_config)

            try:
                # Instantiate a SummaryWriter to output summaries and the Graph.
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                # Scalar placeholders used to push Python-side evaluation
                # results into the summaries.
                val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss', val_error_)
                val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)
                val_summary = tf.summary.merge(
                    [val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
                train_error_summary = tf.summary.scalar('training_loss', train_error_)
                train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice', train_dice_)
                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = exp_config.learning_rate

                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(exp_config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    # This loop can also be driven by a BackgroundGenerator
                    # wrapped around iterate_minibatches for some speed-up.
                    # Be aware that an exception inside the loop may then not
                    # be caught: the generator may keep running silently even
                    # though the code has crashed.
                    for batch in iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch):

                        # Warm-up: a tenth of the learning rate for the first
                        # 50 steps, then back to the configured value.
                        if exp_config.warmup_training:
                            if step < 50:
                                curr_lr = exp_config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = exp_config.learning_rate

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK (to avoid incomplete batches)
                        if y.shape[0] < exp_config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly often.
                        if step % 10 == 0:
                            # Print status to stdout.
                            logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                         (step, loss_value, duration))
                            # Update the events file.
                            summary_str = sess.run(summary, feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % exp_config.train_eval_frequency == 0:

                            logging.info('Training Data Eval:')
                            [train_loss, train_dice] = do_eval(
                                sess, eval_loss, images_pl, labels_pl,
                                training_pl, images_train, labels_train,
                                exp_config.batch_size)

                            train_summary_msg = sess.run(train_summary,
                                                         feed_dict={
                                                             train_error_: train_loss,
                                                             train_dice_: train_dice
                                                         })
                            summary_writer.add_summary(train_summary_msg, step)

                            # Keep a window of the last five training losses
                            # and derive a crude slope from its two ends.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' % curr_lr)

                                # reset loss history to give the optimisation
                                # some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            # Track consecutive evaluations without
                            # improvement so the message below reports a real
                            # count (the counter used to stay at 0 forever).
                            if train_loss <= last_train:  # best_train:
                                no_improvement_counter = 0
                                logging.info('Decrease in training error!')
                            else:
                                no_improvement_counter += 1
                                logging.info(
                                    'No improvment in training error for %d steps'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model periodically.
                        if (step + 1) % exp_config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                            saver.save(sess, checkpoint_file, global_step=step)

                            if not train_on_all_data:

                                # Evaluate against the validation set.
                                logging.info('Validation Data Eval:')
                                [val_loss, val_dice] = do_eval(
                                    sess, eval_loss, images_pl, labels_pl,
                                    training_pl, images_val, labels_val,
                                    exp_config.batch_size)

                                val_summary_msg = sess.run(val_summary,
                                                           feed_dict={
                                                               val_error_: val_loss,
                                                               val_dice_: val_dice
                                                           })
                                summary_writer.add_summary(val_summary_msg, step)

                                if val_dice > best_dice:
                                    best_dice = val_dice
                                    best_file = os.path.join(
                                        log_dir, 'model_best_dice.ckpt')
                                    saver_best_dice.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best dice on validation set! - %f - Saving model_best_dice.ckpt'
                                        % val_dice)

                                if val_loss < best_val:
                                    best_val = val_loss
                                    best_file = os.path.join(
                                        log_dir, 'model_best_xent.ckpt')
                                    saver_best_xent.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                        % val_loss)

                        step += 1
            finally:
                sess.close()
    finally:
        data.close()
def main(exp_config):
    """Visualise predictions of the recursion-1 model on test slices 10..19.

    For each slice a figure is shown with the input image, the ground-truth
    mask, the arg-max prediction and up to five softmax channels. Channel 0
    of the displayed softmax is replaced by a constant 0.95 plane.

    Args:
        exp_config: experiment configuration. The attributes read here are
            ``data_mode``, ``image_size``, ``target_resolution``,
            ``model_handle`` and ``nlabels``.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False)

    # Release the data handle even if building or running the model fails.
    try:
        batch_size = 1
        image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')

        training_time_placeholder = tf.placeholder(tf.bool, shape=[])
        logits = model.inference(images_pl,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)
        softmax_pl = tf.nn.softmax(logits)

        # Swap the first softmax channel for a constant threshold plane. The
        # plane takes its shape from the softmax itself instead of a
        # hard-coded [1, 212, 212, 1], so other image sizes work too.
        threshold = tf.constant(0.95, dtype=tf.float32)
        s = tf.multiply(tf.ones_like(softmax_pl[..., :1]), threshold)
        softmax_pl = tf.concat([s, softmax_pl[..., 1:]], axis=-1)

        # tf.argmax replaces the deprecated tf.arg_max alias.
        mask_pl = tf.argmax(logits, axis=-1)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'recursion_1_model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            for ind in range(10, 20):
                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)

                feed_dict = {images_pl: x, training_time_placeholder: False}
                mask_out, softmax_out = sess.run([mask_pl, softmax_pl],
                                                 feed_dict=feed_dict)

                # Top row: input image, ground truth, predicted mask.
                fig = plt.figure()
                ax_image = fig.add_subplot(251)
                ax_image.imshow(np.squeeze(x), cmap='gray')
                ax_label = fig.add_subplot(252)
                ax_label.imshow(np.squeeze(y))
                ax_mask = fig.add_subplot(253)
                ax_mask.imshow(np.squeeze(mask_out))

                # Bottom row: one panel per softmax channel. The 2x5 grid has
                # room for five; fewer channels no longer raise an IndexError.
                for channel in range(min(softmax_out.shape[-1], 5)):
                    ax_channel = fig.add_subplot(2, 5, 6 + channel)
                    ax_channel.imshow(np.squeeze(softmax_out[..., channel]))

                plt.show()
    finally:
        data.close()
def score_data(input_folder, output_folder, model_path, exp_config, do_postprocessing=False, recursion=None):
    """Predict every image of the prostate data set, save the results and report Dice.

    For each image the arg-max of the network's softmax output is saved as
    NIfTI next to the ground truth and the input image, and a Dice score
    against the ground truth is computed with ``get_dice``. The mean Dice is
    printed at the end.

    Args:
        input_folder: unused; the data file path is hard-coded below.
        output_folder: root folder; files are written to its ``prediction``,
            ``ground_truth`` and ``image`` sub-folders.
        model_path: folder containing the model checkpoints.
        exp_config: experiment configuration. The attributes read here are
            ``image_size``, ``model_handle`` and ``nlabels``.
        do_postprocessing: if truthy, predictions are passed through
            ``image_utils.keep_largest_connected_components``.
        recursion: if ``None`` the ``model_best_dice.ckpt`` checkpoint is
            used; otherwise the checkpoint of that recursion (best-dice
            variant first, regular checkpoint as fallback).

    Returns:
        The training step encoded in the restored checkpoint's file name.
    """
    print("KOD YENİ")
    dices = []

    # NOTE(review): the data location is hard-coded and `input_folder` is
    # ignored - confirm this is intended.
    images, labels = read_data('/scratch/cany/scribble/scribble_data/prostate_divided.h5')

    num_images = images.shape[0]
    print(str(num_images))
    print(str(images.shape))

    batch_size = 1
    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    mask_pl, softmax_pl, logits = model.predict_logits(images_pl,
                                                       exp_config.model_handle,
                                                       exp_config.nlabels)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:

        sess.run(init)

        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path,
                    'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # No "best dice" checkpoint for this recursion: fall back to
                # the regular checkpoint.
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path,
                    'recursion_{}_model.ckpt'.format(recursion))

        saver.restore(sess, checkpoint_path)

        # The step is the numeric suffix of the checkpoint file name.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        for k in range(num_images):
            # Ground truth for this image. It used to be referenced below
            # without ever being assigned, which raised a NameError.
            mask = labels[k, :, :]

            # Add the channel and batch axes expected by the placeholder.
            network_input = np.expand_dims(np.expand_dims(images[k, :, :], 2), 0)
            softmax_out = sess.run(softmax_pl, feed_dict={images_pl: network_input})
            prediction_cropped = np.squeeze(softmax_out[0, ...])

            # Hard label map from the per-class probabilities.
            prediction_arr = np.uint8(np.argmax(prediction_cropped, axis=-1))

            if do_postprocessing:
                print("Entered post processing " + str(True))
                prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction',
                                         'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % out_file_name)
            save_nii(out_file_name, prediction_arr)

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth',
                                        'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % gt_file_name)
            save_nii(gt_file_name, np.uint8(mask))

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(output_folder, 'image',
                                           'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % image_file_name)
            save_nii(image_file_name, images[k, :, :])

            # NOTE(review): this adds new ops to the graph for every image and
            # hard-codes 4 classes instead of exp_config.nlabels - consider
            # building the Dice op once, outside the loop.
            cdice = sess.run(get_dice(tf.one_hot(np.uint8(prediction_arr), depth=4),
                                      np.uint8(np.squeeze(mask)),
                                      4))
            print(str(cdice))
            dices.append(cdice)

        print("Average Dice : " + str(np.mean(dices)))

    return init_iteration
def run_training():
    """Adapt the normalization module on the test images via the self-supervised loss.

    The function
      1. loads the test images and labels (HCP T2 with the hard-coded
         switches below) and drops slices that are entirely zero,
      2. builds a graph made of the normalizer followed by the segmentation
         and the rotation (self-supervised) networks,
      3. restores the segmentation and rotation weights from the
         'Initial_training_segmentation' / 'Initial_training_rotation'
         checkpoints under ``sys_config.log_root``,
      4. optimises only the variables whose name contains 'normalization'
         by minimising the self-supervised loss, and
      5. periodically writes summaries, test-set metrics and checkpoints.

    Besides ``exp_config`` and ``sys_config`` the function reads the
    module-level ``log_dir`` (summary and checkpoint destination); it is not
    a parameter.

    The session and the summary writer are always released, also when an
    exception interrupts the training loop.
    """
    banner = '============================================================'
    debug_banner = '================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(banner)
    logging.info('EXPERIMENT NAME: %s', exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # Determine the data set.
    # Both switches are hard-coded; flip them by hand to evaluate on
    # ABIDE (hcp = False) or on HCP T1 (target_data = False).
    # ============================
    target_data = True

    # ============================
    # Load data
    # ============================
    if target_data:
        hcp = True
        if hcp:
            logging.info(banner)
            logging.info('Loading data...')
            logging.info('Reading HCP - 3T - T2 images...')
            logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
            imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                            sys_config.preproc_folder_hcp,
                                            'T2',
                                            1,
                                            exp_config.test_size + 1)
        else:
            logging.info(banner)
            logging.info('Loading data...')
            logging.info('Reading ABIDE caltech...')
            logging.info('Data root directory: ' + sys_config.orig_data_root_abide)
            imts, gtts = data_abide.load_data(sys_config.orig_data_root_abide,
                                              sys_config.preproc_folder_abide,
                                              1,
                                              exp_config.test_size + 1)
    else:
        logging.info(banner)
        logging.info('Loading data...')
        logging.info('Reading HCP - 3T - T1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        imts, gtts = data_hcp.load_data(
            sys_config.orig_data_root_hcp,
            sys_config.preproc_folder_hcp,
            'T1',
            exp_config.train_size + exp_config.val_size + 1,
            exp_config.train_size + exp_config.val_size + exp_config.test_size + 1)

    # expected shapes: [num_slices, img_size_x, img_size_y]
    logging.info('Test Images: %s', str(imts.shape))
    logging.info('Test Labels: %s', str(gtts.shape))
    logging.info(banner)

    # ============================
    # Remove exclusively black images
    # ============================
    mask = ~(imts == 0).all(axis=(1, 2))
    imts = imts[mask]
    gtts = gtts[mask]

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')

        # Placeholders for the images and labels
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')

        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [exp_config.num_rotation_values]

        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')

        # Placeholders for the learning rate and to indicate whether the
        # model is being trained or tested
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # ================================================================
        # Define the image normalization function - these parameters will
        # be updated for each test image
        # ================================================================
        image_normalized = exp_config.model_handle_normalizer(images_pl, image_tensor_shape)

        # ================================================================
        # Define the segmentation network here
        # ================================================================
        logits_segmentation = exp_config.model_handle_segmentation(
            image_normalized, image_tensor_shape, training_pl, exp_config.nlabels)

        # ================================================================
        # Define the self-supervised network here
        # ================================================================
        logits_self_supervised = exp_config.model_handle_rotation(
            image_normalized, image_tensor_shape, training_pl)

        # ================================================================
        # determine trainable variables: they are split by substring of the
        # variable name; the first matching substring wins.
        # ================================================================
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []

        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info(debug_banner)
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)

            logging.info(debug_banner)
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)

            logging.info(debug_banner)
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)

            logging.info(debug_banner)
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss_self_supervised = tf.reduce_mean(
            exp_config.loss_handle_self_supervised(labels_self_supervised_pl,
                                                   logits_self_supervised))
        tf.compat.v1.summary.scalar('loss_self_supervised', loss_self_supervised)

        # ================================================================
        # Add optimization ops: only the test-time (normalization) variables
        # are handed to the optimiser.
        # ================================================================
        train_op_test_time = model.training_step(
            loss_self_supervised,
            test_time_opt_vars,
            exp_config.optimizer_handle_test_time,
            learning_rate_pl)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)
        eval_loss_self_supervised = model.evaluation_rotate(
            logits_self_supervised, labels_self_supervised_pl)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # The test summaries created further below are added after this
        # merge and are therefore not part of it.
        # ================================================================
        summary = tf.compat.v1.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # ================================================================
        # create savers for each domain
        # ================================================================
        saver = tf.compat.v1.train.Saver(max_to_keep=15)
        # Not referenced again in this function; kept so that the set of
        # save/restore ops in the graph stays as before.
        saver_best_da = tf.compat.v1.train.Saver()
        saver_seg = tf.compat.v1.train.Saver(var_list=segmentation_vars)
        saver_ss = tf.compat.v1.train.Saver(var_list=self_supervised_vars)

        # ================================================================
        # Create session
        # ================================================================
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)

        summary_writer = None
        try:
            # ============================================================
            # create a summary writer
            # ============================================================
            summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

            # ============================================================
            # summaries of the test errors
            # ============================================================
            ts_error_seg = tf.compat.v1.placeholder(tf.float32, shape=[], name='ts_error_seg')
            ts_error_summary_seg = tf.compat.v1.summary.scalar('test/loss_seg', ts_error_seg)
            ts_dice_seg = tf.compat.v1.placeholder(tf.float32, shape=[], name='ts_dice_seg')
            ts_dice_summary_seg = tf.compat.v1.summary.scalar('test/dice_seg', ts_dice_seg)
            ts_summary_seg = tf.compat.v1.summary.merge(
                [ts_error_summary_seg, ts_dice_summary_seg])

            ts_error_ss = tf.compat.v1.placeholder(tf.float32, shape=[], name='ts_error_ss')
            ts_error_summary_ss = tf.compat.v1.summary.scalar('test/loss_ss', ts_error_ss)
            ts_acc_ss = tf.compat.v1.placeholder(tf.float32, shape=[], name='ts_acc_ss')
            ts_acc_summary_ss = tf.compat.v1.summary.scalar('test/acc_ss', ts_acc_ss)
            ts_summary_ss = tf.compat.v1.summary.merge(
                [ts_error_summary_ss, ts_acc_summary_ss])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.compat.v1.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(banner)
            logging.info('initializing all variables...')
            sess.run(init_g)
            sess.run(init_l)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(banner)
            logging.info('This is the list of uninitialized variables:')
            for v in sess.run(uninit_vars):
                logging.info(v)

            # ============================================================
            # restore shared weights
            # ============================================================
            logging.info(banner)
            logging.info('Restore segmentation...')
            model_path = os.path.join(sys_config.log_root, 'Initial_training_segmentation')
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'models/best_dice.ckpt')
            logging.info('Restroring session from: %s', checkpoint_path)
            saver_seg.restore(sess, checkpoint_path)

            logging.info(banner)
            logging.info('Restore self-supervised...')
            model_path = os.path.join(sys_config.log_root, 'Initial_training_rotation')
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'models/best_dice.ckpt')
            logging.info('Restroring session from: %s', checkpoint_path)
            saver_ss.restore(sess, checkpoint_path)

            # ============================================================
            # run training epochs; the learning rate is constant
            # ============================================================
            step = init_step
            curr_lr = exp_config.learning_rate

            for epoch in range(exp_config.max_epochs):
                logging.info(banner)
                logging.info('EPOCH %d', epoch)

                for batch in iterate_minibatches(imts,
                                                 gtts,
                                                 batch_size=exp_config.batch_size):
                    start_time = time.time()
                    x, y_seg, y_ss = batch

                    # ===========================
                    # avoid incomplete batches (the placeholders have a
                    # fixed batch dimension); the step still counts
                    # ===========================
                    if y_seg.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_self_supervised_pl: y_ss,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # ===========================
                    # update vars
                    # ===========================
                    _, loss_value = sess.run(
                        [train_op_test_time, loss_self_supervised],
                        feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.2f (%.3f sec for the last step)',
                            step + 1, loss_value, duration)

                        # ===========================
                        # print values of a parameter (to debug); indexes
                        # the first global variable with four indices
                        # ===========================
                        if exp_config.debug is True:
                            var_value = tf.compat.v1.get_collection(
                                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                            )[0].eval(session=sess)
                            logging.info('value of one of the parameters %f',
                                         var_value[0, 0, 0, 0])

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # compute the loss on the entire test set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Test Data Eval:')
                        [test_loss_seg, test_dice_seg, test_loss_ss, test_acc_ss] = do_eval(
                            sess,
                            eval_loss_segmentation,
                            eval_loss_self_supervised,
                            images_pl,
                            labels_segmentation_pl,
                            labels_self_supervised_pl,
                            training_pl,
                            imts,
                            gtts,
                            exp_config.batch_size)

                        ts_summary_msg = sess.run(ts_summary_seg,
                                                  feed_dict={
                                                      ts_error_seg: test_loss_seg,
                                                      ts_dice_seg: test_dice_seg
                                                  })
                        summary_writer.add_summary(ts_summary_msg, step)

                        ts_summary_msg = sess.run(ts_summary_ss,
                                                  feed_dict={
                                                      ts_error_ss: test_loss_ss,
                                                      ts_acc_ss: test_acc_ss
                                                  })
                        summary_writer.add_summary(ts_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir, 'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    step += 1
        finally:
            # Flush pending events and free the session even if training
            # was interrupted by an exception.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
def run_training(continue_run):
    """Train the segmentation network recursively on scribble-derived labels.

    The scribble data set is loaded (or created) through
    ``acdc_data.load_and_maybe_process_scribbles``. Training labels are the
    random-walked scribbles at first; every
    ``exp_config.epochs_per_recursion`` epochs the current network predicts
    a new set of labels (``predict_next_gt``) which replaces them. The
    network variables are updated at every step, the variables in
    'crf_scope' every 10th step with their own learning rate. Every
    ``exp_config.val_eval_frequency`` steps a checkpoint is written and the
    model is evaluated on the validation data; the best Dice and the best
    loss checkpoints are stored separately per recursion.

    Args:
        continue_run: if true, the latest checkpoint of the current (or the
            previous) recursion is looked up and logged.
            NOTE: restoring the checkpoint is disabled in this script and
            step / recursion / epoch are reset to 0 before training starts,
            so at present this flag only affects logging.

    Reads the module-level ``log_dir`` in addition to ``exp_config`` and
    ``sys_config``.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)
    already_created_recursion = True
    print("ALready created recursion : " + str(already_created_recursion))

    init_step = 0

    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)

    # NOTE: the body used to be wrapped in `try: ... except Exception: raise`,
    # which re-raised everything unchanged; the no-op wrapper was dropped.
    loaded_previous_recursion = False
    start_epoch = 0
    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        # `except Exception` (instead of a bare except) keeps Ctrl-C and
        # SystemExit from being mistaken for "no checkpoint found".
        try:
            try:
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion))
            except Exception:
                # fall back to the checkpoints of the previous recursion
                print("EXCEPTE GİRDİ")
                init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                    log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1))
                loaded_previous_recursion = True
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            # presumably 4 is the batch size the checkpoints were written
            # with - TODO confirm against exp_config.batch_size
            start_epoch = int(init_step / (len(base_data['images_train']) / 4))
            logging.info('Latest step was: %d' % init_step)
            logging.info('Continuing with epoch: %d' % start_epoch)
        except Exception:
            logging.warning(
                '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
            start_epoch = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    if loaded_previous_recursion:
        logging.info(
            "Data file exists for recursion {} "
            "but checkpoints only present up to recursion {}".format(
                recursion, recursion - 1))
        logging.info("Likely means postprocessing was terminated")
        start_epoch = 0
        init_step = 0

    # load images and validation data
    images_train = np.array(base_data['images_train'])
    images_val = np.array(base_data['images_test'])
    labels_val = np.array(base_data['masks_test'])

    logging.info('Data summary:')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # The graph is built into the default graph. This configured session is
    # the one used for everything below; previously a second, unconfigured
    # `tf.Session()` replaced it, ignoring `config` and never being closed.
    with tf.Session(config=config) as sess:
        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        mask_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size)

        images_placeholder = tf.placeholder(tf.float32,
                                            shape=image_tensor_shape,
                                            name='images')
        labels_placeholder = tf.placeholder(tf.uint8,
                                            shape=mask_tensor_shape,
                                            name='labels')
        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])
        keep_prob = tf.placeholder(tf.float32, shape=[])
        crf_learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(images_placeholder,
                                 keep_prob,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)

        # Add to the Graph the Ops for loss calculation.
        # second output is unregularised loss
        [loss, _, weights_norm] = model.loss(logits,
                                             labels_placeholder,
                                             nlabels=exp_config.nlabels,
                                             loss_type=exp_config.loss_type,
                                             weight_decay=exp_config.weight_decay)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        # Split the variables into the CRF part and everything else.
        # The comparison is done on names: comparing `v.name` (a string)
        # against the Variable objects themselves never matches, which made
        # the network optimiser update the CRF variables as well.
        crf_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='crf_scope')
        crf_variable_names = {v.name for v in crf_variables}
        restore_var = [
            v for v in tf.global_variables()
            if v.name not in crf_variable_names
        ]

        # Single step counter shared by both optimisers (it used to be
        # created twice, leaving one copy unused in every checkpoint).
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Add to the Graph the Ops that calculate and apply gradients.
        network_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate_placeholder).minimize(
                loss,
                var_list=restore_var,
                colocate_gradients_with_ops=True,
                global_step=global_step)

        crf_train_op = tf.train.AdamOptimizer(
            learning_rate=crf_learning_rate_placeholder).minimize(
                loss,
                var_list=crf_variables,
                colocate_gradients_with_ops=True,
                global_step=global_step)

        eval_val_loss = model.evaluation(
            logits,
            labels_placeholder,
            images_placeholder,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type,
            weak_supervision=True,
            cnn_threshold=exp_config.cnn_threshold,
            include_bg=False)

        # Build the summary Tensor based on the TF collection of Summaries.
        # It is currently not evaluated anywhere in the training loop.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        # Only keep two checkpoints, as checkpoints are kept for every
        # recursion and they can be 300MB +
        saver = tf.train.Saver(max_to_keep=2)
        saver_best_dice = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        try:
            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss', val_error_)
            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)
            val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
            train_error_summary = tf.summary.scalar('training_loss', train_error_)
            train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice', train_dice_)
            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            # Restoring a checkpoint is disabled here, therefore training
            # always starts from scratch regardless of what was found above.
            init_step = 0
            recursion = 0
            start_epoch = 0

            step = init_step
            curr_lr = exp_config.learning_rate / 10
            crf_curr_lr = 1e-07 / 10
            best_val = np.inf
            best_dice = 0

            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            if recursion == 0:
                recursion_data = acdc_data.random_walk_epoch(
                    recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                    exp_config.random_walk)
                print("Random walku geçti")
                # get ground truths
                labels_train = np.array(recursion_data['random_walked'])
            else:
                labels_train = np.array(recursion_data['predicted'])

            print("Start epoch : " + str(start_epoch) + " : max epochs : " +
                  str(exp_config.epochs_per_recursion))

            for epoch in range(start_epoch, exp_config.max_epochs):
                if epoch % exp_config.epochs_per_recursion == 0 and epoch != 0:
                    # Have reached end of recursion: let the current network
                    # produce the labels for the next one.
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        keep_prob=keep_prob,
                        logits=logits,
                        sess=sess)
                    recursion += 1

                    # get ground truths
                    labels_train = np.array(recursion_data['predicted'])

                    # reinitialise savers - otherwise, no checkpoints will be
                    # saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)

                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))

                # The loop runs with the BACKGROUND GENERATOR, which speeds
                # up training. Be aware that an exception inside this loop
                # may not be caught: the batch generator may continue
                # running silently even though the code has crashed.
                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    # decay both learning rates every 3000 steps
                    if step % 3000 == 0 and step > 0:
                        curr_lr = curr_lr * 0.9
                        crf_curr_lr = crf_curr_lr * 0.9

                    start_time = time.time()

                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    network_feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        keep_prob: 0.5,
                        training_time_placeholder: True
                    }

                    crf_feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        crf_learning_rate_placeholder: crf_curr_lr,
                        keep_prob: 1,
                        training_time_placeholder: True
                    }

                    # the CRF variables are only updated every 10th step
                    if step % 10 == 0:
                        _, loss_value = sess.run([crf_train_op, loss],
                                                 feed_dict=crf_feed_dict)
                    _, loss_value = sess.run([network_train_op, loss],
                                             feed_dict=network_feed_dict)

                    duration = time.time() - start_time

                    # Print an overview fairly often.
                    if step % 10 == 0:
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:
                        checkpoint_file = os.path.join(
                            log_dir, 'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice] = do_eval(
                            sess, eval_val_loss, images_placeholder,
                            labels_placeholder, training_time_placeholder,
                            keep_prob, images_val, labels_val,
                            exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(recursion))
                            saver_best_dice.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.format(
                                    val_dice, recursion))
                            # the context manager closes the file even if
                            # the write fails
                            with open('val_results.txt', "a") as text_file:
                                text_file.write("\nVal dice " + str(step) +
                                                " : " + str(val_dice))

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(recursion))
                            saver_best_xent.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.format(
                                    val_loss, recursion))

                    step += 1
        finally:
            # make sure buffered validation summaries reach the events file
            summary_writer.close()
# ================================================================
# Create the session
# ================================================================
sess = tf.Session()

# ====================================
# Initialize
# ====================================
sess.run(init_g)
sess.run(init_l)

# ====================================
# Restore training models: the latest 'models/best_dice.ckpt' checkpoint
# found in the experiment's log directory is loaded into the session
# ====================================
log_dir = os.path.join(sys_config.log_root, exp_config.experiment_name)
path_to_model = log_dir
checkpoint_path = utils.get_latest_model_checkpoint_path(
    path_to_model, 'models/best_dice.ckpt')
logging.info('========================================================')
logging.info('Restoring the trained parameters from %s...' % checkpoint_path)
saver.restore(sess, checkpoint_path)

# ================================
# open a text file for writing the mean dice scores for each subject that is evaluated
# ================================
# Compare by value: `is` tests object identity, which is not guaranteed to
# hold for equal strings (and raises a SyntaxWarning on Python >= 3.8).
if test_dataset == 'freiburg':
    subject_string = '000'

# NOTE(review): for any other dataset `subject_string` must already have been
# set before this point - confirm against the surrounding code.
# The file is left open here; nothing in this section closes it.
results_file = open(
    log_dir + '/results/' + test_dataset + '/' + subject_string + '.txt', "w")
results_file.write("================================== \n")
def run_training(log_dir,
                 image,
                 label,
                 atlas,
                 continue_run,
                 log_dir_first_TD_subject=''):
    """Adapt the image normalization module to a single test volume.

    A fresh TF graph is built containing the normalization module, the
    image-to-label (i2l) segmentation network and the label-to-label (l2l)
    denoising autoencoder (DAE). The pre-trained i2l and l2l parameters are
    restored from checkpoints under ``sys_config.log_root``. Gradients of a
    prior loss (computed in the downsampled space) are then accumulated over
    all batches of the volume and applied once per pass over the volume.
    Periodically, the current prediction is compared with the DAE output and
    with the atlas in order to choose which of them serves as the target
    for the following updates.

    Args:
        log_dir: Directory for tensorboard summaries, checkpoints
            (``models/model.ckpt``, ``models/best_score.ckpt``) and
            visualizations (``results/visualize_images``).
        image: Test volume; it is fed to the network in batches of
            ``exp_config.batch_size``.
        label: Ground truth labels of the volume. Used for monitoring the
            dice score, and as the training target only when
            ``exp_config.use_gt_for_tta`` is True.
        atlas: Atlas used as a prior. NOTE(review): treated as one-hot along
            the last axis (``np.argmax(atlas, axis=-1)``) - confirm with caller.
        continue_run: If True, try to resume from the latest
            ``models/model.ckpt`` checkpoint in ``log_dir``. If no such
            checkpoint can be found, training starts from step 0.
        log_dir_first_TD_subject: If non-empty, the parameters adapted for the
            first subject (``<dir>/models/best_score.ckpt``) are restored
            and the adaptation runs for ``exp_config.max_steps`` steps;
            otherwise it runs for ``5 * exp_config.max_steps`` steps.

    Returns:
        0, after the session has been closed.
    """

    # ============================
    # down sample the atlas - the losses will be evaluated in the downsampled space
    # ============================
    atlas_downsampled = rescale(atlas, [
        1 / exp_config.downsampling_factor_x,
        1 / exp_config.downsampling_factor_y,
        1 / exp_config.downsampling_factor_z
    ],
                                order=1,
                                preserve_range=True,
                                multichannel=True,
                                mode='constant')
    atlas_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        atlas_downsampled, int(256 / exp_config.downsampling_factor_x))

    label_onehot = utils.make_onehot(label, exp_config.nlabels)
    label_onehot_downsampled = rescale(label_onehot, [
        1 / exp_config.downsampling_factor_x,
        1 / exp_config.downsampling_factor_y,
        1 / exp_config.downsampling_factor_z
    ],
                                       order=1,
                                       preserve_range=True,
                                       multichannel=True,
                                       mode='constant')
    label_onehot_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        label_onehot_downsampled, int(256 / exp_config.downsampling_factor_x))

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # 'Exception' rather than a bare 'except', so that
            # KeyboardInterrupt / SystemExit are not swallowed here.
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[exp_config.batch_size] +
                                   list(exp_config.image_size) + [1],
                                   name='images')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl, exp_config, training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # By setting the 'training_pl' to false directly, the update ops for the moments in the BN layer are not created at all.
        # This allows grouping the update ops together with the optimizer training, while training the normalizer - in case the normalizer has BN.
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_1hot_shape = [1] + list(
            exp_config.image_size_downsampled) + [exp_config.nlabels]
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_1hot_3d_pl = tf.placeholder(tf.float32,
                                                  shape=labels_3d_1hot_shape,
                                                  name='predicted_labels_3d')

        # denoise the noisy segmentation
        _, predicted_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                i2l_vars.append(
                    v
                )  # the normalization vars also need to be restored from the pre-trained i2l mapper
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # Make a list of trainable i2l vars. This will be used to compute gradients.
        # The other list contains trainable as well as non-trainable parameters.
        # This is required for saving and loading all parameters, but runs into trouble
        # when gradients are asked to be computed for non-trainable parameters
        # ================================================================
        i2l_vars_trainable = []
        for v in i2l_vars:
            if v.trainable is True:
                i2l_vars_trainable.append(v)

        # ================================================================
        # add ops for calculation of the prior loss - wrt an atlas or the outputs of the DAE
        # ================================================================
        prior_label_1hot_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size_downsampled] + list(
                (exp_config.image_size_downsampled[1],
                 exp_config.image_size_downsampled[2])) + [exp_config.nlabels],
            name='labels_prior')

        # down sample the predicted logits
        predicted_seg_logits_expanded = tf.expand_dims(predicted_seg_logits,
                                                       axis=0)
        # the 'upsample' function will actually downsample the predictions, as the scaling factors have been set appropriately
        predicted_seg_logits_downsampled = layers.bilinear_upsample3D_(
            predicted_seg_logits_expanded,
            name='downsampled_predictions',
            factor_x=1 / exp_config.downsampling_factor_x,
            factor_y=1 / exp_config.downsampling_factor_y,
            factor_z=1 / exp_config.downsampling_factor_z)
        predicted_seg_logits_downsampled = tf.squeeze(
            predicted_seg_logits_downsampled
        )  # the first axis was added only for the downsampling in 3d

        # compute the dice between the predictions and the prior in the downsampled space
        loss_op = model.loss(logits=predicted_seg_logits_downsampled,
                             labels=prior_label_1hot_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_prior,
                             mask_for_loss_within_mask=None,
                             are_labels_1hot=True)
        tf.summary.scalar('tr_losses/loss', loss_op)

        # ================================================================
        # one of the two prior losses will be used in the following manner:
        # the atlas prior will be used when the current prediction is deemed to be very far away from a reasonable solution
        # once a reasonable solution is reached, the dae prior will be used.
        # these 3d computations will be done outside the graph and will be passed via placeholders for logging in tensorboard
        # ================================================================
        lambda_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='lambda_prior_atlas')
        lambda_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='lambda_prior_dae')
        tf.summary.scalar('lambdas/prior_atlas', lambda_prior_atlas_pl)
        tf.summary.scalar('lambdas/prior_dae', lambda_prior_dae_pl)

        dice3d_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='dice3d_prior_atlas')
        dice3d_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='dice3d_prior_dae')
        dice3d_gt_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_gt')
        tf.summary.scalar('dice3d/prior_atlas', dice3d_prior_atlas_pl)
        tf.summary.scalar('dice3d/prior_dae', dice3d_prior_dae_pl)
        tf.summary.scalar('dice3d/gt', dice3d_gt_pl)

        # ================================================================
        # add optimization ops
        # ================================================================
        if exp_config.debug:
            print('creating training op...')

        # create an instance of the required optimizer
        optimizer = exp_config.optimizer_handle(learning_rate=learning_rate_pl)

        # initialize variable holding the accumlated gradients and create a zero-initialisation op
        accumulated_gradients = [
            tf.Variable(tf.zeros_like(var.initialized_value()),
                        trainable=False) for var in i2l_vars_trainable
        ]

        # accumulated gradients init op
        accumulated_gradients_zero_op = [
            ac.assign(tf.zeros_like(ac)) for ac in accumulated_gradients
        ]

        # calculate gradients and define accumulation op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            gradients = optimizer.compute_gradients(
                loss_op, var_list=i2l_vars_trainable)
        # compute_gradients return a list of (gradient, variable) pairs.
        accumulate_gradients_op = [
            ac.assign_add(gg[0])
            for ac, gg in zip(accumulated_gradients, gradients)
        ]

        # define the gradient mean op
        num_accumulation_steps_pl = tf.placeholder(
            dtype=tf.float32, name='num_accumulation_steps')
        accumulated_gradients_mean_op = [
            ag.assign(tf.divide(ag, num_accumulation_steps_pl))
            for ag in accumulated_gradients
        ]

        # reassemble the gradients in the [value, var] format and do define train op
        final_gradients = [(ag, gg[1])
                           for ag, gg in zip(accumulated_gradients, gradients)]
        train_op = optimizer.apply_gradients(final_gradients)

        # ================================================================
        # sequence of running opt ops:
        # 1. at the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
        # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
        # 3. at the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
        # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
        # ================================================================

        # ================================================================
        # previous train_op without accumulation of gradients
        # ================================================================
        # train_op = model.training_step(loss_op, normalization_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars = True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        if exp_config.debug:
            print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        prior_dae_dice = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='prior_dae_dice')
        prior_dae_dice_summary = tf.summary.scalar('test_img/prior_dae_dice',
                                                   prior_dae_dice)

        prior_dae_output_dice_wrt_gt = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_output_dice_wrt_gt')
        prior_dae_output_dice_wrt_gt_summary = tf.summary.scalar(
            'test_img/prior_dae_output_dice_wrt_gt',
            prior_dae_output_dice_wrt_gt)

        prior_atlas_dice = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='prior_atlas_dice')
        prior_atlas_dice_summary = tf.summary.scalar(
            'test_img/prior_atlas_dice', prior_atlas_dice)

        prior_dae_atlas_dice_ratio = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_atlas_dice_ratio')
        prior_dae_atlas_dice_ratio_summary = tf.summary.scalar(
            'test_img/prior_dae_atlas_dice_ratio', prior_dae_atlas_dice_ratio)

        prior_dice = tf.placeholder(tf.float32, shape=[], name='prior_dice')
        prior_dice_summary = tf.summary.scalar('test_img/prior_dice',
                                               prior_dice)

        gt_dice = tf.placeholder(tf.float32, shape=[], name='gt_dice')
        gt_dice_summary = tf.summary.scalar('test_img/gt_dice', gt_dice)

        # ================================================================
        # create savers
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)
        saver_test_data = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
        saver_best_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

        # ================================================================
        # add operations to compute dice between two 3d volumes
        # ================================================================
        pred_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='pred_3d')
        labl_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='labl_3d')
        atls_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='atls_3d')

        dice_3d_op_dae = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=labl_3d_1hot_pl)
        dice_3d_op_atlas = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=atls_3d_1hot_pl)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the prior network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_l2l.restore(sess, checkpoint_path)

        # ================================================================
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        # compare by value: identity ('is not') against a str literal is
        # implementation-dependent
        if log_dir_first_TD_subject != '':
            logging.info(
                '============================================================')
            path_to_model = log_dir_first_TD_subject + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_score.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_test_data.restore(sess, checkpoint_path)
            max_steps_tta = exp_config.max_steps
        else:
            max_steps_tta = 5 * exp_config.max_steps  # run the adaptation of the 1st TD subject for longer

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restroring normalization module from: %s' %
                         init_checkpoint_path)
            saver_test_data.restore(sess, init_checkpoint_path)

        # ================================================================
        # run training epochs
        # ================================================================
        step = init_step
        best_score = 0.0

        while (step < max_steps_tta):

            # ================================================
            # After every some epochs,
            # Get the prediction for the entire volume, evaluate it using the DAE.
            # Now, decide whether to use the DAE output or the atlas as the ground truth for the next update.
            # ================================================
            # compare by value: 'is 0' only works by accident for small
            # CPython ints and fails for float / numpy frequencies
            if (step == init_step) or (
                    step % exp_config.check_ood_frequency == 0):

                # ==================
                # 1. compute the current 3d segmentation prediction
                # ==================
                y_pred_soft = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred_soft.append(
                        sess.run(predicted_seg_softmax,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred_soft = np.squeeze(np.array(y_pred_soft)).astype(float)
                y_pred_soft = np.reshape(y_pred_soft, [
                    -1, y_pred_soft.shape[2], y_pred_soft.shape[3],
                    y_pred_soft.shape[4]
                ])

                # ==================
                # 2. downsample it. Let's call this guy 'A'
                # ==================
                y_pred_soft_downsampled = rescale(
                    y_pred_soft, [
                        1 / exp_config.downsampling_factor_x,
                        1 / exp_config.downsampling_factor_y,
                        1 / exp_config.downsampling_factor_z
                    ],
                    order=1,
                    preserve_range=True,
                    multichannel=True,
                    mode='constant').astype(np.float32)

                # ==================
                # 3. pass the downsampled prediction through the DAE and get its output 'B'
                # ==================
                feed_dict = {
                    predicted_seg_1hot_3d_pl:
                    np.expand_dims(y_pred_soft_downsampled, axis=0)
                }
                y_pred_noisy_denoised_softmax = np.squeeze(
                    sess.run(
                        predicted_seg_softmax_3d_noisy_autoencoded_softmax,
                        feed_dict=feed_dict)).astype(np.float16)
                y_pred_noisy_denoised = np.argmax(
                    y_pred_noisy_denoised_softmax, axis=-1)

                # ==================
                # 4. compute the dice between:
                #       a. 'A' (seg network prediction downsampled) and 'B' (dae network output)
                #       b. 'B' (dae network output) and downsampled gt labels (for debugging, to see if the dae output is close to the gt.)
                #       c. 'A' (seg network prediction downsampled) and 'C' (downsampled atlas)
                # ==================
                dAB = sess.run(dice_3d_op_dae,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   labl_3d_1hot_pl:
                                   y_pred_noisy_denoised_softmax
                               })
                dBgt = sess.run(dice_3d_op_dae,
                                feed_dict={
                                    pred_3d_1hot_pl:
                                    y_pred_noisy_denoised_softmax,
                                    labl_3d_1hot_pl: label_onehot_downsampled
                                })
                dAC = sess.run(dice_3d_op_atlas,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   atls_3d_1hot_pl: atlas_downsampled
                               })

                # ==================
                # 5. compute the ratio dice(AB) / dice(AC). pass the ratio through a threshold and decide whether to use the DAE or the atlas as the prior
                # ==================
                ratio_dice = dAB / (dAC + 1e-5)

                if exp_config.use_gt_for_tta is True:
                    target_labels_for_this_epoch = label_onehot_downsampled
                    prr = dBgt
                elif (ratio_dice > exp_config.dae_atlas_ratio_threshold) and (
                        dAC > exp_config.min_atlas_dice):
                    target_labels_for_this_epoch = y_pred_noisy_denoised_softmax
                    prr = dAB
                else:
                    target_labels_for_this_epoch = atlas_downsampled
                    prr = dAC

                # ==================
                # update losses on tensorboard
                # ==================
                summary_writer.add_summary(
                    sess.run(prior_dae_dice_summary,
                             feed_dict={prior_dae_dice: dAB}), step)
                summary_writer.add_summary(
                    sess.run(prior_dae_output_dice_wrt_gt_summary,
                             feed_dict={prior_dae_output_dice_wrt_gt: dBgt}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_atlas_dice_summary,
                             feed_dict={prior_atlas_dice: dAC}), step)
                summary_writer.add_summary(
                    sess.run(
                        prior_dae_atlas_dice_ratio_summary,
                        feed_dict={prior_dae_atlas_dice_ratio: ratio_dice}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_dice_summary, feed_dict={prior_dice: prr}),
                    step)

                # ==================
                # save best model so far
                # ==================
                if best_score < prr:
                    best_score = prr
                    best_file = os.path.join(log_dir, 'models/best_score.ckpt')
                    saver_best_loss.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best score (%f) at step %d -  Saving model.'
                        % (best_score, step))

                # ==================
                # dice wrt gt
                # ==================
                y_pred = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred.append(
                        sess.run(predicted_seg,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred = np.squeeze(np.array(y_pred)).astype(float)
                y_pred = np.reshape(y_pred,
                                    [-1, y_pred.shape[2], y_pred.shape[3]])
                dice_wrt_gt = met.f1_score(label.flatten(),
                                           y_pred.flatten(),
                                           average=None)
                summary_writer.add_summary(
                    sess.run(gt_dice_summary,
                             feed_dict={gt_dice: np.mean(dice_wrt_gt[1:])}),
                    step)

                # ==================
                # visualize results
                # ==================
                if step % exp_config.vis_frequency == 0:

                    # ===========================
                    # save checkpoint
                    # ===========================
                    logging.info(
                        '=============== Saving checkkpoint at step %d ... ' %
                        step)
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver_test_data.save(sess,
                                         checkpoint_file,
                                         global_step=step)

                    y_pred_noisy_denoised_upscaled = utils.crop_or_pad_volume_to_size_along_x(
                        rescale(y_pred_noisy_denoised, [
                            exp_config.downsampling_factor_x,
                            exp_config.downsampling_factor_y,
                            exp_config.downsampling_factor_z
                        ],
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant'),
                        image.shape[0]).astype(np.uint8)

                    x_norm = []
                    for batch in iterate_minibatches_images(
                            image, batch_size=exp_config.batch_size):
                        x = batch
                        x_norm.append(
                            sess.run(images_normalized,
                                     feed_dict={
                                         images_pl: x,
                                         training_pl: False
                                     }))
                    x_norm = np.squeeze(np.array(x_norm)).astype(float)
                    x_norm = np.reshape(
                        x_norm, [-1, x_norm.shape[2], x_norm.shape[3]])

                    utils_vis.save_sample_results(
                        x=image,
                        x_norm=x_norm,
                        x_diff=x_norm - image,
                        y=y_pred,
                        y_pred_dae=y_pred_noisy_denoised_upscaled,
                        at=np.argmax(atlas, axis=-1),
                        gt=label,
                        savepath=log_dir + '/results/visualize_images/step' +
                        str(step) + '.png')

            # ================================================
            # Part of training ops sequence:
            # 1. At the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
            # ================================================
            sess.run(accumulated_gradients_zero_op)
            num_accumulation_steps = 0

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches_images_and_downsampled_labels(
                    images=image,
                    batch_size=exp_config.batch_size,
                    labels_downsampled=target_labels_for_this_epoch,
                    batch_size_downsampled=exp_config.batch_size_downsampled):

                x, y = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    prior_label_1hot_pl: y,
                    learning_rate_pl: exp_config.learning_rate,
                    training_pl: True
                }

                # ================================================
                # Part of training ops sequence:
                # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
                # ================================================
                sess.run(accumulate_gradients_op, feed_dict=feed_dict)
                num_accumulation_steps = num_accumulation_steps + 1

                # the step counter counts mini-batch runs
                step += 1

            # ================================================
            # Part of training ops sequence:
            # 3. At the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
            # ================================================
            sess.run(
                accumulated_gradients_mean_op,
                feed_dict={num_accumulation_steps_pl: num_accumulation_steps})

            # ================================================================
            # sequence of running opt ops:
            # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
            # ================================================================
            # (the feed dict of the last batch of this epoch is re-used here)
            sess.run(train_op, feed_dict=feed_dict)

        # ================================================================
        # ================================================================
        sess.close()

    # ================================================================
    # ================================================================
    gc.collect()

    return 0
def run_training(continue_run):
    """Build the classification graph and run the training loop.

    Gradients are accumulated over ``exp_config.n_accum_batches`` mini-batches,
    averaged, and only then applied by the optimiser. Checkpoints and
    TensorBoard summaries are written to ``log_dir``. Both ``exp_config`` and
    ``log_dir`` are read from module scope.

    Args:
        continue_run: If truthy, try to resume from the latest ``model.ckpt``
            checkpoint found in ``log_dir``. When no checkpoint can be
            located a warning is logged and training starts from step 0.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Deliberately broad (the helper's failure mode is not visible
            # here), but no longer a bare ``except`` so that Ctrl-C and
            # SystemExit are not swallowed.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=True,
        force_overwrite=False)

    # the following are HDF5 datasets, not numpy arrays
    images_train = data['images_train']
    fieldstr_train = data['field_strength_train']
    labels_train = utils.fstr_to_label(fieldstr_train,
                                       exp_config.field_strength_list,
                                       exp_config.fs_label_list)
    ages_train = data['age_train']
    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(
            ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None

    images_val = data['images_val']
    fieldstr_val = data['field_strength_val']
    labels_val = utils.fstr_to_label(fieldstr_val,
                                     exp_config.field_strength_list,
                                     exp_config.fs_label_list)
    ages_val = data['age_val']
    if exp_config.age_ordinal_regression:
        ages_val = utils.age_to_ordinal_reg_format(ages_val,
                                                   bins=exp_config.age_bins)
    else:
        ages_val = utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    if exp_config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * exp_config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]
        # Keep the age targets aligned with the truncated images/labels so
        # all three training arrays have the same number of samples.
        ages_train = ages_train[0:new_last_index, ...]

    logging.info('Data summary:')
    logging.info('TRAINING')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)
    logging.info('VALIDATiON')
    logging.info(' - Images:')
    logging.info(images_val.shape)
    logging.info(images_val.dtype)
    logging.info(' - Labels:')
    logging.info(labels_val.shape)
    logging.info(labels_val.dtype)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        labels_tensor_shape = [exp_config.batch_size]
        if exp_config.age_ordinal_regression:
            ages_tensor_shape = [
                exp_config.batch_size,
                len(exp_config.age_bins)
            ]
        else:
            ages_tensor_shape = [exp_config.batch_size]

        images_placeholder = tf.placeholder(tf.float32,
                                            shape=image_tensor_shape,
                                            name='images')
        diag_placeholder = tf.placeholder(tf.uint8,
                                          shape=labels_tensor_shape,
                                          name='labels')
        ages_placeholder = tf.placeholder(tf.uint8,
                                          shape=ages_tensor_shape,
                                          name='ages')

        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   shape=[],
                                                   name='learning_rate')
        training_time_placeholder = tf.placeholder(tf.bool,
                                                   shape=[],
                                                   name='training_time')

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        diag_logits, ages_logits = exp_config.clf_model_handle(
            images_placeholder,
            nlabels=exp_config.nlabels,
            training=training_time_placeholder,
            n_age_thresholds=len(exp_config.age_bins),
            bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.
        [loss, diag_loss, age_loss, weights_norm
         ] = model_mt.loss(diag_logits,
                           ages_logits,
                           diag_placeholder,
                           ages_placeholder,
                           nlabels=exp_config.nlabels,
                           weight_decay=exp_config.weight_decay,
                           diag_weight=exp_config.diag_weight,
                           age_weight=exp_config.age_weight,
                           use_ordinal_reg=exp_config.age_ordinal_regression,
                           ordinal_reg_weights=ordinal_reg_weights)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        if exp_config.momentum is not None:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder,
                momentum=exp_config.momentum)
        else:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder)

        # create a copy of all trainable variables with `0` as initial values
        t_vars = tf.global_variables()  # tf.trainable_variables()
        accum_tvars = [
            tf.Variable(tf.zeros_like(tv.initialized_value()),
                        trainable=False) for tv in t_vars
        ]

        # create a op to initialize all accums vars
        zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

        # compute gradients for a batch
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            batch_grads_vars = optimiser.compute_gradients(loss, t_vars)

        # collect the batch gradient into accumulated vars
        accum_ops = [
            accum_tvar.assign_add(batch_grad_var[0])
            for accum_tvar, batch_grad_var in zip(accum_tvars,
                                                  batch_grads_vars)
        ]

        accum_normaliser_pl = tf.placeholder(dtype=tf.float32,
                                             name='accum_normaliser')
        accum_mean_op = [
            accum_tvar.assign(tf.divide(accum_tvar, accum_normaliser_pl))
            for accum_tvar in accum_tvars
        ]

        # apply accums gradients
        with tf.control_dependencies(update_ops):
            train_op = optimiser.apply_gradients([
                (accum_tvar, batch_grad_var[1])
                for accum_tvar, batch_grad_var in zip(accum_tvars,
                                                      batch_grads_vars)
            ])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = model_mt.evaluation(
            diag_logits,
            ages_logits,
            diag_placeholder,
            ages_placeholder,
            images_placeholder,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            nlabels=exp_config.nlabels,
            use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver(max_to_keep=3)
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        # Do not assign whole gpu memory, just use it on the go
        config.gpu_options.allow_growth = True
        # If a operation is not defined in the default device, let it execute in another.
        config.allow_soft_placement = True

        # Create a session for running Ops on the Graph. The context manager
        # guarantees the session is closed even if training raises.
        with tf.Session(config=config) as sess:

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            try:
                # with tf.name_scope('monitoring'):
                val_error_ = tf.placeholder(tf.float32,
                                            shape=[],
                                            name='val_error_diag')
                val_error_summary = tf.summary.scalar('validation_loss',
                                                      val_error_)

                val_diag_f1_score_ = tf.placeholder(tf.float32,
                                                    shape=[],
                                                    name='val_diag_f1')
                val_f1_diag_summary = tf.summary.scalar(
                    'validation_diag_f1', val_diag_f1_score_)

                # Fed below but not part of any merged summary.
                val_ages_f1_score_ = tf.placeholder(tf.float32,
                                                    shape=[],
                                                    name='val_ages_f1')

                val_summary = tf.summary.merge(
                    [val_error_summary, val_f1_diag_summary])

                train_error_ = tf.placeholder(tf.float32,
                                              shape=[],
                                              name='train_error_diag')
                train_error_summary = tf.summary.scalar(
                    'training_loss', train_error_)

                train_diag_f1_score_ = tf.placeholder(tf.float32,
                                                      shape=[],
                                                      name='train_diag_f1')
                train_diag_f1_summary = tf.summary.scalar(
                    'training_diag_f1', train_diag_f1_score_)

                # Fed below but not part of any merged summary.
                train_ages_f1_score_ = tf.placeholder(tf.float32,
                                                      shape=[],
                                                      name='train_ages_f1')

                train_summary = tf.summary.merge(
                    [train_error_summary, train_diag_f1_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = exp_config.learning_rate

                # NOTE(review): this counter is never updated, so the
                # 'No improvment ...' message below always reports 0.
                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_diag_f1_score = 0

                for epoch in range(exp_config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    # Any partially accumulated gradients from the previous
                    # epoch are discarded here.
                    sess.run(zero_ops)
                    accum_counter = 0

                    for batch in iterate_minibatches(
                            images_train, [labels_train, ages_train],
                            batch_size=exp_config.batch_size,
                            augmentation_function=exp_config.
                            augmentation_function,
                            exp_config=exp_config):

                        if exp_config.warmup_training:
                            if step < 50:
                                curr_lr = exp_config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = exp_config.learning_rate

                        start_time = time.time()

                        # get a batch
                        x, [y, a] = batch

                        # TEMPORARY HACK (to avoid incomplete batches)
                        if y.shape[0] < exp_config.batch_size:
                            step += 1
                            continue

                        # Run accumulation
                        feed_dict = {
                            images_placeholder: x,
                            diag_placeholder: y,
                            ages_placeholder: a,
                            learning_rate_placeholder: curr_lr,
                            training_time_placeholder: True
                        }

                        _, loss_value = sess.run([accum_ops, loss],
                                                 feed_dict=feed_dict)

                        accum_counter += 1

                        if accum_counter == exp_config.n_accum_batches:

                            # Average gradient over batches
                            sess.run(accum_mean_op,
                                     feed_dict={
                                         accum_normaliser_pl:
                                         float(exp_config.n_accum_batches)
                                     })
                            sess.run(train_op, feed_dict=feed_dict)

                            # Reset all counters etc.
                            sess.run(zero_ops)
                            accum_counter = 0

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly often.
                        if step % 10 == 0:

                            # Print status to stdout.
                            logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                         (step, loss_value, duration))

                            # Update the events file.
                            summary_str = sess.run(summary,
                                                   feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % exp_config.train_eval_frequency == 0:

                            # Evaluate against the training set
                            logging.info('Training Data Eval:')
                            [train_loss, train_diag_f1,
                             train_ages_f1] = do_eval(
                                 sess,
                                 eval_diag_loss,
                                 eval_ages_loss,
                                 pred_labels,
                                 ages_softmaxs,
                                 images_placeholder,
                                 diag_placeholder,
                                 ages_placeholder,
                                 training_time_placeholder,
                                 images_train, [labels_train, ages_train],
                                 batch_size=exp_config.batch_size,
                                 do_ordinal_reg=exp_config.
                                 age_ordinal_regression)

                            train_summary_msg = sess.run(
                                train_summary,
                                feed_dict={
                                    train_error_: train_loss,
                                    train_diag_f1_score_: train_diag_f1,
                                    train_ages_f1_score_: train_ages_f1
                                })
                            summary_writer.add_summary(train_summary_msg,
                                                       step)

                            # Sliding window of the last five training
                            # losses, used for the learning-rate schedule.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' %
                                             curr_lr)

                                # reset loss history to give the optimisation some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            if train_loss <= last_train:  # best_train:
                                logging.info('Decrease in training error!')
                            else:
                                logging.info(
                                    'No improvment in training error for %d steps'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model periodically.
                        if (step + 1) % exp_config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(
                                log_dir, 'model.ckpt')
                            saver.save(sess,
                                       checkpoint_file,
                                       global_step=step)

                            # Evaluate against the validation set.
                            logging.info('Validation Data Eval:')
                            [val_loss, val_diag_f1, val_ages_f1] = do_eval(
                                sess,
                                eval_diag_loss,
                                eval_ages_loss,
                                pred_labels,
                                ages_softmaxs,
                                images_placeholder,
                                diag_placeholder,
                                ages_placeholder,
                                training_time_placeholder,
                                images_val, [labels_val, ages_val],
                                batch_size=exp_config.batch_size,
                                do_ordinal_reg=exp_config.
                                age_ordinal_regression)

                            val_summary_msg = sess.run(
                                val_summary,
                                feed_dict={
                                    val_error_: val_loss,
                                    val_diag_f1_score_: val_diag_f1,
                                    val_ages_f1_score_: val_ages_f1
                                })
                            summary_writer.add_summary(val_summary_msg, step)

                            if val_diag_f1 >= best_diag_f1_score:
                                best_diag_f1_score = val_diag_f1
                                best_file = os.path.join(
                                    log_dir, 'model_best_diag_f1.ckpt')
                                saver_best_diag_f1.save(sess,
                                                        best_file,
                                                        global_step=step)
                                logging.info(
                                    'Found new best DIAGNOSIS F1 score on validation set! - %f - Saving model_best_diag_f1.ckpt'
                                    % val_diag_f1)

                            if val_loss <= best_val:
                                best_val = val_loss
                                best_file = os.path.join(
                                    log_dir, 'model_best_xent.ckpt')
                                saver_best_xent.save(sess,
                                                     best_file,
                                                     global_step=step)
                                logging.info(
                                    'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                    % val_loss)

                        step += 1
            finally:
                # Flush and close the events file so the last summaries are
                # not lost, regardless of how training ended.
                summary_writer.close()
def _score_data_read_info(cfg_path):
    """Parse a ``key: value`` text file into a dict of strings."""
    infos = {}
    with open(cfg_path) as cfg_file:
        for line in cfg_file:
            if ':' not in line:
                # Blank or malformed line; nothing to parse.
                continue
            # Split on the first ':' only so values may contain colons.
            label, value = line.split(':', 1)
            infos[label] = value.rstrip('\n').lstrip(' ')
    return infos


def _score_data_crop_or_pad(slice_rescaled, nx, ny):
    """Centre-crop and/or zero-pad a 2D array to exactly ``(nx, ny)``."""
    x, y = slice_rescaled.shape

    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    if x > nx and y > ny:
        return slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]

    slice_cropped = np.zeros((nx, ny))
    if x <= nx and y > ny:
        slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
    elif x > nx and y <= ny:
        slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
    else:
        slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]
    return slice_cropped


def _score_data_uncrop(prediction_cropped, x, y, nx, ny, out_shape):
    """Undo ``_score_data_crop_or_pad`` on the two leading axes.

    ``prediction_cropped`` has leading shape ``(nx, ny)``; the result is a
    zero-initialised array of ``out_shape`` (leading shape ``(x, y)``) with
    the prediction pasted back at the position it was cropped from. Any
    trailing axes are copied unchanged.
    """
    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    restored = np.zeros(out_shape)
    if x > nx and y > ny:
        restored[x_s:x_s + nx, y_s:y_s + ny, ...] = prediction_cropped
    elif x <= nx and y > ny:
        restored[:, y_s:y_s + ny, ...] = prediction_cropped[x_c:x_c + x, ...]
    elif x > nx and y <= ny:
        restored[x_s:x_s + nx, ...] = prediction_cropped[:, y_c:y_c + y, ...]
    else:
        restored[...] = prediction_cropped[x_c:x_c + x, y_c:y_c + y, ...]
    return restored


def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               use_iter=None):
    """Run the segmentation network on patient volumes and save the results.

    Every sub-directory of ``input_folder`` is treated as one patient. For
    each ``patient???_frame??.nii.gz`` file the network prediction is
    computed (slice-wise for ``data_mode == '2D'``, whole volume for
    ``'3D'``), resampled back to the original grid and written below
    ``output_folder`` (sub-folders ``prediction``, ``image`` and, when ground
    truth exists, ``ground_truth`` and ``difference``).

    Args:
        input_folder: Directory containing one folder per patient.
        output_folder: Directory receiving the output NIfTI files.
        model_path: Directory holding the model checkpoints.
        exp_config: Experiment configuration; the attributes ``image_size``,
            ``nlabels``, ``data_mode`` and ``target_resolution`` are read.
        do_postprocessing: If true, pass the prediction through
            ``image_utils.keep_largest_connected_components``.
        gt_exists: Whether ``*_gt.nii.gz`` masks are available. When false,
            all patients are scored.
        evaluate_all: If true, score all patients instead of only those
            whose three-digit folder suffix is divisible by 5.
        use_iter: If given (and non-zero), load ``model.ckpt-<use_iter>``
            instead of the latest ``model_best_dice.ckpt``.

    Returns:
        The integer that follows the last ``-`` in the checkpoint file name.

    Raises:
        ValueError: If ``exp_config.data_mode`` is neither ``'2D'`` nor
            ``'3D'``, if the 3D prediction has an unexpected number of
            slices, or if a frame is neither the ED nor the ES frame.
    """
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    mask_pl, softmax_pl = model.predict(images_pl, exp_config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path,
                                           'model.ckpt-%d' % use_iter)

        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) %
                                        5 == 0) else 'train'

            if train_test != 'test':
                continue

            infos = _score_data_read_info(
                os.path.join(folder_path, 'Info.cfg'))

            # NOTE(review): lstrip removes a *character set*, not a prefix.
            # That yields the numeric id for names like 'patient001'; kept
            # as is to preserve the output file names.
            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file_path in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file_path)
                logging.info(' --------------------------------------------')

                file_base = file_path.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(file_path)
                img = img_dat[0].copy()
                img = image_utils.normalise_image(img)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if exp_config.data_mode == '2D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(
                            slice_img,
                            scale_vector,
                            order=1,
                            preserve_range=True,
                            multichannel=False,
                            mode='constant')

                        x, y = slice_rescaled.shape

                        # Crop section of image for prediction
                        slice_cropped = _score_data_crop_or_pad(
                            slice_rescaled, nx, ny)

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        _, logits_out = sess.run(
                            [mask_pl, softmax_pl],
                            feed_dict={images_pl: network_input})
                        prediction_cropped = np.squeeze(logits_out[0, ...])

                        # ASSEMBLE BACK THE SLICES
                        # insert cropped region into original image again
                        slice_predictions = _score_data_uncrop(
                            prediction_cropped, x, y, nx, ny,
                            (x, y, num_channels))

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                slice_predictions,
                                (1.0 / scale_vector[0],
                                 1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.data_mode == '3D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2],
                                  img_dat[2].structarr['pixdim'][3])

                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1],
                                    pixel_size[2] /
                                    exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    # Centre the rescaled slices along z inside the
                    # fixed-depth network input.
                    nz_curr = vol_scaled.shape[2]
                    stack_from = (nz_max - nz_curr) // 2
                    stack_to = stack_from + nz_curr

                    x, y = vol_scaled.shape[:2]

                    for zz in range(nz_curr):
                        slice_vol[:, :, stack_from +
                                  zz] = _score_data_crop_or_pad(
                                      vol_scaled[:, :, zz], nx, ny)

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))

                    start_time = time.time()
                    _, logits_out = sess.run(
                        [mask_pl, softmax_pl],
                        feed_dict={images_pl: network_input})
                    logging.info('Classified 3D: %f secs' %
                                 (time.time() - start_time))

                    # non-zero-slices
                    prediction_nzs = logits_out[0, :, :,
                                                stack_from:stack_to, ...]

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES
                    # last dim is for logits classes; insert cropped region
                    # into original image again. (The previous inline code
                    # indexed with ':...', i.e. slice(None, Ellipsis), which
                    # NumPy rejects.)
                    prediction_scaled = _score_data_uncrop(
                        prediction_nzs, x, y, nx, ny,
                        list(vol_scaled.shape) + [num_channels])

                    logging.info('Prediction_scaled mean %f' %
                                 (np.mean(prediction_scaled)))

                    # Without ground truth there is no mask; fall back to
                    # the grid of the input volume.
                    target_shape = mask.shape if gt_exists else img.shape

                    prediction = transform.resize(
                        prediction_scaled,
                        (target_shape[0], target_shape[1], target_shape[2],
                         num_channels),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                else:
                    raise ValueError('Unknown data mode: %s' %
                                     exp_config.data_mode)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' %
                             elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # Save prediced mask
                out_file_name = os.path.join(
                    output_folder, 'prediction',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine,
                               out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(
                    output_folder, 'image',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine,
                               out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(
                        output_folder, 'ground_truth',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine,
                                   out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(
                        np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask,
                                                 dtype=np.uint8)
                    diff_file_name = os.path.join(
                        output_folder, 'difference',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask,
                                   out_affine, out_header)

        if total_volumes > 0:
            logging.info('Average time per volume: %f' %
                         (total_time / total_volumes))
        else:
            # Avoid a ZeroDivisionError when nothing matched.
            logging.warning('No volumes were evaluated in: %s' % input_folder)

    return init_iteration
def main(config):
    """Interactively visualise predictions on random test images.

    Loads the preprocessed data, restores the latest ``model_best_dice.ckpt``
    checkpoint from the module-level ``model_path`` and then loops forever:
    a random test image is drawn, passed through the network, and the image,
    its mask, the predicted mask and the first four softmax channels are
    plotted. The loop has no exit condition and ends only when an exception
    (e.g. ``KeyboardInterrupt``) propagates; the data handle is closed on
    the way out.

    Args:
        config: Experiment configuration. The attributes ``data_root``,
            ``preprocessing_folder``, ``data_mode``, ``image_size``,
            ``target_resolution``, ``standardize``, ``normalize``, ``min``
            and ``max`` are read here.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=config.data_root,  # data_root test_data_root
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False
    )

    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            while True:

                ind = np.random.randint(data['images_test'].shape[0])

                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                # Apply the configured preprocessing to the image itself.
                # The previous code looped ``for img in x`` and re-bound the
                # loop variable, so the result was discarded and ``x`` was
                # fed to the network unprocessed.
                # NOTE(review): assumes the model was trained with the same
                # whole-image preprocessing - confirm against the training
                # script.
                if config.standardize:
                    x = image_utils.standardize_image(x)
                if config.normalize:
                    x = cv2.normalize(x, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX)

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)

                logging.info('x')
                logging.info(x.shape)
                logging.info(x.dtype)
                logging.info(x.min())
                logging.info(x.max())

                plt.imshow(np.squeeze(x))
                plt.gray()
                plt.axis('off')
                plt.show()

                logging.info('y')
                logging.info(y.shape)
                logging.info(y.dtype)

                feed_dict = {
                    images_pl: x,
                }

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)

                logging.info('mask_out')
                logging.info(mask_out.shape)
                logging.info('softmax_out')
                logging.info(softmax_out.shape)

                # Top row: input, ground-truth mask, predicted mask.
                fig = plt.figure(1)
                ax1 = fig.add_subplot(241)
                ax1.set_axis_off()
                ax1.imshow(np.squeeze(x), cmap='gray')
                ax2 = fig.add_subplot(242)
                ax2.set_axis_off()
                ax2.imshow(np.squeeze(y))
                ax3 = fig.add_subplot(243)
                ax3.set_axis_off()
                ax3.imshow(np.squeeze(mask_out))

                ax1.title.set_text('a')
                ax2.title.set_text('b')
                ax3.title.set_text('c')

                # Bottom row: softmax channels 0..3 (four classes are
                # hard-coded here).
                ax5 = fig.add_subplot(245)
                ax5.set_axis_off()
                ax5.imshow(np.squeeze(softmax_out[..., 0]))
                ax6 = fig.add_subplot(246)
                ax6.set_axis_off()
                ax6.imshow(np.squeeze(softmax_out[..., 1]))
                ax7 = fig.add_subplot(247)
                ax7.set_axis_off()
                ax7.imshow(np.squeeze(softmax_out[..., 2]))
                ax8 = fig.add_subplot(248)
                ax8.set_axis_off()
                ax8.imshow(np.squeeze(softmax_out[..., 3]))  # cmap=cm.Blues

                ax5.title.set_text('d')
                ax6.title.set_text('e')
                ax7.title.set_text('f')
                ax8.title.set_text('g')

                plt.gray()
                plt.show()

                logging.info('mask_out type')
                logging.info(mask_out.dtype)
    finally:
        # Previously placed after the endless loop and therefore never
        # reached; close the data handle however the loop is left.
        data.close()