Exemple #1
0
 def used_arg_scope(self, use_batch_norm, renorm, weight_decay):
     """Build the argument scope used for the main computations.

     The options are forwarded to `nets_arg_scope`. `is_training` is taken
     from `self.training`, and both the batch-norm decay and the renorm
     decay are fixed at 0.99.

     Args:
         use_batch_norm: Forwarded to `nets_arg_scope` as `use_batch_norm`.
         renorm: Forwarded to `nets_arg_scope` as `renorm`.
         weight_decay: Forwarded to `nets_arg_scope` as `weight_decay`.

     Returns:
         The value returned by `nets_arg_scope` for these settings.
     """
     scope_options = dict(
         is_training=self.training,
         use_batch_norm=use_batch_norm,
         renorm=renorm,
         batch_norm_decay=0.99,
         renorm_decay=0.99,
         weight_decay=weight_decay,
     )
     return nets_arg_scope(**scope_options)
Exemple #2
0
    def used_arg_scope(self, use_batch_norm, renorm, weight_decay):
        """Return the slim argument scope used for the main computations.

        The scope is produced by `nets_arg_scope`, with `self.batch_stat`
        passed as its `is_training` argument. Per the original author's
        notes, the scope normally covers batch normalization, proper
        initialization and weight regularization, and `self.batch_stat`
        selects between batch statistics (`True`) and the moving
        mean/variance (`False`) for batch normalization:

        * Training must use batch statistics, so it is always `True`.
        * Testing usually uses `False`. When the test distribution differs
          strongly from the training one, `True` may give more sensible
          results.
        * Dropout is affected as well: in `slim.dropout`, dropout is applied
          whenever `self.batch_stat` is `True`.

        Args:
            use_batch_norm: Whether to do batch normalization or not.
            renorm: Whether to do batch renormalization or not (the
                original author notes never having used it).
            weight_decay: The weight regularization coefficient.

        Returns:
            An argument scope to be used for model computations.
        """
        scope_options = dict(
            is_training=self.batch_stat,
            use_batch_norm=use_batch_norm,
            renorm=renorm,
            weight_decay=weight_decay,
        )
        return nets_arg_scope(**scope_options)
Exemple #3
0
    def used_arg_scope(self, batch_stat, use_batch_norm):
        """Return the slim argument scope used for the main computations.

        Args:
            batch_stat: Forwarded to `nets_arg_scope` as `is_training`.
                Presumably selects batch statistics over moving averages
                for batch normalization -- confirm against
                `nets_arg_scope`.
            use_batch_norm: Whether to do batch normalization or not.

        Returns:
            An argument scope to be used for model computations.
        """
        scope_options = dict(
            is_training=batch_stat,
            use_batch_norm=use_batch_norm,
        )
        return nets_arg_scope(**scope_options)
def reconstruct(image_path,
                train_dir,
                CAE_architecture,
                log_dir=None,
                image_size=299,
                channels=3):
    """Reconstruct a single image with a trained auto-encoder.

    The image is decoded, preprocessed with
    `inception_preprocessing.preprocess_image` (evaluation mode), turned
    into a batch of one and passed through `CAE_architecture`. Model
    variables are restored from the latest checkpoint in `train_dir`.

    Args:
        image_path: Path to a `.jpg`, `.jpeg` or `.png` file. The extension
            is matched case-insensitively.
        train_dir: Directory searched for the latest checkpoint.
        CAE_architecture: Callable that takes the image batch and returns a
            tuple whose first element is the batch of reconstructions.
        log_dir: If not `None`, image summaries of the input and of the
            reconstruction are written to this directory.
        image_size: Height and width passed to the preprocessing function.
        channels: Number of channels requested from the image decoder.

    Returns:
        The evaluated reconstruction with all size-1 dimensions removed.

    Raises:
        ValueError: If the image extension is not supported or no
            checkpoint is found in `train_dir`.
    """
    with tf.Graph().as_default():

        # Image files are binary: read them in 'rb' mode (text mode breaks
        # on Python 3) and close the handle as soon as the bytes are read.
        with tf.gfile.FastGFile(image_path, 'rb') as image_file:
            image_string = image_file.read()

        _, image_ext = os.path.splitext(image_path)
        image_ext = image_ext.lower()

        if image_ext in ('.jpg', '.jpeg'):
            image = tf.image.decode_jpeg(image_string, channels=channels)
        elif image_ext == '.png':
            image = tf.image.decode_png(image_string, channels=channels)
        else:
            raise ValueError('image format not supported, must be jpg or png')

        processed_image = inception_preprocessing.preprocess_image(
            image, image_size, image_size, is_training=False)
        processed_images = tf.expand_dims(processed_image, 0)

        with slim.arg_scope(nets_arg_scope(is_training=False)):
            reconstructions, _ = CAE_architecture(processed_images)
        # NOTE: squeezing without an explicit axis drops every size-1
        # dimension, so a single-channel output also loses its channel axis.
        reconstruction = tf.squeeze(reconstructions)

        tf.summary.image('input', processed_images)
        tf.summary.image('reconstruction', reconstructions)
        summary_op = tf.summary.merge_all()

        checkpoint_path = tf.train.latest_checkpoint(train_dir)
        if checkpoint_path is None:
            # Fail with a clear message instead of an opaque restore error.
            raise ValueError('no checkpoint found in %s' % train_dir)
        saver = tf.train.Saver(tf.model_variables())

        writer = None
        if log_dir is not None:
            writer = tf.summary.FileWriter(log_dir)

        try:
            with tf.Session() as sess:
                saver.restore(sess, checkpoint_path)
                reconstruction_value, summaries = sess.run(
                    [reconstruction, summary_op])
            if writer is not None:
                writer.add_summary(summaries)
            return reconstruction_value
        finally:
            # Closing flushes pending events; without it the summary may
            # never reach the disk.
            if writer is not None:
                writer.close()
Exemple #5
0
def train_mapping(tfrecord_dir,
                  checkpoints_dir_color,
                  checkpoints_dir_depth,
                  log_dir,
                  number_of_steps=None,
                  number_of_epochs=5,
                  batch_size=24,
                  save_summaries_step=5,
                  do_test=True,
                  dropout_keep_prob=0.8,
                  initial_learning_rate=0.005,
                  lr_decay_steps=100,
                  lr_decay_rate=0.8):
    """Train a linear mapping from color features to depth features.

    Two Inception-v4 feature extractors are built under the variable scopes
    'Color' and 'Depth'. A fully connected layer without activation
    ('Mapping') is trained with a mean squared error loss to predict the
    depth features from the color features. Only trainable variables under
    the 'Mapping' scope are optimized.

    Args:
        tfrecord_dir: Directory passed to `get_split_color_depth` for the
            'train' and 'validation' splits.
        checkpoints_dir_color: Passed to `get_init_fn`.
        checkpoints_dir_depth: Passed to `get_init_fn`.
        log_dir: Supervisor log directory; created if it does not exist.
        number_of_steps: Number of training steps. When `None` it is
            derived from `number_of_epochs`, the number of training samples
            and `batch_size`, rounded up.
        number_of_epochs: Used only when `number_of_steps` is `None`.
        batch_size: Batch size of both the train and validation loaders.
        save_summaries_step: Summaries are computed every this many steps.
        do_test: Whether to also evaluate loss and summaries on a
            validation batch each time summaries are computed.
        dropout_keep_prob: Not used by this function; kept so existing
            callers keep working.
        initial_learning_rate: Starting learning rate of the decay schedule.
        lr_decay_steps: Number of steps between learning rate decays.
        lr_decay_rate: Multiplicative factor of each decay.
    """
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    image_size = 299

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        with tf.name_scope('Data_provider'):
            dataset = get_split_color_depth('train', tfrecord_dir)
            images_color_train, images_depth_train, _ = \
                load_batch_color_depth(
                    dataset, height=image_size, width=image_size,
                    batch_size=batch_size)

            dataset_test = get_split_color_depth('validation', tfrecord_dir)
            images_color_test, images_depth_test, _ = \
                load_batch_color_depth(
                    dataset_test, height=image_size, width=image_size,
                    batch_size=batch_size)

        # Boolean switch between the train and validation input pipelines.
        # NOTE(review): the placeholder has no default value; presumably
        # `train_step` (defined elsewhere) feeds it -- confirm.
        training = tf.placeholder(tf.bool, shape=(), name='training')
        images_color = tf.cond(training, lambda: images_color_train,
                               lambda: images_color_test)
        images_depth = tf.cond(training, lambda: images_depth_train,
                               lambda: images_depth_test)

        if number_of_steps is None:
            # Force float division: with integer operands and Python 2
            # semantics `/` floors first, which makes the ceil a no-op and
            # drops the last partial batch.
            number_of_steps = int(np.ceil(
                float(dataset.num_samples) * number_of_epochs / batch_size))

        with slim.arg_scope(nets_arg_scope(is_training=training)):
            with tf.variable_scope('Color', values=[images_color]):
                net_color, _ = inception_v4.inception_v4_base(images_color)
                net_color = inception_feature(net_color)

            with tf.variable_scope('Depth', values=[images_depth]):
                net_depth, _ = inception_v4.inception_v4_base(images_depth)
                net_depth = inception_feature(net_depth)

            mapping = slim.fully_connected(net_color,
                                           net_depth.get_shape()[1].value,
                                           activation_fn=None,
                                           scope='Mapping')

        # The depth features are the regression target. The loss value is
        # symmetric in its two arguments; keywords only make the roles
        # explicit.
        tf.losses.mean_squared_error(labels=net_depth, predictions=mapping)
        total_loss = tf.losses.get_total_loss()

        # Create the global step for monitoring training
        global_step = tf.train.get_or_create_global_step()

        # Exponentially decaying learning rate
        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=lr_decay_steps,
            decay_rate=lr_decay_rate,
            staircase=True)

        # Optimizer and train op; only the 'Mapping' variables are updated.
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = slim.learning.create_train_op(
            total_loss,
            optimizer,
            variables_to_train=tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='Mapping'))

        # Track moving mean and moving variance. The histograms are only
        # added when both kinds of variable exist.
        # NOTE(review): index 0 picks the first matching variable in
        # `tf.model_variables()`, although the summary tags say
        # 'last_layer' -- confirm which layer is intended.
        moving_means = [
            v for v in tf.model_variables()
            if v.op.name.endswith('moving_mean')
        ]
        moving_variances = [
            v for v in tf.model_variables()
            if v.op.name.endswith('moving_variance')
        ]
        if moving_means and moving_variances:
            tf.summary.histogram('batch_norm/last_layer/moving_mean',
                                 moving_means[0])
            tf.summary.histogram('batch_norm/last_layer/moving_variance',
                                 moving_variances[0])

        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('losses/train/total_loss', total_loss)
        tf.summary.image('train/color', images_color)
        tf.summary.image('train/depth', images_depth)
        # Merge now so that the test summaries created below are not part
        # of the training summary op.
        summary_op = tf.summary.merge_all()

        ls_test_summary = tf.summary.scalar('losses/test/total_loss',
                                            total_loss)
        imgs_test_color_summary = tf.summary.image('test/color', images_color)
        imgs_test_depth_summary = tf.summary.image('test/depth', images_depth)
        test_summary_op = tf.summary.merge([
            ls_test_summary, imgs_test_color_summary, imgs_test_depth_summary
        ])

        # summary_op=None: summaries are computed explicitly in the loop
        # below and handed to the supervisor via `summary_computed`.
        sv = tf.train.Supervisor(logdir=log_dir,
                                 summary_op=None,
                                 init_fn=get_init_fn(checkpoints_dir_color,
                                                     checkpoints_dir_depth))

        with sv.managed_session() as sess:
            # Keeps the final log statement valid when no step is run.
            loss = None
            for step in xrange(number_of_steps):
                if (step + 1) % save_summaries_step == 0:
                    loss, _, summaries = train_step(sess, train_op,
                                                    sv.global_step, summary_op)
                    sv.summary_computed(sess, summaries)
                    if do_test:
                        ls, summaries_test = sess.run(
                            [total_loss, test_summary_op],
                            feed_dict={training: False})
                        tf.logging.info('Current Test Loss: %s', ls)
                        sv.summary_computed(sess, summaries_test)
                else:
                    loss = train_step(sess, train_op, sv.global_step)[0]

            tf.logging.info('Finished training. Final Loss: %s', loss)
            tf.logging.info('Saving model to disk now.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)