def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train ``model`` using a one-cycle learning-rate schedule.

    Args:
        epochs: Number of passes over ``train_loader``.
        max_lr: Learning rate handed to the optimizer and used as the peak
            learning rate of ``OneCycleLR``.
        model: Object providing ``parameters()``, ``train()``,
            ``training_step(batch)`` (must return a loss tensor that supports
            ``backward()``) and ``epoch_end(epoch, result)``.
        train_loader: Sized iterable of training batches; ``len()`` of it is
            used as the number of scheduler steps per epoch.
        val_loader: Passed unchanged to ``evaluate`` after every epoch.
        weight_decay: Forwarded to the optimizer constructor.
        grad_clip: If truthy, gradients are clipped element-wise to
            ``[-grad_clip, grad_clip]`` before each optimizer step.
        opt_func: Optimizer factory called as
            ``opt_func(params, max_lr, weight_decay=...)``.

    Returns:
        A list with one entry per epoch: the object returned by ``evaluate``
        with the keys ``'train_loss'`` (mean batch loss as a float) and
        ``'lrs'`` (learning rate recorded after every batch) added.
    """
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            # Keep only the value: a detached tensor does not hold on to the
            # autograd graph of the batch for the remainder of the epoch.
            train_losses.append(loss.detach())
            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
def run_uda_training(log_dir, images_sd_tr, labels_sd_tr, images_sd_vl,
                     labels_sd_vl, images_td_tr, images_td_vl):
    """Build the adversarial feature-invariance graph and train it.

    The graph contains a normalization module, a segmentation network and a
    discriminator that operates on resized features of both input streams.
    The segmentation side (normalizer + i2l mapper variables) is trained on
    ``loss_seg + exp_config.lambda_uda * loss_invariance``; the discriminator
    is trained on its own sigmoid cross-entropy loss. Both are updated once
    per mini-batch.

    Args:
        log_dir: Directory that receives the TensorBoard event files and the
            ``models/`` checkpoints.
        images_sd_tr, labels_sd_tr: Images/labels of the "sd" stream used for
            training (presumably the source domain - confirm with caller).
        images_sd_vl, labels_sd_vl: Same stream, used for validation.
        images_td_tr: Unlabelled "td" images used for training (presumably
            the target domain - confirm with caller).
        images_td_vl: "td" images used for validation.

    Returns:
        0 once training has finished.
    """
    # reset the graph built so far and build a new TF graph
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # set random seed for reproducibility
        tf.random.set_random_seed(exp_config.run_num_uda)
        np.random.seed(exp_config.run_num_uda)

        # ------------------------------------------------------------
        # placeholders
        # ------------------------------------------------------------
        image_shape = ([exp_config.batch_size]
                       + list(exp_config.image_size) + [1])
        label_shape = [exp_config.batch_size] + list(exp_config.image_size)
        images_sd_pl = tf.placeholder(tf.float32, shape=image_shape,
                                      name='images_sd')
        images_td_pl = tf.placeholder(tf.float32, shape=image_shape,
                                      name='images_td')
        labels_sd_pl = tf.placeholder(tf.uint8, shape=label_shape,
                                      name='labels_sd')
        training_pl = tf.placeholder(tf.bool, shape=[],
                                     name='training_or_testing')

        # ------------------------------------------------------------
        # normalization module in front of the segmentation network
        # (second call re-uses the variables created by the first)
        # ------------------------------------------------------------
        images_sd_normalized, _ = model.normalize(
            images_sd_pl, exp_config, training_pl, scope_reuse=False)
        images_td_normalized, _ = model.normalize(
            images_td_pl, exp_config, training_pl, scope_reuse=True)

        # logit predictions from the segmentation network
        predicted_seg_sd_logits, _, _ = model.predict_i2l(
            images_sd_normalized, exp_config, training_pl, scope_reuse=False)

        # all features from the segmentation network
        seg_features_sd = model.get_all_features(
            images_sd_normalized, exp_config, scope_reuse=True)
        seg_features_td = model.get_all_features(
            images_td_normalized, exp_config, scope_reuse=True)

        # resize all features (and the normalized images) to the same size
        images_sd_features_resized = model.resize_features(
            [images_sd_normalized] + list(seg_features_sd), (64, 64),
            'resize_sd_xfeat')
        images_td_features_resized = model.resize_features(
            [images_td_normalized] + list(seg_features_td), (64, 64),
            'resize_td_xfeat')

        # discriminator on features (shared weights for both streams)
        d_logits_sd = model.discriminator(
            images_sd_features_resized, exp_config, training_pl,
            scope_name='discriminator', scope_reuse=False)
        d_logits_td = model.discriminator(
            images_td_features_resized, exp_config, training_pl,
            scope_name='discriminator', scope_reuse=True)

        # discriminator loss: "sd" features are labelled 1, "td" features 0
        d_loss_sd = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_logits_sd), logits=d_logits_sd)
        d_loss_td = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_logits_td), logits=d_logits_td)
        loss_d_op = tf.reduce_mean(d_loss_sd + d_loss_td)
        tf.summary.scalar('tr_losses/loss_discriminator', loss_d_op)

        # adversarial loss that tries to get domain invariant features
        loss_g_op = model.loss_invariance(d_logits_td)
        tf.summary.scalar('tr_losses/loss_invariant_features', loss_g_op)

        # supervised segmentation loss
        loss_seg_op = model.loss(predicted_seg_sd_logits,
                                 labels_sd_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('tr_losses/loss_segmentation', loss_seg_op)

        # total training loss for uda
        loss_total_op = loss_seg_op + exp_config.lambda_uda * loss_g_op
        tf.summary.scalar('tr_losses/loss_total_uda', loss_total_op)

        # Merge the summaries registered so far. This must happen before the
        # placeholder-driven training/validation summaries below are created,
        # otherwise they would be pulled into this op as well.
        summary_scalars = tf.summary.merge_all()

        # ------------------------------------------------------------
        # split the variables between the segmentation side and the
        # discriminator
        # ------------------------------------------------------------
        i2l_vars = []
        discriminator_vars = []
        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                # the normalization vars also need to be restored from the
                # pre-trained i2l mapper
                i2l_vars.append(v)
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'discriminator' in var_name:
                discriminator_vars.append(v)

        # optimization ops
        train_i2l_op = model.training_step(
            loss_total_op, i2l_vars, exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)
        train_discriminator_op = model.training_step(
            loss_d_op, discriminator_vars, exp_config.optimizer_handle,
            learning_rate=exp_config.learning_rate)

        # ops for model evaluation
        eval_loss = model.evaluation_i2l_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl, d_logits_td,
            nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_i2l)

        # ops for adding image summary to tensorboard
        summary_images = model.write_image_summary_uda_invariant_features(
            predicted_seg_sd_logits, labels_sd_pl, images_sd_pl,
            exp_config.nlabels)

        if exp_config.debug:
            print('creating summary op...')

        # init ops
        init_ops = tf.global_variables_initializer()

        # op that reports uninitialized variables
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ------------------------------------------------------------
        # create session; everything from here on runs under try/finally
        # so the session and the event file are released on any error
        # ------------------------------------------------------------
        sess = tf.Session()
        summary_writer = None
        try:
            # file writer for the event files
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # savers (both restricted to the normalizer + i2l variables)
            saver = tf.train.Saver(var_list=i2l_vars)
            saver_lowest_loss = tf.train.Saver(var_list=i2l_vars,
                                               max_to_keep=3)

            # summaries of the validation errors
            vl_error_seg = tf.placeholder(tf.float32, shape=[],
                                          name='vl_error_seg')
            vl_error_seg_summary = tf.summary.scalar('validation/loss_seg',
                                                     vl_error_seg)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_error_invariance = tf.placeholder(tf.float32, shape=[],
                                                 name='vl_error_invariance')
            vl_error_invariance_summary = tf.summary.scalar(
                'validation/loss_invariance', vl_error_invariance)
            vl_error_total = tf.placeholder(tf.float32, shape=[],
                                            name='vl_error_total')
            vl_error_total_summary = tf.summary.scalar(
                'validation/loss_total', vl_error_total)
            vl_summary = tf.summary.merge([
                vl_error_seg_summary, vl_dice_summary,
                vl_error_invariance_summary, vl_error_total_summary
            ])

            # summaries of the training errors
            tr_error_seg = tf.placeholder(tf.float32, shape=[],
                                          name='tr_error_seg')
            tr_error_seg_summary = tf.summary.scalar('training/loss_seg',
                                                     tr_error_seg)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_error_invariance = tf.placeholder(tf.float32, shape=[],
                                                 name='tr_error_invariance')
            tr_error_invariance_summary = tf.summary.scalar(
                'training/loss_invariance', tr_error_invariance)
            tr_error_total = tf.placeholder(tf.float32, shape=[],
                                            name='tr_error_total')
            tr_error_total_summary = tf.summary.scalar(
                'training/loss_total', tr_error_total)
            tr_summary = tf.summary.merge([
                tr_error_seg_summary, tr_dice_summary,
                tr_error_invariance_summary, tr_error_total_summary
            ])

            # freeze the graph before execution
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # initialize the variables
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('initializing all variables...')
            sess.run(init_ops)

            # print names of uninitialized variables
            uninit_variables = sess.run(uninit_vars)
            if exp_config.debug:
                logging.info(
                    '============================================================')
                logging.info('This is the list of uninitialized variables:')
                for v in uninit_variables:
                    print(v)

            # restore the pre-trained normalizer + segmentation parameters
            if exp_config.train_from_scratch is False:
                logging.info(
                    '============================================================')
                path_to_model = (sys_config.log_root + exp_config.expname_i2l
                                 + '/models/')
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    path_to_model, 'best_dice.ckpt')
                logging.info('Restoring the trained parameters from %s...' %
                             checkpoint_path)
                saver_lowest_loss.restore(sess, checkpoint_path)

            # --------------------------------------------------------
            # run training steps
            # --------------------------------------------------------
            step = 0
            lowest_loss = 10000.0
            validation_total_loss_list = []

            while step < exp_config.max_steps:

                step_at_epoch_start = step

                for batch in iterate_minibatches(
                        images_sd=images_sd_tr,
                        labels_sd=labels_sd_tr,
                        images_td=images_td_tr,
                        batch_size=exp_config.batch_size):

                    x_sd, y_sd, x_td = batch

                    # feed dict for this iteration
                    feed_dict = {
                        images_sd_pl: x_sd,
                        labels_sd_pl: y_sd,
                        images_td_pl: x_td,
                        training_pl: True
                    }

                    # update i2l and D successively
                    sess.run(train_i2l_op, feed_dict=feed_dict)
                    sess.run(train_discriminator_op, feed_dict=feed_dict)

                    # write the summaries fairly often
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            '============== Updating summary at step %d ' %
                            step)
                        summary_writer.add_summary(
                            sess.run(summary_scalars, feed_dict=feed_dict),
                            step)
                        summary_writer.flush()

                    # compute the loss on the entire training set
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('============== Training Data Eval:')
                        (train_loss_seg, train_dice,
                         train_loss_invariance) = do_eval(
                             sess, eval_loss, images_sd_pl, labels_sd_pl,
                             images_td_pl, training_pl, images_sd_tr,
                             labels_sd_tr, images_td_tr,
                             exp_config.batch_size)

                        # total training loss
                        train_total_loss = (
                            train_loss_seg
                            + exp_config.lambda_uda * train_loss_invariance)

                        # update tensorboard summary of scalars
                        tr_summary_msg = sess.run(
                            tr_summary,
                            feed_dict={
                                tr_error_seg: train_loss_seg,
                                tr_dice: train_dice,
                                tr_error_invariance: train_loss_invariance,
                                tr_error_total: train_total_loss
                            })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # save a checkpoint periodically
                    if step % exp_config.save_frequency == 0:
                        logging.info(
                            '============== Periodically saving checkpoint:')
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # evaluate the model periodically on a validation set
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('============== Validation Data Eval:')
                        (val_loss_seg, val_dice,
                         val_loss_invariance) = do_eval(
                             sess, eval_loss, images_sd_pl, labels_sd_pl,
                             images_td_pl, training_pl, images_sd_vl,
                             labels_sd_vl, images_td_vl,
                             exp_config.batch_size)

                        # total val loss
                        val_total_loss = (
                            val_loss_seg
                            + exp_config.lambda_uda * val_loss_invariance)
                        validation_total_loss_list.append(val_total_loss)

                        # update tensorboard summary of scalars
                        vl_summary_msg = sess.run(
                            vl_summary,
                            feed_dict={
                                vl_error_seg: val_loss_seg,
                                vl_dice: val_dice,
                                vl_error_invariance: val_loss_invariance,
                                vl_error_total: val_total_loss
                            })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # update tensorboard summary of images (uses the
                        # current training batch, in inference mode)
                        summary_writer.add_summary(
                            sess.run(summary_images,
                                     feed_dict={
                                         images_sd_pl: x_sd,
                                         labels_sd_pl: y_sd,
                                         training_pl: False
                                     }), step)
                        summary_writer.flush()

                        # Smooth the validation loss before comparing; until
                        # enough values exist, fall back to the latest one.
                        window_length = 5
                        if len(validation_total_loss_list) < window_length + 1:
                            expo_moving_avg_loss_value = (
                                validation_total_loss_list[-1])
                        else:
                            expo_moving_avg_loss_value = (
                                utils.exponential_moving_average(
                                    validation_total_loss_list,
                                    window=window_length)[-1])

                        if expo_moving_avg_loss_value < lowest_loss:
                            # Track the same quantity that is compared (the
                            # smoothed loss). Storing the raw loss here would
                            # make later comparisons mix two different
                            # measures.
                            lowest_loss = expo_moving_avg_loss_value
                            lowest_loss_file = os.path.join(
                                log_dir, 'models/lowest_loss.ckpt')
                            saver_lowest_loss.save(sess,
                                                   lowest_loss_file,
                                                   global_step=step)
                            logging.info(
                                '******* SAVED MODEL at NEW BEST AVERAGE LOSS on VALIDATION SET at step %d ********'
                                % step)

                    # increment step
                    step += 1

                    # honour max_steps inside an epoch as well
                    if step >= exp_config.max_steps:
                        break

                # No batch was produced: the step counter can never advance,
                # so looping again would hang forever.
                if step == step_at_epoch_start:
                    logging.warning(
                        'iterate_minibatches yielded no batches; '
                        'stopping training.')
                    break

        finally:
            # release the event file and the session on success and failure
            if summary_writer is not None:
                summary_writer.close()
            sess.close()

    return 0
def run_training(continue_run):
    """Train the segmentation network on the HCP T1 data.

    Loads training and validation arrays through ``data_hcp.load_data``,
    drops entries along axis 0 whose pixels are all zero, builds the TF graph
    (segmentation network, loss, optimizer, evaluation and summary ops) and
    runs ``exp_config.max_epochs`` epochs of mini-batch training. Checkpoints
    are written to ``<log_dir>/models`` periodically and whenever the
    validation dice improves.

    Args:
        continue_run: If truthy, try to resume from the latest
            ``models/model.ckpt`` checkpoint in ``log_dir``. When no usable
            checkpoint is found, training starts from scratch instead.

    Returns:
        None.
    """
    # log experiment details
    logging.info(
        '============================================================')
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # step number = number of mini-batch runs so far
    init_step = 0

    # if continue_run is set, locate the parameters saved earlier;
    # otherwise start training from scratch
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # the step is encoded after the last '-' of the file name;
            # plus 1 as otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Exception (not a bare except) so that Ctrl-C / SystemExit
            # still abort the run instead of silently restarting it.
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ------------------------------------------------------------
    # load data
    # ------------------------------------------------------------
    logging.info(
        '============================================================')
    logging.info('Loading data...')
    logging.info('Reading HCP - 3T - T1 images...')
    logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
    imtr, gttr = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                    sys_config.preproc_folder_hcp, 'T1', 1,
                                    exp_config.train_size + 1)
    imvl, gtvl = data_hcp.load_data(
        sys_config.orig_data_root_hcp, sys_config.preproc_folder_hcp, 'T1',
        exp_config.train_size + 1,
        exp_config.train_size + exp_config.val_size + 1)
    # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Training Images: %s' % str(imtr.shape))
    logging.info('Training Labels: %s' % str(gttr.shape))
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(
        '============================================================')

    # ------------------------------------------------------------
    # discard entries whose image is zero everywhere (and their labels)
    # ------------------------------------------------------------
    mask = ~(imtr == 0).all(axis=(1, 2))
    imtr = imtr[mask]
    gttr = gttr[mask]
    mask = ~(imvl == 0).all(axis=(1, 2))
    imvl = imvl[mask]
    gtvl = gtvl[mask]

    # ------------------------------------------------------------
    # build the TF graph
    # ------------------------------------------------------------
    with tf.Graph().as_default():

        # placeholders for the images and labels
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        images_pl = tf.compat.v1.placeholder(tf.float32,
                                             shape=image_tensor_shape,
                                             name='images')
        labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
            exp_config.image_size)
        labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
            exp_config.num_rotation_values
        ]
        labels_segmentation_pl = tf.compat.v1.placeholder(
            tf.uint8,
            shape=labels_tensor_shape_segmentation,
            name='labels_segmentation')
        labels_self_supervised_pl = tf.compat.v1.placeholder(
            tf.float16,
            shape=labels_tensor_shape_self_supervised,
            name='labels_self_supervised')

        # placeholders for the learning rate and the train/test switch
        learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='learning_rate')
        training_pl = tf.compat.v1.placeholder(tf.bool,
                                               shape=[],
                                               name='training_or_testing')

        # segmentation network
        logits_segmentation = exp_config.model_handle_segmentation(
            images_pl, image_tensor_shape, training_pl, exp_config.nlabels)

        # sort the trainable variables by the sub-network named in them
        segmentation_vars = []
        self_supervised_vars = []
        test_time_opt_vars = []
        for v in tf.compat.v1.trainable_variables():
            var_name = v.name
            if 'rotation' in var_name:
                self_supervised_vars.append(v)
            elif 'segmentation' in var_name:
                segmentation_vars.append(v)
            elif 'normalization' in var_name:
                test_time_opt_vars.append(v)

        if exp_config.debug is True:
            logging.info('================================')
            logging.info('List of trainable variables in the graph:')
            for v in tf.compat.v1.trainable_variables():
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all segmentation variables:')
            for v in segmentation_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all self-supervised variables:')
            for v in self_supervised_vars:
                logging.info(v.name)
            logging.info('================================')
            logging.info('List of all test time variables:')
            for v in test_time_opt_vars:
                logging.info(v.name)

        # training loss
        loss_segmentation = model.loss(logits_segmentation,
                                       labels_segmentation_pl,
                                       nlabels=exp_config.nlabels,
                                       loss_type=exp_config.loss_type)
        tf.compat.v1.summary.scalar('loss_segmentation', loss_segmentation)

        # optimization op (only the segmentation variables are updated)
        train_op_train_time = model.training_step(
            loss_segmentation, segmentation_vars,
            exp_config.optimizer_handle_training, learning_rate_pl)

        # ops for model evaluation
        eval_loss_segmentation = model.evaluation(
            logits_segmentation,
            labels_segmentation_pl,
            images_pl,
            nlabels=exp_config.nlabels,
            loss_type=exp_config.loss_type)

        # Merge the summaries registered so far; the placeholder-driven
        # training/validation summaries are created afterwards on purpose.
        summary = tf.compat.v1.summary.merge_all()

        # init ops
        init_g = tf.compat.v1.global_variables_initializer()
        init_l = tf.compat.v1.local_variables_initializer()

        # op that reports uninitialized variables
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.compat.v1.report_uninitialized_variables()

        # savers: periodic checkpoints and best-validation-dice checkpoints
        max_to_keep = 15
        saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
        saver_best_da = tf.compat.v1.train.Saver()

        # ------------------------------------------------------------
        # create session; everything from here on runs under try/finally
        # so the session and the event file are released on any error
        # ------------------------------------------------------------
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.compat.v1.Session(config=config)
        summary_writer = None
        try:
            # summary writer
            summary_writer = tf.compat.v1.summary.FileWriter(
                log_dir, sess.graph)

            # summaries of the training errors
            tr_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='tr_error_seg')
            tr_error_summary_seg = tf.compat.v1.summary.scalar(
                'training/loss_seg', tr_error_seg)
            tr_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='tr_dice_seg')
            tr_dice_summary_seg = tf.compat.v1.summary.scalar(
                'training/dice_seg', tr_dice_seg)
            tr_summary_seg = tf.compat.v1.summary.merge(
                [tr_error_summary_seg, tr_dice_summary_seg])

            # summaries of the validation errors
            vl_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='vl_error_seg')
            vl_error_summary_seg = tf.compat.v1.summary.scalar(
                'validation/loss_seg', vl_error_seg)
            vl_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='vl_dice_seg')
            vl_dice_summary_seg = tf.compat.v1.summary.scalar(
                'validation/dice_seg', vl_dice_seg)
            vl_summary_seg = tf.compat.v1.summary.merge(
                [vl_error_summary_seg, vl_dice_summary_seg])

            # freeze the graph before execution
            logging.info('Freezing the graph now!')
            tf.compat.v1.get_default_graph().finalize()

            # initialize the variables
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
            sess.run(init_g)
            sess.run(init_l)

            # print names of uninitialized variables
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                logging.info(v)

            # continue run from a saved checkpoint
            if continue_run:
                logging.info(
                    '============================================================')
                logging.info('Restroring session from: %s' %
                             init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate  # constant for the whole run
            best_da = 0

            # --------------------------------------------------------
            # run training epochs
            # --------------------------------------------------------
            for epoch in range(exp_config.max_epochs):

                logging.info(
                    '============================================================')
                logging.info('EPOCH %d' % epoch)

                for batch in iterate_minibatches(
                        imtr, gttr, batch_size=exp_config.batch_size):

                    start_time = time.time()
                    x, y_seg, y_ss = batch

                    # Skip incomplete batches: the placeholders have a fixed
                    # batch dimension. The step counter still advances.
                    if y_seg.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_segmentation_pl: y_seg,
                        labels_self_supervised_pl: y_ss,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # update vars
                    _, loss_value = sess.run(
                        [train_op_train_time, loss_segmentation],
                        feed_dict=feed_dict)

                    # time taken by this mini-batch
                    duration = time.time() - start_time

                    # write the summaries and print an overview fairly often
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.2f (%.3f sec for the last step)'
                            % (step + 1, loss_value, duration))

                        # print values of a parameter (to debug)
                        if exp_config.debug is True:
                            # first global variable; filter by name here if
                            # only specific parameters are of interest
                            var_value = tf.compat.v1.get_collection(
                                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                            )[0].eval(session=sess)
                            logging.info('value of one of the parameters %f' %
                                         var_value[0, 0, 0, 0])

                        # update the events file
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # compute the loss on the entire training set
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        train_loss_seg, train_dice_seg = do_eval(
                            sess, eval_loss_segmentation, images_pl,
                            labels_segmentation_pl, training_pl, imtr, gttr,
                            exp_config.batch_size)

                        tr_summary_msg = sess.run(
                            tr_summary_seg,
                            feed_dict={
                                tr_error_seg: train_loss_seg,
                                tr_dice_seg: train_dice_seg
                            })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # save a checkpoint periodically
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # evaluate the model periodically on the validation set
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('Validation Data Eval:')
                        val_loss_seg, val_dice_seg = do_eval(
                            sess, eval_loss_segmentation, images_pl,
                            labels_segmentation_pl, training_pl, imvl, gtvl,
                            exp_config.batch_size)

                        vl_summary_msg = sess.run(
                            vl_summary_seg,
                            feed_dict={
                                vl_error_seg: val_loss_seg,
                                vl_dice_seg: val_dice_seg
                            })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # save model if the val dice is the best yet
                        if val_dice_seg > best_da:
                            best_da = val_dice_seg
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_da.save(sess,
                                               best_file,
                                               global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f - Saving model.'
                                % val_dice_seg)

                    step += 1

        finally:
            # release the event file and the session on success and failure
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
def run_training(continue_run):
    """Train the segmentation network on the ACDC data, optionally resuming.

    Loads the data through ``acdc_data``, builds the TF graph (placeholders,
    inference, loss, training and evaluation ops), then runs
    ``exp_config.max_epochs`` epochs of mini-batch training. Summaries are
    written to ``log_dir``; a regular checkpoint is saved at every validation
    interval, and separate "best dice" / "best crossentropy" checkpoints are
    saved whenever the validation metrics improve.

    Args:
        continue_run: If true, try to resume from the latest ``model.ckpt``
            checkpoint found in ``log_dir``. When the checkpoint cannot be
            located, a warning is logged and training starts from scratch.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Deliberately broad (the helper's failure mode is not visible
            # here), but unlike a bare ``except`` this lets KeyboardInterrupt
            # and SystemExit propagate.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Older experiment configs may not define this flag.
    train_on_all_data = getattr(exp_config, 'train_on_all_data', False)

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=(not train_on_all_data))

    sess = None
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_test']
            labels_val = data['masks_test']

        if exp_config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(
                float(num_images) * exp_config.use_data_fraction)

            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(
                ' - Number of imgs orig: %d, Number of imgs new: %d' %
                (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = (
                [exp_config.batch_size] + list(exp_config.image_size) + [1])
            mask_tensor_shape = (
                [exp_config.batch_size] + list(exp_config.image_size))

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_pl,
                                     exp_config,
                                     training=training_pl)

            # Add to the Graph the Ops for loss calculation.
            # (second output is the unregularised loss)
            [loss, _, weights_norm] = model.loss(
                logits,
                labels_pl,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weight_decay=exp_config.weight_decay)

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during evaluation.
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            max_to_keep = None if train_on_all_data else 5
            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            config = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            config.gpu_options.allow_growth = True
            # If an operation is not defined in the default device, let it
            # execute in another.
            config.allow_soft_placement = True
            sess = tf.Session(config=config)

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # Placeholders through which the python-side evaluation results
            # are fed into the validation / training summaries.
            val_error_ = tf.placeholder(tf.float32, shape=[],
                                        name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)
            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)
            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32, shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)
            train_dice_ = tf.placeholder(tf.float32, shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)
            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            # Counts consecutive training evaluations (not mini-batch steps)
            # in which the training loss did not decrease.
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0

            for epoch in range(exp_config.max_epochs):

                logging.info('EPOCH %d' % epoch)

                # This loop can also be driven through a BackgroundGenerator
                # wrapped around iterate_minibatches, which speeds up
                # training. Be aware that an exception inside the loop may
                # then not be caught: the generator may keep running silently
                # even though the code has crashed.
                for batch in iterate_minibatches(
                        images_train,
                        labels_train,
                        batch_size=exp_config.batch_size,
                        augment_batch=exp_config.augment_batch):

                    if exp_config.warmup_training:
                        # Reduced learning rate for the first 50 steps.
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_pl: y,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % exp_config.train_eval_frequency == 0:

                        logging.info('Training Data Eval:')
                        [train_loss, train_dice] = do_eval(
                            sess, eval_loss, images_pl, labels_pl,
                            training_pl, images_train, labels_train,
                            exp_config.batch_size)

                        train_summary_msg = sess.run(
                            train_summary,
                            feed_dict={
                                train_error_: train_loss,
                                train_dice_: train_dice
                            })
                        summary_writer.add_summary(train_summary_msg, step)

                        # Keep a window of the last five training losses and
                        # use its end-to-end difference as a loss "gradient".
                        loss_history.append(train_loss)
                        if len(loss_history) > 5:
                            loss_history.pop(0)
                            loss_gradient = (
                                loss_history[-5] - loss_history[-1]) / 2

                        logging.info('loss gradient is currently %f' %
                                     loss_gradient)

                        if (exp_config.schedule_lr and loss_gradient <
                                exp_config.schedule_gradient_threshold):
                            logging.warning('Reducing learning rate!')
                            curr_lr /= 10.0
                            logging.info('Learning rate changed to: %f' %
                                         curr_lr)

                            # reset loss history to give the optimisation some
                            # time to start decreasing again
                            loss_gradient = np.inf
                            loss_history = []

                        if train_loss <= last_train:
                            logging.info('Decrease in training error!')
                            no_improvement_counter = 0
                        else:
                            # The counter used to stay at 0 forever, which
                            # made this message meaningless.
                            no_improvement_counter += 1
                            logging.info(
                                'No improvment in training error for %d steps'
                                % no_improvement_counter)

                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        if not train_on_all_data:

                            # Evaluate against the validation set.
                            logging.info('Validation Data Eval:')
                            [val_loss, val_dice] = do_eval(
                                sess, eval_loss, images_pl, labels_pl,
                                training_pl, images_val, labels_val,
                                exp_config.batch_size)

                            val_summary_msg = sess.run(
                                val_summary,
                                feed_dict={
                                    val_error_: val_loss,
                                    val_dice_: val_dice
                                })
                            summary_writer.add_summary(val_summary_msg, step)

                            if val_dice > best_dice:
                                best_dice = val_dice
                                best_file = os.path.join(
                                    log_dir, 'model_best_dice.ckpt')
                                saver_best_dice.save(sess,
                                                     best_file,
                                                     global_step=step)
                                logging.info(
                                    'Found new best dice on validation set! - %f - Saving model_best_dice.ckpt'
                                    % val_dice)

                            if val_loss < best_val:
                                best_val = val_loss
                                best_file = os.path.join(
                                    log_dir, 'model_best_xent.ckpt')
                                saver_best_xent.save(sess,
                                                     best_file,
                                                     global_step=step)
                                logging.info(
                                    'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                    % val_loss)

                    step += 1
    finally:
        # Release the session and the data handle even when training is
        # interrupted or fails part-way through.
        try:
            if sess is not None:
                sess.close()
        finally:
            data.close()
def _load_train_val_data(loader, train_range, val_range, **loader_kwargs):
    """Load the training and validation index ranges with one data loader.

    Args:
        loader: Data module exposing ``load_and_maybe_process_data``.
        train_range: ``(idx_start, idx_end)`` passed for the training data.
        val_range: ``(idx_start, idx_end)`` passed for the validation data.
        **loader_kwargs: Remaining keyword arguments, forwarded unchanged to
            both loader calls.

    Returns:
        Tuple ``(train_data, val_data)`` of the objects the loader returned.
    """
    train_data = loader.load_and_maybe_process_data(
        idx_start=train_range[0], idx_end=train_range[1], **loader_kwargs)
    val_data = loader.load_and_maybe_process_data(
        idx_start=val_range[0], idx_end=val_range[1], **loader_kwargs)
    return train_data, val_data


def run_training(continue_run):
    """Train the image-to-label (i2l) segmentation network.

    Loads the training / validation data of ``exp_config.train_dataset``,
    builds the TF graph (normalization module followed by the segmentation
    network, loss, training and evaluation ops) and trains until ``step``
    reaches ``exp_config.max_steps``. Summaries and checkpoints are written
    below ``log_dir``; the model with the best validation dice is saved
    separately.

    Args:
        continue_run: If true, try to resume from the latest
            ``models/model.ckpt`` checkpoint in ``log_dir``. When it cannot be
            located, a warning is logged and training starts from scratch.

    Raises:
        ValueError: If ``exp_config.train_dataset`` is not one of the
            supported dataset names.
    """
    separator = '============================================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(separator)
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(separator)
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 as otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Deliberately broad (the helper's failure mode is not visible
            # here), but unlike a bare ``except`` this lets KeyboardInterrupt
            # and SystemExit propagate.
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(separator)

    # ============================
    # Load data
    # ============================
    logging.info(separator)
    logging.info('Loading data...')

    # Dataset names are compared with ``==``: ``is`` tests object identity and
    # only works for strings by accident of interning.
    dataset = exp_config.train_dataset

    if dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        # NOTE(review): the training range (0, 1040) contains the validation
        # range (20, 25) and differs from HCPT2's (0, 20) - confirm this is
        # intended.
        data_brain_train, data_brain_val = _load_train_val_data(
            data_hcp, (0, 1040), (20, 25),
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)

    elif dataset == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train, data_brain_val = _load_train_val_data(
            data_hcp, (0, 20), (20, 25),
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)

    elif dataset == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')
        data_brain_train, data_brain_val = _load_train_val_data(
            data_abide, (0, 10), (10, 15),
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_caltech,
            target_resolution=exp_config.target_resolution_brain)

    elif dataset == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')
        data_brain_train, data_brain_val = _load_train_val_data(
            data_abide, (0, 10), (10, 15),
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_stanford,
            target_resolution=exp_config.target_resolution_brain)

    elif dataset == 'IXI':
        logging.info('Reading IXI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_ixi)
        data_brain_train, data_brain_val = _load_train_val_data(
            data_ixi, (0, 12), (12, 17),
            input_folder=sys_config.orig_data_root_ixi,
            preprocessing_folder=sys_config.preproc_folder_ixi,
            protocol='T2',
            size=exp_config.image_size,
            depth=exp_config.image_depth_ixi,
            target_resolution=exp_config.target_resolution_brain)

    else:
        # Previously an unknown name fell through every branch and surfaced
        # later as a NameError on ``imtr``.
        raise ValueError('Unknown train_dataset: %r' % (dataset,))

    imtr, gttr = data_brain_train['images'], data_brain_train['labels']
    imvl, gtvl = data_brain_val['images'], data_brain_val['labels']

    # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Training Images: %s' % str(imtr.shape))
    logging.info('Training Labels: %s' % str(gttr.shape))
    logging.info('Validation Images: %s' % str(imvl.shape))
    logging.info('Validation Labels: %s' % str(gtvl.shape))
    logging.info(separator)

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = (
            [exp_config.batch_size] + list(exp_config.image_size) + [1])
        mask_tensor_shape = (
            [exp_config.batch_size] + list(exp_config.image_size))
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=mask_tensor_shape,
                                   name='labels')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module will be adapted for each test image
        # ================================================================
        images_normalized, _ = model.normalize(images_pl, exp_config,
                                               training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        logits, _, _ = model.predict_i2l(images_normalized,
                                         exp_config,
                                         training_pl=training_pl)

        print('shape of inputs: ', images_pl.shape)  # (batch_size, 256, 256, 1)
        print('shape of logits: ', logits.shape)  # (batch_size, 256, 256, nlabels)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        i2l_vars = list(tf.trainable_variables())

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits,
                             labels_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_i2l)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       i2l_vars,
                                       exp_config.optimizer_handle,
                                       learning_rate_pl,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_i2l(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_i2l)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # Everything below runs under try/finally so the session is released
        # even when training is interrupted or fails part-way through.
        try:
            # ============================================================
            # create a summary writer
            # ============================================================
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # ============================================================
            # summaries of the validation errors
            # ============================================================
            vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
            vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

            # ============================================================
            # summaries of the training errors
            # ============================================================
            tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
            tr_error_summary = tf.summary.scalar('training/loss', tr_error)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(separator)
            logging.info('initializing all variables...')
            sess.run(init_ops)

            # ============================================================
            # print names of all variables
            # ============================================================
            logging.info(separator)
            logging.info('This is the list of all variables:')
            for v in tf.trainable_variables():
                print(v.name)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(separator)
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                print(v)

            # ============================================================
            # continue run from a saved checkpoint
            # ============================================================
            if continue_run:
                # Restore session
                logging.info(separator)
                logging.info('Restroring session from: %s' %
                             init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            # The learning rate is constant for the whole run.
            curr_lr = exp_config.learning_rate
            best_dice = 0

            # ============================================================
            # run training epochs
            # NOTE: max_steps is only checked between passes over the
            # training data, so the last pass may run beyond it.
            # ============================================================
            while step < exp_config.max_steps:

                if step % 1000 == 0:
                    logging.info(separator)
                    logging.info('step %d' % step)

                # ================================================
                # batches
                # ================================================
                for batch in iterate_minibatches(
                        imtr,
                        gttr,
                        batch_size=exp_config.batch_size,
                        train_or_eval='train'):

                    start_time = time.time()
                    x, y = batch

                    # ===========================
                    # avoid incomplete batches
                    # ===========================
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    # ===========================
                    # create the feed dict for this training iteration
                    # ===========================
                    feed_dict = {
                        images_pl: x,
                        labels_pl: y,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # ===========================
                    # opt step
                    # ===========================
                    _, loss = sess.run([train_op, loss_op],
                                       feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.3f (%.3f sec for the last step)'
                            % (step + 1, loss, duration))

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # Compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        train_loss, train_dice = do_eval(
                            sess, eval_loss, images_pl, labels_pl,
                            training_pl, imtr, gttr, exp_config.batch_size)

                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error: train_loss,
                                                      tr_dice: train_dice
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # ===========================
                    # Evaluate the model periodically on a validation set
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('Validation Data Eval:')
                        val_loss, val_dice = do_eval(
                            sess, eval_loss, images_pl, labels_pl,
                            training_pl, imvl, gtvl, exp_config.batch_size)

                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error: val_loss,
                                                      vl_dice: val_dice
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # ===========================
                        # save model if the val dice is the best yet
                        # ===========================
                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f - Saving model.'
                                % val_dice)

                    step += 1
        finally:
            sess.close()
def run_training():
    """Run test-time adaptation driven by the self-supervised rotation loss.

    The function
      1. loads the test images/labels (the hard-coded ``target_data`` / ``hcp``
         flags currently select the HCP T2 branch) and drops slices that are
         entirely zero,
      2. builds a graph containing the normalizer, the segmentation network and
         the rotation (self-supervised) network,
      3. restores the segmentation and rotation weights from the
         ``Initial_training_segmentation`` and ``Initial_training_rotation``
         folders under ``sys_config.log_root``,
      4. runs ``exp_config.max_epochs`` epochs of ``train_op_test_time``, which
         is built by ``model.training_step`` from the self-supervised loss and
         the variables whose name contains ``'normalization'``,
      5. periodically logs, writes summaries, evaluates on the whole test set
         via ``do_eval`` and saves checkpoints under ``<log_dir>/models``.

    Relies on module-level names defined elsewhere in this file:
    ``exp_config``, ``sys_config``, ``log_dir``, ``model``, ``utils``,
    ``data_hcp``, ``data_abide``, ``iterate_minibatches`` and ``do_eval``.

    The session and the summary writer are released in a ``finally`` clause so
    that an exception raised while building the graph or training does not
    leak them (and pending summary events are written out).

    Returns:
        None
    """
    rule = '============================================================'
    debug_rule = '================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(rule)
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # Determine the data set (hard-coded switches)
    # ============================
    target_data = True
    hcp = True  # only consulted when target_data is True

    # ============================
    # Load data
    # ============================
    logging.info(rule)
    logging.info('Loading data...')
    if target_data and hcp:
        logging.info('Reading HCP - 3T - T2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                        sys_config.preproc_folder_hcp,
                                        'T2',
                                        1,
                                        exp_config.test_size + 1)
    elif target_data:
        logging.info('Reading ABIDE caltech...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide)
        imts, gtts = data_abide.load_data(sys_config.orig_data_root_abide,
                                          sys_config.preproc_folder_abide,
                                          1,
                                          exp_config.test_size + 1)
    else:
        logging.info('Reading HCP - 3T - T1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        first_test_idx = exp_config.train_size + exp_config.val_size + 1
        imts, gtts = data_hcp.load_data(sys_config.orig_data_root_hcp,
                                        sys_config.preproc_folder_hcp,
                                        'T1',
                                        first_test_idx,
                                        first_test_idx + exp_config.test_size)
    # expected: [num_slices, img_size_x, img_size_y]
    logging.info('Test Images: %s' % str(imts.shape))
    logging.info('Test Labels: %s' % str(gtts.shape))
    logging.info(rule)

    # ============================
    # Remove exclusively black images
    # ============================
    mask = ~(imts == 0).all(axis=(1, 2))
    imts = imts[mask]
    gtts = gtts[mask]

    sess = None
    summary_writer = None
    try:
        # ================================================================
        # build the TF graph
        # ================================================================
        with tf.Graph().as_default():

            # ================================================================
            # create placeholders
            # ================================================================
            logging.info('Creating placeholders...')
            # Placeholders for the images and labels
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            images_pl = tf.compat.v1.placeholder(tf.float32,
                                                 shape=image_tensor_shape,
                                                 name='images')
            labels_tensor_shape_segmentation = [exp_config.batch_size] + list(
                exp_config.image_size)
            labels_tensor_shape_self_supervised = [exp_config.batch_size] + [
                exp_config.num_rotation_values
            ]
            labels_segmentation_pl = tf.compat.v1.placeholder(
                tf.uint8,
                shape=labels_tensor_shape_segmentation,
                name='labels_segmentation')
            labels_self_supervised_pl = tf.compat.v1.placeholder(
                tf.float16,
                shape=labels_tensor_shape_self_supervised,
                name='labels_self_supervised')
            # Placeholders for the learning rate and to indicate whether the
            # model is being trained or tested
            learning_rate_pl = tf.compat.v1.placeholder(tf.float32,
                                                        shape=[],
                                                        name='learning_rate')
            training_pl = tf.compat.v1.placeholder(tf.bool,
                                                   shape=[],
                                                   name='training_or_testing')

            # ================================================================
            # Define the image normalization function - these parameters will
            # be updated for each test image
            # ================================================================
            image_normalized = exp_config.model_handle_normalizer(
                images_pl, image_tensor_shape)

            # ================================================================
            # Define the segmentation network here
            # ================================================================
            logits_segmentation = exp_config.model_handle_segmentation(
                image_normalized, image_tensor_shape, training_pl,
                exp_config.nlabels)

            # ================================================================
            # Define the self-supervised network here
            # ================================================================
            logits_self_supervised = exp_config.model_handle_rotation(
                image_normalized, image_tensor_shape, training_pl)

            # ================================================================
            # determine trainable variables (grouped by substring of the name;
            # the first matching substring wins)
            # ================================================================
            segmentation_vars = []
            self_supervised_vars = []
            test_time_opt_vars = []
            for v in tf.compat.v1.trainable_variables():
                var_name = v.name
                if 'rotation' in var_name:
                    self_supervised_vars.append(v)
                elif 'segmentation' in var_name:
                    segmentation_vars.append(v)
                elif 'normalization' in var_name:
                    test_time_opt_vars.append(v)

            if exp_config.debug is True:
                logging.info(debug_rule)
                logging.info('List of trainable variables in the graph:')
                for v in tf.compat.v1.trainable_variables():
                    logging.info(v.name)
                logging.info(debug_rule)
                logging.info('List of all segmentation variables:')
                for v in segmentation_vars:
                    logging.info(v.name)
                logging.info(debug_rule)
                logging.info('List of all self-supervised variables:')
                for v in self_supervised_vars:
                    logging.info(v.name)
                logging.info(debug_rule)
                logging.info('List of all test time variables:')
                for v in test_time_opt_vars:
                    logging.info(v.name)

            # ================================================================
            # Add ops for calculation of the training loss
            # ================================================================
            loss_self_supervised = tf.reduce_mean(
                exp_config.loss_handle_self_supervised(
                    labels_self_supervised_pl, logits_self_supervised))
            tf.compat.v1.summary.scalar('loss_self_supervised',
                                        loss_self_supervised)

            # ================================================================
            # Add optimization ops.
            # ================================================================
            train_op_test_time = model.training_step(
                loss_self_supervised, test_time_opt_vars,
                exp_config.optimizer_handle_test_time, learning_rate_pl)

            # ================================================================
            # Add ops for model evaluation
            # ================================================================
            eval_loss_segmentation = model.evaluation(
                logits_segmentation,
                labels_segmentation_pl,
                images_pl,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type)
            eval_loss_self_supervised = model.evaluation_rotate(
                logits_self_supervised, labels_self_supervised_pl)

            # ================================================================
            # Build the summary Tensor based on the TF collection of Summaries.
            # ================================================================
            summary = tf.compat.v1.summary.merge_all()

            # ================================================================
            # Add init ops
            # ================================================================
            init_g = tf.compat.v1.global_variables_initializer()
            init_l = tf.compat.v1.local_variables_initializer()

            # ================================================================
            # Find if any vars are uninitialized
            # ================================================================
            logging.info(
                'Adding the op to get a list of initialized variables...')
            uninit_vars = tf.compat.v1.report_uninitialized_variables()

            # ================================================================
            # create savers for each domain
            # ================================================================
            max_to_keep = 15
            saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
            # NOTE(review): saver_best_da is never used below; it is kept so
            # that the set of save ops added to the graph is unchanged.
            saver_best_da = tf.compat.v1.train.Saver()
            saver_seg = tf.compat.v1.train.Saver(var_list=segmentation_vars)
            saver_ss = tf.compat.v1.train.Saver(var_list=self_supervised_vars)

            # ================================================================
            # Create session
            # ================================================================
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            sess = tf.compat.v1.Session(config=config)

            # ================================================================
            # create a summary writer
            # ================================================================
            summary_writer = tf.compat.v1.summary.FileWriter(
                log_dir, sess.graph)

            # ================================================================
            # summaries of the test errors
            # ================================================================
            ts_error_seg = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[],
                                                    name='ts_error_seg')
            ts_error_summary_seg = tf.compat.v1.summary.scalar(
                'test/loss_seg', ts_error_seg)
            ts_dice_seg = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='ts_dice_seg')
            ts_dice_summary_seg = tf.compat.v1.summary.scalar(
                'test/dice_seg', ts_dice_seg)
            ts_summary_seg = tf.compat.v1.summary.merge(
                [ts_error_summary_seg, ts_dice_summary_seg])

            ts_error_ss = tf.compat.v1.placeholder(tf.float32,
                                                   shape=[],
                                                   name='ts_error_ss')
            ts_error_summary_ss = tf.compat.v1.summary.scalar(
                'test/loss_ss', ts_error_ss)
            ts_acc_ss = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[],
                                                 name='ts_acc_ss')
            ts_acc_summary_ss = tf.compat.v1.summary.scalar(
                'test/acc_ss', ts_acc_ss)
            ts_summary_ss = tf.compat.v1.summary.merge(
                [ts_error_summary_ss, ts_acc_summary_ss])

            # ================================================================
            # freeze the graph before execution
            # ================================================================
            logging.info('Freezing the graph now!')
            tf.compat.v1.get_default_graph().finalize()

            # ================================================================
            # Run the Op to initialize the variables.
            # ================================================================
            logging.info(rule)
            logging.info('initializing all variables...')
            sess.run(init_g)
            sess.run(init_l)

            # ================================================================
            # print names of uninitialized variables
            # ================================================================
            logging.info(rule)
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                logging.info(v)

            # ================================================================
            # restore shared weights
            # ================================================================
            logging.info(rule)
            logging.info('Restore segmentation...')
            model_path = os.path.join(sys_config.log_root,
                                      'Initial_training_segmentation')
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'models/best_dice.ckpt')
            logging.info('Restroring session from: %s' % checkpoint_path)
            saver_seg.restore(sess, checkpoint_path)

            logging.info(rule)
            logging.info('Restore self-supervised...')
            model_path = os.path.join(sys_config.log_root,
                                      'Initial_training_rotation')
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'models/best_dice.ckpt')
            logging.info('Restroring session from: %s' % checkpoint_path)
            saver_ss.restore(sess, checkpoint_path)

            # ================================================================
            # training state; the learning rate is constant for the whole run
            # ================================================================
            step = init_step
            curr_lr = exp_config.learning_rate

            # ================================================================
            # run training epochs
            # ================================================================
            for epoch in range(exp_config.max_epochs):
                logging.info(rule)
                logging.info('EPOCH %d' % epoch)

                for batch in iterate_minibatches(
                        imts, gtts, batch_size=exp_config.batch_size):
                    start_time = time.time()
                    x, y_seg, y_ss = batch

                    # ===========================
                    # avoid incomplete batches (placeholders have a fixed
                    # batch dimension); the step counter still advances
                    # ===========================
                    if y_seg.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_self_supervised_pl: y_ss,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    # ===========================
                    # update vars
                    # ===========================
                    _, loss_value = sess.run(
                        [train_op_test_time, loss_self_supervised],
                        feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.2f (%.3f sec for the last step)'
                            % (step + 1, loss_value, duration))

                        # ===========================
                        # print values of a parameter (to debug)
                        # ===========================
                        if exp_config.debug is True:
                            # can add name if you only want parameters with a
                            # specific name
                            var_value = tf.compat.v1.get_collection(
                                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES
                            )[0].eval(session=sess)
                            logging.info('value of one of the parameters %f' %
                                         var_value[0, 0, 0, 0])

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # compute the loss on the entire test set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Test Data Eval:')
                        [test_loss_seg, test_dice_seg, test_loss_ss,
                         test_acc_ss] = do_eval(sess,
                                                eval_loss_segmentation,
                                                eval_loss_self_supervised,
                                                images_pl,
                                                labels_segmentation_pl,
                                                labels_self_supervised_pl,
                                                training_pl,
                                                imts,
                                                gtts,
                                                exp_config.batch_size)

                        ts_summary_msg = sess.run(ts_summary_seg,
                                                  feed_dict={
                                                      ts_error_seg: test_loss_seg,
                                                      ts_dice_seg: test_dice_seg
                                                  })
                        summary_writer.add_summary(ts_summary_msg, step)

                        ts_summary_msg = sess.run(ts_summary_ss,
                                                  feed_dict={
                                                      ts_error_ss: test_loss_ss,
                                                      ts_acc_ss: test_acc_ss
                                                  })
                        summary_writer.add_summary(ts_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    step += 1
    finally:
        # Release resources on both the success and the failure path; closing
        # the writer also writes out summary events that were added after the
        # last explicit flush.
        if summary_writer is not None:
            summary_writer.close()
        if sess is not None:
            sess.close()
def run_training(continue_run):
    """Train the segmentation network selected by ``config.experiment_name``.

    Loads the data via ``read_data.load_and_maybe_process_data``, builds the
    graph (``model.inference`` for the U-Net style experiment names,
    ``model_structure.ENet`` for ``'ENet'``), then trains for
    ``config.max_epochs`` epochs. During training it
      * logs the loss and writes summaries every 20 steps,
      * evaluates on the training data every ``config.train_eval_frequency``
        steps and, if ``config.schedule_lr`` is set, divides the learning rate
        by 10 when the recent loss gradient falls below
        ``config.schedule_gradient_threshold``,
      * saves ``model.ckpt`` every ``config.val_eval_frequency`` steps and,
        unless ``config.train_on_all_data`` is set, evaluates on the
        validation data and keeps the best-dice / best-loss checkpoints,
      * multiplies the learning rate by 0.98 every ``config.epoch_freq``
        epochs.

    Relies on module-level names defined elsewhere in this file: ``config``,
    ``log_dir``, ``utils``, ``read_data``, ``model``, ``model_structure``,
    ``slim``, ``iterate_minibatches`` and ``do_eval``.

    Args:
        continue_run: if truthy, try to resume from the latest ``model.ckpt``
            checkpoint in ``log_dir``; when the lookup fails the run starts
            from scratch.

    Raises:
        ValueError: if ``config.experiment_name`` is not one of the supported
            names (previously this surfaced later as a ``NameError`` on
            ``logits``).

    Returns:
        None
    """
    logging.info('EXPERIMENT NAME: %s' % config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # plus 1 b/c otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Deliberate best-effort fallback, but no longer swallows
            # KeyboardInterrupt / SystemExit as the bare ``except`` did.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = config.train_on_all_data

    # Load data
    data = read_data.load_and_maybe_process_data(
        input_folder=config.data_root,
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False)

    sess = None
    summary_writer = None
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_val']
            labels_val = data['masks_val']

        if config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(float(num_images) * config.use_data_fraction)
            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(
                ' - Number of imgs orig: %d, Number of imgs new: %d' %
                (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Disabled offline augmentation, kept for reference:
        # if config.prob:  # if prob is not 0
        #     logging.info('Before data_augmentation the number of training images is:')
        #     logging.info(images_train.shape[0])
        #     # augmentation
        #     image_aug, label_aug = aug.augmentation_function(images_train, labels_train)
        #     # num_aug = image_aug.shape[0]
        #     # id images augmented will be b'0.0'
        #     # id_aug = np.zeros([num_aug,]).astype('|S9')
        #     # concatenate
        #     # id_train = np.concatenate((id__train, id_aug))
        #     images_train = np.concatenate((images_train, image_aug))
        #     labels_train = np.concatenate((labels_train, label_aug))
        #     logging.info('After data_augmentation the number of training images is:')
        #     logging.info(images_train.shape[0])
        # else:
        #     logging.info('No data_augmentation. Number of training images is:')
        #     logging.info(images_train.shape[0])

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [config.batch_size] + list(
                config.image_size) + [1]
            mask_tensor_shape = [config.batch_size] + list(config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')
            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference model.
            unet_experiments = ('unet2D_valid', 'unet2D_same',
                                'unet2D_same_mod', 'unet2D_light',
                                'Dunet2D_same_mod', 'Dunet2D_same_mod2',
                                'Dunet2D_same_mod3')
            if config.experiment_name in unet_experiments:
                logits = model.inference(images_pl,
                                         config,
                                         training=training_pl)
            elif config.experiment_name == 'ENet':
                with slim.arg_scope(
                        model_structure.ENet_arg_scope(weight_decay=2e-4)):
                    # NOTE(review): is_training is hard-coded to True here
                    # rather than wired to training_pl - confirm intended.
                    logits = model_structure.ENet(
                        images_pl,
                        num_classes=config.nlabels,
                        batch_size=config.batch_size,
                        is_training=True,
                        reuse=None,
                        num_initial_blocks=1,
                        stage_two_repeat=2,
                        skip_connections=config.skip_connections)
            else:
                # Without a network there is nothing to train; fail here with
                # a clear message instead of a NameError on `logits` below.
                logging.warning('invalid experiment_name!')
                raise ValueError('invalid experiment_name: %r' %
                                 (config.experiment_name,))

            logging.info('images_pl shape')
            logging.info(images_pl.shape)
            logging.info('labels_pl shape')
            logging.info(labels_pl.shape)
            logging.info('logits shape:')
            logging.info(logits.shape)

            # Add to the Graph the Ops for loss calculation.
            # second output is unregularised loss
            [loss, _, weights_norm] = model.loss(
                logits,
                labels_pl,
                nlabels=config.nlabels,
                loss_type=config.loss_type,
                weight_decay=config.weight_decay)

            # record how Total loss and weight decay change over time
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if config.momentum is not None:
                train_op = model.training_step(loss,
                                               config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=config.momentum)
            else:
                train_op = model.training_step(loss,
                                               config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during
            # evaluation: loss and dice on a minibatch
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=config.nlabels,
                                         loss_type=config.loss_type)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            max_to_keep = None if train_on_all_data else 5
            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            configP = tf.ConfigProto()
            # Do not assign whole gpu memory, just use it on the go
            configP.gpu_options.allow_growth = True
            # If an operation is not defined on the default device, let it
            # execute on another one.
            configP.allow_soft_placement = True
            sess = tf.Session(config=configP)

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # with tf.name_scope('monitoring'):
            val_error_ = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)
            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)
            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)
            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)
            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = config.learning_rate

            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0

            for epoch in range(config.max_epochs):

                logging.info('EPOCH %d' % epoch)

                for batch in iterate_minibatches(
                        images_train,
                        labels_train,
                        batch_size=config.batch_size,
                        augment_batch=config.augment_batch):

                    if config.warmup_training:
                        if step < 50:
                            curr_lr = config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = config.learning_rate

                    start_time = time.time()

                    # batch = bgn_train.retrieve()
                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches): the
                    # placeholders have a fixed batch dimension
                    if y.shape[0] < config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_pl: x,
                        labels_pl: y,
                        learning_rate_pl: curr_lr,
                        training_pl: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 20 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.3f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % config.train_eval_frequency == 0:

                        logging.info('Training Data Eval:')
                        [train_loss, train_dice] = do_eval(sess,
                                                           eval_loss,
                                                           images_pl,
                                                           labels_pl,
                                                           training_pl,
                                                           images_train,
                                                           labels_train,
                                                           config.batch_size)

                        train_summary_msg = sess.run(train_summary,
                                                     feed_dict={
                                                         train_error_: train_loss,
                                                         train_dice_: train_dice
                                                     })
                        summary_writer.add_summary(train_summary_msg, step)

                        # Keep the last five training losses; the "gradient"
                        # is half the drop from the oldest to the newest.
                        loss_history.append(train_loss)
                        if len(loss_history) > 5:
                            loss_history.pop(0)
                            loss_gradient = (loss_history[-5] -
                                             loss_history[-1]) / 2

                        logging.info('loss gradient is currently %f' %
                                     loss_gradient)

                        if (config.schedule_lr and loss_gradient <
                                config.schedule_gradient_threshold):
                            logging.warning('Reducing learning rate!')
                            curr_lr /= 10.0
                            logging.info('Learning rate changed to: %f' %
                                         curr_lr)

                            # reset loss history to give the optimisation some
                            # time to start decreasing again
                            loss_gradient = np.inf
                            loss_history = []

                        if train_loss <= last_train:  # best_train:
                            no_improvement_counter = 0
                            logging.info('Decrease in training error!')
                        else:
                            no_improvement_counter = no_improvement_counter + 1
                            logging.info(
                                'No improvment in training error for %d steps'
                                % no_improvement_counter)

                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        if not train_on_all_data:

                            # Evaluate against the validation set.
                            logging.info('Validation Data Eval:')
                            [val_loss, val_dice] = do_eval(sess,
                                                           eval_loss,
                                                           images_pl,
                                                           labels_pl,
                                                           training_pl,
                                                           images_val,
                                                           labels_val,
                                                           config.batch_size)

                            val_summary_msg = sess.run(val_summary,
                                                       feed_dict={
                                                           val_error_: val_loss,
                                                           val_dice_: val_dice
                                                       })
                            summary_writer.add_summary(val_summary_msg, step)

                            if val_dice > best_dice:
                                best_dice = val_dice
                                best_file = os.path.join(
                                    log_dir, 'model_best_dice.ckpt')
                                # keep only the newest best-dice checkpoint
                                for stale in glob.glob(
                                        os.path.join(log_dir,
                                                     'model_best_dice*')):
                                    os.remove(stale)
                                saver_best_dice.save(sess,
                                                     best_file,
                                                     global_step=step)
                                logging.info(
                                    'Found new best dice on validation set! - %f - Saving model_best_dice.ckpt'
                                    % val_dice)

                            if val_loss < best_val:
                                best_val = val_loss
                                best_file = os.path.join(
                                    log_dir, 'model_best_xent.ckpt')
                                # keep only the newest best-loss checkpoint
                                for stale in glob.glob(
                                        os.path.join(log_dir,
                                                     'model_best_xent*')):
                                    os.remove(stale)
                                saver_best_xent.save(sess,
                                                     best_file,
                                                     global_step=step)
                                logging.info(
                                    'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                    % val_loss)

                    step += 1

                # end epoch
                # NOTE(review): the source's indentation was lost; the log
                # line is assumed to belong to the decay branch - confirm.
                if (epoch + 1) % config.epoch_freq == 0:
                    curr_lr = curr_lr * 0.98
                    logging.info('Learning rate: %f' % curr_lr)
    finally:
        # Always release the summary writer, the session and the HDF5 handle,
        # including when graph construction or training raises.
        if summary_writer is not None:
            summary_writer.close()
        if sess is not None:
            sess.close()
        data.close()
def run_training(continue_run): logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name) init_step = 0 # Load data base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles( scribble_file=sys_config.project_root + exp_config.scribble_data, target_folder=log_dir, percent_full_sup=exp_config.percent_full_sup, scr_ratio=exp_config.length_ratio) #wrap everything from this point onwards in a try-except to catch keyboard interrupt so #can control h5py closing data try: loaded_previous_recursion = False start_epoch = 0 if continue_run: logging.info( '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!' ) try: try: init_checkpoint_path = utils.get_latest_model_checkpoint_path( log_dir, 'recursion_{}_model.ckpt'.format(recursion)) except: init_checkpoint_path = utils.get_latest_model_checkpoint_path( log_dir, 'recursion_{}_model.ckpt'.format(recursion - 1)) loaded_previous_recursion = True logging.info('Checkpoint path: %s' % init_checkpoint_path) init_step = int( init_checkpoint_path.split('/')[-1].split('-') [-1]) + 1 # plus 1 b/c otherwise starts with eval start_epoch = int( init_step / (len(base_data['images_train']) / exp_config.batch_size)) logging.info('Latest step was: %d' % init_step) logging.info('Continuing with epoch: %d' % start_epoch) except: logging.warning( '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...' ) continue_run = False init_step = 0 start_epoch = 0 logging.info( '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!' 
) if loaded_previous_recursion: logging.info( "Data file exists for recursion {} " "but checkpoints only present up to recursion {}".format( recursion, recursion - 1)) logging.info("Likely means postprocessing was terminated") recursion_data = acdc_data.load_different_recursion( recursion_data, -1) recursion -= 1 # load images and validation data images_train = np.array(base_data['images_train']) scribbles_train = np.array(base_data['scribbles_train']) images_val = np.array(base_data['images_test']) labels_val = np.array(base_data['masks_test']) # if exp_config.use_data_fraction: # num_images = images_train.shape[0] # new_last_index = int(float(num_images)*exp_config.use_data_fraction) # # logging.warning('USING ONLY FRACTION OF DATA!') # logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' % (num_images, new_last_index)) # images_train = images_train[0:new_last_index,...] # labels_train = labels_train[0:new_last_index,...] logging.info('Data summary:') logging.info(' - Images:') logging.info(images_train.shape) logging.info(images_train.dtype) #logging.info(' - Labels:') #logging.info(labels_train.shape) #logging.info(labels_train.dtype) # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Generate placeholders for the images and labels. image_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) + [1] mask_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') labels_placeholder = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels') learning_rate_placeholder = tf.placeholder(tf.float32, shape=[]) training_time_placeholder = tf.placeholder(tf.bool, shape=[]) tf.summary.scalar('learning_rate', learning_rate_placeholder) # Build a Graph that computes predictions from the inference model. 
logits = model.inference(images_placeholder, exp_config.model_handle, training=training_time_placeholder, nlabels=exp_config.nlabels) # Add to the Graph the Ops for loss calculation. [loss, _, weights_norm ] = model.loss(logits, labels_placeholder, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type, weight_decay=exp_config.weight_decay ) # second output is unregularised loss tf.summary.scalar('loss', loss) tf.summary.scalar('weights_norm_term', weights_norm) # Add to the Graph the Ops that calculate and apply gradients. if exp_config.momentum is not None: train_op = model.training_step(loss, exp_config.optimizer_handle, learning_rate_placeholder, momentum=exp_config.momentum) else: train_op = model.training_step(loss, exp_config.optimizer_handle, learning_rate_placeholder) # Add the Op to compare the logits to the labels during evaluation. # eval_loss = model.evaluation(logits, # labels_placeholder, # images_placeholder, # nlabels=exp_config.nlabels, # loss_type=exp_config.loss_type, # weak_supervision=True, # cnn_threshold=exp_config.cnn_threshold, # include_bg=True) eval_val_loss = model.evaluation( logits, labels_placeholder, images_placeholder, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type, weak_supervision=True, cnn_threshold=exp_config.cnn_threshold, include_bg=False) # Build the summary Tensor based on the TF collection of Summaries. summary = tf.summary.merge_all() # Add the variable initializer Op. init = tf.global_variables_initializer() # Create a saver for writing training checkpoints. # Only keep two checkpoints, as checkpoints are kept for every recursion # and they can be 300MB + saver = tf.train.Saver(max_to_keep=2) saver_best_dice = tf.train.Saver(max_to_keep=2) saver_best_xent = tf.train.Saver(max_to_keep=2) # Create a session for running Ops on the Graph. sess = tf.Session() # Instantiate a SummaryWriter to output summaries and the Graph. 
summary_writer = tf.summary.FileWriter(log_dir, sess.graph) # with tf.name_scope('monitoring'): val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error') val_error_summary = tf.summary.scalar('validation_loss', val_error_) val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice') val_dice_summary = tf.summary.scalar('validation_dice', val_dice_) val_summary = tf.summary.merge( [val_error_summary, val_dice_summary]) train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error') train_error_summary = tf.summary.scalar('training_loss', train_error_) train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice') train_dice_summary = tf.summary.scalar('training_dice', train_dice_) train_summary = tf.summary.merge( [train_error_summary, train_dice_summary]) # Run the Op to initialize the variables. sess.run(init) # Restore session # crf_weights = [] # for v in tf.all_variables(): # # if v.name[0:4]=='bila': # print(str(v)) # crf_weights.append(v.name) # elif v.name[0:4] =='spat': # print(str(v)) # crf_weights.append(v.name) # elif v.name[0:4] =='comp': # print(str(v)) # crf_weights.append(v.name) # restore_var = [v for v in tf.all_variables() if v.name not in crf_weights] # # load_saver = tf.train.Saver(var_list=restore_var) # load_saver.restore(sess, '/scratch_net/biwirender02/cany/basil/logdir/unet2D_ws_spot_blur/recursion_0_model.ckpt-5699') if continue_run: # Restore session saver.restore(sess, init_checkpoint_path) step = init_step curr_lr = exp_config.learning_rate no_improvement_counter = 0 best_val = np.inf last_train = np.inf loss_history = [] loss_gradient = np.inf best_dice = 0 logging.info('RECURSION {0}'.format(recursion)) # random walk - if it already has been random walked it won't redo recursion_data = acdc_data.random_walk_epoch( recursion_data, exp_config.rw_beta, exp_config.rw_threshold, exp_config.random_walk) #get ground truths labels_train = np.array(recursion_data['random_walked']) for epoch in range(start_epoch, 
exp_config.max_epochs): if (epoch % exp_config.epochs_per_recursion == 0 and epoch != 0) \ or loaded_previous_recursion: loaded_previous_recursion = False #Have reached end of recursion recursion_data = predict_next_gt( data=recursion_data, images_train=images_train, images_placeholder=images_placeholder, training_time_placeholder=training_time_placeholder, logits=logits, sess=sess) recursion_data = postprocess_gt( data=recursion_data, images_train=images_train, scribbles_train=scribbles_train) recursion += 1 # random walk - if it already has been random walked it won't redo recursion_data = acdc_data.random_walk_epoch( recursion_data, exp_config.rw_beta, exp_config.rw_threshold, exp_config.random_walk) #get ground truths labels_train = np.array(recursion_data['random_walked']) #reinitialise savers - otherwise, no checkpoints will be saved for each recursion saver = tf.train.Saver(max_to_keep=2) saver_best_dice = tf.train.Saver(max_to_keep=2) saver_best_xent = tf.train.Saver(max_to_keep=2) logging.info( 'Epoch {0} ({1} of {2} epochs for recursion {3})'.format( epoch, 1 + epoch % exp_config.epochs_per_recursion, exp_config.epochs_per_recursion, recursion)) # for batch in iterate_minibatches(images_train, # labels_train, # batch_size=exp_config.batch_size, # augment_batch=exp_config.augment_batch): # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvements in the # training speed. However, be aware that currently an exception inside this loop may not be caught. # The batch generator may just continue running silently without warning even though the code has # crashed. 
for batch in BackgroundGenerator( iterate_minibatches( images_train, labels_train, batch_size=exp_config.batch_size, augment_batch=exp_config.augment_batch)): if exp_config.warmup_training: if step < 50: curr_lr = exp_config.learning_rate / 10.0 elif step == 50: curr_lr = exp_config.learning_rate start_time = time.time() # batch = bgn_train.retrieve() x, y = batch # TEMPORARY HACK (to avoid incomplete batches if y.shape[0] < exp_config.batch_size: step += 1 continue feed_dict = { images_placeholder: x, labels_placeholder: y, learning_rate_placeholder: curr_lr, training_time_placeholder: True } _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict) duration = time.time() - start_time # Write the summaries and print an overview fairly often. if step % 10 == 0: # Print status to stdout. logging.info('Step %d: loss = %.6f (%.3f sec)' % (step, loss_value, duration)) # Update the events file. summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) summary_writer.flush() # if (step + 1) % exp_config.train_eval_frequency == 0: # # logging.info('Training Data Eval:') # [train_loss, train_dice] = do_eval(sess, # eval_loss, # images_placeholder, # labels_placeholder, # training_time_placeholder, # images_train, # labels_train, # exp_config.batch_size) # # train_summary_msg = sess.run(train_summary, feed_dict={train_error_: train_loss, # train_dice_: train_dice} # ) # summary_writer.add_summary(train_summary_msg, step) # # loss_history.append(train_loss) # if len(loss_history) > 5: # loss_history.pop(0) # loss_gradient = (loss_history[-5] - loss_history[-1]) / 2 # # logging.info('loss gradient is currently %f' % loss_gradient) # # if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold: # logging.warning('Reducing learning rate!') # curr_lr /= 10.0 # logging.info('Learning rate changed to: %f' % curr_lr) # # # reset loss history to give the optimisation some time to start decreasing again # 
loss_gradient = np.inf # loss_history = [] # # if train_loss <= last_train: # best_train: # logging.info('Decrease in training error!') # else: # logging.info('No improvement in training error for %d steps' % no_improvement_counter) # # last_train = train_loss # Save a checkpoint and evaluate the model periodically. if (step + 1) % exp_config.val_eval_frequency == 0: checkpoint_file = os.path.join( log_dir, 'recursion_{}_model.ckpt'.format(recursion)) saver.save(sess, checkpoint_file, global_step=step) # Evaluate against the training set. # Evaluate against the validation set. logging.info('Validation Data Eval:') [val_loss, val_dice ] = do_eval(sess, eval_val_loss, images_placeholder, labels_placeholder, training_time_placeholder, images_val, labels_val, exp_config.batch_size) val_summary_msg = sess.run(val_summary, feed_dict={ val_error_: val_loss, val_dice_: val_dice }) summary_writer.add_summary(val_summary_msg, step) if val_dice > best_dice: best_dice = val_dice best_file = os.path.join( log_dir, 'recursion_{}_model_best_dice.ckpt'.format( recursion)) saver_best_dice.save(sess, best_file, global_step=step) logging.info( 'Found new best dice on validation set! - {} - ' 'Saving recursion_{}_model_best_dice.ckpt'. format(val_dice, recursion)) if val_loss < best_val: best_val = val_loss best_file = os.path.join( log_dir, 'recursion_{}_model_best_xent.ckpt'.format( recursion)) saver_best_xent.save(sess, best_file, global_step=step) logging.info( 'Found new best crossentropy on validation set! - {} - ' 'Saving recursion_{}_model_best_xent.ckpt'. format(val_loss, recursion)) step += 1 except Exception: raise
def run_training(continue_run):
    """Train the label-to-label (l2l) network on HCP T1 label volumes.

    Training and validation label volumes are loaded through ``data_hcp``.
    A TF graph is then built in which the one-hot labels are corrupted
    (multiplied by a blank mask and summed with a "wrong labels" tensor),
    passed through ``model.predict_l2l`` and compared against the clean
    labels. Summaries, periodic checkpoints, sample visualisations and the
    checkpoint with the best validation dice are written below the
    module-level ``log_dir``.

    Args:
        continue_run: If truthy, try to resume from the latest
            ``models/model.ckpt`` checkpoint found in ``log_dir``. If the
            checkpoint lookup fails, a warning is logged and training starts
            from step 0.

    Raises:
        ValueError: If ``exp_config.train_dataset`` is not ``'HCPT1'``, the
            only dataset this function loads.
    """
    separator = '============================================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(separator)
    logging.info('EXPERIMENT NAME: %s', exp_config.experiment_name_l2l)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(separator)
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s', init_checkpoint_path)
            # plus 1 as otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d', init_step)
        except Exception:
            # Deliberate best-effort: any failure to locate/parse the
            # checkpoint falls back to a fresh run. `Exception` (rather than
            # a bare except) keeps Ctrl-C / SystemExit working.
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(separator)

    # ============================
    # Load data
    # ============================
    logging.info(separator)
    logging.info('Loading data...')
    # Compare by value: `is` on a string literal only works by accident of
    # interning.
    if exp_config.train_dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth,
            target_resolution=exp_config.target_resolution_brain)
        gttr = data_brain_train['labels']

        data_brain_val = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth,
            target_resolution=exp_config.target_resolution_brain)
        gtvl = data_brain_val['labels']
    else:
        # Without this, the code below would fail later with an opaque
        # UnboundLocalError on `gttr`.
        raise ValueError('Unsupported train_dataset: %r (expected %r)' %
                         (exp_config.train_dataset, 'HCPT1'))

    # expected: [num_subjects, img_size_x, img_size_y, img_size_z]
    logging.info('Training Labels: %s', gttr.shape)
    logging.info('Validation Labels: %s', gtvl.shape)
    logging.info(separator)

    # visualize downsampled volumes
    # for subject_num in range(gttr.shape[0]):
    #     utils_vis.save_samples_downsampled(gttr[subject_num, ::2, :, :],
    #         savepath = log_dir + '/training_image_' + str(subject_num+1) + '.png')

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        true_labels_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        true_labels_pl = tf.placeholder(tf.uint8,
                                        shape=true_labels_shape,
                                        name='true_labels')

        # ================================================================
        # This will be a mask with all zeros in locations of pixels that we
        # want to alter the labels of. Multiply with this mask to have zero
        # vectors for all those pixels.
        # ================================================================
        blank_masks_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nlabels]
        blank_masks_pl = tf.placeholder(tf.float32,
                                        shape=blank_masks_shape,
                                        name='blank_masks')

        # ================================================================
        # This tensor is added to the blanked 1-hot labels (see
        # noisy_labels_1hot below). Presumably it holds the replacement
        # 1-hot vectors at the blanked locations and zeros elsewhere
        # - TODO confirm against iterate_minibatches.
        # ================================================================
        wrong_labels_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nlabels]
        wrong_labels_pl = tf.placeholder(tf.float32,
                                         shape=wrong_labels_shape,
                                         name='wrong_labels')

        # ================================================================
        # Training placeholder
        # ================================================================
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # make true labels 1-hot
        # ================================================================
        true_labels_1hot = tf.one_hot(true_labels_pl,
                                      depth=exp_config.nlabels)

        # ================================================================
        # Blank certain locations and write wrong labels in those locations
        # ================================================================
        noisy_labels_1hot = tf.math.multiply(
            true_labels_1hot, blank_masks_pl) + wrong_labels_pl

        # ================================================================
        # build the graph that computes predictions from the inference model
        # ================================================================
        autoencoded_logits, _, _ = model.predict_l2l(noisy_labels_1hot,
                                                     exp_config,
                                                     training_pl=training_pl)

        print('shape of input tensor: ',
              true_labels_pl.shape)  # (batch_size, 64, 256, 256)
        print('shape of input tensor converted to 1-hot: ',
              true_labels_1hot.shape)  # (batch_size, 64, 256, 256, 15)
        print('shape of predicted logits: ',
              autoencoded_logits.shape)  # (batch_size, 64, 256, 256, 15)

        # ================================================================
        # create a list of all vars that must be optimized wrt
        # ================================================================
        l2l_vars = []
        for v in tf.trainable_variables():
            print(v.name)
            l2l_vars.append(v)

        # ================================================================
        # add ops for calculation of the supervised training loss
        # ================================================================
        loss_op = model.loss(logits=autoencoded_logits,
                             labels=true_labels_1hot,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_l2l,
                             are_labels_1hot=True)
        tf.summary.scalar('loss', loss_op)

        # ================================================================
        # add optimization ops.
        # Create different ops according to the variables that must be trained
        # ================================================================
        print('creating training op...')
        train_op = model.training_step(loss_op,
                                       l2l_vars,
                                       exp_config.optimizer_handle,
                                       exp_config.learning_rate_l2l,
                                       update_bn_nontrainable_vars=True)

        # ================================================================
        # add ops for model evaluation
        # ================================================================
        print('creating eval op...')
        eval_loss = model.evaluation_l2l(logits=autoencoded_logits,
                                         labels=true_labels_1hot,
                                         labels_masked=noisy_labels_1hot,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type_l2l,
                                         are_labels_1hot=True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        print('creating summary op...')
        summary = tf.summary.merge_all()

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create saver
        # ================================================================
        saver = tf.train.Saver(max_to_keep=10)
        saver_best_dice = tf.train.Saver(max_to_keep=3)

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()
        summary_writer = None

        # Everything that uses the session lives inside try/finally so the
        # session (and the event file) are released even if training raises.
        try:
            # ============================================================
            # create a summary writer
            # ============================================================
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # ============================================================
            # summaries of the validation errors
            # ============================================================
            vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
            vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

            # ============================================================
            # summaries of the training errors
            # ============================================================
            tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
            tr_error_summary = tf.summary.scalar('training/loss', tr_error)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(separator)
            logging.info('initializing all variables...')
            sess.run(init_ops)

            # ============================================================
            # print names of all variables
            # ============================================================
            logging.info(separator)
            logging.info('This is the list of all variables:')
            for v in tf.trainable_variables():
                print(v.name)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(separator)
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                print(v)

            # ============================================================
            # continue run from a saved checkpoint
            # ============================================================
            if continue_run:
                # Restore session
                logging.info(separator)
                logging.info('Restroring session from: %s',
                             init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            best_dice = 0

            # ============================================================
            # run training epochs
            # ============================================================
            while step < exp_config.max_steps_l2l:

                # Value comparison; `is 0` relies on small-int caching.
                if step % 1000 == 0:
                    logging.info(separator)
                    logging.info('step %d', step)

                # ================================================
                # batches
                # ================================================
                for batch in iterate_minibatches(gttr, exp_config.batch_size):

                    start_time = time.time()
                    true_labels, blank_masks, wrong_labels = batch

                    # ===========================
                    # avoid incomplete batches (the placeholders have a
                    # fixed batch dimension)
                    # ===========================
                    if true_labels.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    # ===========================
                    # create the feed dict for this training iteration
                    # ===========================
                    feed_dict = {
                        true_labels_pl: true_labels,
                        blank_masks_pl: blank_masks,
                        wrong_labels_pl: wrong_labels,
                        training_pl: True
                    }

                    # ===========================
                    # opt step
                    # ===========================
                    _, loss = sess.run([train_op, loss_op],
                                       feed_dict=feed_dict)

                    # ===========================
                    # compute the time for this mini-batch computation
                    # ===========================
                    duration = time.time() - start_time

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info(
                            'Step %d: loss = %.3f (%.3f sec for the last step)',
                            step + 1, loss, duration)

                        # ===========================
                        # Update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # Compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        train_loss, train_dice = do_eval(
                            sess, eval_loss, true_labels_pl, blank_masks_pl,
                            wrong_labels_pl, training_pl, gttr,
                            exp_config.batch_size)

                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error: train_loss,
                                                      tr_dice: train_dice
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # Save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        # at some frequency, visualize the noisy and clean
                        # segmentation pairs used for training the DAE
                        noisy_labels = sess.run(noisy_labels_1hot,
                                                feed_dict={
                                                    true_labels_pl: true_labels,
                                                    blank_masks_pl: blank_masks,
                                                    wrong_labels_pl: wrong_labels
                                                })
                        # collapse the 1-hot axis back to label indices
                        noisy_labels = np.argmax(noisy_labels, axis=-1)
                        basepath = log_dir + '/training_data/iter' + str(step)
                        for zz in np.arange(20, 50, 10):
                            utils_vis.save_single_image(
                                true_labels[0, zz, :, :],
                                basepath + '_slice' + str(zz) + '_clean.png',
                                15, True, 'tab20', False)
                            utils_vis.save_single_image(
                                noisy_labels[0, zz, :, :],
                                basepath + '_slice' + str(zz) + '_noisy.png',
                                15, True, 'tab20', False)

                    # ===========================
                    # Evaluate the model periodically on a validation set
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('Validation Data Eval:')
                        val_loss, val_dice = do_eval(
                            sess, eval_loss, true_labels_pl, blank_masks_pl,
                            wrong_labels_pl, training_pl, gtvl,
                            exp_config.batch_size)

                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error: val_loss,
                                                      vl_dice: val_dice
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # ===========================
                        # save model if the val dice is the best yet
                        # ===========================
                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f - Saving model.',
                                val_dice)

                    step += 1
        finally:
            # Close the event file first so pending summaries are flushed,
            # then release the session's resources.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
def run_training(continue_run):
    """Train the segmentation network on the Freiburg training set.

    Training and validation images/labels are loaded through
    ``data_freiburg_numpy_to_hdf5.load_data``. A TF graph is built with
    ``model.inference`` / ``model.loss`` / ``model.training_step`` /
    ``model.evaluation``, and mini-batch training runs until
    ``exp_config.max_steps`` is reached. Summaries, periodic checkpoints and
    the checkpoint with the best validation dice are written below the
    module-level ``log_dir``.

    Args:
        continue_run: If truthy, try to resume from the latest
            ``models/model.ckpt`` checkpoint found in ``log_dir``. If the
            checkpoint lookup fails, a warning is logged and training starts
            from step 0.
    """
    separator = '============================================================'

    # ============================
    # log experiment details
    # ============================
    logging.info(separator)
    logging.info('EXPERIMENT NAME: %s', exp_config.experiment_name)

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(separator)
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s', init_checkpoint_path)
            # plus 1 as otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d', init_step)
        except Exception:
            # Deliberate best-effort: any failure to locate/parse the
            # checkpoint falls back to a fresh run. `Exception` (rather than
            # a bare except) keeps Ctrl-C / SystemExit working.
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(separator)

    # ============================
    # Load data
    # ============================
    logging.info(separator)
    logging.info('Loading training data from: ' +
                 sys_config.project_data_root)
    data_tr = data_freiburg_numpy_to_hdf5.load_data(
        basepath=sys_config.project_data_root,
        idx_start=0,
        idx_end=19,
        train_test='train')
    images_tr = data_tr['images_train']
    labels_tr = data_tr['labels_train']
    # expected: [img_size_z*num_images, img_size_x, vol_size_y, img_size_t, n_channels]
    logging.info('Shape of training images: %s', images_tr.shape)
    # expected: [img_size_z*num_images, img_size_x, vol_size_y, img_size_t]
    logging.info('Shape of training labels: %s', labels_tr.shape)

    logging.info(separator)
    logging.info('Loading validation data from: ' +
                 sys_config.project_data_root)
    data_vl = data_freiburg_numpy_to_hdf5.load_data(
        basepath=sys_config.project_data_root,
        idx_start=20,
        idx_end=24,
        train_test='validation')
    images_vl = data_vl['images_validation']
    labels_vl = data_vl['labels_validation']
    logging.info('Shape of validation images: %s', images_vl.shape)
    logging.info('Shape of validation labels: %s', labels_vl.shape)
    logging.info(separator)

    # Compare by value: `is 1` relies on CPython's small-int caching.
    if exp_config.nchannels == 1:
        logging.info(separator)
        logging.info(
            'Only the proton density images (channel 0) will be used for the segmentation...'
        )
        logging.info(separator)

    # visualize some training images and their labels
    # (debug switch; flip to True by hand to dump sample images)
    visualize_images = False
    if visualize_images:
        for sub_tr in range(20):
            # slices [sub_tr*32, (sub_tr+1)*32) are passed per subject
            utils.save_sample_image_and_labels_across_z(
                images_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                labels_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                log_dir + '/training_subject' + str(sub_tr))
            utils.save_sample_image_and_labels_across_t(
                images_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                labels_tr[sub_tr * 32:(sub_tr + 1) * 32, ...],
                log_dir + '/training_subject' + str(sub_tr))

    # ================================================================
    # build the TF graph
    # ================================================================
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders
        # ================================================================
        logging.info('Creating placeholders...')
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [exp_config.nchannels]
        label_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size)
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        labels_pl = tf.placeholder(tf.uint8,
                                   shape=label_tensor_shape,
                                   name='labels')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # Build the graph that computes predictions from the inference model
        # ================================================================
        logits = model.inference(images_pl, exp_config.model_handle,
                                 training_pl)

        # ================================================================
        # Add ops for calculation of the training loss
        # ================================================================
        loss = model.loss(logits,
                          labels_pl,
                          exp_config.nlabels,
                          loss_type=exp_config.loss_type)

        # ================================================================
        # Add the loss to tensorboard for visualizing its evolution as
        # training proceeds
        # ================================================================
        tf.summary.scalar('loss', loss)

        # ================================================================
        # Add optimization ops
        # ================================================================
        train_op = model.training_step(loss, exp_config.optimizer_handle,
                                       exp_config.learning_rate)

        # ================================================================
        # Add ops for model evaluation
        # ================================================================
        eval_loss = model.evaluation(logits,
                                     labels_pl,
                                     images_pl,
                                     nlabels=exp_config.nlabels,
                                     loss_type=exp_config.loss_type)

        # ================================================================
        # Build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        summary = tf.summary.merge_all()

        # ================================================================
        # Add init ops
        # ================================================================
        init_op = tf.global_variables_initializer()

        # ================================================================
        # Find if any vars are uninitialized
        # ================================================================
        logging.info('Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create savers: one for periodic checkpoints, one for the best
        # validation dice
        # ================================================================
        max_to_keep = 15
        saver = tf.train.Saver(max_to_keep=max_to_keep)
        saver_best_dice = tf.train.Saver()

        # ================================================================
        # Create session
        # ================================================================
        sess = tf.Session()
        summary_writer = None

        # Everything that uses the session lives inside try/finally so the
        # session (and the event file) are released even if training raises.
        try:
            # ============================================================
            # create a summary writer
            # ============================================================
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # ============================================================
            # summaries of the validation errors
            # ============================================================
            vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error')
            vl_error_summary = tf.summary.scalar('validation/loss', vl_error)
            vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice')
            vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice)
            vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary])

            # ============================================================
            # summaries of the training errors
            # ============================================================
            tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error')
            tr_error_summary = tf.summary.scalar('training/loss', tr_error)
            tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice')
            tr_dice_summary = tf.summary.scalar('training/dice', tr_dice)
            tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary])

            # ============================================================
            # freeze the graph before execution
            # ============================================================
            logging.info('Freezing the graph now!')
            tf.get_default_graph().finalize()

            # ============================================================
            # Run the Op to initialize the variables.
            # ============================================================
            logging.info(separator)
            logging.info('initializing all variables...')
            sess.run(init_op)

            # ============================================================
            # print names of uninitialized variables
            # ============================================================
            logging.info(separator)
            logging.info('This is the list of uninitialized variables:')
            uninit_variables = sess.run(uninit_vars)
            for v in uninit_variables:
                print(v)

            # ============================================================
            # continue run from a saved checkpoint
            # ============================================================
            if continue_run:
                # Restore session
                logging.info(separator)
                logging.info('Restroring session from: %s',
                             init_checkpoint_path)
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            best_dice = 0

            # ============================================================
            # run training epochs
            # ============================================================
            while step < exp_config.max_steps:

                logging.info(separator)
                logging.info('Step %d', step)

                for batch in iterate_minibatches(
                        images_tr, labels_tr,
                        batch_size=exp_config.batch_size):

                    x, y = batch

                    # ===========================
                    # run training iteration
                    # ===========================
                    feed_dict = {images_pl: x, labels_pl: y, training_pl: True}
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    # ===========================
                    # write the summaries and print an overview fairly often
                    # ===========================
                    if (step + 1) % exp_config.summary_writing_frequency == 0:
                        logging.info('Step %d: loss = %.2f', step + 1,
                                     loss_value)

                        # ===========================
                        # update the events file
                        # ===========================
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # ===========================
                    # compute the loss on the entire training set
                    # ===========================
                    if step % exp_config.train_eval_frequency == 0:
                        logging.info('Training Data Eval:')
                        [train_loss, train_dice
                         ] = do_eval(sess, eval_loss, images_pl, labels_pl,
                                     training_pl, images_tr, labels_tr,
                                     exp_config.batch_size)

                        tr_summary_msg = sess.run(tr_summary,
                                                  feed_dict={
                                                      tr_error: train_loss,
                                                      tr_dice: train_dice
                                                  })
                        summary_writer.add_summary(tr_summary_msg, step)

                    # ===========================
                    # save a checkpoint periodically
                    # ===========================
                    if step % exp_config.save_frequency == 0:
                        checkpoint_file = os.path.join(log_dir,
                                                       'models/model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                    # ===========================
                    # evaluate the model on the validation set
                    # ===========================
                    if step % exp_config.val_eval_frequency == 0:
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_loss, images_pl, labels_pl,
                                     training_pl, images_vl, labels_vl,
                                     exp_config.batch_size)

                        vl_summary_msg = sess.run(vl_summary,
                                                  feed_dict={
                                                      vl_error: val_loss,
                                                      vl_dice: val_dice
                                                  })
                        summary_writer.add_summary(vl_summary_msg, step)

                        # ===========================
                        # save model if the val dice is the best yet
                        # ===========================
                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir,
                                                     'models/best_dice.ckpt')
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new average best dice on validation sets! - %f - Saving model.',
                                val_dice)

                    step += 1
        finally:
            # Close the event file first so pending summaries are flushed,
            # then release the session's resources.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()