def evaluate(hps, design):
    """Eval loop."""

    eval_records = _get_tfrecord_files_from_dir(
        FLAGS.eval_data_path)  # get tfrecord files for eval
    eval_iterator = petct_input.build_input(eval_records, hps.batch_size,
                                            hps.num_epochs, FLAGS.mode)
    eval_iterator_handle = eval_iterator.string_handle()
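    # This string handle is evaluated and fed into the restored graph's 'data'
    # placeholder so the imported model reads batches from the eval dataset.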

    # The model graph is not rebuilt here; it is restored from the checkpoint
    # .meta file below, and the tensors of interest are fetched by name.
    #handle = tf.placeholder(tf.string, shape=[], name='data')
    #iterator = Iterator.from_string_handle(handle, eval_iterator.output_types, eval_iterator.output_shapes)
    #ct, pt, ctlb, ptlb, bglb = iterator.get_next()

    #model = fuse_cnn_petct.FuseNet(hps, design, ct, pt, ctlb, ptlb, bglb, FLAGS.mode)
    #model.build_cross_modal_model()
    #eval_summary_op, eval_precision_op, eval_recall_op, eval_accuracy_op, eval_rmse_op = get_metrics_ops(model)

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, device_count={'GPU': 1})) as mon_sess:
        mon_sess.run([g_init_op, l_init_op])

        ckpt_meta = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
            FLAGS.eval_chkpt_num) + '-end.ckpt.meta'
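        # import_meta_graph() recreates the saved graph in the default graph and
        # returns a Saver, which is used below to restore the checkpoint weights.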
        meta_restore = tf.train.import_meta_graph(ckpt_meta)

        #ckpt_saver = tf.train.Saver()
        eval_writer = tf.summary.FileWriter(FLAGS.eval_dir)
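        # NOTE: nothing is currently written to this writer because the metric
        # summary fetches below are commented out.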

        eval_handle = mon_sess.run(eval_iterator_handle)

        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            tf.logging.error('Cannot restore checkpoint: %s', e)
            sys.exit(0)
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
            sys.exit(0)
        tf.logging.info('Loading checkpoint %s',
                        ckpt_state.model_checkpoint_path)
        #ckpt_saver.restore(mon_sess, ckpt_state.model_checkpoint_path)
        meta_restore.restore(mon_sess, ckpt_state.model_checkpoint_path)

        # Fetch the feedable tensors of the restored graph by name.
        handle = tf.get_default_graph().get_tensor_by_name(
            'data:0')  # string handle selecting which dataset feeds the graph
        train_mode = tf.get_default_graph().get_tensor_by_name(
            'train_mode:0')  # set to False so batch norm runs in inference mode

        # Metric tensors that could also be monitored during evaluation
        # (kept for reference; not currently fetched):
        #eval_summary_op = tf.get_default_graph().get_tensor_by_name('metrics/valid_summary/valid_summary:0')
        #eval_precision_op = tf.get_default_graph().get_tensor_by_name('metrics/valid_precision/update_op:0')
        #eval_recall_op = tf.get_default_graph().get_tensor_by_name('metrics/valid_recall/update_op:0')
        #eval_accuracy_op = tf.get_default_graph().get_tensor_by_name('metrics/valid_accuracy/update_op:0')
        #eval_rmse_op = tf.get_default_graph().get_tensor_by_name('metrics/Sqrt_3:0')  # works but is not initialized (may be ignored)
        all_probabilities = tf.get_default_graph().get_tensor_by_name(
            'costs/all_probabilities:0')  # predicted probabilities
        all_pred = tf.get_default_graph().get_tensor_by_name(
            'costs/all_prediction:0')  # predicted labels
        ct_img = tf.get_default_graph().get_tensor_by_name('ct:0')  # CT input
        pt_img = tf.get_default_graph().get_tensor_by_name('pt:0')  # PET input
        lb_pos_gt = tf.get_default_graph().get_tensor_by_name(
            'lb_pos_gt:0')  # positive ground-truth labels
        lbbg = tf.get_default_graph().get_tensor_by_name(
            'lbbg:0')  # background labels
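        # These tensor names must match the names assigned when the graph was
        # originally built in train()/FuseNet.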

        step = 0
        while True:

            try:
                step += 1
                # Run one batch of inference; batch norm stays in inference mode.
                cts, pts, trallpos, trbgs, recon_all, all_preds = mon_sess.run(
                    [
                        ct_img, pt_img, lb_pos_gt, lbbg, all_pred,
                        all_probabilities
                    ],
                    feed_dict={
                        handle: eval_handle,
                        train_mode: False
                    })
                #eval_summary, p, r, a, e, cts, pts, trallpos, trbgs, recon_all, all_preds = mon_sess.run([eval_summary_op, eval_precision_op, eval_recall_op, eval_accuracy_op, eval_rmse_op, ct_img, pt_img, lb_pos_gt, lbbg, all_pred, all_probabilities], feed_dict={handle: eval_handle, train_mode: False})
                print('[EVAL] STEP: %d' % (step))
                #print('[EVAL] STEP: %d, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f' % (step, p, r, a, e))
                #eval_writer.add_summary(eval_summary, step)
                #eval_writer.flush()

                if FLAGS.IMSAVE > 0:
                    if step % FLAGS.IMSAVE == 0:
                        # image saving here only supports the single output style
                        print('SAVING IMAGES')
                        _saveImages(hps.batch_size,
                                    step,
                                    cts,
                                    pts,
                                    trallpos=trallpos,
                                    trbgs=trbgs,
                                    recon_all=recon_all,
                                    all_preds=all_preds)

            except tf.errors.OutOfRangeError:
                print('OUT OF DATA - ENDING')
                break


def train(hps, design):
    """Training loop."""
    train_records = _get_tfrecord_files_from_dir(
        FLAGS.train_data_path)  # get tfrecord files for train
    train_iterator = petct_input.build_input(train_records, hps.batch_size,
                                             hps.num_epochs, FLAGS.mode)
    train_iterator_handle = train_iterator.string_handle()

    if FLAGS.val_data_path:  # skip validation if no path is given
        val_records = _get_tfrecord_files_from_dir(
            FLAGS.val_data_path)  # get tfrecord files for val
        val_iterator = petct_input.build_input(val_records, hps.batch_size,
                                               hps.num_epochs, 'valid')
        val_iterator_handle = val_iterator.string_handle()

    handle = tf.placeholder(tf.string, shape=[], name='data')
    iterator = Iterator.from_string_handle(handle, train_iterator.output_types,
                                           train_iterator.output_shapes)
    ct, pt, ctlb, ptlb, bglb = iterator.get_next()
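    # The feedable string-handle placeholder lets the same graph pull batches
    # from either the training or the validation iterator on each run() call.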

    model = fuse_cnn_petct.FuseNet(hps, design, ct, pt, ctlb, ptlb, bglb,
                                   FLAGS.mode)
    model.build_cross_modal_model()

    # for use in loading later
    #tf.get_collection('model')
    #tf.add_to_collection('model',model)

    # put get metrics ops here for train and val
    with tf.variable_scope('metrics'):
        tr_summary_op, tr_precision_op, tr_recall_op, tr_accuracy_op, tr_rmse_op = get_metrics_ops(
            model, 'train')
        val_summary_op, val_precision_op, val_recall_op, val_accuracy_op, val_rmse_op = get_metrics_ops(
            model, 'valid')
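    # Building the metrics inside the 'metrics' scope is what makes names such as
    # 'metrics/valid_precision/update_op:0' resolvable when evaluate() restores
    # the graph from a checkpoint.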

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, device_count={'GPU': 1})) as mon_sess:
        # Need a saver to save and restore all the variables.
        saver = tf.train.Saver()
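        # With no arguments, the Saver covers all saveable (global) variables.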

        if FLAGS.DEBUG:
            print('ENABLING DEBUG')
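            # Wrap the session in the tfdbg CLI debugger; the has_inf_or_nan
            # filter flags any fetched tensor containing inf or NaN values.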
            mon_sess = tf_debug.LocalCLIDebugWrapperSession(mon_sess)
            mon_sess.add_tensor_filter("has_inf_or_nan",
                                       tf_debug.has_inf_or_nan)

        training_handle = mon_sess.run(train_iterator_handle)
        if FLAGS.val_data_path:  # skip validation if no path is given
            validation_handle = mon_sess.run(val_iterator_handle)

        train_writer = tf.summary.FileWriter(FLAGS.log_root + '/train',
                                             mon_sess.graph)
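        # Passing mon_sess.graph writes the graph definition so TensorBoard can
        # display the model under log_root/train.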

        if FLAGS.val_data_path:  # skip validation if no path is given
            valid_writer = tf.summary.FileWriter(FLAGS.log_root + '/valid')

        mon_sess.run([g_init_op, l_init_op])

        summary = None
        step = None
        val_summary = None
        val_step = None
        while True:
            try:
                ## FIRST RUN TRAINING OP BASED ON OUTPUT STYLE
                if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                    # get PET and CT recons separately
                    _, summary, step, loss, p, r, a, e, cts, pts, trcts, trpts, trbgs, recon_cts, recon_pts, ct_preds, pt_preds = mon_sess.run(
                        [
                            model.train_op, tr_summary_op, model.global_step,
                            model.cost, tr_precision_op, tr_recall_op,
                            tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                            model.lbct, model.lbpt, model.lbbg, model.ct_pred,
                            model.pt_pred, model.ct_probabilities,
                            model.pt_probabilities
                        ],
                        feed_dict={
                            handle: training_handle,
                            model.is_training: True
                        })
                elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                    # get PET and CT recons together
                    _, summary, step, loss, p, r, a, e, cts, pts, trallpos, trbgs, recon_all, all_preds = mon_sess.run(
                        [
                            model.train_op, tr_summary_op, model.global_step,
                            model.cost, tr_precision_op, tr_recall_op,
                            tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                            model.lb_pos_gt, model.lbbg, model.all_pred,
                            model.all_probabilities
                        ],
                        feed_dict={
                            handle: training_handle,
                            model.is_training: True
                        })

                if step % FLAGS.train_iter == 0:
                    print(
                        '[TRAIN] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                        % (step, loss, p, r, a, e))
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if FLAGS.IMSAVE > 0:
                    if step % FLAGS.IMSAVE == 0:
                        print('SAVING IMAGES')
                        if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                            _saveImages(hps.batch_size,
                                        step,
                                        cts,
                                        pts,
                                        trcts=trcts,
                                        trpts=trpts,
                                        trbgs=trbgs,
                                        recon_cts=recon_cts,
                                        recon_pts=recon_pts,
                                        ct_preds=ct_preds,
                                        pt_preds=pt_preds)
                        elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                            _saveImages(hps.batch_size,
                                        step,
                                        cts,
                                        pts,
                                        trallpos=trallpos,
                                        trbgs=trbgs,
                                        recon_all=recon_all,
                                        all_preds=all_preds)

                if FLAGS.val_data_path:  # skip validation if no path is given
                    if step % FLAGS.val_iter == 0:
                        _, val_summary, loss, p, r, a, e = mon_sess.run(
                            [
                                model.val_op, val_summary_op, model.cost,
                                val_precision_op, val_recall_op,
                                val_accuracy_op, val_rmse_op
                            ],
                            feed_dict={
                                handle: validation_handle,
                                model.is_training: False
                            })
                        val_step = step
                        print(
                            '[VALID] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                            % (step, loss, p, r, a, e))
                        valid_writer.add_summary(val_summary, step)
                        valid_writer.flush()

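                # Periodic checkpoint save; the final '-end.ckpt' checkpoint written
                # on dataset exhaustion below is the one evaluate() loads by name.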
                if step % FLAGS.chkpt_iter == 0:
                    save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                        step) + '.ckpt'
                    save_path = saver.save(mon_sess, save_loc)
                    print('Model saved in path: %s' % save_path)

            except tf.errors.OutOfRangeError:
                print('OUT OF DATA - ENDING')
                # Finished training (either the train or the validation iterator ran out).
                if summary is not None:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                if FLAGS.val_data_path and val_summary is not None:
                    valid_writer.add_summary(val_summary, val_step)
                    valid_writer.flush()
                save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                    step) + '-end.ckpt'
                save_path = saver.save(mon_sess, save_loc)
                print('Model saved in path: %s' % save_path)
                break