示例#1
0
def run_training():
    """Build the classification graph and train until the input runs out.

    The graph is assembled from the project's ``nn`` helpers (``inputs``,
    ``inference``, ``loss``, ``evaluation``, ``training``).  Every 20 steps
    the loss and accuracy are printed and the merged summaries are written to
    ``FLAGS.checkpoint_dir``.  Every 500 steps, and once more when the input
    raises ``tf.errors.OutOfRangeError``, a checkpoint named ``model.ckpt`` is
    saved under ``check_save``.

    Takes no arguments and returns ``None``.  Configuration is read from
    ``FLAGS`` and from ``TRAIN_FILES`` / ``check_save``, which are defined
    outside this function.
    """
    # construct the graph
    with tf.Graph().as_default():

        # full paths of the training data files
        trainfiles = [os.path.join(FLAGS.data_dir, fi) for fi in TRAIN_FILES]

        # read the images and labels
        x, y_ = nn.inputs(batch_size=FLAGS.batch_size,
                          num_epochs=FLAGS.num_epochs,
                          filenames=trainfiles,
                          ifeval=False)

        # dropout keep probability, fed on every run of the graph
        keep_prob = tf.placeholder(tf.float32)

        # run inference on the images
        y_conv = nn.inference(x, np.array([65, 65, 65]), keep_prob,
                              FLAGS.batch_size)

        # calculate the loss from the results of inference and the labels
        loss = nn.loss(y_conv, y_)

        # calculate the accuracy
        accuracy = nn.evaluation(y_conv, y_)

        # setup the training operations
        train_op = nn.training(loss, FLAGS.learning_rate, FLAGS.decay_steps,
                               FLAGS.decay_rate)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op for the global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # the context manager closes the session on every exit path
        with tf.Session() as sess:

            sess.run(init_op)

            # saver for checkpoints, writer for the TensorBoard event files
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # coordinator and threads that feed the input queues
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # loop will continue until we run out of input training cases
            try:
                step = 0
                while not coord.should_stop():
                    log_step = step % 20 == 0

                    # On logging steps the summaries are fetched in the same
                    # run as the training op.  A separate sess.run(summary_op)
                    # would need its own keep_prob feed and would evaluate the
                    # graph a second time (presumably on a fresh input batch
                    # that is never trained on -- depends on nn.inputs).
                    fetches = [train_op, loss, accuracy]
                    if log_step:
                        fetches.append(summary_op)

                    # run one training iteration and time it
                    start_time = time.time()
                    results = sess.run(fetches, feed_dict={keep_prob: 0.5})
                    duration = time.time() - start_time

                    l, acc = results[1], results[2]

                    # print some output periodically
                    if log_step:
                        print(
                            'OUTPUT: Step %d: loss = %.3f (%.3f sec), accuracy = %.3f'
                            % (step, l, duration, acc))
                        # output some data to the log files for tensorboard
                        summary_writer.add_summary(results[3], step)
                        summary_writer.flush()

                    # less frequently output checkpoint files
                    if step % 500 == 0:
                        checkpoint_path = os.path.join(check_save,
                                                       'model.ckpt')
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                checkpoint_path = os.path.join(check_save, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()

            # shut down the threads gracefully
            coord.join(threads)
示例#2
0
def run_training():
    """Build the training graph and train until the input runs out.

    Reads ``TRAIN_FILE`` from ``FLAGS.data_dir`` through ``nn.inputs``, builds
    the model with ``nn.inference`` / ``nn.loss`` / ``nn.training`` and loops
    until ``tf.errors.OutOfRangeError`` is raised.  Every 100 steps the loss
    is printed and the merged summaries are written to
    ``FLAGS.checkpoint_dir``.  Every 1000 steps, and once at the end of the
    input, a ``model.ckpt`` checkpoint is saved in the same directory.

    Takes no arguments and returns ``None``.
    """
    # construct the graph
    with tf.Graph().as_default():

        # specify the training data file location
        trainfile = os.path.join(FLAGS.data_dir, TRAIN_FILE)

        # read the images and labels
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=FLAGS.num_epochs,
                                   filename=trainfile)

        # run inference on the images
        results = nn.inference(images)

        # calculate the loss from the results of inference and the labels
        loss = nn.loss(results, labels)

        # setup the training operations
        train_op = nn.training(loss, FLAGS.learning_rate, FLAGS.decay_steps,
                               FLAGS.decay_rate)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op for the global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # setup a saver for saving checkpoints
        saver = tf.train.Saver()

        # The context manager closes the session on every exit path; the
        # previous explicit sess.close() was skipped whenever an exception
        # other than OutOfRangeError escaped the loop.
        with tf.Session() as sess:

            # specify where to write the log files for import to TensorBoard
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # initialize the graph
            sess.run(init_op)

            # coordinator and threads that feed the input queues
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # loop will continue until we run out of input training cases
            try:
                step = 0
                while not coord.should_stop():
                    log_step = step % 100 == 0

                    # On logging steps fetch the summaries together with the
                    # training op instead of issuing a second sess.run, which
                    # would evaluate the graph again (presumably on a fresh
                    # input batch that is never trained on -- depends on
                    # nn.inputs).
                    fetches = [train_op, loss]
                    if log_step:
                        fetches.append(summary_op)

                    # run one training iteration and time it
                    start_time = time.time()
                    fetched = sess.run(fetches)
                    duration = time.time() - start_time

                    loss_value = fetched[1]

                    # print some output periodically
                    if log_step:
                        print('OUTPUT: Step %d: loss = %.3f (%.3f sec)' %
                              (step, loss_value, duration))
                        # output some data to the log files for tensorboard
                        summary_writer.add_summary(fetched[2], step)
                        summary_writer.flush()

                    # less frequently output checkpoint files
                    if step % 1000 == 0:
                        checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()

            # shut down the threads gracefully
            coord.join(threads)
示例#3
0
def run_eval():
    """Evaluate the latest checkpoint and report the Dice metric.

    Depending on ``FLAGS.eval_data`` either ``VALIDATION_FILE`` or
    ``TRAIN_FILE`` (both under ``FLAGS.data_dir``) is read for one epoch.  For
    every batch the three values returned by ``nn.evaluation`` are summed, and
    when the input raises ``tf.errors.OutOfRangeError`` the metric
    ``2 * int_sum / (label_sum + example_sum)`` is printed and written as a
    summary to ``FLAGS.eval_dir``.

    The pass is repeated every ``FLAGS.eval_interval_secs`` seconds unless
    ``FLAGS.run_once`` is set.  The function returns ``None``; it returns
    early when no checkpoint is found in ``FLAGS.checkpoint_dir``.
    """
    # Run evaluation on the input data set
    with tf.Graph().as_default() as g:

        # choose whether to evaluate the training set or the evaluation set
        eval_data = FLAGS.eval_data == 'eval'
        evalfile = os.path.join(FLAGS.data_dir,
                                VALIDATION_FILE if eval_data else TRAIN_FILE)

        # read the proper data set
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=1, filename=evalfile)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = nn.inference(images)

        # Calculate predictions.
        int_area, label_area, example_area = nn.evaluation(logits, labels)

        # setup the initialization of variables
        local_init = tf.local_variables_initializer()

        # Build the summary operation based on the TF collection of Summaries.
        # (tf.summary.* replaces the pre-1.0 tf.merge_all_summaries /
        # tf.train.SummaryWriter names.)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        # create the saver
        saver = tf.train.Saver()

        while True:
            # One session per evaluation pass.  Previously a single session
            # was closed at the bottom of the loop and then reused by the
            # next pass, and it was never closed on the run_once and
            # no-checkpoint exits.
            with tf.Session() as sess:

                # init the local variables of this session
                sess.run(local_init)

                # read in the most recent checkpointed weights
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found in %s' %
                          FLAGS.checkpoint_dir)
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # the step is the suffix after the last '-' of the file name
                global_step = ckpt.model_checkpoint_path.split(
                    '/')[-1].split('-')[-1]

                # start up the threads
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    # running sums of the three values from nn.evaluation
                    int_sum = 0
                    label_sum = 0
                    example_sum = 0
                    step = 0
                    while not coord.should_stop():
                        # run a single iteration of evaluation
                        ii, ll, ee = sess.run(
                            [int_area, label_area, example_area])
                        int_sum += ii
                        label_sum += ll
                        example_sum += ee
                        step += 1

                except tf.errors.OutOfRangeError:
                    # print and output the relevant prediction accuracy
                    precision = (2.0 * int_sum) / (label_sum + example_sum)
                    print('OUTPUT: %s: Dice metric = %.3f' %
                          (datetime.now(), precision))
                    print('OUTPUT: %d images evaluated from file %s' %
                          (step, evalfile))

                    # create summary to show in TensorBoard
                    summary = tf.Summary()
                    summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='2Dice metric',
                                      simple_value=precision)
                    summary_writer.add_summary(summary, global_step)
                    # nothing else closes the writer, so flush explicitly
                    summary_writer.flush()

                finally:
                    coord.request_stop()

                # shutdown gracefully
                coord.join(threads)

            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
示例#4
0
def run_eval():
    """Evaluate the latest checkpoint and report per-pixel precision.

    Depending on ``FLAGS.eval_data`` either ``VALIDATION_FILE`` or
    ``TRAIN_FILE`` (both under ``FLAGS.data_dir``) is read for one epoch.  The
    result of ``nn.evaluation`` is summed over all batches, and when the
    input raises ``tf.errors.OutOfRangeError`` the value
    ``true_count / (step * 256.0 * 256)`` is printed and written as a summary
    to ``FLAGS.eval_dir``.

    The pass is repeated every ``FLAGS.eval_interval_secs`` seconds unless
    ``FLAGS.run_once`` is set.  The function returns ``None``; it returns
    early when no checkpoint is found in ``FLAGS.checkpoint_dir``.
    """
    # Run evaluation on the input data set
    with tf.Graph().as_default() as g:

        # choose whether to evaluate the training set or the evaluation set
        eval_data = FLAGS.eval_data == 'eval'
        evalfile = os.path.join(FLAGS.data_dir,
                                VALIDATION_FILE if eval_data else TRAIN_FILE)

        # read the proper data set
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=1, filename=evalfile)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = nn.inference(images)

        # Calculate predictions.
        top_k_op = nn.evaluation(logits, labels)

        # setup the initialization of variables
        local_init = tf.local_variables_initializer()

        # Build the summary operation based on the TF collection of Summaries.
        # (tf.summary.* replaces the pre-1.0 tf.merge_all_summaries /
        # tf.train.SummaryWriter names.)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        # create the saver
        saver = tf.train.Saver()

        while True:
            # One session per evaluation pass.  Previously a single session
            # was closed at the bottom of the loop and then reused by the
            # next pass, and it was never closed on the run_once and
            # no-checkpoint exits.
            with tf.Session() as sess:

                # init the local variables of this session
                sess.run(local_init)

                # read in the most recent checkpointed weights
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found in %s' %
                          FLAGS.checkpoint_dir)
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # the step is the suffix after the last '-' of the file name
                global_step = ckpt.model_checkpoint_path.split(
                    '/')[-1].split('-')[-1]

                # start up the threads
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    # true_count accumulates the correct predictions
                    true_count = 0
                    step = 0
                    while not coord.should_stop():
                        # run a single iteration of evaluation
                        predictions = sess.run([top_k_op])

                        # aggregate correct predictions
                        true_count += np.sum(predictions)
                        step += 1

                except tf.errors.OutOfRangeError:
                    # print and output the relevant prediction accuracy
                    # NOTE(review): the 256 * 256 divisor looks like a fixed
                    # image size -- confirm against nn.inputs.
                    precision = true_count / (step * 256.0 * 256)
                    print('OUTPUT: %s: precision = %.3f' %
                          (datetime.now(), precision))
                    print('OUTPUT: %d images evaluated from file %s' %
                          (step, evalfile))

                    # create summary to show in TensorBoard
                    summary = tf.Summary()
                    summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='1cnn_accuracy',
                                      simple_value=precision)
                    summary_writer.add_summary(summary, global_step)
                    # nothing else closes the writer, so flush explicitly
                    summary_writer.flush()

                finally:
                    coord.request_stop()

                # shutdown gracefully
                coord.join(threads)

            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
示例#5
0
def run_training():
    """Train the generator network, optionally as part of a TF cluster.

    When ``FLAGS.cluster`` names a JSON cluster description, a
    ``tf.train.Server`` is started for ``FLAGS.job_name`` /
    ``FLAGS.task_index``; a ``"ps"`` job then calls ``server.join()``.

    The graph reads ``FLAGS.tf_records`` through ``nn.inputs``, runs
    ``nn.generator`` (when ``FLAGS.img_size`` has four entries) or
    ``nn.generator2d`` (otherwise) and optimises ``nn.loss`` with
    ``nn.training_adam``.  Every 20 steps the loss and accuracy are printed
    and summaries are written to a time-stamped directory under
    ``FLAGS.checkpoint_dir``.  Every 1000 steps, and once when the input
    raises ``tf.errors.OutOfRangeError``, a checkpoint is saved under
    ``check_save`` (defined outside this function).

    Takes no arguments and returns ``None``.
    """
    if FLAGS.cluster:
        with open(FLAGS.cluster) as data_file:
            cluster_spec = json.load(data_file)

        cluster = tf.train.ClusterSpec(cluster_spec)

        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index)

        if FLAGS.job_name == "ps":
            server.join()

    # construct the graph
    with tf.Graph().as_default():

        size = np.array(FLAGS.img_size)

        # read the images and labels for the generator network
        _x, _y_ = nn.inputs(
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            filenames=[FLAGS.tf_records],
            size=size,
            namescope="input_generator",
        )

        # dropout keep probability, fed on every run of the graph
        keep_prob = tf.placeholder(tf.float32)

        ps_device = "/gpu:0"
        w_device = "/gpu:0"

        # a four-entry size selects nn.generator, anything else the 2-D one
        build_generator = nn.generator if len(size) == 4 else nn.generator2d
        with tf.variable_scope("generator"):
            gen_x = build_generator(_x,
                                    keep_prob,
                                    FLAGS.batch_size,
                                    ps_device=ps_device,
                                    w_device=w_device,
                                    is_training=True)

        _y_ = tf.layers.batch_normalization(_y_)

        # loss of the generator output against the normalised labels
        loss_g = nn.loss(gen_x, _y_)
        tf.summary.scalar("loss_g", loss_g)

        # setup the training operations
        # NOTE(review): the name "train_discriminator" is passed for the
        # generator's training op -- confirm it is intended.
        train_op_g = nn.training_adam(loss_g, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_discriminator")

        # calculate the accuracy
        accuracy = nn.evaluation(gen_x, _y_, name="accuracy")

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op for the global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # the context manager closes the session on every exit path
        with tf.Session() as sess:

            sess.run(init_op)

            # saver for checkpoints, writer for the TensorBoard event files
            saver = tf.train.Saver()
            now = datetime.now()
            summary_writer = tf.summary.FileWriter(
                os.path.join(
                    FLAGS.checkpoint_dir,
                    FLAGS.model_name + "-" + now.strftime("%Y%m%d-%H%M%S")),
                sess.graph)

            # coordinator and threads that feed the input queues
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # loop will continue until we run out of input training cases
            try:
                step = 0
                while not coord.should_stop():
                    log_step = step % 20 == 0

                    # On logging steps fetch the summaries together with the
                    # training op instead of issuing a second sess.run, which
                    # would evaluate the graph again (presumably on a fresh
                    # input batch that is never trained on -- depends on
                    # nn.inputs).
                    fetches = [train_op_g, loss_g, accuracy]
                    if log_step:
                        fetches.append(summary_op)

                    # run one generator update and time it
                    start_time = time.time()
                    fetched = sess.run(fetches, feed_dict={keep_prob: 0.5})
                    duration = time.time() - start_time

                    l_g, acc = fetched[1], fetched[2]

                    # print some output periodically
                    if log_step:
                        print('OUTPUT: Step', step, 'loss:', l_g, 'accuracy:',
                              acc, 'duraction:', duration)
                        # output some data to the log files for tensorboard
                        summary_writer.add_summary(fetched[3], step)
                        summary_writer.flush()

                    # less frequently output checkpoint files
                    if step % 1000 == 0:
                        checkpoint_path = os.path.join(check_save,
                                                       FLAGS.model_name)
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                        print('MODEL:', checkpoint_path)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                checkpoint_path = os.path.join(check_save, FLAGS.model_name)
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()

            # shut down the threads gracefully
            coord.join(threads)
示例#6
0
def run_eval():
    """Run the generator from the latest checkpoint and dump its output.

    The graph reads ``TEST_FILES`` (under ``FLAGS.data_dir``) for one epoch
    and builds the model with ``nn.inference``, keeping its second return
    value ``Gz``.  On each step a random ``z`` batch is fed, ``Gz`` is
    evaluated, and one slice of the result is written as
    ``Gz_<step>.nrrd`` into the directory ``FLAGS.checkpoint_dir + 'Data'``.
    When the input raises ``tf.errors.OutOfRangeError`` the merged summaries
    are written to ``FLAGS.eval_dir``.

    The pass is repeated every ``FLAGS.eval_interval_secs`` seconds unless
    ``FLAGS.run_once`` is set.  The function returns ``None``; it returns
    early when no checkpoint is found in ``FLAGS.checkpoint_dir``.  Any
    exception raised while saving a slice is reported and re-raised.
    """
    # Run evaluation on the input data set
    with tf.Graph().as_default() as g:

        # full paths of the test data files
        testfiles = [os.path.join(FLAGS.data_dir, fi) for fi in TEST_FILES]

        # read the proper data set
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=1,
                                   filenames=testfiles,
                                   ifeval=True)

        # Build a Graph that computes the generator output from the
        # inference model.
        z_placeholder = tf.placeholder(tf.float32,
                                       [FLAGS.batch_size, z_dimensions])

        _, Gz, _ = nn.inference(images, z_placeholder, z_dimensions,
                                FLAGS.batch_size, 'yo.txt', False)

        # setup the initialization of variables
        local_init = tf.local_variables_initializer()

        # Build the summary operation based on the TF collection of Summaries.
        # (tf.summary.* replaces the pre-1.0 tf.merge_all_summaries /
        # tf.train.SummaryWriter names.)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        # create the saver
        saver = tf.train.Saver()

        # output directory for the generated slices (plain string
        # concatenation, exactly as before, so the location is unchanged)
        data_save = FLAGS.checkpoint_dir + 'Data'

        while True:
            # One session per evaluation pass.  Previously a single session
            # was closed at the bottom of the loop and then reused by the
            # next pass, and it was never closed on the run_once and
            # no-checkpoint exits.
            with tf.Session() as sess:

                # init the local variables of this session
                sess.run(local_init)

                # read in the most recent checkpointed weights
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found in %s' %
                          FLAGS.checkpoint_dir)
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # the step is the suffix after the last '-' of the file name
                global_step = ckpt.model_checkpoint_path.split(
                    '/')[-1].split('-')[-1]

                # start up the threads
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:

                    step = 0
                    while not coord.should_stop():
                        # The fed batch must match the placeholder's
                        # [batch_size, z_dimensions] shape; a fixed leading
                        # dimension of 1 only worked for batch_size == 1.
                        # NOTE(review): normal(-1, 1) means mean -1, std 1 --
                        # confirm this is the intended noise distribution.
                        z_batch = np.random.normal(
                            -1, 1, size=[FLAGS.batch_size, z_dimensions])
                        gz = sess.run([Gz], feed_dict={z_placeholder: z_batch})

                        out_path = os.path.join(data_save,
                                                'Gz_' + str(step) + '.nrrd')
                        try:
                            if not os.path.isdir(data_save):
                                os.makedirs(data_save)
                            # first example of the batch, channel 0
                            im = gz[0][0][:, :, 0]
                            print(np.shape(im))
                            nrrd.write(out_path, np.reshape(im, (195, 233)))
                        except Exception as e:
                            # report the file that actually failed
                            print('Unable to save data to', out_path, ':', e)
                            raise

                        step += 1

                except tf.errors.OutOfRangeError:

                    print('OUTPUT: %s: ' % (datetime.now()))

                    # join instead of indexing [0] and [1] so a single test
                    # file does not raise IndexError
                    print('OUTPUT: %d images evaluated from file %s' %
                          (step, ' & '.join(testfiles)))

                    # create summary to show in TensorBoard
                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, global_step)
                    # nothing else closes the writer, so flush explicitly
                    summary_writer.flush()

                finally:
                    coord.request_stop()

                # shutdown gracefully
                coord.join(threads)

            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
示例#7
0
def run_eval(images=None, labels=None):
    """Evaluate the latest checkpoint, report the Dice metric and plot batches.

    Depending on ``FLAGS.eval_data`` either ``VALIDATION_FILE`` or
    ``TRAIN_FILE`` (both under ``FLAGS.data_dir``) is read for one epoch.  For
    every batch the three values returned by ``nn.evaluation`` are summed, and
    when the input raises ``tf.errors.OutOfRangeError`` the metric
    ``2 * int_sum / (label_sum + example_sum)`` is printed and written as a
    summary to ``FLAGS.eval_dir``.  Unless Python runs with ``-O``
    (``__debug__`` false), every batch is also plotted (image, label,
    prediction) and saved as ``./eval_<step>.png``.

    The pass is repeated every ``FLAGS.eval_interval_secs`` seconds unless
    ``FLAGS.run_once`` is set.

    Args:
        images: Ignored; rebound to the tensor returned by ``nn.inputs``.
        labels: Ignored; rebound to the tensor returned by ``nn.inputs``.

    Returns:
        None.  Returns early when no checkpoint is found in
        ``FLAGS.checkpoint_dir``.
    """
    # Run evaluation on the input data set
    with tf.Graph().as_default() as g:

        # choose whether to evaluate the training set or the evaluation set
        eval_data = FLAGS.eval_data == 'eval'
        evalfile = os.path.join(FLAGS.data_dir,
                                VALIDATION_FILE if eval_data else TRAIN_FILE)

        # read the proper data set
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=1,
                                   filename=evalfile)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = nn.inference(images)

        # Calculate predictions.
        int_area, label_area, example_area = nn.evaluation(logits, labels)

        # setup the initialization of variables
        local_init = tf.local_variables_initializer()

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        # create the saver
        saver = tf.train.Saver()

        # number of batches plotted so far, across all evaluation passes
        count = 0
        while True:
            # One session per evaluation pass.  Previously a single session
            # was closed at the bottom of the loop and then reused by the
            # next pass, and it was never closed on the run_once and
            # no-checkpoint exits.
            with tf.Session() as sess:

                # init the local variables of this session
                sess.run(local_init)

                # read in the most recent checkpointed weights
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found in %s' %
                          FLAGS.checkpoint_dir)
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # the step is the suffix after the last '-' of the file name
                global_step = ckpt.model_checkpoint_path.split(
                    '/')[-1].split('-')[-1]

                # start up the threads
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    # running sums of the three values from nn.evaluation
                    int_sum = 0
                    label_sum = 0
                    example_sum = 0
                    step = 0
                    while not coord.should_stop():

                        # In debug mode the tensors to plot are fetched in
                        # the same run as the metric terms.  A second
                        # sess.run would evaluate the graph again, so the
                        # plotted batch would differ from the scored one
                        # and (presumably, given the one-epoch input) would
                        # be missing from the metric.
                        fetches = [int_area, label_area, example_area]
                        if __debug__:
                            fetches += [images, labels, logits]
                        fetched = sess.run(fetches)

                        ii, ll, ee = fetched[:3]
                        int_sum += ii
                        label_sum += ll
                        example_sum += ee
                        step += 1

                        if __debug__:
                            # single-argument print() calls behave the same
                            # on Python 2 and Python 3
                            print('count %d' % count)
                            images_, labels_, raw_logits = fetched[3:]
                            # shape of the logits before squeezing
                            print(np.shape(raw_logits))

                            images_, labels_, logits_ = map(
                                np.squeeze, [images_, labels_, raw_logits])
                            images_, labels_ = map(change_mode,
                                                   [images_, labels_])

                            # image | label | predicted class image
                            fig = plt.figure()
                            fig.add_subplot(1, 3, 1)
                            plt.imshow(images_)
                            fig.add_subplot(1, 3, 2)
                            plt.imshow(labels_)
                            fig.add_subplot(1, 3, 3)
                            plt.imshow(onehot2cls_image(logits_))
                            plt.savefig('./eval_' + str(step) + '.png')
                            plt.show()
                            # release the figure so they do not pile up
                            plt.close(fig)

                            count += 1

                except tf.errors.OutOfRangeError:
                    # print and output the relevant prediction accuracy
                    precision = (2.0 * int_sum) / (label_sum + example_sum)
                    print('OUTPUT: %s: Dice metric = %.3f' %
                          (datetime.now(), precision))
                    print('OUTPUT: %d images evaluated from file %s' %
                          (step, evalfile))

                    # create summary to show in TensorBoard
                    summary = tf.Summary()
                    summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='2Dice metric',
                                      simple_value=precision)
                    summary_writer.add_summary(summary, global_step)
                    # nothing else closes the writer, so flush explicitly
                    summary_writer.flush()

                finally:
                    coord.request_stop()

                # shutdown gracefully
                coord.join(threads)

            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
示例#8
0
def run_training():
    """Build the GAN training graph and train until the inputs run out.

    Steps performed, all driven by the module-level ``FLAGS``:

    1. If ``FLAGS.cluster`` is set, load the JSON cluster spec it points to
       and start a ``tf.train.Server`` for ``FLAGS.job_name`` /
       ``FLAGS.task_index``.  A ``"ps"`` job blocks in ``server.join()``.
    2. Create two input pipelines with ``nn.inputs``: one reading
       ``FLAGS.generator`` (input to the generator network) and one reading
       ``FLAGS.discriminator`` (the "real" samples).
    3. Build the generator, then the discriminator twice (on generated and
       on real inputs) with shared variables, plus their losses, Adam
       training ops, accuracy metrics and TensorBoard summaries.
    4. Run both training ops in a loop until the input queues raise
       ``tf.errors.OutOfRangeError``, logging every 20 steps and saving a
       checkpoint every 1000 steps and once more at the end.

    The function takes no arguments and returns nothing in the portion
    visible here; its effects are console output, summary files written to
    ``FLAGS.checkpoint_dir`` and checkpoints written under ``check_save``.
    """
    if FLAGS.cluster:
        with open(FLAGS.cluster) as data_file:
            cluster_spec = json.load(data_file)

        cluster = tf.train.ClusterSpec(cluster_spec)

        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index)

        # parameter-server jobs only host variables, so they block here
        # NOTE(review): the session created below does not connect to
        # `server.target` -- confirm that this is intended.
        if FLAGS.job_name == "ps":
            server.join()

    # construct the graph
    with tf.Graph().as_default():

        size = np.array([33, 33, 33, 1])

        # read the images and labels to encode for the generator network 'fake'
        fake_x, fake_y_ = nn.inputs(batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs,
                                    filenames=[FLAGS.generator],
                                    namescope="input_generator",
                                    size=size)

        # read the images and labels for the discriminator network 'real'
        real_x, real_y_ = nn.inputs(batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs,
                                    filenames=[FLAGS.discriminator],
                                    size=size,
                                    namescope="input_discriminator")

        # dropout keep probability, fed on every sess.run call
        keep_prob = tf.placeholder(tf.float32)

        # both the parameter and the worker device are pinned to the first GPU
        ps_device = "/gpu:0"
        w_device = "/gpu:0"

        # run the generator network on the 'fake' input images (encode/decode)
        with tf.variable_scope("generator"):
            gen_x = nn.generator(fake_x,
                                 size,
                                 keep_prob,
                                 FLAGS.batch_size,
                                 ps_device=ps_device,
                                 w_device=w_device)

        with tf.variable_scope("discriminator") as scope:
            # run the discriminator network on the generated images
            gen_y_conv = nn.discriminator(gen_x,
                                          size,
                                          keep_prob,
                                          FLAGS.batch_size,
                                          ps_device=ps_device,
                                          w_device=w_device)

            # the second discriminator instance must share the weights of
            # the first one instead of creating new variables
            scope.reuse_variables()
            # run the discriminator network on the real images
            real_y_conv = nn.discriminator(real_x,
                                           size,
                                           keep_prob,
                                           FLAGS.batch_size,
                                           ps_device=ps_device,
                                           w_device=w_device)

        # Discriminator loss = loss on the real images + loss on the
        # generated images, each scored against the labels of its own input.
        # calculate the loss for the real images
        loss_real_d = nn.loss(real_y_conv, real_y_)
        tf.summary.scalar("loss_real_d", loss_real_d)
        # calculate the loss for the fake images
        loss_fake_d = nn.loss(gen_y_conv, fake_y_)
        tf.summary.scalar("loss_fake_d", loss_fake_d)
        # calculate the loss for the discriminator
        loss_d = loss_real_d + loss_fake_d
        tf.summary.scalar("loss_d", loss_d)

        # calculate the loss for the generator, i.e., trick the discriminator:
        # the generated images are scored against the 'real' labels
        loss_g = nn.loss(gen_y_conv, real_y_)
        tf.summary.scalar("loss_g", loss_g)

        # split the trainable variables by the variable scope they were
        # created in, so that each optimizer only updates its own network
        vars_train = tf.trainable_variables()

        vars_gen = [var for var in vars_train if 'generator' in var.name]
        vars_dis = [var for var in vars_train if 'discriminator' in var.name]

        for var in vars_gen:
            print('gen', var.name)

        for var in vars_dis:
            print('dis', var.name)

        # setup the training operations
        train_op_d = nn.training_adam(loss_d, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_discriminator",
                                      vars_dis)

        train_op_g = nn.training_adam(loss_g, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_generator",
                                      vars_gen)

        # calculate the accuracy
        accreal = nn.evaluation(real_y_conv, real_y_, name="accuracy_real")
        tf.summary.scalar(accreal.op.name, accreal)

        accfake = nn.evaluation(gen_y_conv, fake_y_, name="accuracy_fake")
        tf.summary.scalar(accfake.op.name, accfake)

        # overall accuracy is the mean of the real and the fake accuracy
        accuracy = (accreal + accfake) / 2.0
        tf.summary.scalar("accuracy", accuracy)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init to setup the initial values of the weights; the local
        # variables are initialised as well
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # create the session
        with tf.Session() as sess:

            sess.run(init_op)
            # setup a saver for saving checkpoints
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # setup the coordinator and threads.  Used for multiple threads to
            # read data.  Not strictly required since we don't have a lot of
            # data but typically using multiple threads to read data improves
            # performance
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            start_training_time = time.time()
            # loop will continue until we run out of input training cases
            try:
                step = 0
                while not coord.should_stop():
                    # print some output periodically
                    log_step = step % 20 == 0

                    # Fetch the summaries in the SAME run call as the training
                    # ops: a separate sess.run(summary_op) would pull another
                    # batch from the input queues, so the summaries would not
                    # match the printed losses and that batch would be
                    # consumed without being trained on.
                    fetches = [train_op_g, train_op_d, loss_g, loss_d,
                               accuracy]
                    if log_step:
                        fetches.append(summary_op)

                    # start time and run one training iteration that updates
                    # both the generator and the discriminator
                    start_time = time.time()
                    results = sess.run(fetches, feed_dict={keep_prob: 0.5})
                    duration = time.time() - start_time

                    # results[0] and results[1] belong to the training ops
                    l_g, l_d, acc = results[2:5]

                    if log_step:
                        print(
                            'OUTPUT: Step %d: loss_g = %.3f, loss_d = %.3f, accuracy = %.3f, (%.3f sec)'
                            % (step, l_g, l_d, acc, duration))
                        # output some data to the log files for tensorboard
                        summary_writer.add_summary(results[5], step)
                        summary_writer.flush()

                    # less frequently output checkpoint files.  Used for evaluating the model
                    # NOTE(review): `check_save` is not defined in this
                    # function; presumably a module-level checkpoint
                    # directory -- confirm.
                    if step % 1000 == 0:
                        checkpoint_path = os.path.join(check_save,
                                                       'model.ckpt')
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                checkpoint_path = os.path.join(check_save, 'model.ckpt')

                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()

            # shut down the threads gracefully; the session itself is closed
            # by the enclosing `with` block
            coord.join(threads)
            end_training_time = time.time()