Example #1
    def _build_graph(self):
        from tf_utils.layers import conv2d, max_pool, rescale_bilinear

        def layer_width(layer: int):  # number of channels (features per pixel)
            return min(4 * 4**(layer + 1), 64)

        input_shape = [None] + list(self.input_shape)
        output_shape = input_shape[:3] + [self.class_count]

        # Input image and labels placeholders
        input = tf.placeholder(tf.float32, shape=input_shape)
        target = tf.placeholder(tf.float32, shape=output_shape)

        # Downsampled input (to improve speed at the cost of accuracy)
        h = rescale_bilinear(input, 0.5)

        # Hidden layers
        h = conv2d(h, 3, layer_width(0))
        h = tf.nn.relu(h)
        for l in range(1, self.conv_layer_count):
            h = max_pool(h, 2)
            h = conv2d(h, 3, layer_width(l))
            h = tf.nn.relu(h)

        # Pixelwise softmax classification and label upscaling
        logits = conv2d(h, 1, self.class_count)
        probs = tf.nn.softmax(logits)
        probs = tf.image.resize_bilinear(probs, output_shape[1:3])

        # Loss
        clipped_probs = tf.clip_by_value(probs, 1e-10, 1.0)
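        # ts() drops class 0 from targets/probabilities when class 0 means
        # "unknown", so unlabeled pixels do not contribute to the loss.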
        ts = lambda x: x[:, :, :, 1:] if self.class0_unknown else x
        cost = -tf.reduce_mean(ts(target) * tf.log(ts(clipped_probs)))

        # Optimization
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        training_step = optimizer.minimize(cost)

        # Dense predictions and labels
        preds, dense_labels = tf.argmax(probs, 3), tf.argmax(target, 3)

        # Other evaluation measures
        self._n_accuracy = tf.reduce_mean(
            tf.cast(tf.equal(preds, dense_labels), tf.float32))

        return AbstractModel.EssentialNodes(
            input=input,
            target=target,
            probs=probs,
            loss=cost,
            training_step=training_step)
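
A minimal sketch of how the returned EssentialNodes might be driven during training (hypothetical: `nodes`, `images`, and `labels` do not appear in the example above; `images` and `labels` stand for NumPy batches matching the placeholder shapes):

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One training step: feed a batch, run the optimizer, and fetch the loss.
    _, batch_loss = sess.run(
        [nodes.training_step, nodes.loss],
        feed_dict={nodes.input: images, nodes.target: labels})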
Example #2
    def _build_graph(self, learning_rate, epoch, is_training):
        import layers_exp  # assumed location of rbf_resnet (used below)
        import regularization  # assumed location of l2_regularization (used below)
        from layers import conv

        # Input image and labels placeholders
        input_shape = [None] + list(self.input_shape)
        output_shape = [None, self.class_count]
        input = tf.placeholder(tf.float32, shape=input_shape)
        target = tf.placeholder(tf.float32, shape=output_shape)

        # Hidden layers
        h = layers_exp.rbf_resnet(input,
                                  is_training=is_training,
                                  base_width=self.base_width,
                                  widening_factor=self.widening_factor,
                                  group_lengths=self.group_lengths)

        # Global pooling and softmax classification
        h = tf.reduce_mean(h, axis=[1, 2], keep_dims=True)
        logits = conv(h, 1, self.class_count)
        logits = tf.reshape(logits, [-1, self.class_count])
        probs = tf.nn.softmax(logits)

        # Loss
        clipped_probs = tf.clip_by_value(probs, 1e-10, 1.0)
        loss = -tf.reduce_mean(target * tf.log(clipped_probs))

        # Regularization
        w_vars = [v for v in tf.global_variables() if 'weights' in v.name]
        loss += self.weight_decay * regularization.l2_regularization(w_vars)

        # Optimization
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
        training_step = optimizer.minimize(loss)

        # Dense predictions and labels
        preds, dense_labels = tf.argmax(probs, 1), tf.argmax(target, 1)

        # Other evaluation measures
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(preds, dense_labels), tf.float32))

        #writer = tf.summary.FileWriter('logs', self._sess.graph)

        return AbstractModel.EssentialNodes(input=input,
                                            target=target,
                                            probs=probs,
                                            loss=loss,
                                            training_step=training_step,
                                            evaluation={'accuracy': accuracy})
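
The `regularization.l2_regularization` helper used above is not shown in this example. A minimal sketch of such a helper, assuming it simply sums `tf.nn.l2_loss` over the given variables (an assumption, not necessarily the repository's implementation):

import tensorflow as tf

def l2_regularization(variables):
    # Assumed behavior: sum of 0.5 * ||v||^2 over the given variables.
    return tf.add_n([tf.nn.l2_loss(v) for v in variables])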
Example #3
    def _build_graph(self, learning_rate, epoch, is_training):
        from tensorflow.contrib import layers
        from tf_utils.losses import multiclass_hinge_loss

        def get_ortho_penalty():
            all_vars = tf.contrib.framework.get_variables('')
            filt = lambda x: 'conv' in x.name and 'weights' in x.name
            weight_vars = list(filter(filt, all_vars))
            loss = tf.constant(0.0)
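            # For each conv kernel reshaped to a matrix M of shape
            # (k*k*c_in, c_out), penalize the squared Frobenius norm of
            # M^T M - I/c_out, pushing the filters toward orthogonality.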
            for v in weight_vars:
                m = tf.reshape(v, (-1, v.shape[3].value))
                d = (tf.matmul(m, m, transpose_a=True) -
                     tf.eye(v.shape[3].value) / v.shape[3].value)
                loss += tf.reduce_sum(d**2)
            return loss

        input_shape = [None] + list(self.input_shape)
        output_shape = [None, self.class_count]

        # Input image and labels placeholders
        input = tf.placeholder(tf.float32, shape=input_shape, name='input')
        target = tf.placeholder(tf.float32, shape=output_shape, name='target')

        # L2 regularization
        weight_decay = tf.constant(self.weight_decay, dtype=tf.float32)

        # Hidden layers
        h = input
        with tf.contrib.framework.arg_scope(
            [layers.conv2d],
                kernel_size=5,
                data_format='NHWC',
                padding='SAME',
                activation_fn=tf.nn.relu,
                weights_initializer=layers.variance_scaling_initializer(),
                weights_regularizer=layers.l2_regularizer(weight_decay)):
            h = layers.conv2d(h, 16, scope='convrelu1')
            h = layers.max_pool2d(h, 2, 2, scope='pool1')
            h = layers.conv2d(h, 32, scope='convrelu2')
            h = layers.max_pool2d(h, 2, 2, scope='pool2')
        with tf.contrib.framework.arg_scope(
            [layers.fully_connected],
                activation_fn=tf.nn.relu,
                weights_initializer=layers.variance_scaling_initializer(),
                weights_regularizer=layers.l2_regularizer(weight_decay)):
            h = layers.flatten(h, scope='flatten3')
            h = layers.fully_connected(h, 512, scope='fc3')

        self._print_vars()

        # Softmax classification
        logits = layers.fully_connected(h,
                                        self.class_count,
                                        activation_fn=None,
                                        scope='logits')
        probs = tf.nn.softmax(logits, name='probs')

        # Loss
        mhl = lambda t, lo: 0.1 * multiclass_hinge_loss(t, lo)
        sce = tf.losses.softmax_cross_entropy
        loss = (mhl if self.use_multiclass_hinge_loss else sce)(target, logits)
        loss = loss + tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        if self.ortho_penalty > 0:
            loss += self.ortho_penalty * get_ortho_penalty()

        # Optimization
        optimizer = tf.train.AdamOptimizer(learning_rate)
        training_step = optimizer.minimize(loss)

        # Dense predictions and labels
        preds, dense_labels = tf.argmax(probs, 1), tf.argmax(target, 1)

        # Other evaluation measures
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(preds, dense_labels), tf.float32))

        return AbstractModel.EssentialNodes(input=input,
                                            target=target,
                                            probs=probs,
                                            loss=loss,
                                            training_step=training_step,
                                            evaluation={'accuracy': accuracy})
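
The `multiclass_hinge_loss` imported from `tf_utils.losses` is likewise not shown. A minimal sketch of one common formulation (a one-vs-all hinge on one-hot targets, matching the `(target, logits)` call order above), offered as an illustration rather than the repository's actual definition:

import tensorflow as tf

def multiclass_hinge_loss(target, logits, margin=1.0):
    # One-vs-all hinge: push the true-class logit above +margin and all
    # other logits below -margin, then average the total over the batch.
    signs = 2.0 * target - 1.0  # +1 for the true class, -1 elsewhere
    per_class = tf.nn.relu(margin - signs * logits)
    return tf.reduce_mean(tf.reduce_sum(per_class, axis=1))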