Exemplo n.º 1
0
def run_training():
    """Build the classification graph and train until the input runs out.

    Batches are read through ``nn.inputs`` from every file in ``TRAIN_FILES``
    (joined onto ``FLAGS.data_dir``).  Inference, loss, accuracy and the
    training op all come from the ``nn`` module.  During training:

    * every 20 steps the loss/accuracy are printed and the merged summaries
      are written to ``FLAGS.checkpoint_dir``;
    * every 500 steps a checkpoint is saved under ``check_save``;
    * when the input raises ``tf.errors.OutOfRangeError`` a final checkpoint
      is saved and the loop ends.

    ``check_save``, ``TRAIN_FILES`` and ``z_dimensions`` are not defined in
    this function; they are expected to exist at module level.
    """
    # construct the graph
    with tf.Graph().as_default():

        # specify the training data file locations
        trainfiles = [os.path.join(FLAGS.data_dir, fi) for fi in TRAIN_FILES]

        # read the images and labels
        x, y_ = nn.inputs(batch_size=FLAGS.batch_size,
                          num_epochs=FLAGS.num_epochs,
                          filenames=trainfiles,
                          ifeval=False)

        # presumably the dropout keep probability; fed on every run below
        keep_prob = tf.placeholder(tf.float32)

        # NOTE(review): z_placeholder is not referenced anywhere else in this
        # function.  It is kept only so the constructed graph stays identical;
        # candidate for removal.
        z_placeholder = tf.placeholder(tf.float32,
                                       [FLAGS.batch_size, z_dimensions])

        # run inference on the images
        y_conv = nn.inference(x, np.array([65, 65, 65]), keep_prob,
                              FLAGS.batch_size)

        # calculate the loss from the results of inference and the labels
        loss = nn.loss(y_conv, y_)

        # calculate the accuracy
        accuracy = nn.evaluation(y_conv, y_)

        # setup the training operations
        train_op = nn.training(loss, FLAGS.learning_rate, FLAGS.decay_steps,
                               FLAGS.decay_rate)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op that sets the initial values of global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # the context manager closes the session, no explicit close needed
        with tf.Session() as sess:

            sess.run(init_op)
            # setup a saver for saving checkpoints
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # setup the coordinator and the queue-runner threads that read
            # the input data
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # The same feed is used for the training run and the summary run.
            # The original summary run fed nothing, which fails as soon as a
            # summary depends on the keep_prob placeholder.
            feed = {keep_prob: 0.5}
            checkpoint_path = os.path.join(check_save, 'model.ckpt')
            step = 0

            # loop will continue until we run out of input training cases
            try:
                while not coord.should_stop():
                    # time one training iteration
                    start_time = time.time()
                    _, loss_value, acc = sess.run(
                        [train_op, loss, accuracy], feed_dict=feed)
                    duration = time.time() - start_time

                    # print some output periodically
                    if step % 20 == 0:
                        print(
                            'OUTPUT: Step %d: loss = %.3f (%.3f sec), accuracy = %.3f'
                            % (step, loss_value, duration, acc))
                        # NOTE(review): this is a second sess.run; if
                        # nn.inputs is queue-backed the summaries are computed
                        # on a different batch than the values printed above.
                        summary_str = sess.run(summary_op, feed_dict=feed)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # less frequently output checkpoint files, used for
                    # evaluating the model
                    if step % 500 == 0:
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()
                # release the event file even when training aborts
                summary_writer.close()

            # shut down the threads gracefully
            coord.join(threads)
Exemplo n.º 2
0
def run_training():
    """Build the generator graph and train it until the input runs out.

    If ``FLAGS.cluster`` is set, the JSON cluster spec it points to is loaded
    and a ``tf.train.Server`` is started for ``FLAGS.job_name`` /
    ``FLAGS.task_index``; a ``"ps"`` job then blocks in ``server.join()``.

    Inputs are read through ``nn.inputs`` from ``FLAGS.tf_records``.  The
    generator (``nn.generator`` when ``FLAGS.img_size`` has four entries,
    ``nn.generator2d`` otherwise) is trained with ``nn.training_adam`` on
    ``nn.loss`` between its output and the batch-normalized labels.

    * every 20 steps loss/accuracy are printed and summaries are written to
      a timestamped sub-directory of ``FLAGS.checkpoint_dir``;
    * every 1000 steps a checkpoint is saved under ``check_save``;
    * on ``tf.errors.OutOfRangeError`` a final checkpoint is saved.

    ``check_save`` is not defined in this function; it is expected to exist
    at module level.
    """
    if FLAGS.cluster:
        with open(FLAGS.cluster) as data_file:
            cluster_spec = json.load(data_file)

        cluster = tf.train.ClusterSpec(cluster_spec)

        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index)

        # parameter-server jobs only serve variables and never return
        if FLAGS.job_name == "ps":
            server.join()
        # NOTE(review): the session below is created without a target, so the
        # server started here is not used by it -- confirm this is intended.

    # construct the graph
    with tf.Graph().as_default():

        size = np.array(FLAGS.img_size)

        # read the input images and the target images/labels
        _x, _y_ = nn.inputs(
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            filenames=[FLAGS.tf_records],
            size=size,
            namescope="input_generator",
        )

        # presumably the dropout keep probability; fed on every run below
        keep_prob = tf.placeholder(tf.float32)

        ps_device = "/gpu:0"
        w_device = "/gpu:0"

        # four entries in `size` select nn.generator, anything else the 2d one
        build_generator = nn.generator if len(size) == 4 else nn.generator2d

        # run the generator network on the input images (encode/decode)
        with tf.variable_scope("generator"):
            gen_x = build_generator(_x,
                                    keep_prob,
                                    FLAGS.batch_size,
                                    ps_device=ps_device,
                                    w_device=w_device,
                                    is_training=True)

        # the targets go through a batch-normalization layer before the loss
        _y_ = tf.layers.batch_normalization(_y_)

        # calculate the loss of the generator output against the targets
        loss_g = nn.loss(gen_x, _y_)
        tf.summary.scalar("loss_g", loss_g)

        # setup the training operations
        # NOTE(review): the op is named "train_discriminator" although it
        # trains the generator loss.  Left unchanged because renaming may
        # alter graph/variable names -- confirm before changing.
        train_op_g = nn.training_adam(loss_g, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_discriminator")

        # calculate the accuracy
        accuracy = nn.evaluation(gen_x, _y_, name="accuracy")

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op that sets the initial values of global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # the context manager closes the session, no explicit close needed
        with tf.Session() as sess:

            sess.run(init_op)
            # setup a saver for saving checkpoints
            saver = tf.train.Saver()

            # one log directory per run: <model_name>-<YYYYmmdd-HHMMSS>
            now = datetime.now()
            log_dir = os.path.join(
                FLAGS.checkpoint_dir,
                FLAGS.model_name + "-" + now.strftime("%Y%m%d-%H%M%S"))
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # setup the coordinator and the queue-runner threads that read
            # the input data
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            feed = {keep_prob: 0.5}
            checkpoint_path = os.path.join(check_save, FLAGS.model_name)
            step = 0

            # loop will continue until we run out of input training cases
            try:
                while not coord.should_stop():
                    # time one generator update
                    start_time = time.time()
                    _, l_g, acc = sess.run(
                        [train_op_g, loss_g, accuracy], feed_dict=feed)
                    duration = time.time() - start_time

                    # print some output periodically
                    if step % 20 == 0:
                        # NOTE(review): 'duraction:' looks like a typo for
                        # 'duration:'; left as-is in case logs are parsed.
                        print('OUTPUT: Step', step, 'loss:', l_g, 'accuracy:',
                              acc, 'duraction:', duration)
                        # output some data to the log files for tensorboard
                        summary_str = sess.run(summary_op, feed_dict=feed)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # less frequently output checkpoint files, used for
                    # evaluating the model
                    if step % 1000 == 0:
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                        print('MODEL:', checkpoint_path)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()
                # release the event file even when training aborts
                summary_writer.close()

            # shut down the threads gracefully
            coord.join(threads)
Exemplo n.º 3
0
def run_training():
    """Build the graph and train until the input runs out.

    Batches are read through ``nn.inputs`` from ``TRAIN_FILE`` inside
    ``FLAGS.data_dir``.  Inference, loss and the training op all come from
    the ``nn`` module.  During training:

    * every 100 steps the loss is printed and the merged summaries are
      written to ``FLAGS.checkpoint_dir``;
    * every 1000 steps a checkpoint ``model.ckpt-<step>`` is saved to
      ``FLAGS.checkpoint_dir``;
    * on ``tf.errors.OutOfRangeError`` a final checkpoint is saved.
    """
    # construct the graph
    with tf.Graph().as_default():

        # specify the training data file location
        trainfile = os.path.join(FLAGS.data_dir, TRAIN_FILE)

        # read the images and labels
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=FLAGS.num_epochs,
                                   filename=trainfile)

        # run inference on the images
        results = nn.inference(images)

        # calculate the loss from the results of inference and the labels
        loss = nn.loss(results, labels)

        # setup the training operations
        train_op = nn.training(loss, FLAGS.learning_rate, FLAGS.decay_steps,
                               FLAGS.decay_rate)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op that sets the initial values of global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # setup a saver for saving checkpoints
        saver = tf.train.Saver()

        # The context manager guarantees the session is closed even when
        # training aborts with an unexpected exception.
        with tf.Session() as sess:

            # specify where to write the log files for import to TensorBoard
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # initialize the graph
            sess.run(init_op)

            # setup the coordinator and the queue-runner threads that read
            # the input data
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
            step = 0

            # loop will continue until we run out of input training cases
            try:
                while not coord.should_stop():

                    # time one training iteration
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    # print some output periodically
                    if step % 100 == 0:
                        print('OUTPUT: Step %d: loss = %.3f (%.3f sec)' %
                              (step, loss_value, duration))
                        # output some data to the log files for tensorboard
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # less frequently output checkpoint files, used for
                    # evaluating the model
                    if step % 1000 == 0:
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()
                # release the event file even when training aborts
                summary_writer.close()

            # shut down the threads gracefully
            coord.join(threads)
Exemplo n.º 4
0
def run_training():
    """Build the generator/discriminator graph and train both networks.

    If ``FLAGS.cluster`` is set, the JSON cluster spec it points to is loaded
    and a ``tf.train.Server`` is started for ``FLAGS.job_name`` /
    ``FLAGS.task_index``; a ``"ps"`` job then blocks in ``server.join()``.

    'Fake' inputs are read from ``FLAGS.generator`` and 'real' inputs from
    ``FLAGS.discriminator``.  The generator runs on the fake inputs; the
    discriminator runs on the generator output and, with reused variables,
    on the real inputs.  Two ``nn.training_adam`` ops are built, each
    restricted to the trainable variables whose name contains
    ``'discriminator'`` or ``'generator'``, and both are run in the same
    ``sess.run`` on every step.

    * every 20 steps losses/accuracy are printed and summaries are written
      to ``FLAGS.checkpoint_dir``;
    * every 1000 steps a checkpoint is saved under ``check_save``;
    * on ``tf.errors.OutOfRangeError`` a final checkpoint is saved.

    ``check_save`` is not defined in this function; it is expected to exist
    at module level.
    """
    if FLAGS.cluster:
        with open(FLAGS.cluster) as data_file:
            cluster_spec = json.load(data_file)

        cluster = tf.train.ClusterSpec(cluster_spec)

        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index)

        # parameter-server jobs only serve variables and never return
        if FLAGS.job_name == "ps":
            server.join()
        # NOTE(review): the session below is created without a target, so the
        # server started here is not used by it -- confirm this is intended.

    # construct the graph
    with tf.Graph().as_default():

        size = np.array([33, 33, 33, 1])

        # read the images and labels to encode for the generator network 'fake'
        fake_x, fake_y_ = nn.inputs(batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs,
                                    filenames=[FLAGS.generator],
                                    namescope="input_generator",
                                    size=size)

        # read the images and labels for the discriminator network 'real'
        real_x, real_y_ = nn.inputs(batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs,
                                    filenames=[FLAGS.discriminator],
                                    size=size,
                                    namescope="input_discriminator")

        # presumably the dropout keep probability; fed on every run below
        keep_prob = tf.placeholder(tf.float32)

        ps_device = "/gpu:0"
        w_device = "/gpu:0"

        # run the generator network on the 'fake' input images (encode/decode)
        with tf.variable_scope("generator"):
            gen_x = nn.generator(fake_x,
                                 size,
                                 keep_prob,
                                 FLAGS.batch_size,
                                 ps_device=ps_device,
                                 w_device=w_device)

        with tf.variable_scope("discriminator") as scope:
            # run the discriminator network on the generated images
            gen_y_conv = nn.discriminator(gen_x,
                                          size,
                                          keep_prob,
                                          FLAGS.batch_size,
                                          ps_device=ps_device,
                                          w_device=w_device)

            # the second discriminator call must share the first one's weights
            scope.reuse_variables()
            # run the discriminator network on the real images
            real_y_conv = nn.discriminator(real_x,
                                           size,
                                           keep_prob,
                                           FLAGS.batch_size,
                                           ps_device=ps_device,
                                           w_device=w_device)

        # calculate the loss for the real images
        loss_real_d = nn.loss(real_y_conv, real_y_)
        tf.summary.scalar("loss_real_d", loss_real_d)
        # calculate the loss for the fake images
        loss_fake_d = nn.loss(gen_y_conv, fake_y_)
        tf.summary.scalar("loss_fake_d", loss_fake_d)
        # calculate the loss for the discriminator
        loss_d = loss_real_d + loss_fake_d
        tf.summary.scalar("loss_d", loss_d)

        # calculate the loss for the generator, i.e., trick the discriminator:
        # generated images are scored against the 'real' labels
        loss_g = nn.loss(gen_y_conv, real_y_)
        tf.summary.scalar("loss_g", loss_g)

        # split the trainable variables by network, based on their name
        vars_train = tf.trainable_variables()
        vars_gen = [var for var in vars_train if 'generator' in var.name]
        vars_dis = [var for var in vars_train if 'discriminator' in var.name]

        for var in vars_gen:
            print('gen', var.name)

        for var in vars_dis:
            print('dis', var.name)

        # setup the training operations, each updating only its own network
        train_op_d = nn.training_adam(loss_d, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_discriminator",
                                      vars_dis)

        train_op_g = nn.training_adam(loss_g, FLAGS.learning_rate, FLAGS.beta1,
                                      FLAGS.beta2, FLAGS.epsilon,
                                      FLAGS.use_locking, "train_generator",
                                      vars_gen)

        # calculate the accuracy
        accreal = nn.evaluation(real_y_conv, real_y_, name="accuracy_real")
        tf.summary.scalar(accreal.op.name, accreal)

        accfake = nn.evaluation(gen_y_conv, fake_y_, name="accuracy_fake")
        tf.summary.scalar(accfake.op.name, accfake)

        accuracy = (accreal + accfake) / 2.0
        tf.summary.scalar("accuracy", accuracy)

        # setup the summary ops to use TensorBoard
        summary_op = tf.summary.merge_all()

        # init op that sets the initial values of global and local variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # the context manager closes the session, no explicit close needed
        with tf.Session() as sess:

            sess.run(init_op)
            # setup a saver for saving checkpoints
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                                   sess.graph)

            # setup the coordinator and the queue-runner threads that read
            # the input data
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            feed = {keep_prob: 0.5}
            checkpoint_path = os.path.join(check_save, 'model.ckpt')
            step = 0

            # loop will continue until we run out of input training cases
            try:
                while not coord.should_stop():
                    # time one combined generator + discriminator update
                    start_time = time.time()
                    _, _, l_g, l_d, acc = sess.run(
                        [train_op_g, train_op_d, loss_g, loss_d, accuracy],
                        feed_dict=feed)
                    duration = time.time() - start_time

                    # print some output periodically
                    if step % 20 == 0:
                        # loss_d uses %.3f; the previous '%3.f' printed it
                        # with zero decimals
                        print(
                            'OUTPUT: Step %d: loss_g = %.3f, loss_d = %.3f, accuracy = %.3f, (%.3f sec)'
                            % (step, l_g, l_d, acc, duration))
                        # output some data to the log files for tensorboard
                        summary_str = sess.run(summary_op, feed_dict=feed)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # less frequently output checkpoint files, used for
                    # evaluating the model
                    if step % 1000 == 0:
                        saver.save(sess,
                                   save_path=checkpoint_path,
                                   global_step=step)
                    step += 1

            # quit after we run out of input files to read
            except tf.errors.OutOfRangeError:
                print('OUTPUT: Done training for %d epochs, %d steps.' %
                      (FLAGS.num_epochs, step))
                saver.save(sess, checkpoint_path, global_step=step)

            finally:
                coord.request_stop()
                # release the event file even when training aborts
                summary_writer.close()

            # shut down the threads gracefully
            coord.join(threads)