def evaluate(hps, design):
    """Eval loop.

    Restores a trained graph from its meta file plus the latest checkpoint
    under ``FLAGS.log_root``, feeds it the evaluation records until the
    input iterator is exhausted, and optionally saves images every
    ``FLAGS.IMSAVE`` steps.

    Args:
        hps: Hyper-parameters; ``batch_size`` and ``num_epochs`` are read here.
        design: Unused in this function (the model is rebuilt from the saved
            meta graph); kept so the signature matches ``train``.

    The process exits with status 0 when no checkpoint is available yet.
    """
    eval_records = _get_tfrecord_files_from_dir(
        FLAGS.eval_data_path)  # tfrecord files to evaluate on
    eval_iterator = petct_input.build_input(eval_records, hps.batch_size,
                                            hps.num_epochs, FLAGS.mode)
    eval_iterator_handle = eval_iterator.string_handle()

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    device_count={'GPU': 1})
    with tf.Session(config=session_config) as mon_sess:
        mon_sess.run([g_init_op, l_init_op])

        # Look for a checkpoint BEFORE importing the meta graph: importing a
        # missing .meta file raises, which would otherwise pre-empt the
        # "no model yet" exit below.
        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            tf.logging.error('Cannot restore checkpoint: %s', e)
            sys.exit(0)
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
            sys.exit(0)

        # NOTE(review): the graph comes from the '<eval_chkpt_num>-end' meta
        # file while the weights come from the latest checkpoint recorded in
        # log_root -- confirm these are meant to be the same checkpoint.
        ckpt_meta = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
            FLAGS.eval_chkpt_num) + '-end.ckpt.meta'
        meta_restore = tf.train.import_meta_graph(ckpt_meta)

        # The writer is currently unused (summary fetching is disabled, see
        # below) but is still created so FLAGS.eval_dir is set up as before.
        eval_writer = tf.summary.FileWriter(FLAGS.eval_dir)
        try:
            eval_handle = mon_sess.run(eval_iterator_handle)

            tf.logging.info('Loading checkpoint %s',
                            ckpt_state.model_checkpoint_path)
            meta_restore.restore(mon_sess, ckpt_state.model_checkpoint_path)

            graph = tf.get_default_graph()

            # Tensors that must be fed during evaluation.
            handle = graph.get_tensor_by_name('data:0')  # iterator handle
            # Fed False below; per the original note this turns off BN.
            train_mode = graph.get_tensor_by_name('train_mode:0')

            # NOTE: metric tensors are not fetched at present. Names that were
            # previously noted as resolvable in the restored graph:
            #   'metrics/valid_summary/valid_summary:0'
            #   'metrics/valid_precision/update_op:0'
            #   'metrics/valid_recall/update_op:0'
            #   'metrics/valid_accuracy/update_op:0'
            #   'metrics/Sqrt_3:0'  (RMSE; resolvable but noted as not
            #                        initialised)

            # Tensors monitored during evaluation.
            all_probabilities = graph.get_tensor_by_name(
                'costs/all_probabilities:0')
            all_pred = graph.get_tensor_by_name('costs/all_prediction:0')
            ct_img = graph.get_tensor_by_name('ct:0')
            pt_img = graph.get_tensor_by_name('pt:0')
            lb_pos_gt = graph.get_tensor_by_name('lb_pos_gt:0')
            lbbg = graph.get_tensor_by_name('lbbg:0')

            # Only the single-output style tensors are fetched here.
            fetches = [
                ct_img, pt_img, lb_pos_gt, lbbg, all_pred, all_probabilities
            ]
            feed = {handle: eval_handle, train_mode: False}

            step = 0
            while True:
                try:
                    step = step + 1
                    (cts, pts, trallpos, trbgs, recon_all,
                     all_preds) = mon_sess.run(fetches, feed_dict=feed)
                    print('[EVAL] STEP: %d' % (step))
                    if FLAGS.IMSAVE > 0 and step % FLAGS.IMSAVE == 0:
                        # only works for single style
                        print('SAVING IMAGES')
                        _saveImages(hps.batch_size,
                                    step,
                                    cts,
                                    pts,
                                    trallpos=trallpos,
                                    trbgs=trbgs,
                                    recon_all=recon_all,
                                    all_preds=all_preds)
                except tf.errors.OutOfRangeError:
                    # The input iterator is exhausted: evaluation is done.
                    print('OUT OF DATA - ENDING')
                    break
        finally:
            # Release the event-file writer even if evaluation fails.
            eval_writer.close()
def train(hps, design):
    """Training loop.

    Builds the input pipelines and the FuseNet model, then runs training
    steps until the input data is exhausted. Training summaries are written
    every ``FLAGS.train_iter`` steps, validation (when
    ``FLAGS.val_data_path`` is non-empty) runs every ``FLAGS.val_iter``
    steps, and a checkpoint is saved every ``FLAGS.chkpt_iter`` steps plus a
    final ``'<step>-end.ckpt'`` when the data runs out.

    Args:
        hps: Hyper-parameters; ``batch_size`` and ``num_epochs`` are read
            here and the object is passed on to the model.
        design: Passed through to ``fuse_cnn_petct.FuseNet``.

    Raises:
        ValueError: If ``FLAGS.output_style`` is neither ``STYLE_SPLIT`` nor
            ``STYLE_SINGLE``.
    """
    use_validation = FLAGS.val_data_path != ''  # skip validation if no path

    train_records = _get_tfrecord_files_from_dir(
        FLAGS.train_data_path)  # tfrecord files for train
    train_iterator = petct_input.build_input(train_records, hps.batch_size,
                                             hps.num_epochs, FLAGS.mode)
    train_iterator_handle = train_iterator.string_handle()

    if use_validation:
        val_records = _get_tfrecord_files_from_dir(
            FLAGS.val_data_path)  # tfrecord files for val
        val_iterator = petct_input.build_input(val_records, hps.batch_size,
                                               hps.num_epochs, 'valid')
        val_iterator_handle = val_iterator.string_handle()

    # A single feedable iterator: the string handle fed at run time selects
    # whether a batch comes from the training or the validation iterator.
    handle = tf.placeholder(tf.string, shape=[], name='data')
    iterator = Iterator.from_string_handle(handle,
                                           train_iterator.output_types,
                                           train_iterator.output_shapes)
    ct, pt, ctlb, ptlb, bglb = iterator.get_next()

    model = fuse_cnn_petct.FuseNet(hps, design, ct, pt, ctlb, ptlb, bglb,
                                   FLAGS.mode)
    model.build_cross_modal_model()

    # metrics ops for train and val
    with tf.variable_scope('metrics'):
        (tr_summary_op, tr_precision_op, tr_recall_op, tr_accuracy_op,
         tr_rmse_op) = get_metrics_ops(model, 'train')
        (val_summary_op, val_precision_op, val_recall_op, val_accuracy_op,
         val_rmse_op) = get_metrics_ops(model, 'valid')

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    device_count={'GPU': 1})
    with tf.Session(config=session_config) as mon_sess:
        # Need a saver to save and restore all the variables.
        saver = tf.train.Saver()

        if FLAGS.DEBUG:
            print('ENABLING DEBUG')
            mon_sess = tf_debug.LocalCLIDebugWrapperSession(mon_sess)
            mon_sess.add_tensor_filter("has_inf_or_nan",
                                       tf_debug.has_inf_or_nan)

        training_handle = mon_sess.run(train_iterator_handle)
        if use_validation:
            validation_handle = mon_sess.run(val_iterator_handle)

        train_writer = tf.summary.FileWriter(FLAGS.log_root + '/train',
                                             mon_sess.graph)
        valid_writer = None
        if use_validation:
            valid_writer = tf.summary.FileWriter(FLAGS.log_root + '/valid')

        mon_sess.run([g_init_op, l_init_op])

        # Last results seen; they stay None until the corresponding run has
        # completed at least once, which the end-of-data handler relies on.
        summary = None
        step = None
        val_summary = None
        val_step = None

        train_feed = {handle: training_handle, model.is_training: True}

        while True:
            try:
                # FIRST RUN TRAINING OP BASED ON OUTPUT STYLE
                if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                    # get PET and CT recons separately
                    (_, summary, step, loss, p, r, a, e, cts, pts, trcts,
                     trpts, trbgs, recon_cts, recon_pts, ct_preds,
                     pt_preds) = mon_sess.run(
                         [
                             model.train_op, tr_summary_op,
                             model.global_step, model.cost, tr_precision_op,
                             tr_recall_op, tr_accuracy_op, tr_rmse_op,
                             model.ct, model.pt, model.lbct, model.lbpt,
                             model.lbbg, model.ct_pred, model.pt_pred,
                             model.ct_probabilities, model.pt_probabilities
                         ],
                         feed_dict=train_feed)
                elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                    # get PET and CT recons together
                    (_, summary, step, loss, p, r, a, e, cts, pts, trallpos,
                     trbgs, recon_all, all_preds) = mon_sess.run(
                         [
                             model.train_op, tr_summary_op,
                             model.global_step, model.cost, tr_precision_op,
                             tr_recall_op, tr_accuracy_op, tr_rmse_op,
                             model.ct, model.pt, model.lb_pos_gt, model.lbbg,
                             model.all_pred, model.all_probabilities
                         ],
                         feed_dict=train_feed)
                else:
                    # Without a training run `step` would stay None and the
                    # modulo checks below would fail with an opaque TypeError.
                    raise ValueError('Unsupported output_style: %r' %
                                     (FLAGS.output_style,))

                if step % FLAGS.train_iter == 0:
                    print(
                        '[TRAIN] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                        % (step, loss, p, r, a, e))
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if FLAGS.IMSAVE > 0 and step % FLAGS.IMSAVE == 0:
                    print('SAVING IMAGES')
                    if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                        _saveImages(hps.batch_size,
                                    step,
                                    cts,
                                    pts,
                                    trcts=trcts,
                                    trpts=trpts,
                                    trbgs=trbgs,
                                    recon_cts=recon_cts,
                                    recon_pts=recon_pts,
                                    ct_preds=ct_preds,
                                    pt_preds=pt_preds)
                    elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                        _saveImages(hps.batch_size,
                                    step,
                                    cts,
                                    pts,
                                    trallpos=trallpos,
                                    trbgs=trbgs,
                                    recon_all=recon_all,
                                    all_preds=all_preds)

                if use_validation and step % FLAGS.val_iter == 0:
                    _, val_summary, loss, p, r, a, e = mon_sess.run(
                        [
                            model.val_op, val_summary_op, model.cost,
                            val_precision_op, val_recall_op, val_accuracy_op,
                            val_rmse_op
                        ],
                        feed_dict={
                            handle: validation_handle,
                            model.is_training: False
                        })
                    val_step = step
                    print(
                        '[VALID] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                        % (step, loss, p, r, a, e))
                    valid_writer.add_summary(val_summary, step)
                    valid_writer.flush()

                if step % FLAGS.chkpt_iter == 0:
                    save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                        step) + '.ckpt'
                    save_path = saver.save(mon_sess, save_loc)
                    print('Model saved in path: %s' % save_path)

            except tf.errors.OutOfRangeError:
                # Finished training: either the train or the validation
                # iterator has run out of data.
                print('OUT OF DATA - ENDING')

                if step is None:
                    # The data ran out before a single training step
                    # finished, so there is no summary to write and no
                    # meaningful step number for the final checkpoint.
                    print('NO TRAINING STEP COMPLETED - NOTHING TO SAVE')
                    break

                train_writer.add_summary(summary, step)
                train_writer.flush()

                # Validation may never have run (e.g. fewer total steps than
                # FLAGS.val_iter); only write a summary that actually exists
                # so the final checkpoint below is still saved.
                if use_validation and val_summary is not None:
                    valid_writer.add_summary(val_summary, val_step)
                    valid_writer.flush()

                save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                    step) + '-end.ckpt'
                save_path = saver.save(mon_sess, save_loc)
                print('Model saved in path: %s' % save_path)
                break

        # Release the event-file writers once training has ended.
        train_writer.close()
        if valid_writer is not None:
            valid_writer.close()