Example #1
import os
import time
from datetime import datetime

import tensorflow as tf

# LSPModels, LSPGlobals, FLAGS, input_pipeline() and draw() are assumed to
# come from the surrounding project.

def train(trainSetFileNames, testSetFileNames):
    # Read the train/test list files; each line names one example.
    with open(trainSetFileNames) as f:
        trainSet = f.read().splitlines()
    with open(testSetFileNames) as f:
        testSet = f.read().splitlines()
        
    with tf.Graph().as_default():
        # Global step variable for tracking training progress.
        global_step = tf.Variable(0, trainable=False)
        
        # Train and Test Set feeds.
        trainSet_batch, trainLabel_batch = input_pipeline(trainSet, FLAGS.batch_size)
        testSet_batch, testLabel_batch = input_pipeline(testSet, FLAGS.batch_size)

        # Placeholders to switch between train and test sets.
        dataShape = [FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, FLAGS.input_depth]
        labelShape = [FLAGS.batch_size, LSPGlobals.TotalLabels]
        example_batch = tf.placeholder(tf.float32, shape=dataShape)
        label_batch = tf.placeholder(tf.float32, shape=labelShape)
        keepProb = tf.placeholder(tf.float32)
        
        # Build a Graph that computes the logits predictions from the inference model.
        logits = LSPModels.inference(example_batch, FLAGS.batch_size, keepProb=keepProb)
        
        # Calculate loss.
        loss, meanPixelError = LSPModels.loss(logits, label_batch)
        
        # Build a Graph that trains the model with one batch of examples and updates the model parameters.
        train_op = LSPModels.train(loss, global_step)
        
        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        
        with tf.Session() as sess:
            # Start populating the filename queue.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            
            sess.run(init)
            
            stepinit = 0
            ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                stepinit = sess.run(global_step)
            else:
                print("No checkpoint found...")
                
            
            summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph=sess.graph)
            
            for step in range(stepinit, FLAGS.max_steps):
                
                start_time = time.time()
                train_examplebatch, train_labelbatch = sess.run([trainSet_batch, trainLabel_batch])
                feeddict = {example_batch: train_examplebatch, 
                            label_batch: train_labelbatch,
                            keepProb: 0.75}
                _, PxErrValue = sess.run([train_op, meanPixelError], feed_dict=feeddict)
                duration = time.time() - start_time
                                
                if step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    
                    format_str = '%s: step %d, MeanPixelError = %.1f pixels (%.1f examples/sec; %.3f sec/batch)'
                    print(format_str % (datetime.now(), step, PxErrValue,
                                        examples_per_sec, sec_per_batch))
                
                if (step % 50 == 0) and (step != 0):
                    summary_str = sess.run(summary_op, feed_dict=feeddict)
                    summary_writer.add_summary(summary_str, step)

                if (step % 100 == 0) and (step != 0):
                    test_examplebatch, test_labelbatch = sess.run([testSet_batch, testLabel_batch])
                    producedlabels, PxErrValue_Test = sess.run([logits, meanPixelError],
                                                               feed_dict={example_batch: test_examplebatch,
                                                                          label_batch: test_labelbatch,
                                                                          keepProb: 1.0})

                    # Integer division keeps the drawing index an int under Python 3.
                    draw(test_examplebatch[0, ...], producedlabels[0, ...], FLAGS.drawing_dir, step // 100)
                    print('Test Set MeanPixelError: %.1f pixels' % PxErrValue_Test)

                # Save the model checkpoint periodically.
                if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        
            coord.request_stop()
            coord.join(threads)
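
input_pipeline() is not shown in either example. A minimal sketch of what it
might look like under TF 1.x queue runners is below, assuming each line of the
list file holds an image path followed by LSPGlobals.TotalLabels coordinate
values; this layout and the helper's exact signature are assumptions, not
taken from the source:

def input_pipeline(lines, batch_size):
    # Queue the raw text lines and reshuffle them each epoch.
    line_queue = tf.train.string_input_producer(lines, shuffle=True)
    fields = tf.string_split([line_queue.dequeue()], ' ').values

    # First field is the image path; the rest are float label values.
    label = tf.string_to_number(fields[1:], tf.float32)
    label.set_shape([LSPGlobals.TotalLabels])

    # Decode and resize the image to the network's input shape
    # (FLAGS.input_depth is assumed to be 1 or 3).
    image = tf.image.decode_jpeg(tf.read_file(fields[0]), channels=FLAGS.input_depth)
    image = tf.image.resize_images(image, [FLAGS.input_size, FLAGS.input_size])
    image.set_shape([FLAGS.input_size, FLAGS.input_size, FLAGS.input_depth])

    # Assemble shuffled mini-batches; the queue runners started in train()
    # fill the underlying queue asynchronously.
    return tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                  capacity=1000 + 3 * batch_size,
                                  min_after_dequeue=1000)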
Example #2
# Same imports and project modules as Example #1; inputs() takes the place
# of input_pipeline().
def train():
    with tf.Graph().as_default():
        # Global step variable for tracking training progress.
        global_step = tf.Variable(0, trainable=False)

        # Prepare data batches
        train_set_batch, train_label_batch = inputs(is_train=True)
        validation_set_batch, validation_label_batch = inputs(is_train=False)

        # Placeholders to switch between train and test sets.
        image_batch = tf.placeholder(tf.float32,
                                     shape=[
                                         FLAGS.batch_size, FLAGS.input_size,
                                         FLAGS.input_size, FLAGS.input_depth
                                     ])
        label_batch = tf.placeholder(
            tf.int32, shape=[FLAGS.batch_size, LSPGlobals.TotalLabels])
        keep_probability = tf.placeholder(tf.float32)

        # Build a Graph that computes the logits predictions from the inference model.
        logits = LSPModels.inference(image_batch, keep_prob=keep_probability)

        # Calculate loss.
        loss, mean_pixel_error = LSPModels.loss(logits, label_batch)

        # Build a Graph that trains the model with one batch of examples and updates the model parameters.
        train_op = LSPModels.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            # Start populating the filename queue.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            sess.run(init)

            step_init = 0
            checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir)
            if checkpoint and checkpoint.model_checkpoint_path:
                saver.restore(sess, checkpoint.model_checkpoint_path)
                step_init = sess.run(global_step)
            else:
                print("No checkpoint found...")

            summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                   graph=sess.graph)

            for step in range(step_init, FLAGS.max_steps):

                start_time = time.time()
                images, labels = sess.run([train_set_batch, train_label_batch])
                feed_dict = {
                    image_batch: images,
                    label_batch: labels,
                    keep_probability: 0.6
                }
                _, pixel_error_value = sess.run([train_op, mean_pixel_error],
                                                feed_dict=feed_dict)
                duration = time.time() - start_time

                if step != 0:
                    # Print current results.
                    if step % 50 == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = float(duration)

                        format_str = '%s: step %d, MeanPixelError = %.1f pixels (%.1f examples/sec; %.3f sec/batch)'
                        print(format_str %
                              (datetime.now(), step, pixel_error_value,
                               examples_per_sec, sec_per_batch))

                    # Check results for validation set
                    if step % 500 == 0:
                        images, labels = sess.run(
                            [validation_set_batch, validation_label_batch])
                        feed_dict = {
                            image_batch: images,
                            label_batch: labels,
                            keep_probability: 1
                        }
                        produced_labels, pixel_error_value = sess.run(
                            [logits, mean_pixel_error], feed_dict=feed_dict)

                        # Integer division keeps the drawing index an int under Python 3.
                        draw(images[0, ...], produced_labels[0, ...],
                             FLAGS.drawing_dir, step // 500)
                        print('Validation Set MeanPixelError: %.1f pixels' %
                              pixel_error_value)

                    # Add summary to summary writer. Note that when step is a
                    # multiple of 1000, the validation branch above has just
                    # refilled feed_dict, so this summary is computed on the
                    # validation batch rather than the training batch.
                    if step % 1000 == 0:
                        summary_str = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)

                    # Save the model checkpoint periodically.
                    if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                        print('Model checkpoint saved for step %d' % step)

            coord.request_stop()
            coord.join(threads)
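
inputs() is likewise not shown. Queue runners were deprecated in later TF 1.x
releases in favor of tf.data; a rough tf.data equivalent of the batching above,
using the same assumed file layout as the sketch after Example #1 (the list
file names here are hypothetical):

def inputs(is_train):
    list_file = 'train_list.txt' if is_train else 'validation_list.txt'

    def parse_line(line):
        # Split "path x0 y0 x1 y1 ..." into an image path and int labels.
        fields = tf.string_split([line], ' ').values
        image = tf.image.decode_jpeg(tf.read_file(fields[0]),
                                     channels=FLAGS.input_depth)
        image = tf.image.resize_images(image,
                                       [FLAGS.input_size, FLAGS.input_size])
        label = tf.string_to_number(fields[1:], tf.int32)
        return image, label

    dataset = (tf.data.TextLineDataset(list_file)
               .shuffle(buffer_size=1000)
               .repeat()
               .map(parse_line)
               .batch(FLAGS.batch_size))
    return dataset.make_one_shot_iterator().get_next()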