def main(_):

    config = tf.ConfigProto(
        inter_op_parallelism_threads=num_inter_op_threads,
        intra_op_parallelism_threads=num_intra_op_threads)

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()  # For TensorFlow trace

    cluster = tf.train.ClusterSpec({"ps": ps_list, "worker": worker_list})
    server = tf.train.Server(cluster,
                             job_name=job_name,
                             task_index=task_index)

    is_sync = (FLAGS.is_sync == 1)  # Synchronous or asynchronous updates
    is_chief = (task_index == 0)    # Am I the chief node (always task 0)

    greedy = tf.contrib.training.GreedyLoadBalancingStrategy(
        num_tasks=len(ps_hosts),
        load_fn=tf.contrib.training.byte_size_load_fn)

    if job_name == "ps":

        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:ps/task:{}".format(task_index),
                    ps_tasks=len(ps_hosts),
                    ps_strategy=greedy,
                    cluster=cluster)):

            sess = tf.Session(server.target, config=config)
            queue = create_done_queue(task_index)

            print("*" * 30)
            print("\nParameter server #{} on {}.\n\n"
                  "Waiting on workers to finish.\n\n"
                  "Press CTRL-\\ to terminate early.\n"
                  .format(task_index, ps_hosts[task_index]))
            print("*" * 30)

            # Wait until all workers are done
            for i in range(len(worker_hosts)):
                sess.run(queue.dequeue())
                print("Worker #{} reports job finished.".format(i))

            print("Parameter server #{} is quitting".format(task_index))
            print("Training complete.")

    elif job_name == "worker":

        if is_chief:
            print("I am chief worker {} with task #{}".format(
                worker_hosts[task_index], task_index))
        else:
            print("I am worker {} with task #{}".format(
                worker_hosts[task_index], task_index))

        if len(ps_list) > 0:
            setDevice = tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                ps_tasks=len(ps_hosts),
                ps_strategy=greedy,
                cluster=cluster)
        else:
            # No parameter server, so put the variables on the chief worker
            setDevice = "/cpu:0"

        with tf.device(setDevice):

            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Load the data
            imgs_train, msks_train, imgs_test, msks_test = load_all_data()

            train_length = imgs_train.shape[0]  # Number of train datasets
            test_length = imgs_test.shape[0]    # Number of test datasets

            """
            BEGIN: Define our model
            """
            imgs = tf.placeholder(tf.float32,
                                  shape=(None, msks_train.shape[1],
                                         msks_train.shape[2],
                                         msks_train.shape[3]))
            msks = tf.placeholder(tf.float32,
                                  shape=(None, msks_train.shape[1],
                                         msks_train.shape[2],
                                         msks_train.shape[3]))

            preds = define_model(imgs, FLAGS.use_upsampling,
                                 settings_dist.OUT_CHANNEL_NO)
            print("Model defined")

            loss_value = dice_coef_loss(msks, preds)
            dice_value = dice_coef(msks, preds)
            sensitivity_value = sensitivity(msks, preds)
            specificity_value = specificity(msks, preds)

            test_loss_value = tf.placeholder(tf.float32, ())
            test_dice_value = tf.placeholder(tf.float32, ())
            test_sensitivity_value = tf.placeholder(tf.float32, ())
            test_specificity_value = tf.placeholder(tf.float32, ())
            """
            END: Define our model
            """

            # Decay the learning rate from initial_learn_rate to
            # initial_learn_rate * lr_fraction over decay_steps global steps
            if FLAGS.const_learningrate:
                learning_rate = tf.convert_to_tensor(FLAGS.learning_rate,
                                                     dtype=tf.float32)
            else:
                learning_rate = tf.train.exponential_decay(
                    FLAGS.learning_rate, global_step,
                    FLAGS.decay_steps, FLAGS.lr_fraction,
                    staircase=False)

            # Compensate the learning rate for asynchronous distributed training.
            # THEORY: We should cut the learning rate by at least the number of
            # workers, since there are likely to be that many times more
            # parameter updates.
            # if not is_sync:
            #     learning_rate /= len(worker_hosts)
            #     optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            #     # optimizer = tf.train.AdagradOptimizer(learning_rate)
            # else:
            #     optimizer = tf.train.AdamOptimizer(learning_rate)

            optimizer = tf.train.AdamOptimizer(learning_rate)

            grads_and_vars = optimizer.compute_gradients(loss_value)

            if is_sync:
                rep_op = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=len(worker_hosts),
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                train_op = rep_op.apply_gradients(grads_and_vars,
                                                  global_step=global_step)
                init_token_op = rep_op.get_init_tokens_op()
                chief_queue_runner = rep_op.get_chief_queue_runner()
            else:
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)

            init_op = tf.global_variables_initializer()

            saver = tf.train.Saver()

            # These are the values we wish to print to TensorBoard
            tf.summary.scalar("loss", loss_value)
            tf.summary.histogram("loss", loss_value)

            tf.summary.scalar("dice", dice_value)
            tf.summary.histogram("dice", dice_value)

            tf.summary.scalar("sensitivity", sensitivity_value)
            tf.summary.histogram("sensitivity", sensitivity_value)

            tf.summary.scalar("specificity", specificity_value)
            tf.summary.histogram("specificity", specificity_value)

            tf.summary.image("predictions", preds,
                             max_outputs=settings_dist.TENSORBOARD_IMAGES)
            tf.summary.image("ground_truth", msks,
                             max_outputs=settings_dist.TENSORBOARD_IMAGES)
            tf.summary.image("images", imgs,
                             max_outputs=settings_dist.TENSORBOARD_IMAGES)

            print("Loading epoch")
            epoch = get_epoch(batch_size, imgs_train, msks_train)
            num_batches = len(epoch)
            print("Loaded")

            # Print the percent of steps complete to TensorBoard
            # so that we know how much of the training remains.
            num_steps_tf = tf.constant(num_batches * FLAGS.epochs, tf.float32)
            percent_done_value = tf.constant(100.0) * \
                tf.to_float(global_step) / num_steps_tf
            tf.summary.scalar("percent_complete", percent_done_value)

            # Need to remove the checkpoint directory before each new run
            # import shutil
            # shutil.rmtree(CHECKPOINT_DIRECTORY, ignore_errors=True)

            # Send a signal to the ps when done by simply updating
            # a queue in the shared graph
            enq_ops = []
            for q in create_done_queues():
                qop = q.enqueue(1)
                enq_ops.append(qop)

            # Only the chief does the summary
            if is_chief:
                summary_op = tf.summary.merge_all()
            else:
                summary_op = None

            # Add summaries for test data.
            # These summary ops are not part of the merge-all op.
            # This way we can call these separately.
            test_loss_value = tf.placeholder(tf.float32, ())
            test_dice_value = tf.placeholder(tf.float32, ())

            test_loss_summary = tf.summary.scalar("loss_test", test_loss_value)
            test_dice_summary = tf.summary.scalar("dice_test", test_dice_value)
            test_sens_summary = tf.summary.scalar("sensitivity_test",
                                                  test_sensitivity_value)
            test_spec_summary = tf.summary.scalar("specificity_test",
                                                  test_specificity_value)

            # TODO: Theoretically I can pass the summary_op into
            # the Supervisor and have it handle the TensorBoard
            # log entries. However, doing so seems to hang the code.
            # For now, I just handle the summary calls explicitly.
            # import time
            # logDirName = CHECKPOINT_DIRECTORY + "/run" + \
            #     time.strftime("_%Y%m%d_%H%M%S")

            if FLAGS.use_upsampling:
                method_up = "upsample2D"
            else:
                method_up = "conv2DTranspose"

            logDirName = CHECKPOINT_DIRECTORY + "/unet," + \
                "lr={},{},intra={},inter={}".format(FLAGS.learning_rate,
                                                    method_up,
                                                    num_intra_op_threads,
                                                    num_inter_op_threads)

            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=logDirName,
                init_op=init_op,
                summary_op=None,
                saver=saver,
                global_step=global_step,
                save_model_secs=60)  # Save the model (with weights) every 60 seconds

            # TODO: I'd like to use managed_session for this as it is more
            # abstract and probably less sensitive to changes from the TF team.
            # However, I am finding that the chief worker hangs on exit if I
            # use managed_session.
            with sv.prepare_or_wait_for_session(server.target,
                                                config=config) as sess:
                # with sv.managed_session(server.target) as sess:

                if sv.is_chief and is_sync:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                    sess.run(init_token_op)

                step = 0
                progressbar = trange(num_batches * FLAGS.epochs)
                last_step = 0

                # Start TensorBoard on the chief worker
                if sv.is_chief:
                    cmd = "tensorboard --logdir={}".format(CHECKPOINT_DIRECTORY)
                    tb_process = subprocess.Popen(cmd,
                                                  stdout=subprocess.PIPE,
                                                  shell=True,
                                                  preexec_fn=os.setsid)

                while (not sv.should_stop()) and \
                        (step < (num_batches * FLAGS.epochs)):

                    batch_idx = step % num_batches  # Which batch within the epoch?

                    data = epoch[batch_idx, 0]
                    labels = epoch[batch_idx, 1]

                    # For n workers, break up the batch into n sections.
                    # Send each worker a different section of the batch.
                    data_range = int(batch_size / len(worker_hosts))
                    start = data_range * task_index
                    end = start + data_range

                    feed_dict = {imgs: data[start:end],
                                 msks: labels[start:end]}

                    history, loss_v, dice_v, step = sess.run(
                        [train_op, loss_value, dice_value, global_step],
                        feed_dict=feed_dict)

                    # Print the summary only on the chief
                    if sv.is_chief:
                        summary = sess.run(summary_op, feed_dict=feed_dict)
                        sv.summary_computed(sess, summary)  # Update the summary

                    # Calculate metrics on the test dataset every epoch
                    if (batch_idx == 0) and (step > num_batches):

                        dice_v_test = 0.0
                        loss_v_test = 0.0
                        sens_v_test = 0.0
                        spec_v_test = 0.0

                        for idx in tqdm(range(0,
                                              imgs_test.shape[0] - batch_size,
                                              batch_size),
                                        desc="Calculating metrics on test dataset",
                                        leave=False):

                            x_test = imgs_test[idx:(idx + batch_size)]
                            y_test = msks_test[idx:(idx + batch_size)]
                            feed_dict = {imgs: x_test, msks: y_test}

                            l_v, d_v, st_v, sp_v = sess.run(
                                [loss_value, dice_value,
                                 sensitivity_value, specificity_value],
                                feed_dict=feed_dict)

                            dice_v_test += d_v / (test_length // batch_size)
                            loss_v_test += l_v / (test_length // batch_size)
                            sens_v_test += st_v / (test_length // batch_size)
                            spec_v_test += sp_v / (test_length // batch_size)

                        print("\nEpoch {} of {}: TEST DATASET\nloss = {:.4f}\n"
                              "Dice = {:.4f}\nSensitivity = {:.4f}\n"
                              "Specificity = {:.4f}"
                              .format((step // num_batches), FLAGS.epochs,
                                      loss_v_test, dice_v_test,
                                      sens_v_test, spec_v_test))

                        # Add our test summary metrics to TensorBoard
                        sv.summary_computed(sess, sess.run(
                            test_loss_summary,
                            feed_dict={test_loss_value: loss_v_test}))
                        sv.summary_computed(sess, sess.run(
                            test_dice_summary,
                            feed_dict={test_dice_value: dice_v_test}))
                        sv.summary_computed(sess, sess.run(
                            test_sens_summary,
                            feed_dict={test_sensitivity_value: sens_v_test}))
                        sv.summary_computed(sess, sess.run(
                            test_spec_summary,
                            feed_dict={test_specificity_value: spec_v_test}))

                        saver.save(sess,
                                   CHECKPOINT_DIRECTORY + "/last_good_model.cpkt")

                    # Shuffle every epoch
                    if (batch_idx == 0) and (step > num_batches):
print("Shuffling epoch") epoch = get_epoch(batch_size, imgs_train, msks_train) # Print the loss and dice metric in the progress bar. progressbar.set_description( "(loss={:.4f}, dice={:.4f})".format(loss_v, dice_v)) progressbar.update(step - last_step) last_step = step # Perform the final test set metric if sv.is_chief: dice_v_test = 0.0 loss_v_test = 0.0 for idx in tqdm(range(0, imgs_test.shape[0] - batch_size, batch_size), desc="Calculating metrics on test dataset", leave=False): x_test = imgs_test[idx:(idx + batch_size)] y_test = msks_test[idx:(idx + batch_size)] feed_dict = {imgs: x_test, msks: y_test} l_v, d_v = sess.run([loss_value, dice_value], feed_dict=feed_dict) dice_v_test += d_v / (test_length // batch_size) loss_v_test += l_v / (test_length // batch_size) print("\nEpoch {} of {}: Test loss = {:.4f}, Test Dice = {:.4f}" \ .format((step // num_batches), FLAGS.epochs, loss_v_test, dice_v_test)) sv.summary_computed( sess, sess.run(test_loss_summary, feed_dict={test_loss_value: loss_v_test})) sv.summary_computed( sess, sess.run(test_dice_summary, feed_dict={test_dice_value: dice_v_test})) saver.save(sess, CHECKPOINT_DIRECTORY + "/last_good_model.cpkt") if sv.is_chief: export_model( sess, imgs, preds ) # Save the final model as protbuf for TensorFlow Serving os.killpg(os.getpgid(tb_process.pid), signal.SIGTERM) # Stop TensorBoard process # Send a signal to the ps when done by simply updating a queue in the shared graph for op in enq_ops: sess.run( op ) # Send the "work completed" signal to the parameter server print("\n\nFinished work on this node.") import time time.sleep(3) # Sleep for 3 seconds then exit sv.request_stop()
                     args.num_channels)

    img = tf.placeholder(tf.float32, shape=shape)  # Input tensor
    msk = tf.placeholder(tf.float32, shape=shape)  # Label tensor

    # Define the model
    # Predict the output mask
    preds = define_model(img,
                         learning_rate=args.lr,
                         use_upsampling=args.use_upsampling,
                         print_summary=args.print_model)

    # Performance metrics for the model
    loss = dice_coef_loss(msk, preds)  # Loss is the Dice between mask and prediction
    dice_score = dice_coef(msk, preds)
    sensitivity_score = sensitivity(msk, preds)
    specificity_score = specificity(msk, preds)

    train_op = tf.train.AdamOptimizer(args.lr).minimize(
        loss, global_step=global_step)

    # Just feed completely random data in for the benchmark testing
    imgs = np.random.rand(args.bz, args.dim_length, args.dim_length,
                          args.dim_length, args.num_channels)
    msks = imgs + np.random.rand(args.bz, args.dim_length, args.dim_length,
                                 args.dim_length, args.num_channels)

    # Initialize all variables
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
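
    # A minimal timing-loop sketch (not from the original source) showing how
    # the random imgs/msks arrays above could be used to benchmark throughput.
    # `num_iters` is a hypothetical value used only for illustration.
    #
    # import time
    # num_iters = 10
    # start = time.time()
    # for _ in range(num_iters):
    #     sess.run(train_op, feed_dict={img: imgs, msk: msks})
    # elapsed = time.time() - start
    # print("Approx. throughput: {:.2f} images/second".format(
    #     args.bz * num_iters / elapsed))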