Example #1
def _combine_conv(inputs,
                  layer,
                  training,
                  out_channels,
                  padding='VALID',
                  is_batchnorm=True):
    conv = conv2d(inputs,
                  kernel=3,
                  out_channels=out_channels,
                  stride=1,
                  padding=padding,
                  name='combine_conv' + layer)
    print('conv layer: %s, shape: %s' % (layer, conv.get_shape()))
    if is_batchnorm:
        bn = tf.layers.batch_normalization(conv,
                                           axis=1,
                                           center=True,
                                           scale=False,
                                           training=training,
                                           name='bn' + layer)
    else:
        bn = conv
    relu = tf.nn.relu(bn, name='relu' + layer)

    print('layer: %s, shape: %s' % (layer, relu.get_shape()))
    return relu
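A minimal call sketch for this helper, assuming `conv2d` is the project's wrapper around `tf.contrib.layers.conv2d` operating on NCHW tensors; the tensor names and shapes below are illustrative, not from the original code:

import tensorflow as tf

# Hypothetical inputs; the shape is an assumption for illustration.
images = tf.placeholder(tf.float32, [8, 3, 224, 224], name='images')
is_training = tf.placeholder(tf.bool, name='is_training')
block = _combine_conv(images, layer='1', training=is_training, out_channels=64)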
Example #2
def _combine_conv(inputs, layer, training, in_channels, out_channels, kernel=3, padding='SAME', is_batchnorm=False):

    def _weights_initializer(weights):
        if weights.shape != (kernel, kernel, in_channels, out_channels):
            weights = np.reshape(weights, (kernel, kernel, in_channels, out_channels))
        return tf.constant_initializer(weights)

    conv = conv2d(
        inputs, kernel=kernel,
        out_channels=out_channels, stride=1,
        weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005),
        weights_initializer=_weights_initializer(init_weights['conv' + layer + '_W']),
        biases_initializer=tf.constant_initializer(init_weights['conv' + layer + '_b']),
        padding=padding, name='combine_conv' + layer
    )
    print('conv layer: %s, shape: %s' % (layer, conv.get_shape()))
    if is_batchnorm:
        bn = tf.layers.batch_normalization(conv, axis=1, center=True, scale=False,
                                           training=training, name='bn' + layer)
    else:
        bn = conv
    relu = tf.nn.relu(bn, name='relu' + layer)

    print('layer: %s, shape: %s' % (layer, relu.get_shape()))
    return relu
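This variant reads pretrained arrays from a module-level `init_weights` mapping (and uses `numpy` as `np`), neither of which appears in the excerpt. A plausible loader, offered only as a hedged sketch; the file name and the 'conv<layer>_W' / 'conv<layer>_b' key layout are assumptions:

import numpy as np

# Hypothetical: an .npz archive keyed like 'conv1_1_W' and 'conv1_1_b'.
init_weights = np.load('pretrained_weights.npz')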
Example #3
def _combine_conv(inputs,
                  layer,
                  training,
                  kernel=5,
                  padding='VALID',
                  out_channels=32):
    conv = conv2d(inputs,
                  kernel=kernel,
                  out_channels=out_channels,
                  stride=1,
                  padding=padding,
                  name='combine_conv' + layer)
    print('conv layer: %s, shape: %s' % (layer, conv.get_shape()))
    bn = tf.layers.batch_normalization(conv,
                                       axis=1,
                                       center=True,
                                       scale=False,
                                       training=training,
                                       name='bn' + layer)
    relu = tf.nn.relu(bn, name='relu' + layer)
    pool = tf.layers.max_pooling2d(inputs=relu,
                                   pool_size=[2, 2],
                                   strides=[2, 2],
                                   data_format="channels_first",
                                   name='pool' + layer)

    print('layer: %s, shape: %s\n' % (layer, pool.get_shape()))
    return pool
Example #4
def _combine_deconv(inputs,
                    layer,
                    training,
                    conv_val,
                    kernel_deconv=5,
                    kernel_conv=3,
                    deconv_out_channels=32,
                    conv_out_channels=32,
                    conv_padding='SAME'):
    deconv1 = deconv(inputs,
                     kernel=kernel_deconv,
                     out_channels=deconv_out_channels,
                     stride=2,
                     data_format='NCHW',
                     activation_fn=tf.nn.relu,
                     name='deconv' + layer)
    print('deconv layer: %s, deconv shape: %s' % (layer, deconv1.get_shape()))

    if conv_val is not None:
        print('conv layer to be concatenated, shape: %s' %
              (conv_val.get_shape()))
        concat1 = tf.concat([conv_val, deconv1], axis=1, name='concat' + layer)
    else:
        concat1 = deconv1
    conv = conv2d(concat1,
                  kernel=kernel_conv,
                  out_channels=conv_out_channels,
                  stride=1,
                  data_format='NCHW',
                  padding=conv_padding,
                  name='deconv_conv' + layer)
    print('deconv layer: %s, conv shape: %s' % (layer, conv.get_shape()))
    bn = tf.layers.batch_normalization(conv,
                                       axis=1,  # channel axis for NCHW data
                                       center=True,
                                       scale=False,
                                       training=training,
                                       name='deconv_bn' + layer)
    print('deconv layer: %s, final shape: %s\n' % (layer, bn.get_shape()))
    return bn
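These two helpers pair up as one down/up rung of a U-Net: `_combine_conv` produces a pooled feature map, and `_combine_deconv` upsamples and concatenates whatever skip tensor the caller saved. A hedged composition sketch with illustrative names:

down1 = _combine_conv(images, layer='1', training=is_training)   # conv + BN + ReLU + pool
down2 = _combine_conv(down1, layer='2', training=is_training)
up2 = _combine_deconv(down2, layer='2', training=is_training, conv_val=down1)
up1 = _combine_deconv(up2, layer='1', training=is_training, conv_val=None)  # no skip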
Example #5
def inference(inputs, num_classes, training=False, name='unet'):
    with tf.variable_scope(name) as scope:
        conv1 = conv2d(inputs,
                       kernel=3,
                       out_channels=32,
                       stride=1,
                       padding='SAME',
                       activation_fn=tf.nn.relu,
                       normalizer_fn=tf.contrib.layers.batch_norm,
                       name='relu_conv1')
        print('conv1 shape: %s' % conv1.get_shape())
        pool1 = tf.nn.max_pool(conv1,
                               ksize=[1, 1, 2, 2],
                               strides=[1, 1, 2, 2],
                               padding='VALID',
                               data_format='NCHW',
                               name='pool1')
        print('pool1 shape: %s' % pool1.get_shape())

        conv2 = conv2d(pool1,
                       kernel=3,
                       out_channels=64,
                       stride=1,
                       padding='SAME',
                       activation_fn=tf.nn.relu,
                       normalizer_fn=tf.contrib.layers.batch_norm,
                       name='relu_conv2')
        print('conv2 shape: %s' % conv2.get_shape())
        pool2 = tf.nn.max_pool(conv2,
                               ksize=[1, 1, 2, 2],
                               strides=[1, 1, 2, 2],
                               padding='VALID',
                               data_format='NCHW',
                               name='pool2')
        print('pool2 shape: %s' % pool2.get_shape())

        conv3 = conv2d(pool2,
                       kernel=3,
                       out_channels=128,
                       stride=1,
                       padding='SAME',
                       activation_fn=tf.nn.relu,
                       normalizer_fn=tf.contrib.layers.batch_norm,
                       name='relu_conv3')
        print('conv3 shape: %s' % conv3.get_shape())
        pool3 = tf.nn.max_pool(conv3,
                               ksize=[1, 1, 2, 2],
                               strides=[1, 1, 2, 2],
                               padding='VALID',
                               data_format='NCHW',
                               name='pool3')
        print('pool3 shape: %s' % pool3.get_shape())

        pool3_dropout = tf.layers.dropout(pool3,
                                          0.5,
                                          training=training,
                                          name='pool3_dropout')

        deconv1 = deconv(pool3_dropout,
                         kernel=2,
                         out_channels=128,
                         stride=2,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         normalizer_fn=tf.contrib.layers.batch_norm,
                         name='deconv1')
        print('deconv1 shape: %s' % deconv1.get_shape())
        deconv1_conv = conv2d(deconv1,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              normalizer_fn=tf.contrib.layers.batch_norm,
                              name='deconv1_conv')
        print('deconv1_conv shape: %s' % deconv1_conv.get_shape())
        concat1 = tf.concat([pool2, deconv1_conv], axis=1, name='concat1')
        dropout1 = tf.layers.dropout(concat1,
                                     0.5,
                                     training=training,
                                     name='dropout1')
        concat1_conv = conv2d(dropout1,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              normalizer_fn=tf.contrib.layers.batch_norm,
                              name='concat1_conv')
        print('concat1_conv shape: %s' % concat1_conv.get_shape())

        deconv2 = deconv(concat1_conv,
                         kernel=2,
                         out_channels=128,
                         stride=2,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         normalizer_fn=tf.contrib.layers.batch_norm,
                         name='deconv2')
        print('deconv2 shape: %s' % deconv2.get_shape())
        deconv2_conv = conv2d(deconv2,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              normalizer_fn=tf.contrib.layers.batch_norm,
                              name='deconv2_conv')
        print('deconv2_conv shape: %s' % deconv2_conv.get_shape())
        concat2 = tf.concat([pool1, deconv2_conv], axis=1, name='concat2')
        dropout2 = tf.layers.dropout(concat2,
                                     0.5,
                                     training=training,
                                     name='dropout2')
        concat2_conv = conv2d(dropout2,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              normalizer_fn=tf.contrib.layers.batch_norm,
                              name='concat2_conv')
        print('concat2_conv shape: %s' % concat2_conv.get_shape())

        deconv3 = deconv(concat2_conv,
                         kernel=2,
                         out_channels=128,
                         stride=2,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         normalizer_fn=tf.contrib.layers.batch_norm,
                         name='deconv3')
        print('deconv3 shape: %s' % deconv3.get_shape())
        deconv3_conv = conv2d(deconv3,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              name='deconv3_conv')
        print('deconv3_conv shape: %s' % deconv3_conv.get_shape())

        class_conv = conv2d(deconv3_conv,
                            kernel=3,
                            out_channels=num_classes,
                            stride=1,
                            padding='SAME',
                            name='class_conv')
        print('class_conv shape: %s' % class_conv.get_shape())

        label_logits = tf.transpose(class_conv, perm=[0, 2, 3, 1])
        label_logits = tf.check_numerics(
            label_logits, message="nan or inf from: label_logits")
        print('label_logits shape: %s' % label_logits.get_shape())
        return label_logits
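Since the logits come back in NHWC order, a per-pixel loss attaches directly. A minimal sketch, assuming integer label maps of shape (b, h, w); the placeholder is illustrative:

labels = tf.placeholder(tf.int32, [None, None, None], name='labels')
pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=label_logits)  # (b, h, w)
loss = tf.reduce_mean(pixel_loss)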
Example #6
def inference(inputs, num_classes, name='unet'):
    with tf.variable_scope(name) as scope:
        conv1 = conv2d(inputs,
                       kernel=5,
                       out_channels=128,
                       stride=1,
                       padding='VALID',
                       activation_fn=tf.nn.relu,
                       name='relu_conv1')
        print('conv1 shape: %s' % conv1.get_shape())
        # max_pool1 = tf.nn.max_pool(
        #     conv1,
        #     ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
        #     padding='VALID', data_format='NCHW', name='max_pool1'
        # )
        conv2 = conv2d(conv1,
                       kernel=5,
                       out_channels=128,
                       stride=1,
                       padding='VALID',
                       activation_fn=tf.nn.relu,
                       name='relu_conv2')
        print('conv2 shape: %s' % conv2.get_shape())
        # max_pool2 = tf.nn.max_pool(
        #     conv2,
        #     ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
        #     padding='VALID', data_format='NCHW', name='max_pool2'
        # )
        conv3 = conv2d(conv2,
                       kernel=3,
                       out_channels=256,
                       stride=2,
                       padding='VALID',
                       activation_fn=tf.nn.relu,
                       name='relu_conv3')
        print('conv3 shape: %s' % conv3.get_shape())
        # max_pool3 = tf.nn.max_pool(
        #     conv3,
        #     ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
        #     padding='VALID', data_format='NCHW', name='max_pool1'
        # )
        conv3 = tf.check_numerics(conv3, message="nan or inf from: conv3")

        deconv1 = deconv(conv3,
                         kernel=4,
                         out_channels=128,
                         stride=2,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         name='deconv1')
        print('deconv1 shape: %s' % deconv1.get_shape())
        concat1 = tf.concat([conv2, deconv1], axis=1, name='concat1')
        # print('concat1 shape: %s' % concat1.get_shape())
        deconv1_conv = conv2d(concat1,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              name='deconv1_conv')
        deconv2 = deconv(deconv1_conv,
                         kernel=5,
                         out_channels=128,
                         stride=1,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         name='deconv2')
        print('deconv2 shape: %s' % deconv2.get_shape())
        concat2 = tf.concat([conv1, deconv2], axis=1, name='concat2')
        # print('concat2 shape: %s' % concat2.get_shape())
        deconv2_conv = conv2d(concat2,
                              kernel=3,
                              out_channels=128,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              name='deconv2_conv')

        deconv3 = deconv(deconv2_conv,
                         kernel=5,
                         out_channels=num_classes,
                         stride=1,
                         data_format='NCHW',
                         activation_fn=tf.nn.relu,
                         name='deconv3')
        # print('deconv3 shape: %s' % deconv3.get_shape())
        deconv3_conv = conv2d(deconv3,
                              kernel=3,
                              out_channels=num_classes,
                              stride=1,
                              padding='SAME',
                              activation_fn=tf.nn.relu,
                              name='deconv3_conv')

        label_logits = tf.transpose(deconv3_conv, perm=[0, 2, 3, 1])
        label_logits = tf.check_numerics(
            label_logits, message="nan or inf from: label_logits")
        print('label_logits shape: %s' % label_logits.get_shape())
        return label_logits
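With 'VALID' padding the spatial extent shrinks at every layer, so the skip concatenations above only line up when the deconv kernels exactly undo the preceding convs. A small checking helper, added here for illustration rather than taken from the original code:

def valid_out_size(size, kernel, stride=1):
    # Output extent of a VALID conv/pool: floor((size - kernel) / stride) + 1.
    return (size - kernel) // stride + 1

# e.g. a 5x5 VALID conv at stride 1 trims 4 pixels from each spatial dim:
assert valid_out_size(56, 5) == 52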
Example #7
def _decode(activations, capsule_num, coupling_coeffs, num_classes, batch_size,
            pool1, pool2, training):
    capsule_probs = tf.norm(activations,
                            axis=-1)  # (b, 32, 4, 20, 8) -> (b, 32, 4, 20)
    caps_probs_tiled = tf.tile(tf.expand_dims(capsule_probs, -1),
                               [1, 1, 1, 1, num_classes])  # (b, 32, 4, 20, 2)
    # caps_probs_tiled = tf.check_numerics(caps_probs_tiled, message="nan or inf from: caps_probs_tiled")

    print('coupling_coeffs shape: %s' % coupling_coeffs.get_shape())
    activations_shape = activations.get_shape()
    height, width = activations_shape[2].value, activations_shape[3].value
    coupling_coeff_reshaped = tf.reshape(
        coupling_coeffs,
        [batch_size, capsule_num, height, width, num_classes])  # (b, 32, 4, 20, 2)
    # coupling_coeff_reshaped = tf.check_numerics(coupling_coeff_reshaped, message="nan or inf from: coupling_coeff_reshaped")

    primary_labels = tf.reduce_sum(coupling_coeff_reshaped * caps_probs_tiled,
                                   1)  # (b, 4, 20, 2)
    # class_labels = tf.Print(class_labels, [tf.constant("class_labels"), class_labels])
    # class_labels = tf.check_numerics(class_labels, message="nan or inf from: class_labels")
    # primary_labels = tf.reduce_sum(caps_probs_tiled, 1)
    # deconv1 = deconv(
    #     class_labels,
    #     kernel=3, out_channels=num_classes, stride=1,
    #     activation_fn=tf.nn.relu, name='deconv1'
    # )
    # deconv1 = tf.Print(deconv1, [tf.constant("deconv1"), deconv1])
    print('primary_labels shape: %s' % primary_labels.get_shape())
    primary_conv = conv2d(tf.transpose(primary_labels, perm=[0, 3, 1, 2]),
                          kernel=3,
                          out_channels=256,
                          stride=1,
                          padding='SAME',
                          activation_fn=tf.nn.relu,
                          name='primary_conv')
    print('primary_conv shape: %s' % primary_conv.get_shape())

    deconv1 = deconv(primary_conv,
                     kernel=3,
                     out_channels=128,
                     stride=1,
                     data_format='NCHW',
                     activation_fn=tf.nn.relu,
                     normalizer_fn=tf.contrib.layers.batch_norm,
                     name='deconv1')
    print('deconv1 shape: %s' % deconv1.get_shape())
    # deconv1_conv = conv2d(
    #     deconv1,
    #     kernel=3, out_channels=128, stride=1, padding='SAME',
    #     activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
    #     name='deconv1_conv'
    # )
    # print('deconv1_conv shape: %s' % deconv1_conv.get_shape())
    concat1 = tf.concat([pool2, deconv1], axis=1, name='concat1')
    dropout1 = tf.layers.dropout(concat1,
                                 0.5,
                                 training=training,
                                 name='dropout1')
    concat1_conv = conv2d(dropout1,
                          kernel=2,
                          out_channels=128,
                          stride=1,
                          padding='VALID',
                          activation_fn=tf.nn.relu,
                          normalizer_fn=tf.contrib.layers.batch_norm,
                          name='concat1_conv')
    print('concat1_conv shape: %s' % concat1_conv.get_shape())

    deconv2 = deconv(concat1_conv,
                     kernel=4,
                     out_channels=128,
                     stride=2,
                     data_format='NCHW',
                     activation_fn=tf.nn.relu,
                     normalizer_fn=tf.contrib.layers.batch_norm,
                     name='deconv2')
    print('deconv2 shape: %s' % deconv2.get_shape())
    # deconv2_conv = conv2d(
    #     deconv2,
    #     kernel=3, out_channels=128, stride=1, padding='SAME',
    #     activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
    #     name='deconv2_conv'
    # )
    # print('deconv2_conv shape: %s' % deconv2_conv.get_shape())
    concat2 = tf.concat([pool1, deconv2], axis=1, name='concat2')
    dropout2 = tf.layers.dropout(concat2,
                                 0.5,
                                 training=training,
                                 name='dropout2')
    concat2_conv = conv2d(dropout2,
                          kernel=2,
                          out_channels=128,
                          stride=1,
                          padding='VALID',
                          activation_fn=tf.nn.relu,
                          normalizer_fn=tf.contrib.layers.batch_norm,
                          name='concat2_conv')
    print('concat2_conv shape: %s' % concat2_conv.get_shape())

    deconv3 = deconv(concat2_conv,
                     kernel=4,
                     out_channels=128,
                     stride=2,
                     data_format='NCHW',
                     activation_fn=tf.nn.relu,
                     normalizer_fn=tf.contrib.layers.batch_norm,
                     name='deconv3')
    print('deconv3 shape: %s' % deconv3.get_shape())
    # deconv3_conv = conv2d(
    #     deconv3,
    #     kernel=3, out_channels=128, stride=1, padding='SAME',
    #     activation_fn=tf.nn.relu, name='deconv3_conv'
    # )
    # print('deconv3_conv shape: %s' % deconv3_conv.get_shape())

    class_conv = conv2d(deconv3,
                        kernel=3,
                        out_channels=num_classes,
                        stride=1,
                        padding='SAME',
                        name='class_conv')
    print('class_conv shape: %s' % class_conv.get_shape())

    label_logits = tf.transpose(class_conv, perm=[0, 2, 3, 1])
    print('label_logits shape: %s' % label_logits.get_shape())
    # label_logits = tf.Print(label_logits, [tf.constant("label_logits"), label_logits])
    return label_logits
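The decode step weights each primary capsule's activation norm by its coupling coefficient toward every class and sums over capsules. The same computation in NumPy, as a shape sanity check (sizes illustrative):

import numpy as np

b, caps, h, w, k = 2, 32, 4, 20, 2
probs = np.random.rand(b, caps, h, w)             # capsule activation norms
coeffs = np.random.rand(b, caps, h, w, k)         # coupling coefficients
labels = (coeffs * probs[..., None]).sum(axis=1)  # (b, h, w, k)
assert labels.shape == (b, h, w, k)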
Example #8
def inference(inputs,
              num_classes,
              routing_ites=4,
              remake=False,
              training=False,
              name='capsnet_1d'):
    """

    :param inputs:
    :param num_classes:
    :param routing_ites:
    :param remake:
    :param name:
    :return:
    """

    with tf.variable_scope(name) as scope:
        inputs_shape = inputs.get_shape()
        batch_size = inputs_shape[0].value
        image_height = inputs_shape[2].value
        image_width = inputs_shape[3].value

        # ReLU Conv1
        # Images shape (b, 1, 24, 56) -> 3x3 conv, 32 output channels, stride 1 with padding, ReLU
        # nets -> (b, 32, 24, 56)
        print('inputs shape: %s' % inputs.get_shape())
        inputs = tf.check_numerics(inputs, message="nan or inf from: inputs")

        conv1 = conv2d(inputs,
                       kernel=3,
                       out_channels=32,
                       stride=1,
                       padding='SAME',
                       activation_fn=tf.nn.relu,
                       normalizer_fn=tf.contrib.layers.batch_norm,
                       name='relu_conv1')
        print('conv1 shape: %s' % conv1.get_shape())
        pool1 = tf.nn.max_pool(conv1,
                               ksize=[1, 1, 2, 2],
                               strides=[1, 1, 2, 2],
                               padding='VALID',
                               data_format='NCHW',
                               name='pool1')
        print('pool1 shape: %s' % pool1.get_shape())

        conv2 = conv2d(pool1,
                       kernel=3,
                       out_channels=64,
                       stride=1,
                       padding='SAME',
                       activation_fn=tf.nn.relu,
                       normalizer_fn=tf.contrib.layers.batch_norm,
                       name='relu_conv2')
        print('conv2 shape: %s' % conv2.get_shape())
        pool2 = tf.nn.max_pool(conv2,
                               ksize=[1, 1, 2, 2],
                               strides=[1, 1, 2, 2],
                               padding='VALID',
                               data_format='NCHW',
                               name='pool2')
        print('pool2 shape: %s' % pool2.get_shape())
        pool2_dropout = tf.layers.dropout(pool2,
                                          0.5,
                                          training=training,
                                          name='pool2_dropout')

        # conv3 = conv2d(
        #     pool2,
        #     kernel=3, out_channels=128, stride=1, padding='SAME',
        #     activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
        #     name='relu_conv3'
        # )
        # print('conv3 shape: %s' % conv3.get_shape())
        # pool3 = tf.nn.max_pool(
        #     conv3,
        #     ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
        #     padding='VALID', data_format='NCHW', name='pool3'
        # )
        # print('pool3 shape: %s' % pool3.get_shape())

        print("\nprimary layer:")
        primary_out_capsules = 32
        primary_caps_activations, conv_primary = primary_caps1d(
            pool2_dropout,
            kernel_size=3,
            out_capsules=primary_out_capsules,
            stride=1,
            padding='VALID',
            activation_length=8,
            name='primary_caps')  # (b, 32, 4, 20, 8)

        print("\nclass capsule layer:")
        class_caps_activations, class_coupling_coeffs = class_caps1d(
            primary_caps_activations,
            num_classes=num_classes,
            activation_length=16,
            routing_ites=routing_ites,
            batch_size=batch_size,
            name='class_capsules')
        # class_coupling_coeffs = tf.Print(class_coupling_coeffs, [class_coupling_coeffs], summarize=50)
        # class_caps_activations = tf.check_numerics(class_caps_activations, message="nan or inf from: class_caps_activations")
        print('class_coupling_coeffs shape: %s' %
              class_coupling_coeffs.get_shape())
        print('class_caps_activations shape: %s' %
              class_caps_activations.get_shape())

        if remake:
            remakes_flatten = _remake(class_caps_activations,
                                      image_height * image_width)
        else:
            remakes_flatten = None

        print("\ndecode layers:")
        label_logits = _decode(primary_caps_activations,
                               primary_out_capsules,
                               coupling_coeffs=class_coupling_coeffs,
                               num_classes=num_classes,
                               batch_size=batch_size,
                               pool1=pool1,
                               pool2=pool2,
                               training=training)
        # label_logits = tf.Print(label_logits, [tf.constant("label_logits"), label_logits[0]], summarize=100)
        # label_logits = tf.check_numerics(label_logits, message="nan or inf from: label_logits")

        labels2d = tf.argmax(label_logits, axis=3)
        labels2d_expanded = tf.expand_dims(labels2d, -1)
        tf.summary.image('labels', tf.cast(labels2d_expanded, tf.uint8))

    return class_caps_activations, remakes_flatten, label_logits
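A sketch of wiring this network into a graph. The batch size must be static because the capsule layers reshape with `inputs_shape[0].value`; the shape below follows the (b, 1, 24, 56) comment and is otherwise an assumption:

images = tf.placeholder(tf.float32, [8, 1, 24, 56], name='images')
caps_acts, remakes, logits = inference(images, num_classes=2,
                                       routing_ites=4, training=True)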
Example #9
def inference(inputs, num_classes, routing_ites=3, remake=False, training=False, name='capsnet_1d'):
    """

    :param inputs:
    :param num_classes:
    :param routing_ites:
    :param remake:
    :param name:
    :return:
    """

    with tf.variable_scope(name) as scope:
        inputs_shape = inputs.get_shape()
        batch_size = inputs_shape[0].value
        image_height = inputs_shape[2].value
        image_width = inputs_shape[3].value

        # ReLU Conv1
        # Images shape (b, 1, 24, 56) -> 3x3 conv, 32 output channels, stride 1 with padding, ReLU
        # nets -> (b, 32, 24, 56)
        print('inputs shape: %s' % inputs.get_shape())
        inputs = tf.check_numerics(inputs, message="nan or inf from: inputs")

        conv1 = conv2d(
            inputs,
            kernel=3, out_channels=32, stride=1, padding='SAME',
            activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
            name='relu_conv1'
        )
        print('conv1 shape: %s' % conv1.get_shape())
        pool1 = tf.nn.max_pool(
            conv1,
            ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
            padding='VALID', data_format='NCHW', name='pool1'
        )
        print('pool1 shape: %s' % pool1.get_shape())
        # pool1_dropout = tf.layers.dropout(pool1, 0.5, training=training, name='pool1_dropout')

        # conv2 = conv2d(
        #     pool1,
        #     kernel=3, out_channels=64, stride=1, padding='VALID',
        #     activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
        #     name='relu_conv2'
        # )
        # print('conv2 shape: %s' % conv2.get_shape())
        # pool2 = tf.nn.max_pool(
        #     conv2,
        #     ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
        #     padding='VALID', data_format='NCHW', name='pool2'
        # )
        # print('pool2 shape: %s' % pool2.get_shape())
        # pool2_dropout = tf.layers.dropout(pool2, 0.5, training=training, name='pool2_dropout')

        # conv3 = conv2d(
        #     pool2,
        #     kernel=3, out_channels=128, stride=1, padding='SAME',
        #     activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
        #     name='relu_conv3'
        # )
        # print('conv3 shape: %s' % conv3.get_shape())
        # pool3 = tf.nn.max_pool(
        #     conv3,
        #     ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
        #     padding='VALID', data_format='NCHW', name='pool3'
        # )
        # print('pool3 shape: %s' % pool3.get_shape())

        print("\nprimary layer:")
        primary_out_capsules = 64
        primary_caps_activations, _ = primary_caps1d(
            pool1,
            kernel_size=3, out_capsules=primary_out_capsules, stride=1,
            padding='VALID', activation_length=3, name='primary_caps'
        )  # (b, 64, h, w, 3)
        # primary_caps_activations = tf.check_numerics(primary_caps_activations,
        #                                              message="nan or inf from: primary_caps_activations")

        print("\nconvolutional capsule layer 1:")
        conv_out_capsules_1 = 32
        conv_kernel_size_1, conv_stride_1 = 4, 2
        conv_caps_activations_1, conv_coupling_coeffs_1 = conv_capsule1d(
            primary_caps_activations,
            kernel_size=conv_kernel_size_1, stride=conv_stride_1, routing_ites=3,
            in_capsules=primary_out_capsules, out_capsules=conv_out_capsules_1,
            activation_length=4, training=training, name="conv_caps_1"
        )  # (b, 32, 6, 6, 8), (b*6*6, 32*9, 32)
        conv_caps_activations_1 = tf.check_numerics(
            conv_caps_activations_1, message="nan or inf from: conv_caps_activations_1")

        print("\nconvolutional capsule layer 2:")
        conv_out_capsules_2 = 24
        conv_kernel_size_2, conv_stride_2 = 3, 1
        conv_caps_activations_2, conv_coupling_coeffs_2 = conv_capsule1d(
            conv_caps_activations_1,
            kernel_size=conv_kernel_size_2, stride=conv_stride_2, routing_ites=3,
            in_capsules=conv_out_capsules_1, out_capsules=conv_out_capsules_2,
            activation_length=8, training=training, name="conv_caps_2"
        )  # (b, 32, 6, 6, 8), (b*6*6, 32*9, 32)
        conv_caps_activations_2 = tf.check_numerics(
            conv_caps_activations_2, message="nan or inf from: conv_caps_activations_2")

        print("\nclass capsule layer:")
        class_caps_activations, class_coupling_coeffs = class_caps1d(
            conv_caps_activations_2,
            num_classes=num_classes, activation_length=16, routing_ites=routing_ites,
            batch_size=batch_size, training=training, name='class_capsules')
        # class_coupling_coeffs = tf.Print(class_coupling_coeffs, [class_coupling_coeffs], summarize=50)
        class_caps_activations = tf.check_numerics(class_caps_activations,
                                                   message="nan or inf from: class_caps_activations")
        print('class_coupling_coeffs shape: %s' % class_coupling_coeffs.get_shape())
        print('class_caps_activations shape: %s' % class_caps_activations.get_shape())

        if remake:
            remakes_flatten = _remake(class_caps_activations, image_height * image_width)
        else:
            remakes_flatten = None

        print("\ntraceback layer 1:")
        conv_activations_2_shape = conv_caps_activations_2.get_shape()  # (b, 32, 4, 20, 8),
        conv_height_2, conv_width_2 = conv_activations_2_shape[2].value, conv_activations_2_shape[3].value
        conv_1_cond_prob = trace_conv_cond_prob(class_coupling_coeffs, conv_coupling_coeffs_2,
                                                conv_height_2, conv_width_2, conv_kernel_size_2, conv_stride_2,
                                                conv_out_capsules_2, conv_out_capsules_1)
        print("\ntraceback layer 2:")
        conv_activations_1_shape = conv_caps_activations_1.get_shape()  # (b, 32, 4, 20, 8),
        conv_height_1, conv_width_1 = conv_activations_1_shape[2].value, conv_activations_1_shape[3].value
        primary_cond_prob = trace_conv_cond_prob(conv_1_cond_prob, conv_coupling_coeffs_1,
                                                 conv_height_1, conv_width_1, conv_kernel_size_1, conv_stride_1,
                                                 conv_out_capsules_1, primary_out_capsules)

        primary_labels = trace_labels(primary_caps_activations, primary_cond_prob, num_classes)

        print('primary_caps_labels shape: %s' % primary_labels.get_shape())
        # class_labels = tf.Print(class_labels, [tf.constant("class_labels"), class_labels])
        # class_labels = tf.check_numerics(class_labels, message="nan or inf from: class_labels")
        # primary_labels = tf.reduce_sum(caps_probs_tiled, 1)

        print("\ndeconv layers:")
        primary_conv = conv2d(
            tf.transpose(primary_labels, perm=[0, 3, 1, 2]),
            kernel=3, out_channels=256, stride=1, padding='SAME',
            activation_fn=tf.nn.relu, name='primary_conv'
        )
        print('primary_conv shape: %s' % primary_conv.get_shape())

        deconv1 = deconv(
            primary_conv,
            kernel=3, out_channels=128, stride=1, data_format='NCHW',
            activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
            name='deconv1'
        )
        print('deconv1 shape: %s' % deconv1.get_shape())
        concat1 = tf.concat([pool1, deconv1], axis=1, name='concat1')
        # dropout2 = tf.layers.dropout(concat2, 0.5, training=training, name='dropout2')
        concat1_conv = conv2d(
            concat1,
            kernel=3, out_channels=128, stride=1, padding='SAME',
            activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
            name='concat1_conv'
        )
        print('concat1_conv shape: %s' % concat1_conv.get_shape())

        deconv2 = deconv(
            concat1_conv,
            kernel=4, out_channels=128, stride=2, data_format='NCHW',
            activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
            name='deconv2'
        )
        print('deconv2 shape: %s' % deconv2.get_shape())

        class_conv = conv2d(
            deconv2,
            kernel=3, out_channels=num_classes, stride=1, padding='VALID',
            name='class_conv'
        )
        print('class_conv shape: %s' % class_conv.get_shape())

        label_logits = tf.transpose(class_conv, perm=[0, 2, 3, 1])
        print('label_logits shape: %s' % label_logits.get_shape())
        # label_logits = tf.Print(label_logits, [tf.constant("label_logits"), label_logits])

        labels2d = tf.argmax(label_logits, axis=3)
        labels2d_expanded = tf.expand_dims(labels2d, -1)
        tf.summary.image('labels', tf.cast(labels2d_expanded, tf.uint8))

    return class_caps_activations, remakes_flatten, label_logits
Example #10
def _decode(activations, capsule_num, coupling_coeffs, num_classes, batch_size,
            conv1, conv2):
    capsule_probs = tf.norm(activations,
                            axis=-1)  # (b, 32, 4, 20, 8) -> (b, 32, 4, 20)
    caps_probs_tiled = tf.tile(tf.expand_dims(capsule_probs, -1),
                               [1, 1, 1, 1, num_classes])  # (b, 32, 4, 20, 2)
    # caps_probs_tiled = tf.check_numerics(caps_probs_tiled, message="nan or inf from: caps_probs_tiled")

    print('coupling_coeffs shape: %s' % coupling_coeffs.get_shape())
    activations_shape = activations.get_shape()
    height, width = activations_shape[2].value, activations_shape[3].value
    coupling_coeff_reshaped = tf.reshape(
        coupling_coeffs,
        [batch_size, capsule_num, height, width, num_classes])  # (b, 32, 4, 20, 2)
    # coupling_coeff_reshaped = tf.check_numerics(coupling_coeff_reshaped, message="nan or inf from: coupling_coeff_reshaped")

    primary_labels = tf.reduce_sum(coupling_coeff_reshaped * caps_probs_tiled,
                                   1)  # (b, 4, 20, 2)
    # class_labels = tf.Print(class_labels, [tf.constant("class_labels"), class_labels])
    # class_labels = tf.check_numerics(class_labels, message="nan or inf from: class_labels")
    # primary_labels = tf.reduce_sum(caps_probs_tiled, 1)
    # deconv1 = deconv(
    #     class_labels,
    #     kernel=3, out_channels=num_classes, stride=1,
    #     activation_fn=tf.nn.relu, name='deconv1'
    # )
    # deconv1 = tf.Print(deconv1, [tf.constant("deconv1"), deconv1])
    print('primary_labels shape: %s' % primary_labels.get_shape())
    concat1 = tf.concat(
        [tf.transpose(conv2, perm=[0, 2, 3, 1]), primary_labels],
        axis=3,
        name='concat1')
    primary_conv = conv2d(concat1,
                          kernel=3,
                          out_channels=128,
                          stride=1,
                          padding='SAME',
                          activation_fn=tf.nn.relu,
                          data_format='NHWC',
                          name='primary_conv')

    deconv2 = deconv(primary_conv,
                     kernel=8,
                     out_channels=128,
                     stride=2,
                     activation_fn=tf.nn.relu,
                     name='deconv2')
    print('deconv2 shape: %s' % deconv2.get_shape())
    concat2 = tf.concat([tf.transpose(conv1, perm=[0, 2, 3, 1]), deconv2],
                        axis=3,
                        name='concat2')
    # deconv2 = tf.Print(deconv2, [tf.constant("deconv2"), deconv2])
    deconv2_conv = conv2d(concat2,
                          kernel=5,
                          out_channels=128,
                          stride=1,
                          padding='SAME',
                          activation_fn=tf.nn.relu,
                          data_format='NHWC',
                          name='deconv2_conv')
    print('deconv2_conv shape: %s' % deconv2_conv.get_shape())

    # deconv3 = deconv(
    #     class_labels,
    #     kernel=9, out_channels=num_classes, stride=1,
    #     activation_fn=tf.nn.relu, name='deconv3'
    # )
    # deconv3 = tf.Print(deconv3, [tf.constant("deconv3"), deconv3])
    deconv3 = deconv(deconv2_conv,
                     kernel=9,
                     out_channels=num_classes,
                     stride=1,
                     activation_fn=tf.nn.relu,
                     name='deconv3')
    deconv3_conv = conv2d(deconv3,
                          kernel=3,
                          out_channels=num_classes,
                          stride=1,
                          padding='SAME',
                          activation_fn=tf.nn.relu,
                          data_format='NHWC',
                          name='deconv3_conv')

    label_logits = deconv3_conv
    print('label_logits shape: %s' % label_logits.get_shape())
    # label_logits = tf.Print(label_logits, [tf.constant("label_logits"), label_logits])
    return label_logits
Example #11
def inference(inputs,
              num_classes,
              routing_ites=3,
              remake=False,
              training=False,
              name='capsnet_1d'):
    """

    :param inputs:
    :param num_classes:
    :param routing_ites:
    :param remake:
    :param name:
    :return:
    """

    with tf.variable_scope(name) as scope:
        inputs_shape = inputs.get_shape()
        batch_size = inputs_shape[0].value
        image_height = inputs_shape[2].value
        image_width = inputs_shape[3].value

        # ReLU Conv1
        # Images shape (b, 1, 24, 56) -> 9x9 conv, 256 output channels, stride 1, no padding, ReLU
        # nets -> (b, 256, 16, 48)
        print('inputs shape: %s' % inputs.get_shape())
        inputs = tf.check_numerics(inputs, message="nan or inf from: inputs")

        print("\nconv1 layer:")
        conv1 = conv2d(inputs,
                       kernel=9,
                       out_channels=256,
                       stride=1,
                       padding='VALID',
                       activation_fn=tf.nn.relu,
                       name='relu_conv1')
        # conv1 = tf.check_numerics(conv1, message="nan or inf from: conv1")
        print('conv1 shape: %s' % conv1.get_shape())

        # print("\nconv2 layer:")
        # conv2 = conv2d(
        #     conv1,
        #     kernel=5, out_channels=256, stride=1, padding='VALID',
        #     activation_fn=tf.nn.relu, name='relu_conv2'
        # )
        # # conv2 = tf.check_numerics(conv2, message="nan or inf from: conv2")
        # print('conv2 shape: %s' % conv2.get_shape())

        # PrimaryCaps
        # (b, 256, 16, 48) -> 7x7 capsule kernel, 24 output capsules, stride 2 without padding
        # nets -> activations (b, 24, h', w', 8)
        print("\nprimary layer:")
        primary_out_capsules = 24
        primary_caps_activations, conv2 = primary_caps1d(
            conv1,
            kernel_size=7,
            out_capsules=primary_out_capsules,
            stride=2,
            padding='VALID',
            activation_length=8,
            name='primary_caps')  # (b, 24, h', w', 8)

        # (b, 32, 4, 20, 8) -> # (b, 32*4*20, 2*64)
        print("\nclass capsule layer:")
        class_caps_activations, class_coupling_coeffs = class_caps1d(
            primary_caps_activations,
            num_classes=num_classes,
            activation_length=16,
            routing_ites=routing_ites,
            batch_size=batch_size,
            name='class_capsules')
        # class_coupling_coeffs = tf.Print(class_coupling_coeffs, [class_coupling_coeffs], summarize=50)
        # class_caps_activations = tf.check_numerics(class_caps_activations, message="nan or inf from: class_caps_activations")
        print('class_coupling_coeffs shape: %s' %
              class_coupling_coeffs.get_shape())
        print('class_caps_activations shape: %s' %
              class_caps_activations.get_shape())

        if remake:
            remakes_flatten = _remake(class_caps_activations,
                                      image_height * image_width)
        else:
            remakes_flatten = None

        print("\ndecode layers:")
        label_logits = _decode(primary_caps_activations,
                               primary_out_capsules,
                               coupling_coeffs=class_coupling_coeffs,
                               num_classes=num_classes,
                               batch_size=batch_size,
                               conv1=conv1,
                               conv2=conv2)
        # label_logits = tf.Print(label_logits, [tf.constant("label_logits"), label_logits[0]], summarize=100)
        # label_logits = tf.check_numerics(label_logits, message="nan or inf from: label_logits")

        labels2d = tf.argmax(label_logits, axis=3)
        labels2d_expanded = tf.expand_dims(labels2d, -1)
        tf.summary.image('labels', tf.cast(labels2d_expanded, tf.uint8))

    return class_caps_activations, remakes_flatten, label_logits
Example #12
def inference(inputs,
              num_classes,
              feature_scale=2,
              training=True,
              name='unet'):
    with tf.variable_scope(name) as scope:
        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / feature_scale) for x in filters]

        combine_conv1_1 = _combine_conv(inputs,
                                        '1_1',
                                        training,
                                        out_channels=filters[0])
        combine_conv1_2 = _combine_conv(combine_conv1_1,
                                        '1_2',
                                        training,
                                        out_channels=filters[0])
        pool1 = tf.layers.max_pooling2d(inputs=combine_conv1_2,
                                        pool_size=[2, 2],
                                        strides=[2, 2],
                                        data_format="channels_first",
                                        name='pool1')
        print('pool1 shape: %s\n' % (pool1.get_shape()))
        combine_conv2_1 = _combine_conv(pool1,
                                        '2_1',
                                        training,
                                        out_channels=filters[1])
        combine_conv2_2 = _combine_conv(combine_conv2_1,
                                        '2_2',
                                        training,
                                        out_channels=filters[1])
        pool2 = tf.layers.max_pooling2d(inputs=combine_conv2_2,
                                        pool_size=[2, 2],
                                        strides=[2, 2],
                                        data_format="channels_first",
                                        name='pool2')
        print('pool2 shape: %s\n' % (pool2.get_shape()))
        combine_conv3_1 = _combine_conv(pool2,
                                        '3_1',
                                        training,
                                        out_channels=filters[2])
        combine_conv3_2 = _combine_conv(combine_conv3_1,
                                        '3_2',
                                        training,
                                        out_channels=filters[2])
        pool3 = tf.layers.max_pooling2d(inputs=combine_conv3_2,
                                        pool_size=[2, 2],
                                        strides=[2, 2],
                                        data_format="channels_first",
                                        name='pool3')
        print('pool3 shape: %s\n' % (pool3.get_shape()))
        combine_conv4_1 = _combine_conv(pool3,
                                        '4_1',
                                        training,
                                        out_channels=filters[3])
        combine_conv4_2 = _combine_conv(combine_conv4_1,
                                        '4_2',
                                        training,
                                        out_channels=filters[3])
        # combine_conv4_2 = tf.nn.dropout(combine_conv4_2, 0.5)
        pool4 = tf.layers.max_pooling2d(inputs=combine_conv4_2,
                                        pool_size=[2, 2],
                                        strides=[2, 2],
                                        data_format="channels_first",
                                        name='pool4')
        print('pool4 shape: %s\n' % (pool4.get_shape()))

        center = _combine_conv(pool4, '5', training, out_channels=filters[4])
        # center = tf.nn.dropout(center, 0.5)

        combine_deconv4 = _combine_deconv(center,
                                          '4',
                                          training,
                                          combine_conv4_2,
                                          deconv_out_channels=filters[3])
        combine_deconv3 = _combine_deconv(combine_deconv4,
                                          '3',
                                          training,
                                          combine_conv3_2,
                                          deconv_out_channels=filters[2])
        combine_deconv2 = _combine_deconv(combine_deconv3,
                                          '2',
                                          training,
                                          combine_conv2_2,
                                          deconv_out_channels=filters[1])
        combine_deconv1 = _combine_deconv(combine_deconv2,
                                          '1',
                                          training,
                                          combine_conv1_2,
                                          deconv_out_channels=filters[0])

        final_deconv = deconv(combine_deconv1,
                              kernel=5,
                              out_channels=num_classes,
                              stride=1,
                              data_format='NCHW',
                              name='final_deconv')
        final = conv2d(final_deconv,
                       kernel=1,
                       out_channels=num_classes,
                       stride=1,
                       padding='SAME',
                       name='final_conv')

        label_logits = tf.transpose(final, perm=[0, 2, 3, 1])
        print('label_logits shape: %s' % label_logits.get_shape())
        return label_logits
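Because `_combine_conv` uses `tf.layers.batch_normalization` with a `training` flag, the moving-average update ops must run alongside the train op or the inference statistics never update. The standard TF1 pattern, assuming a scalar `loss` built from `label_logits`:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)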
Example #13
def inference(inputs, num_classes, feature_scale=1, training=True, name='unet'):
    with tf.variable_scope(name) as scope:
        filters = [64, 128, 256, 512, 512]
        filters = [int(x / feature_scale) for x in filters]

        padded_input = tf.pad(inputs, tf.constant([[0, 0], [0, 0], [100, 100], [100, 100]]), "CONSTANT")

        combine_conv1_1 = _combine_conv(padded_input, '1_1', training, in_channels=3, out_channels=filters[0])
        combine_conv1_2 = _combine_conv(combine_conv1_1, '1_2', training,
                                        in_channels=filters[0], out_channels=filters[0])
        pool1 = tf.layers.max_pooling2d(inputs=combine_conv1_2, pool_size=[2, 2], strides=[2, 2],
                                        data_format="channels_first", name='pool1')
        print('pool1 shape: %s\n' % (pool1.get_shape()))
        combine_conv2_1 = _combine_conv(pool1, '2_1', training, in_channels=filters[0], out_channels=filters[1])
        combine_conv2_2 = _combine_conv(combine_conv2_1, '2_2', training,
                                        in_channels=filters[1], out_channels=filters[1])
        pool2 = tf.layers.max_pooling2d(inputs=combine_conv2_2, pool_size=[2, 2], strides=[2, 2],
                                        data_format="channels_first", name='pool2')
        print('pool2 shape: %s\n' % (pool2.get_shape()))
        combine_conv3_1 = _combine_conv(pool2, '3_1', training,
                                        in_channels=filters[1], out_channels=filters[2])
        combine_conv3_2 = _combine_conv(combine_conv3_1, '3_2', training,
                                        in_channels=filters[2], out_channels=filters[2])
        combine_conv3_3 = _combine_conv(combine_conv3_2, '3_3', training,
                                        in_channels=filters[2], out_channels=filters[2])
        pool3 = tf.layers.max_pooling2d(inputs=combine_conv3_3, pool_size=[2, 2], strides=[2, 2],
                                        data_format="channels_first", name='pool3')
        print('pool3 shape: %s\n' % (pool3.get_shape()))
        combine_conv4_1 = _combine_conv(pool3, '4_1', training, in_channels=filters[2], out_channels=filters[3])
        combine_conv4_2 = _combine_conv(combine_conv4_1, '4_2', training,
                                        in_channels=filters[3], out_channels=filters[3])
        combine_conv4_3 = _combine_conv(combine_conv4_2, '4_3', training,
                                        in_channels=filters[3], out_channels=filters[3])
        pool4 = tf.layers.max_pooling2d(inputs=combine_conv4_3, pool_size=[2, 2], strides=[2, 2],
                                        data_format="channels_first", name='pool4')
        print('pool4 shape: %s\n' % (pool4.get_shape()))
        combine_conv5_1 = _combine_conv(pool4, '5_1', training, in_channels=filters[3], out_channels=filters[4])
        combine_conv5_2 = _combine_conv(combine_conv5_1, '5_2', training,
                                        in_channels=filters[4], out_channels=filters[4])
        combine_conv5_3 = _combine_conv(combine_conv5_2, '5_3', training,
                                        in_channels=filters[4], out_channels=filters[4])
        pool5 = tf.layers.max_pooling2d(inputs=combine_conv5_3, pool_size=[2, 2], strides=[2, 2],
                                        data_format="channels_first", name='pool5')
        print('pool5 shape: %s\n' % (pool5.get_shape()))

        combine_conv6 = _combine_conv(pool5, '6', training, kernel=7,
                                      in_channels=filters[4], out_channels=4096, padding='VALID')
        drop6 = tf.nn.dropout(combine_conv6, 0.5)
        combine_conv7 = _combine_conv(drop6, '7', training, kernel=1,
                                      in_channels=4096, out_channels=4096, padding='VALID')
        drop7 = tf.nn.dropout(combine_conv7, 0.5)
        score_fr = conv2d(
            drop7, kernel=1,
            out_channels=num_classes, stride=1,
            padding='VALID', name='score_fr',
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005)
        )
        kernel_upscore2 = 4
        upscore2 = deconv(
            score_fr,
            kernel=kernel_upscore2, out_channels=num_classes, stride=2,
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005),
            weights_initializer=_deconv_initializer(kernel_upscore2),
            data_format='NCHW', activation_fn=None, name='upscore2'
        )
        print('upscore2 shape: %s\n' % (upscore2.get_shape()))

        score_pool4 = conv2d(
            pool4, kernel=1,
            out_channels=num_classes, stride=1,
            padding='VALID', name='score_pool4',
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005)
        )
        print('score_pool4 shape: %s' % (score_pool4.get_shape()))
        score_pool4c = _crop(score_pool4, upscore2)  # need crop here
        print('score_pool4c shape: %s' % (score_pool4c.get_shape()))
        fuse_pool4 = upscore2 + score_pool4c
        kernel_upscore_pool4 = 4
        upscore_pool4 = deconv(
            fuse_pool4,
            kernel=kernel_upscore_pool4, out_channels=num_classes, stride=2,
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005),
            weights_initializer=_deconv_initializer(kernel_upscore_pool4),
            data_format='NCHW', activation_fn=None, name='upscore_pool4'
        )
        print('upscore_pool4 shape: %s\n' % (upscore_pool4.get_shape()))

        score_pool3 = conv2d(
            pool3, kernel=1,
            out_channels=num_classes, stride=1,
            padding='VALID', name='score_pool3',
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005)
        )
        score_pool3c = _crop(score_pool3, upscore_pool4)  # need crop here
        print('score_pool3c shape: %s' % (score_pool3c.get_shape()))
        fuse_pool3 = upscore_pool4 + score_pool3c
        kernel_upscore_pool8 = 16
        upscore_pool8 = deconv(
            fuse_pool3,
            kernel=kernel_upscore_pool8, out_channels=num_classes, stride=8,
            data_format='NCHW', activation_fn=None, name='upscore_pool8',
            weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.0005),
            weights_initializer=_deconv_initializer(kernel_upscore_pool8)
        )
        print('upscore_pool8 shape: %s\n' % (upscore_pool8.get_shape()))

        score = _crop(upscore_pool8, inputs)  # need crop here
        print('score shape: %s\n' % (score.get_shape()))
        label_logits = tf.transpose(score, perm=[0, 2, 3, 1])
        print('label_logits shape: %s' % label_logits.get_shape())
        return label_logits
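`_deconv_initializer` is called above but not defined in the excerpt. In FCN-style models the upsampling deconvs are conventionally initialized to bilinear interpolation weights; a plausible implementation follows as a hedged sketch (the `channels` default is an assumption, since the real helper is called with only the kernel size):

import numpy as np

def _deconv_initializer(kernel_size, channels=21):
    # Standard 2-D bilinear upsampling profile, as in the FCN reference code.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    profile = (1 - abs(og[0] - center) / factor) * \
              (1 - abs(og[1] - center) / factor)
    # Place the profile on the diagonal of the (k, k, out, in) kernel so each
    # class channel is upsampled independently.
    kernel = np.zeros((kernel_size, kernel_size, channels, channels), np.float32)
    for c in range(channels):
        kernel[:, :, c, c] = profile
    return tf.constant_initializer(kernel)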
Example #14
def inference(inputs,
              num_classes,
              training=True,
              routing_ites=3,
              name='unet_pascal'):
    with tf.variable_scope(name) as scope:
        inputs_shape = inputs.get_shape()
        batch_size = inputs_shape[0].value

        combine_conv1 = _combine_conv(inputs, '1', training, out_channels=16)
        combine_conv2 = _combine_conv(combine_conv1,
                                      '2',
                                      training,
                                      out_channels=32)
        combine_conv3 = _combine_conv(combine_conv2,
                                      '3',
                                      training,
                                      out_channels=64)
        combine_conv4 = _combine_conv(combine_conv3,
                                      '4',
                                      training,
                                      out_channels=128)

        print("\nprimary layer:")
        primary_out_capsules = 32
        primary_caps_activations, primary_conv = primary_caps1d(
            combine_conv4,
            kernel_size=5,
            out_capsules=primary_out_capsules,
            stride=2,
            padding='VALID',
            activation_length=8,
            name='primary_caps')
        print('primary_conv shape: %s' % primary_conv.get_shape())

        print("\nclass capsule layer:")
        class_caps_activations, class_coupling_coeffs = class_caps1d(
            primary_caps_activations,
            num_classes=num_classes,
            activation_length=16,
            routing_ites=routing_ites,
            batch_size=batch_size,
            name='class_capsules')
        print('class_coupling_coeffs shape: %s' %
              class_coupling_coeffs.get_shape())
        print('class_caps_activations shape: %s' %
              class_caps_activations.get_shape())

        capsule_probs = tf.norm(primary_caps_activations, axis=-1)
        caps_probs_tiled = tf.tile(tf.expand_dims(capsule_probs, -1),
                                   [1, 1, 1, 1, num_classes])

        primary_activations_shape = primary_caps_activations.get_shape()
        height = primary_activations_shape[2].value
        width = primary_activations_shape[3].value
        coupling_coeff_reshaped = tf.reshape(
            class_coupling_coeffs,
            [batch_size, primary_out_capsules, height, width, num_classes])

        primary_labels = tf.reduce_sum(
            coupling_coeff_reshaped * caps_probs_tiled, 1)
        print('\nprimary_labels shape: %s\n' % primary_labels.get_shape())
        concat1 = tf.concat(
            [primary_conv,
             tf.transpose(primary_labels, perm=[0, 3, 1, 2])],
            axis=1,
            name='concat1')
        primary_label_conv = conv2d(concat1,
                                    kernel=3,
                                    out_channels=128,
                                    stride=1,
                                    padding='SAME',
                                    activation_fn=tf.nn.relu,
                                    data_format='NCHW',
                                    name='primary_label_conv')

        combine_deconv5 = _combine_deconv(primary_label_conv,
                                          '5',
                                          training,
                                          combine_conv4,
                                          deconv_out_channels=256)
        combine_deconv4 = _combine_deconv(combine_deconv5,
                                          '4',
                                          training,
                                          combine_conv3,
                                          kernel_deconv=[7, 7],
                                          deconv_out_channels=128)
        combine_deconv3 = _combine_deconv(combine_deconv4,
                                          '3',
                                          training,
                                          combine_conv2,
                                          kernel_deconv=[6, 6],
                                          deconv_out_channels=64)
        combine_deconv2 = _combine_deconv(combine_deconv3,
                                          '2',
                                          training,
                                          combine_conv1,
                                          kernel_deconv=[6, 7],
                                          deconv_out_channels=64)
        combine_deconv1 = _combine_deconv(combine_deconv2,
                                          '1',
                                          training,
                                          None,
                                          kernel_deconv=[6, 7],
                                          deconv_out_channels=32)

        conv = conv2d(combine_deconv1,
                      kernel=3,
                      out_channels=num_classes,
                      stride=1,
                      padding='SAME',
                      name='label_conv')

        label_logits = tf.transpose(conv, perm=[0, 2, 3, 1])
        # label_logits = tf.check_numerics(label_logits, message="nan or inf from: label_logits")
        print('label_logits shape: %s' % label_logits.get_shape())
        return class_caps_activations, None, label_logits
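The class capsule activations returned here are typically trained with the margin loss of Sabour et al. (2017). A hedged sketch, assuming `class_caps_activations` has shape (b, num_classes, 16) and one-hot targets `onehot` of shape (b, num_classes):

caps_norm = tf.norm(class_caps_activations, axis=-1)  # (b, num_classes)
pos = onehot * tf.square(tf.maximum(0.0, 0.9 - caps_norm))
neg = 0.5 * (1.0 - onehot) * tf.square(tf.maximum(0.0, caps_norm - 0.1))
margin_loss = tf.reduce_mean(tf.reduce_sum(pos + neg, axis=1))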