Example #1
def primary_caps(conv, conv_dim, output_dim, out_atoms):
    """First Capsule layer where activation is calculated via sigmoid+conv."""
    with tf.variable_scope('conv_capsule1'):
        w_kernel = utils.weight_variable(
            shape=[1, 1, conv_dim, (out_atoms) * output_dim], stddev=0.5)
        # w_kernel = tf.clip_by_norm(w_kernel, 1.0, axes=[2])
        with tf.variable_scope('conv_capsule1_act'):
            a_kernel = utils.weight_variable(
                shape=[1, 1, conv_dim, output_dim], stddev=3.0)
        kernel = tf.concat((w_kernel, a_kernel), axis=3)
        conv_caps = tf.nn.conv2d(conv,
                                 kernel, [1, 1, 1, 1],
                                 padding='SAME',
                                 data_format='NCHW')
        _, _, c_height, c_width = conv_caps.get_shape()
        conv_shape = tf.shape(conv_caps)
        # conv_reshaped: [batch, output_dim, out_atoms + 1, c_height, c_width]
        conv_reshaped = tf.reshape(conv_caps, [
            conv_shape[0], output_dim, out_atoms + 1, conv_shape[2],
            conv_shape[3]
        ])
        conv_reshaped.set_shape(
            (None, output_dim, out_atoms + 1, c_height.value, c_width.value))
        conv_caps_center, conv_caps_logit = tf.split(conv_reshaped,
                                                     [out_atoms, 1],
                                                     axis=2)
        conv_caps_activation = tf.sigmoid(conv_caps_logit - 1.0)
    return conv_caps_activation, conv_caps_center
def primary_caps(conv, conv_dim, output_dim, out_atoms):
    """First Capsule layer where activation is calculated via sigmoid+conv."""
    with tf.variable_scope('conv_capsule1'):
        w_kernel = utils.weight_variable(
            shape=[1, 1, conv_dim, (out_atoms) * output_dim], stddev=0.5)
        # w_kernel = tf.clip_by_norm(w_kernel, 1.0, axes=[2])
        with tf.variable_scope('conv_capsule1_act'):
            a_kernel = utils.weight_variable(
                shape=[1, 1, conv_dim, output_dim], stddev=3.0)
        kernel = tf.concat((w_kernel, a_kernel), axis=3)
        if FLAGS.cpu_way:
            conv = tf.transpose(conv, [0, 2, 3, 1])
            data_format = 'NHWC'
        else:
            data_format = 'NCHW'
        conv_caps = tf.nn.conv2d(conv,
                                 kernel, [1, 1, 1, 1],
                                 padding='SAME',
                                 data_format=data_format)
        if FLAGS.cpu_way:
            conv_caps = tf.transpose(conv_caps, [0, 3, 1, 2])
        _, _, c_height, c_width = conv_caps.get_shape()
        conv_shape = tf.shape(conv_caps)
        conv_caps_center, conv_caps_logit = tf.split(
            conv_caps, [out_atoms * output_dim, output_dim], axis=1)
        # center_reshaped: [batch, output_dim, out_atoms, c_height, c_width]
        center_reshaped = tf.reshape(conv_caps_center, [
            conv_shape[0], output_dim, out_atoms, conv_shape[2], conv_shape[3]
        ])
        center_reshaped.set_shape(
            (None, output_dim, out_atoms, c_height, c_width))
        logit_reshaped = tf.reshape(
            conv_caps_logit,
            [conv_shape[0], output_dim, 1, conv_shape[2], conv_shape[3]])
        logit_reshaped.set_shape((None, output_dim, 1, c_height, c_width))
        conv_caps_activation = tf.sigmoid(logit_reshaped - 1.0)
    return conv_caps_activation, center_reshaped
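# Hedged usage sketch (not part of the original examples): illustrates the
# expected call pattern for primary_caps. `conv` is assumed to be an NCHW
# feature map, e.g. [batch, 256, H, W]; the layer sizes are illustrative only.
def _primary_caps_usage_sketch(conv):
    # activation: [batch, 32, 1, H, W]; center: [batch, 32, 9, H, W]
    # (9 pose atoms correspond to a 3x3 pose matrix).
    activation, center = primary_caps(
        conv, conv_dim=256, output_dim=32, out_atoms=9)
    return activation, center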
def conv_capsule_mat(input_tensor,
                     input_activation,
                     input_dim,
                     output_dim,
                     layer_name,
                     num_routing=3,
                     num_in_atoms=3,
                     num_out_atoms=3,
                     stride=2,
                     kernel_size=5,
                     min_var=0.0005,
                     final_beta=1.0):
    """Convolutional Capsule layer with Pose Matrices."""
    print('caps conv stride: {}'.format(stride))
    in_atom_sq = num_in_atoms * num_in_atoms
    with tf.variable_scope(layer_name):
        input_shape = tf.shape(input_tensor)
        _, _, _, in_height, in_width = input_tensor.get_shape()
        # This Variable will hold the state of the weights for the layer
        kernel = utils.weight_variable(shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
                                       stddev=0.3)
        # kernel = tf.clip_by_norm(kernel, 3.0, axes=[1, 2, 3])
        activation_biases = utils.bias_variable(
            [1, 1, output_dim, 1, 1, 1, 1, 1],
            init_value=0.5,
            name='activation_biases')
        sigma_biases = utils.bias_variable([1, 1, output_dim, 1, 1, 1, 1, 1],
                                           init_value=.5,
                                           name='sigma_biases')
        with tf.name_scope('conv'):
            # input_tensor: [x, input_dim, in_atom_sq, c1, c2]
            #   -> [x * input_dim * in_atom_sq, c1, c2, 1]
            print(input_tensor.get_shape())
            input_tensor_reshaped = tf.reshape(input_tensor, [
                input_shape[0] * input_dim * in_atom_sq, input_shape[3],
                input_shape[4], 1
            ])
            input_tensor_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                             input_tensor.get_shape()[4], 1))
            input_act_reshaped = tf.reshape(input_activation, [
                input_shape[0] * input_dim, input_shape[3], input_shape[4], 1
            ])
            input_act_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                          input_tensor.get_shape()[4], 1))
            print(input_tensor_reshaped.get_shape())
            # conv_patches: [x * input_dim * in_atom_sq, o_height, o_width, k * k]
            conv_patches = tf.extract_image_patches(
                images=input_tensor_reshaped,
                ksizes=[1, kernel_size, kernel_size, 1],
                strides=[1, stride, stride, 1],
                rates=[1, 1, 1, 1],
                padding='VALID',
            )
            act_patches = tf.extract_image_patches(
                images=input_act_reshaped,
                ksizes=[1, kernel_size, kernel_size, 1],
                strides=[1, stride, stride, 1],
                rates=[1, 1, 1, 1],
                padding='VALID',
            )
            o_height = (in_height - kernel_size) // stride + 1
            o_width = (in_width - kernel_size) // stride + 1
            patches = tf.reshape(conv_patches,
                                 (input_shape[0], input_dim, in_atom_sq,
                                  o_height, o_width, kernel_size, kernel_size))
            patches.set_shape((None, input_dim, in_atom_sq, o_height, o_width,
                               kernel_size, kernel_size))
            patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
            patch_split = tf.reshape(
                patch_trans,
                (input_dim, kernel_size, kernel_size, input_shape[0] *
                 o_height * o_width * num_in_atoms, num_in_atoms))
            patch_split.set_shape(
                (input_dim, kernel_size, kernel_size, None, num_in_atoms))
            a_patches = tf.reshape(act_patches,
                                   (input_shape[0], input_dim, 1, 1, o_height,
                                    o_width, kernel_size, kernel_size))
            a_patches.set_shape((None, input_dim, 1, 1, o_height, o_width,
                                 kernel_size, kernel_size))
            with tf.name_scope('input_act'):
                utils.activation_summary(
                    tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(a_patches,
                                                              axis=1),
                                                axis=-1),
                                  axis=-1))
            with tf.name_scope('Wx'):
                wx = tf.matmul(patch_split, kernel)
                wx = tf.reshape(wx, (input_dim, kernel_size, kernel_size,
                                     input_shape[0], o_height, o_width,
                                     num_in_atoms * num_out_atoms, output_dim))
                wx.set_shape(
                    (input_dim, kernel_size, kernel_size, None, o_height,
                     o_width, num_in_atoms * num_out_atoms, output_dim))
                wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
                utils.activation_summary(wx)

        with tf.name_scope('routing'):
            # Routing
            # logit_shape: [input_dim, output_dim, 1, o_height, o_width, k, k]
            logit_shape = [
                input_dim, output_dim, 1, o_height, o_width, kernel_size,
                kernel_size
            ]
            activation, center = update_conv_routing(
                wx=wx,
                input_activation=a_patches,
                activation_biases=activation_biases,
                sigma_biases=sigma_biases,
                logit_shape=logit_shape,
                num_out_atoms=num_out_atoms * num_out_atoms,
                input_dim=input_dim,
                num_routing=num_routing,
                output_dim=output_dim,
                min_var=min_var,
                final_beta=final_beta,
            )
            # activation: [x, 1, output_dim, 1, o_height, o_width, 1, 1]
            # center: [x, 1, output_dim, num_out_atoms**2, o_height, o_width, 1, 1]

        out_activation = tf.squeeze(activation, axis=[1, 3, 6, 7])
        out_center = tf.squeeze(center, axis=[1, 6, 7])
        with tf.name_scope('center'):
            utils.activation_summary(out_center)
        return tf.sigmoid(out_activation), out_center
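# Hedged usage sketch (not part of the original examples): routes the primary
# capsules through a convolutional capsule layer. Shapes assume 3x3 pose
# matrices (num_in_atoms = num_out_atoms = 3); the layer sizes, stride, and
# kernel size are illustrative only.
def _conv_capsule_usage_sketch(caps_activation, caps_center):
    # caps_center: [batch, 32, 9, H, W]; caps_activation: [batch, 32, 1, H, W]
    activation, center = conv_capsule_mat(
        input_tensor=caps_center,
        input_activation=caps_activation,
        input_dim=32,
        output_dim=32,
        layer_name='conv_capsule2',
        num_routing=3,
        stride=2,
        kernel_size=3)
    # activation: [batch, 32, o_H, o_W]; center: [batch, 32, 9, o_H, o_W]
    return activation, center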
def connector_capsule_mat(input_tensor,
                          position_grid,
                          input_activation,
                          input_dim,
                          output_dim,
                          layer_name,
                          num_routing=3,
                          num_in_atoms=3,
                          num_out_atoms=3,
                          leaky=False,
                          final_beta=1.0,
                          min_var=0.0005):
    """Final Capsule Layer with Pose Matrices and Shared connections."""
    # One weight tensor for each capsule of the layer below: w: [8*128, 8*10]
    with tf.variable_scope(layer_name):
        # This Variable will hold the state of the weights for the layer
        with tf.name_scope('input_center_connector'):
            utils.activation_summary(input_tensor)
        weights = utils.weight_variable(
            [input_dim, num_out_atoms, output_dim * num_out_atoms],
            stddev=0.01)
        # weights = tf.clip_by_norm(weights, 1.0, axes=[1])
        activation_biases = utils.bias_variable([1, 1, output_dim, 1, 1, 1],
                                                init_value=1.0,
                                                name='activation_biases')
        sigma_biases = utils.bias_variable([1, 1, output_dim, 1, 1, 1],
                                           init_value=2.0,
                                           name='sigma_biases')

        with tf.name_scope('Wx_plus_b'):
            # input_tensor: [x, 128, 8, h, w]
            input_shape = tf.shape(input_tensor)
            input_trans = tf.transpose(input_tensor, [1, 0, 3, 4, 2])
            input_share = tf.reshape(input_trans,
                                     [input_dim, -1, num_in_atoms])
            # wx_share: [input_dim, x * h * w * num_in_atoms, output_dim * num_out_atoms]
            wx_share = tf.matmul(input_share, weights)
            # sqr_num_out_atoms = num_out_atoms
            num_out_atoms *= num_out_atoms
            wx_trans = tf.reshape(wx_share, [
                input_dim, input_shape[0], input_shape[3], input_shape[4],
                num_out_atoms, output_dim
            ])
            wx_trans.set_shape(
                (input_dim, None, input_tensor.get_shape()[3],
                 input_tensor.get_shape()[4], num_out_atoms, output_dim))
            h, w, _ = position_grid.get_shape()
            height = h
            width = w
            # t_pose = tf.transpose(position_grid, [2, 0, 1])
            # t_pose_exp = tf.scatter_nd([[sqr_num_out_atoms -1],
            #   [2 * sqr_num_out_atoms - 1]], t_pose, [num_out_atoms, height, width])
            # pose_g_exp = tf.transpose(t_pose_exp, [1, 2, 0])
            zero_grid = tf.zeros([height, width, num_out_atoms - 2])
            pose_g_exp = tf.concat([position_grid, zero_grid], axis=2)
            pose_g = tf.expand_dims(
                tf.expand_dims(tf.expand_dims(pose_g_exp, -1), 0), 0)
            wx_posed = wx_trans + pose_g
            wx_posed_t = tf.transpose(wx_posed, [1, 0, 2, 3, 5, 4])

            # wx: [x, input_dim * height * width, output_dim, num_out_atoms, 1, 1]
            wx = tf.reshape(wx_posed_t, [
                -1, input_dim * height * width, output_dim, num_out_atoms, 1, 1
            ])
        with tf.name_scope('routing'):
            # Routing
            # logit_shape: [input_dim * height * width, output_dim, 1, 1, 1]
            logit_shape = [input_dim * height * width, output_dim, 1, 1, 1]
            for _ in range(4):
                input_activation = tf.expand_dims(input_activation, axis=-1)
            activation, center = update_em_routing(
                wx=wx,
                input_activation=input_activation,
                activation_biases=activation_biases,
                sigma_biases=sigma_biases,
                logit_shape=logit_shape,
                num_out_atoms=num_out_atoms,
                num_routing=num_routing,
                output_dim=output_dim,
                leaky=leaky,
                final_beta=final_beta / 4,
                min_var=min_var,
            )
        out_activation = tf.squeeze(activation, axis=[1, 3, 4, 5])
        out_center = tf.squeeze(center, axis=[1, 4, 5])
        return tf.sigmoid(out_activation), out_center
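# Hedged usage sketch (not part of the original examples): the final
# class-capsule layer. `position_grid` must match the spatial size of
# `caps_center`; flattening the activation to [batch, input_dim * h * w] is an
# assumption about the layout expected by the shared EM routing.
def _connector_capsule_usage_sketch(caps_center, caps_activation, position_grid):
    # caps_center: [batch, 32, 9, h, w]; caps_activation: [batch, 32, h, w]
    batch = tf.shape(caps_activation)[0]
    act_flat = tf.reshape(caps_activation, [batch, -1])
    activation, center = connector_capsule_mat(
        input_tensor=caps_center,
        position_grid=position_grid,
        input_activation=act_flat,
        input_dim=32,
        output_dim=10,
        layer_name='class_capsules')
    # activation: [batch, 10]; center: [batch, 10, 9]
    return activation, center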
def conv_capsule_mat_fast(
    input_tensor,
    input_activation,
    input_dim,
    output_dim,
    layer_name,
    num_routing=3,
    num_in_atoms=3,
    num_out_atoms=3,
    stride=2,
    kernel_size=5,
    min_var=0.0005,
    final_beta=1.0,
):
    """Convolutional Capsule layer with fast EM routing.

  Args:
    input_tensor: The input capsule features.
    input_activation: The input capsule activations.
    input_dim: Number of input capsule types.
    output_dim: Number of output capsule types.
    layer_name: Name of this layer, e.g. conv_capsule1
    num_routing: Number of routing iterations.
    num_in_atoms: Number of features in each of the input capsules.
    num_out_atoms: Number of features in each of the output capsules.
    stride: Stride of the convolution.
    kernel_size: Kernel size of the convolution.
    min_var: Minimum variance for each capsule, to avoid NaNs.
    final_beta: Beta for sharpening the routing factors.

  Returns:
    The final capsule center and activations.
  """
    tf.logging.info('conv_capsule_mat_fast %s', layer_name)
    tf.logging.info('input_shape %s', input_tensor.shape.as_list())
    in_atom_sq = num_in_atoms * num_in_atoms
    with tf.variable_scope(layer_name):
        # This should be fully defined...
        # input_shape = tf.shape(input_tensor)
        input_shape = input_tensor.shape.as_list()
        batch, _, _, in_height, in_width = input_shape
        o_height = (in_height - kernel_size) // stride + 1
        o_width = (in_width - kernel_size) // stride + 1

        # This Variable will hold the state of the weights for the layer.
        kernel = utils.weight_variable(shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
                                       stddev=0.1)
        activation_biases = utils.bias_variable(
            [1, 1, output_dim, 1, 1, 1, 1, 1],
            init_value=0.2,
            name='activation_biases')
        sigma_biases = utils.bias_variable([1, 1, output_dim, 1, 1, 1, 1, 1],
                                           init_value=.5,
                                           name='sigma_biases')

        with utils.maybe_jit_scope(), tf.name_scope('conv'):
            input_tensor_reshaped = tf.reshape(
                input_tensor,
                [batch * input_dim * in_atom_sq, in_height, in_width, 1])
            input_act_reshaped = tf.reshape(
                input_activation, [batch * input_dim, in_height, in_width, 1])

            conv_patches = utils.kernel_tile(input_tensor_reshaped,
                                             kernel_size, stride)
            act_patches = utils.kernel_tile(input_act_reshaped, kernel_size,
                                            stride)

            patches = tf.reshape(conv_patches,
                                 (batch, input_dim, in_atom_sq, o_height,
                                  o_width, kernel_size, kernel_size))
            patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
            patch_split = tf.reshape(
                patch_trans,
                (input_dim, kernel_size, kernel_size,
                 batch * o_height * o_width * num_in_atoms, num_in_atoms),
                name='patch_split')
            a_patches = tf.reshape(act_patches,
                                   (batch, input_dim, 1, 1, o_height, o_width,
                                    kernel_size, kernel_size),
                                   name='a_patches')

        # Recompute Wx on backprop to save memory (perhaps redo patches as well?)
        # @tf.contrib.layers.recompute_grad
        def compute_wx(patch_split, kernel, is_recomputing=False):
            tf.logging.info('compute_wx(is_recomputing=%s)', is_recomputing)
            with utils.maybe_jit_scope(), tf.name_scope('wx'):
                wx = tf.matmul(patch_split, kernel)
                wx = tf.reshape(
                    wx, (input_dim, kernel_size, kernel_size, batch, o_height,
                         o_width, num_in_atoms * num_out_atoms, output_dim))
                wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
            return wx

        wx = compute_wx(patch_split, kernel.value())

        with utils.maybe_jit_scope():
            # Routing
            logit_shape = [
                input_dim, output_dim, 1, o_height, o_width, kernel_size,
                kernel_size
            ]
            tf.logging.info('logit_shape: %s', logit_shape)
            activation, center = update_conv_routing_fast(
                wx=wx,
                input_activation=a_patches,
                activation_biases=activation_biases,
                sigma_biases=sigma_biases,
                logit_shape=logit_shape,
                num_out_atoms=num_out_atoms * num_out_atoms,
                input_dim=input_dim,
                num_routing=num_routing,
                output_dim=output_dim,
                min_var=min_var,
                final_beta=4 * final_beta,
                stride=stride,
                layer_name=layer_name,
            )

        with utils.maybe_jit_scope():
            out_activation = tf.squeeze(activation,
                                        axis=[1, 3, 6, 7],
                                        name='out_activation')
            out_center = tf.squeeze(center, axis=[1, 6, 7], name='out_center')
            out_activation = tf.sigmoid(out_activation)

        with tf.name_scope('center'):
            utils.activation_summary(out_center)

        return out_activation, out_center
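# Hedged usage sketch (not part of the original examples): the fast variant
# requires a fully defined static input shape (including the batch size),
# unlike conv_capsule_mat above. The sizes below are illustrative only.
def _conv_capsule_fast_usage_sketch(caps_activation, caps_center):
    # caps_center: [batch, 32, 9, H, W] with every dimension statically known.
    activation, center = conv_capsule_mat_fast(
        input_tensor=caps_center,
        input_activation=caps_activation,
        input_dim=32,
        output_dim=32,
        layer_name='conv_capsule2_fast',
        stride=2,
        kernel_size=3)
    return activation, center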
def add_convs(features):
    """Stack Convolution layers."""
    image_dim = features['height']
    image_depth = features['depth']
    image = features['images']
    position_grid = tf.reshape(
        tf.constant(np.mgrid[(-image_dim // 2):((image_dim + 1) // 2),
                             (-image_dim // 2):((image_dim + 1) // 2)],
                    dtype=tf.float32) / 100.0, (1, 2, image_dim, image_dim))
    if FLAGS.verbose_image:
        with tf.name_scope('input_reshape'):
            image_shaped_input = tf.reshape(
                image, [-1, image_dim, image_dim, image_depth])
            tf.summary.image('input', image_shaped_input, 10)

    with tf.variable_scope('conv1') as scope:
        kernel = utils.weight_variable(shape=[
            FLAGS.kernel_size, FLAGS.kernel_size, image_depth,
            FLAGS.num_start_conv
        ],
                                       stddev=5e-2)

        image_reshape = tf.reshape(image,
                                   [-1, image_depth, image_dim, image_dim])
        if FLAGS.cpu_way:
            image_reshape = tf.transpose(image_reshape, [0, 2, 3, 1])
            data_format = 'NHWC'
            strides = [1, FLAGS.stride_1, FLAGS.stride_1, 1]
        else:
            data_format = 'NCHW'
            strides = [1, 1, FLAGS.stride_1, FLAGS.stride_1]
        conv = tf.nn.conv2d(image_reshape,
                            kernel,
                            strides,
                            padding=FLAGS.padding,
                            data_format=data_format)
        biases = utils.bias_variable([FLAGS.num_start_conv])
        pre_activation = tf.nn.bias_add(conv, biases, data_format=data_format)
        if FLAGS.cpu_way:
            pre_activation = tf.transpose(pre_activation, [0, 3, 1, 2])
        position_grid = conv_pos(position_grid, FLAGS.kernel_size,
                                 FLAGS.stride_1, FLAGS.padding)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)
        if FLAGS.verbose:
            tf.summary.histogram('activation', conv1)
        if FLAGS.pooling:
            pool1 = tf.nn.max_pool2d(conv1,
                                     ksize=2,
                                     strides=2,
                                     data_format='NCHW',
                                     padding='SAME')
            convs = [pool1]
        else:
            convs = [conv1]

    conv_outputs = [FLAGS.num_start_conv]

    for i in range(int(FLAGS.extra_conv)):
        conv_outputs += [int(FLAGS.conv_dims.split(',')[i])]
        with tf.variable_scope('conv{}'.format(i + 2)) as scope:
            kernel = utils.weight_variable(shape=[
                int(FLAGS.conv_kernels.split(',')[i]),
                int(FLAGS.conv_kernels.split(',')[i]), conv_outputs[i],
                conv_outputs[i + 1]
            ],
                                           stddev=5e-2)
            conv = tf.nn.conv2d(convs[i],
                                kernel, [
                                    1, 1,
                                    int(FLAGS.conv_strides.split(',')[i]),
                                    int(FLAGS.conv_strides.split(',')[i])
                                ],
                                padding=FLAGS.padding,
                                data_format='NCHW')
            position_grid = conv_pos(position_grid,
                                     int(FLAGS.conv_kernels.split(',')[i]),
                                     int(FLAGS.conv_strides.split(',')[i]),
                                     FLAGS.padding)
            biases = utils.bias_variable([conv_outputs[i + 1]])
            pre_activation = tf.nn.bias_add(conv, biases, data_format='NCHW')
            cur_conv = tf.nn.relu(pre_activation, name=scope.name)
            if FLAGS.pooling:
                convs += [
                    tf.nn.max_pool2d(cur_conv,
                                     ksize=2,
                                     strides=2,
                                     data_format='NCHW',
                                     padding='SAME')
                ]
            else:
                convs += [cur_conv]
            if FLAGS.verbose:
                tf.summary.histogram('activation', convs[-1])
    return convs[-1], conv_outputs[-1], position_grid
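# Hedged end-to-end sketch (not part of the original examples): one way the
# pieces above could be wired together. The layer sizes, kernel size, stride,
# and the extra conv_pos call that keeps position_grid aligned with the
# capsule resolution are all assumptions for illustration.
def _capsule_network_sketch(features, num_classes=10):
    conv, conv_dim, position_grid = add_convs(features)
    caps_act, caps_center = primary_caps(
        conv, conv_dim=conv_dim, output_dim=32, out_atoms=9)
    caps_act2, caps_center2 = conv_capsule_mat(
        caps_center, caps_act, input_dim=32, output_dim=32,
        layer_name='conv_capsule2', stride=2, kernel_size=3)
    # Downsample the position grid so it matches the conv-capsule output size.
    position_grid = conv_pos(position_grid, 3, 2, 'VALID')
    batch = tf.shape(caps_act2)[0]
    class_act, class_center = connector_capsule_mat(
        caps_center2, position_grid, tf.reshape(caps_act2, [batch, -1]),
        input_dim=32, output_dim=num_classes, layer_name='class_capsules')
    return class_act, class_center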