def main(unused_argv):
    """Train a DeepLab model with a `tf.train.MonitoredTrainingSession`.

    Builds the input pipeline (`data_generator.Dataset`) and the training graph
    (`_train_deeplab_model`) under a replica device setter, optionally restores
    `FLAGS.tf_initial_checkpoint`, and runs the train op until the global step
    reaches `FLAGS.training_number_of_steps`. When `FLAGS.profile_logdir` is set
    the whole run is wrapped in a tfprof `ProfileContext`.

    Args:
      unused_argv: Command-line arguments left after flag parsing; ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisble by number of clones (GPUs).')
            # Each clone receives an equal share of the global batch.
            clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=2,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

            train_tensor, summary_op = _train_deeplab_model(
                dataset.get_one_shot_iterator(), dataset.num_of_classes,
                dataset.ignore_label)

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)

            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            # Stop once the (absolute) global step reaches the target.
            stop_hook = tf.train.StopAtStepHook(
                last_step=FLAGS.training_number_of_steps)

            profile_dir = FLAGS.profile_logdir
            if profile_dir is not None:
                tf.gfile.MakeDirs(profile_dir)

            with tf.contrib.tfprof.ProfileContext(
                    enabled=profile_dir is not None, profile_dir=profile_dir):
                with tf.train.MonitoredTrainingSession(
                        master=FLAGS.master,
                        is_chief=(FLAGS.task == 0),
                        config=session_config,
                        scaffold=scaffold,
                        checkpoint_dir=FLAGS.train_logdir,
                        summary_dir=FLAGS.train_logdir,
                        log_step_count_steps=FLAGS.log_steps,
                        # The flag is an interval in seconds, so it must feed
                        # the seconds-based argument rather than
                        # `save_summaries_steps`, which would treat it as a
                        # step count.
                        save_summaries_secs=FLAGS.save_summaries_secs,
                        save_checkpoint_secs=FLAGS.save_interval_secs,
                        hooks=[stop_hook]) as sess:
                    while not sess.should_stop():
                        sess.run([train_tensor])
def main(unused_argv):
    """Build a multi-clone U-Net training graph and run it with slim.

    Deployment (clones, replicas, parameter servers) is configured from FLAGS
    via `model_deploy.DeploymentConfig`. Samples come from
    `input_generator.get` through a prefetch queue, the model is replicated with
    `_build_unet`, and clone gradients are combined by
    `model_deploy.optimize_clones`. They are optionally rescaled for the last
    layers and applied with momentum SGD. `slim.learning.train` runs the loop.

    Args:
      unused_argv: Command-line arguments left after flag parsing; ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisble by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones. `_build_unet` is handed the
            # dataset object itself together with its ignore label.
            model_fn = _build_unet
            model_args = (inputs_queue, dataset, dataset.ignore_label)
            clones = model_deploy.create_clones(config, model_fn, args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables. Done exactly once: repeating the
        # loop would create duplicate `<name>_1` histogram summaries.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions.
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            # Argmax over axis 3 assumes NHWC logits — TODO confirm for U-Net.
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Startup delay grows with the task index.
        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op; running train_op also runs the
            # first clone's update ops (e.g. batch-norm statistics).
            grad_updates = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)

        # Start the training.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            init_fn=train_utils.get_model_init_fn(
                FLAGS.train_logdir,
                FLAGS.tf_initial_checkpoint,
                FLAGS.initialize_last_layer,
                last_layers,
                ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
def main(unused_arg):
    """Build the multi-clone `_build_model` training graph and train it with slim.

    The deployment layout (clones, replicas, parameter servers) is read from
    FLAGS via `model_deploy.DeploymentConfig`. Batches come from
    `get_dataset.get_dataset` through a prefetch queue. Clone gradients are
    combined by `model_deploy.optimize_clones`, optionally rescaled for the
    last layers, and applied with Adam. `slim.learning.train` drives the loop
    and keeps up to 50 checkpoints in `FLAGS.train_dir`.

    Args:
      unused_arg: Command-line arguments left after flag parsing; ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Deployment layout: multi-GPU clones and/or multiple replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)
    num_clones = deploy_config.num_clones

    # The global batch is divided evenly between the clones.
    assert FLAGS.train_batch_size % num_clones == 0, (
        'Training batch size not divisble by number of clones (GPUs).')
    per_clone_batch = FLAGS.train_batch_size // num_clones

    tf.gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default() as graph:
        # Input pipeline.
        with tf.device(deploy_config.inputs_device()):
            samples, num_samples = get_dataset.get_dataset(
                FLAGS.dataset,
                FLAGS.dataset_dir,
                split_name=FLAGS.train_split,
                is_training=True,
                image_size=[FLAGS.image_size, FLAGS.image_size],
                batch_size=per_clone_batch,
                channel=FLAGS.input_channel)
            tf.logging.info('Training on %s set: %d', FLAGS.train_split,
                            num_samples)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * num_clones)

        # Global step, model clones and the first clone's update ops
        # (e.g. batch-norm statistics) are created on the variables device.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()
            clones = model_deploy.create_clones(
                deploy_config, _build_model,
                args=(inputs_queue, per_clone_batch))
            first_clone_scope = deploy_config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Summaries: everything registered so far, optional variable
        # histograms, then the first clone's individual losses.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        if FLAGS.save_summaries_variables:
            summaries.update(
                tf.summary.histogram(var.op.name, var)
                for var in slim.get_model_variables())
        summaries.update(
            tf.summary.scalar('losses/%s' % loss.op.name, loss)
            for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                          first_clone_scope))

        # Learning-rate schedule and optimizer on the optimizer device.
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Startup delay grows linearly with the task index.
        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(deploy_config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

            # The protein dataset may carry an extra counts head that is
            # treated as a "last layer" alongside the logits.
            use_counts_logits = (FLAGS.dataset == 'protein'
                                 and FLAGS.add_counts_logits)
            last_layers = (['Logits', 'Counts_logits'] if use_counts_logits
                           else ['Logits'])
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # train_op returns the loss only after the gradient step and the
            # first clone's update ops have run.
            apply_op = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
            update_op = tf.group(*(update_ops + [apply_op]))
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Fold in summaries created inside the first clone's scope.
        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        # GPU memory is allocated on demand, bounded by a 0.9 per-process
        # fraction.
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            gpu_options=tf.GPUOptions(allow_growth=True,
                                      per_process_gpu_memory_fraction=0.9))

        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_dir,
            FLAGS.fine_tune_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)
        checkpoint_saver = tf.train.Saver(max_to_keep=50)

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            is_chief=(FLAGS.task == 0),
            master=FLAGS.master,
            graph=graph,
            log_every_n_steps=FLAGS.log_every_n_steps,
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            number_of_steps=FLAGS.number_of_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            init_fn=init_fn,
            summary_op=summary_op,
            saver=checkpoint_saver)
def main(unused_argv):
    """Build a multi-clone DeepLab training graph and run it with slim.

    Deployment (clones, replicas, parameter servers) is configured from FLAGS
    via `model_deploy.DeploymentConfig`. Samples come from
    `input_generator.get` through a prefetch queue, the model is replicated with
    `_build_deeplab`, and clone gradients are combined by
    `model_deploy.optimize_clones`. They are optionally rescaled for the last
    layers and applied with momentum SGD. `slim.learning.train` runs the loop.

    Args:
      unused_argv: Command-line arguments left after flag parsing; ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs. Floor division keeps the per-clone batch
    # size an int; true division would produce a float under Python 3, which
    # is not a valid batch size.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisble by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config, model_fn, args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables. Done exactly once: repeating the
        # loop would create duplicate `<name>_1` histogram summaries.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Startup delay grows with the task index.
        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op; running train_op also runs the
            # first clone's update ops.
            grad_updates = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)

        # Start the training.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            init_fn=train_utils.get_model_init_fn(
                FLAGS.train_logdir,
                FLAGS.tf_initial_checkpoint,
                FLAGS.initialize_last_layer,
                last_layers,
                ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv):
    """Train the PGN model, validating periodically and keeping the best checkpoint.

    Steps:
      1. Write the flag configuration to `json.txt` and `logging.txt` in
         `FLAGS.train_logdir`.
      2. Build the training input pipeline and graph. The validation pipeline
         and graph are built too, unless "val" is one of the training splits.
      3. Run a `MonitoredTrainingSession` until the global step reaches
         `FLAGS.training_number_of_steps`.

    Every `FLAGS.validation_steps` steps the whole "val" split is evaluated and
    the mean-DSC history is plotted to `losses.png`. Whenever the mean DSC
    improves, the model is saved as `model.ckpt-best`.

    Args:
      unused_argv: Command-line arguments left after flag parsing; ignored.
    """
    # Check model parameters.
    check_model_conflict()

    data_information = data_generator._DATASETS_INFORMATION[FLAGS.dataset_name]

    tf.logging.set_verbosity(tf.logging.INFO)
    tf.gfile.MakeDirs(FLAGS.train_logdir)
    for split in FLAGS.train_split:
        tf.logging.info('Training on %s set', split)

    path = FLAGS.train_logdir
    log_path = os.path.join(path, 'logging.txt')

    def _now():
        # Local wall-clock timestamp for the start/end markers in the log.
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # Persist the full flag configuration next to the checkpoints.
    parameters_dict = vars(FLAGS)
    with open(os.path.join(path, 'json.txt'), 'w', encoding='utf-8') as f:
        json.dump(parameters_dict, f, indent=3)
    with open(log_path, 'w', encoding='utf-8') as f:
        for key in parameters_dict:
            f.write("{}: {}".format(str(key), str(parameters_dict[key])))
            f.write("\n")
        f.write("\nStart time: {}".format(_now()))
        f.write("\n")

    # Validation is only possible when "val" is not itself a training split.
    run_validation = "val" not in FLAGS.train_split

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            assert FLAGS.batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisble by number of clones (GPUs).')
            clone_batch_size = FLAGS.batch_size // FLAGS.num_clones

            # Resolve the resize range once and use it for both generators.
            # For the CHAOS MR datasets it is pinned to the dataset height;
            # otherwise the flags are used, falling back to the height when
            # a flag is unset.
            if FLAGS.dataset_name in ('2019_ISBI_CHAOS_MR_T1',
                                      '2019_ISBI_CHAOS_MR_T2'):
                min_resize_value = data_information.height
                max_resize_value = data_information.height
            else:
                min_resize_value = (FLAGS.min_resize_value
                                    if FLAGS.min_resize_value is not None
                                    else data_information.height)
                max_resize_value = (FLAGS.max_resize_value
                                    if FLAGS.max_resize_value is not None
                                    else data_information.height)

            train_generator = data_generator.Dataset(
                dataset_name=FLAGS.dataset_name,
                split_name=FLAGS.train_split,
                guidance_type=FLAGS.guidance_type,
                batch_size=clone_batch_size,
                pre_crop_flag=FLAGS.pre_crop_flag,
                mt_class=FLAGS.mt_output_node,
                crop_size=data_information.train["train_crop_size"],
                min_resize_value=min_resize_value,
                max_resize_value=max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                num_readers=2,
                is_training=True,
                shuffle_data=True,
                repeat_data=True,
                prior_num_slice=FLAGS.prior_num_slice,
                prior_num_subject=FLAGS.prior_num_subject,
                seq_length=FLAGS.seq_length,
                seq_type="bidirection",
                z_loss_name=FLAGS.z_loss_name,
            )

            if run_validation:
                val_generator = data_generator.Dataset(
                    dataset_name=FLAGS.dataset_name,
                    split_name=["val"],
                    guidance_type=FLAGS.guidance_type,
                    batch_size=1,
                    mt_class=FLAGS.mt_output_node,
                    crop_size=[data_information.height, data_information.width],
                    min_resize_value=min_resize_value,
                    max_resize_value=max_resize_value,
                    num_readers=2,
                    is_training=False,
                    shuffle_data=False,
                    repeat_data=True,
                    prior_num_slice=FLAGS.prior_num_slice,
                    prior_num_subject=FLAGS.prior_num_subject,
                    seq_length=FLAGS.seq_length,
                    seq_type="bidirection",
                    z_loss_name=FLAGS.z_loss_name,
                )

            model_options = common.ModelOptions(
                outputs_to_num_classes=train_generator.num_of_classes,
                crop_size=data_information.train["train_crop_size"],
                output_stride=FLAGS.output_stride)

            # Validation sample index, fed on every validation sess.run.
            steps = tf.compat.v1.placeholder(tf.int32, shape=[])

            train_samples = (
                train_generator.get_dataset().make_one_shot_iterator().get_next())
            train_tensor, summary_op = _train_pgn_model(
                train_samples, train_generator.num_of_classes, model_options,
                train_generator.ignore_label)

            if run_validation:
                val_samples = (
                    val_generator.get_dataset().make_one_shot_iterator()
                    .get_next())
                val_tensor, _ = _val_pgn_model(
                    val_samples, val_generator.num_of_classes, model_options,
                    val_generator.ignore_label, steps)

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)

            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    train_logdir=FLAGS.train_logdir,
                    tf_initial_checkpoint=FLAGS.tf_initial_checkpoint,
                    initialize_first_layer=True,
                    initialize_last_layer=FLAGS.initialize_last_layer,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            # Stop at an absolute global step (not "N more steps"), so a
            # resumed run ends at the same step as a fresh one.
            stop_hook = tf.train.StopAtStepHook(
                last_step=FLAGS.training_number_of_steps)
            # Separate saver used only for the best-DSC checkpoint.
            saver = tf.train.Saver()
            best_dice = 0
            with tf.train.MonitoredTrainingSession(
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    config=session_config,
                    scaffold=scaffold,
                    checkpoint_dir=FLAGS.train_logdir,
                    log_step_count_steps=FLAGS.log_steps,
                    save_summaries_steps=20,
                    save_checkpoint_steps=FLAGS.save_checkpoint_steps,
                    hooks=[stop_hook]) as sess:
                global_step_tensor = tf.train.get_global_step()
                val_dice_history, val_step_history = [], []
                while not sess.should_stop():
                    _, global_step = sess.run([train_tensor, global_step_tensor])
                    if (not run_validation
                            or global_step % FLAGS.validation_steps != 0):
                        continue

                    # Accumulate the confusion matrix over the whole val split.
                    cm_total = 0
                    for j in range(val_generator.splits_to_sizes["val"]):
                        cm_total += sess.run(val_tensor, feed_dict={steps: j})
                    mean_dice_score, _ = metrics.compute_mean_dsc(
                        total_cm=cm_total)

                    val_dice_history.append(mean_dice_score)
                    val_step_history.append(global_step)

                    # Redraw the full history on a fresh figure and close it,
                    # so repeated validations neither stack lines on one
                    # global figure nor leak figures.
                    fig, ax = plt.subplots()
                    ax.plot(val_step_history, val_dice_history, "bo-")
                    ax.legend(["validation mean DSC"])
                    ax.set_xlabel("global step")
                    ax.set_ylabel("mean DSC")
                    ax.grid(True)
                    fig.savefig(os.path.join(FLAGS.train_logdir, "losses.png"))
                    plt.close(fig)

                    if mean_dice_score > best_dice:
                        best_dice = mean_dice_score
                        saver.save(
                            get_session(sess),
                            os.path.join(FLAGS.train_logdir, 'model.ckpt-best'))
                        txt = 20 * ">" + (
                            " saving best model model.ckpt-best (step %d) "
                            "with DSC: %f" % (global_step, best_dice))
                        print(txt)
                        with open(log_path, 'a', encoding='utf-8') as f:
                            f.write(txt)
                            f.write("\n")

    with open(log_path, 'a', encoding='utf-8') as f:
        f.write("\nEnd time: {}".format(_now()))
        f.write("\n")