Exemplo n.º 1
0
def training():
    """Fine-tune VGG16 on CIFAR-10 and log train/val summaries + checkpoints.

    Relies on module-level hyperparameters (BATCH_SIZE, IMG_W, IMG_H,
    N_CLASSES, LEARNING_RATE, MAX_STEP) and on helper modules
    (`input_data`, `vgg`, `layers`).  TensorBoard event files go to
    ./log/train and ./log/val; checkpoints are saved under ./log/train.
    """
    pretrained_weights = './pretrain/vgg16.npy'
    data_dir = './data/cifar10_data/cifar-10-batches-bin'
    train_log_dir = './log/train/'
    val_log_dir = './log/val/'

    with tf.name_scope('input'):
        images_train, labels_train = input_data.read_cifar10(
            data_dir, is_train=True, batch_size=BATCH_SIZE, shuffle=True)
        images_val, labels_val = input_data.read_cifar10(
            data_dir, is_train=False, batch_size=BATCH_SIZE, shuffle=False)

    # Shared placeholders let the same graph be fed either train or val
    # batches.  label_holder is 2-D, so labels are expected one-hot.
    image_holder = tf.placeholder(tf.float32,
                                  shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
    label_holder = tf.placeholder(tf.int32, shape=[BATCH_SIZE, N_CLASSES])

    # 0.8 is presumably the dropout keep probability — confirm in vgg.VGG16.
    logits = vgg.VGG16(image_holder, N_CLASSES, 0.8)
    loss = layers.loss(logits, label_holder)
    accuracy = layers.accuracy(logits, label_holder)

    global_steps = tf.Variable(0, name='global_step', trainable=False)
    train_op = layers.optimize(loss, LEARNING_RATE, global_steps)

    saver = tf.train.Saver(tf.global_variables())

    # Reference: https://stackoverflow.com/questions/35413618/tensorflow-placeholder-error-when-using-tf-merge-all-summaries
    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    sess = tf.InteractiveSession()
    sess.run(init)

    print('########################## Start Training ##########################')

    # Load pretrained weights, skipping the fully-connected layers so they
    # are trained from scratch for the N_CLASSES-way task.
    layers.load_with_skip(pretrained_weights, sess, ['fc6', 'fc7', 'fc8'])

    # Queue-runner threads feed the input pipeline.
    # Reference: http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/threading_and_queues.html
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    train_summary_writer = tf.summary.FileWriter(train_log_dir,
                                                 graph=sess.graph)
    val_summary_writer = tf.summary.FileWriter(val_log_dir, graph=sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            train_images, train_labels = sess.run([images_train, labels_train])
            _, train_loss, train_acc, summary_str = sess.run(
                [train_op, loss, accuracy, summary_op],
                feed_dict={image_holder: train_images,
                           label_holder: train_labels})

            if step % 50 == 0 or (step + 1) == MAX_STEP:
                print('step %d, loss = %.4f, accuracy = %.4f%%' %
                      (step, train_loss, train_acc))
                train_summary_writer.add_summary(summary_str, step)

            if step % 200 == 0 or (step + 1) == MAX_STEP:
                val_images, val_labels = sess.run([images_val, labels_val])
                # BUG FIX: evaluate summary_op on the *validation* batch.
                # Previously the training-batch summary was written to the
                # validation log, so val curves mirrored train curves.
                val_loss, val_acc, val_summary_str = sess.run(
                    [loss, accuracy, summary_op],
                    feed_dict={image_holder: val_images,
                               label_holder: val_labels})
                print('step %d, val loss = %.2f, val accuracy = %.2f%%' %
                      (step, val_loss, val_acc))
                val_summary_writer.add_summary(val_summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                # Saved with the Python loop counter rather than the
                # global_step variable (which optimize() increments).
                checkpoint_path = os.path.join(train_log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        # Input queues are exhausted; fall through to cleanup.
        coord.request_stop()
    finally:
        # Always stop and join queue-runner threads, even on unexpected
        # exceptions, so the process can exit cleanly.
        coord.request_stop()
        coord.join(threads)
        sess.close()
Exemplo n.º 2
0
def optimize(loss, global_step):
  """Build the training op for `loss` using flag-configured hyperparameters.

  Thin wrapper around `layers_lib.optimize`; all tuning knobs (gradient
  clipping, learning-rate schedule, replica synchronization) come from FLAGS.
  """
  flags = FLAGS
  return layers_lib.optimize(loss,
                             global_step,
                             flags.max_grad_norm,
                             flags.learning_rate,
                             flags.learning_rate_decay_factor,
                             flags.sync_replicas,
                             flags.replicas_to_aggregate,
                             flags.task)
Exemplo n.º 3
0
def optimize(loss, global_step=None):
    """Create a training op for `loss` (no replica-sync arguments).

    Delegates to `layers_lib.optimize` with gradient-norm clipping and the
    learning-rate schedule taken from FLAGS.
    """
    max_norm = FLAGS.max_grad_norm
    lr = FLAGS.learning_rate
    decay = FLAGS.learning_rate_decay_factor
    return layers_lib.optimize(loss, global_step, max_norm, lr, decay)
Exemplo n.º 4
0
def optimize(loss, global_step):
    """Delegate to `layers_lib.optimize` with replica-sync flag settings."""
    hyperparams = (FLAGS.max_grad_norm,
                   FLAGS.learning_rate,
                   FLAGS.learning_rate_decay_factor,
                   FLAGS.sync_replicas,
                   FLAGS.replicas_to_aggregate,
                   FLAGS.task)
    return layers_lib.optimize(loss, global_step, *hyperparams)
Exemplo n.º 5
0
def training():
    """Fine-tune VGG16 on the dr5 dataset with exponential LR decay.

    Relies on module-level hyperparameters (BATCH_SIZE, IMG_W, IMG_H,
    N_CLASSES, MAX_STEP, start_rate, decay_steps, deacy_rate) and on
    helper modules (`dr5_input`, `vgg`, `layers`).  TensorBoard event
    files and checkpoints are written under ./log_dr50000/.
    """
    pretrained_weights = './pretrain/vgg16.npy'

    train_log_dir = './log_dr50000/train/'
    val_log_dir = './log_dr50000/val/'

    with tf.name_scope('input'):
        images_train, labels_train = dr5_input.input_data(True, BATCH_SIZE)
        images_val, labels_val = dr5_input.input_data(False, BATCH_SIZE)

    # Shared placeholders let the same graph be fed either train or val
    # batches.  label_holder is 2-D, so labels are expected one-hot.
    image_holder = tf.placeholder(tf.float32,
                                  shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
    label_holder = tf.placeholder(tf.int32, shape=[BATCH_SIZE, N_CLASSES])

    # 0.5 is presumably the dropout keep probability — confirm in vgg.VGG16.
    logits = vgg.VGG16(image_holder, N_CLASSES, 0.5)
    loss = layers.loss(logits, label_holder)
    accuracy = layers.accuracy(logits, label_holder)

    global_steps = tf.Variable(0, name='global_step', trainable=False)
    # NOTE(review): `deacy_rate` is a misspelled module-level constant
    # ("decay_rate"); it is defined outside this block so it cannot be
    # renamed here without touching its definition and other users.
    LEARNING_RATE = tf.train.exponential_decay(start_rate,
                                               global_steps,
                                               decay_steps,
                                               deacy_rate,
                                               staircase=True)
    train_op = layers.optimize(loss, LEARNING_RATE, global_steps)

    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    # Local variables are needed by the input pipeline.
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess = tf.InteractiveSession()
    sess.run(init)

    print(
        '########################## Start Training ##########################')

    # Load pretrained weights, skipping the fully-connected layers so they
    # are trained from scratch for the N_CLASSES-way task.
    layers.load_with_skip(pretrained_weights, sess, ['fc6', 'fc7', 'fc8'])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    train_summary_writer = tf.summary.FileWriter(train_log_dir,
                                                 graph=sess.graph)
    val_summary_writer = tf.summary.FileWriter(val_log_dir, graph=sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            train_images, train_labels = sess.run([images_train, labels_train])
            _, train_loss, train_acc, summary_str = sess.run(
                [train_op, loss, accuracy, summary_op],
                feed_dict={
                    image_holder: train_images,
                    label_holder: train_labels
                })

            if step % 50 == 0 or (step + 1) == MAX_STEP:
                print('step %d, loss = %.4f, accuracy = %.4f%%' %
                      (step, train_loss, train_acc))
                train_summary_writer.add_summary(summary_str, step)

            if step % 200 == 0 or (step + 1) == MAX_STEP:
                val_images, val_labels = sess.run([images_val, labels_val])
                # BUG FIX: evaluate summary_op on the *validation* batch.
                # Previously the training-batch summary was written to the
                # validation log, so val curves mirrored train curves.
                val_loss, val_acc, val_summary_str = sess.run(
                    [loss, accuracy, summary_op],
                    feed_dict={
                        image_holder: val_images,
                        label_holder: val_labels
                    })
                print('step %d, val loss = %.2f, val accuracy = %.2f%%' %
                      (step, val_loss, val_acc))
                val_summary_writer.add_summary(val_summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(train_log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                lr = sess.run(LEARNING_RATE)
                print("step %d, learning_rate= %f" % (step, lr))

    except tf.errors.OutOfRangeError:
        # Input queues are exhausted; fall through to cleanup.
        coord.request_stop()
    finally:
        # Always stop and join queue-runner threads, even on unexpected
        # exceptions, so the process can exit cleanly.
        coord.request_stop()
        coord.join(threads)
        sess.close()
    def run(self, run_type):
        """Run one epoch of training or evaluation over the dataset.

        Builds the input pipeline, model, metrics, and summaries for this
        epoch, restores from ``self.checkpoint`` when one exists, iterates
        ``num_examples // batch_size`` steps, writes per-step and per-epoch
        TensorBoard summaries, saves a checkpoint when training, and finally
        tears down the session and resets the default graph.

        Args:
            run_type: ``'train'`` enables optimization for this epoch;
                any other value (e.g. a validation split name) runs
                metric accumulation only.
        """

        is_training = True if run_type == 'train' else False

        self.log('{} epoch: {}'.format(run_type, self.epoch))

        image_filenames, label_filenames = self.dataset.load_filenames(
            run_type)

        # Starts at 1 (not 0); incremented elsewhere — presumably by
        # layers.optimize during training.  TODO confirm.
        global_step = tf.Variable(1, name='global_step', trainable=False)

        # NOTE(review): crop_shape is hard-coded and augment=True is passed
        # even when evaluating — confirm that is intentional for eval runs.
        images, labels = inputs.load_batches(image_filenames,
                                             label_filenames,
                                             shape=self.dataset.SHAPE,
                                             batch_size=self.batch_size,
                                             resize_shape=self.dataset.SHAPE,
                                             crop_shape=(256, 512),
                                             augment=True)

        with tf.name_scope('labels'):
            # Colorized labels and the ignore mask are logged as images for
            # visual inspection in TensorBoard.
            color_labels = util.colorize(labels, self.dataset.augmented_labels)
            labels = tf.cast(labels, tf.int32)
            ignore_mask = util.get_ignore_mask(labels,
                                               self.dataset.augmented_labels)
            tf.summary.image('label', color_labels, 1)
            tf.summary.image('weights', tf.cast(ignore_mask * 255, tf.uint8),
                             1)

        tf.summary.image('image', images, 1)

        logits = self.model.inference(images,
                                      num_classes=self.num_classes,
                                      is_training=is_training)

        with tf.name_scope('outputs'):
            predictions = layers.predictions(logits)
            color_predictions = util.colorize(predictions,
                                              self.dataset.augmented_labels)
            tf.summary.image('prediction', color_predictions, 1)

        # Add some metrics.  Streaming metrics accumulate across steps via
        # their *_update_op; the *_op tensors read the accumulated value.
        # Ignored pixels are excluded through the weights argument.
        with tf.name_scope('metrics'):
            accuracy_op, accuracy_update_op = tf.contrib.metrics.streaming_accuracy(
                predictions, labels, weights=ignore_mask)
            mean_iou_op, mean_iou_update_op = tf.contrib.metrics.streaming_mean_iou(
                predictions,
                labels,
                num_classes=self.num_classes,
                weights=ignore_mask)

        # Loss and train op exist only in training mode; eval runs only the
        # metric update ops.
        if is_training:
            loss_op = layers.loss(logits,
                                  labels,
                                  mask=ignore_mask,
                                  weight_decay=self.weight_decay)
            train_op = layers.optimize(loss_op,
                                       learning_rate=self.learning_rate,
                                       global_step=global_step)

        # Merge all summaries into summary op
        summary_op = tf.summary.merge_all()

        # Create restorer for restoring
        saver = tf.train.Saver()

        # Initialize session and local variables (for input pipeline and metrics)
        sess = tf.Session()
        sess.run(tf.local_variables_initializer())

        # Global variables come either from an existing checkpoint or from
        # fresh initialization.
        if self.checkpoint is None:
            sess.run(tf.global_variables_initializer())
            self.log('{} {} from scratch.'.format(run_type, self.model_name))
        else:
            start_time = time.time()
            saver.restore(sess, self.checkpoint)
            duration = time.time() - start_time
            self.log('{} from previous checkpoint {:s} ({:.2f}s)'.format(
                run_type, self.checkpoint, duration))

        # Create summary writer (one subdirectory per run_type).
        summary_path = os.path.join(self.model_path, run_type)
        step_writer = tf.summary.FileWriter(summary_path, sess.graph)

        # Start filling the input queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        num_examples = self.dataset.NUM_TRAIN_EXAMPLES if is_training else self.dataset.NUM_VALID_EXAMPLES

        # One pass over the epoch; partial trailing batches are dropped by
        # the integer division.
        for local_step in range(num_examples // self.batch_size):

            # Take time!
            start_time = time.time()

            if is_training:
                # `accuracy` / `mean_iou` here are the *streaming* values
                # accumulated so far this epoch, not per-batch values.
                _, loss, accuracy, mean_iou, summary = sess.run([
                    train_op, loss_op, accuracy_update_op, mean_iou_update_op,
                    summary_op
                ])
                duration = time.time() - start_time
                self.log('Epoch: {} train step: {} loss: {:.4f} accuracy: {:.2f}% duration: {:.2f}s' \
                    .format(self.epoch, local_step + 1, loss, accuracy * 100, duration))
            else:
                accuracy, mean_iou, summary = sess.run(
                    [accuracy_update_op, mean_iou_update_op, summary_op])
                duration = time.time() - start_time
                self.log('Epoch: {} eval step: {} accuracy: {:.2f}% duration: {:.2f}s'\
                    .format(self.epoch, local_step + 1, accuracy * 100, duration))

            # Save summary and print stats
            step_writer.add_summary(summary,
                                    global_step=global_step.eval(session=sess))

        # Write additional epoch summaries (hyperparameters and final
        # streaming metric values), indexed by epoch rather than step.
        epoch_writer = tf.summary.FileWriter(summary_path)
        epoch_summaries = []
        if is_training:
            epoch_summaries.append(
                tf.summary.scalar('params/weight_decay', self.weight_decay))
            epoch_summaries.append(
                tf.summary.scalar('params/learning_rate', self.learning_rate))
        epoch_summaries.append(
            tf.summary.scalar('params/batch_size', self.batch_size))
        epoch_summaries.append(
            tf.summary.scalar('metrics/accuracy', accuracy_op))
        epoch_summaries.append(
            tf.summary.scalar('metrics/mean_iou', mean_iou_op))
        epoch_summary_op = tf.summary.merge(epoch_summaries)
        summary = sess.run(epoch_summary_op)
        epoch_writer.add_summary(summary, global_step=self.epoch)

        # Save after each epoch when training; the saved path is kept in
        # self.checkpoint so the next run() call restores from it.
        if is_training:
            checkpoint_path = os.path.join(self.model_path,
                                           self.model_name + '.checkpoint')
            start_time = time.time()
            self.checkpoint = saver.save(sess,
                                         checkpoint_path,
                                         global_step=self.epoch)
            duration = time.time() - start_time
            self.log('Model saved as {:s} ({:.2f}s)'.format(
                self.checkpoint, duration))

        # Stop queue runners and reset the graph
        coord.request_stop()
        coord.join(threads)
        sess.close()
        tf.reset_default_graph()