Example #1
0
def concat_fc(template, search, is_training,
              trainable=True,
              join_dim=128,
              mlp_num_outputs=1,
              mlp_num_layers=2,
              mlp_num_hidden=128,
              mlp_kwargs=None,
              scope=None):
    '''Joins template and search by per-input convolution, sum and an MLP.

    Each input goes through its own bias-free convolution whose kernel has the
    spatial size of the template. The two results are added (the template
    result is broadcast over the search axis), then batch-normalized and
    rectified, and finally passed to `cnn.mlp`.

    Args:
        template: [b, h, w, c]
        search: [b, s, h, w, c]
        is_training: Passed to the batch-norm layer.
        trainable: Whether the variables created here are trainable. Applied to
            both convolutions, the batch-norm and the MLP.
        join_dim: Number of output channels of the two convolutions.
        mlp_num_outputs: `num_outputs` passed to `cnn.mlp`.
        mlp_num_layers: `num_layers` passed to `cnn.mlp`.
        mlp_num_hidden: `num_hidden` passed to `cnn.mlp`.
        mlp_kwargs: Optional dict of extra keyword arguments for `cnn.mlp`.
        scope: Variable scope name; 'concat_fc' is used when None.

    Returns:
        The output of `cnn.mlp`, with the batch dims that were merged before
        the MLP restored.
    '''
    with tf.variable_scope(scope, 'concat_fc'):
        template = cnn.as_tensor(template)
        search = cnn.as_tensor(search)

        # Instead of sliding-window concat, we do separate conv and sum the results.
        # Disable activation and normalizer. Perform these after the sum.
        kernel_size = template.value.shape[-3:-1].as_list()
        conv_kwargs = dict(
            padding='VALID',
            activation_fn=None,
            normalizer_fn=None,
            biases_initializer=None,  # Disable bias because bnorm is performed later.
            # Honour `trainable` for the conv weights too, not only for the MLP.
            trainable=trainable,
        )
        with tf.variable_scope('template'):
            template = cnn.slim_conv2d(template, join_dim, kernel_size,
                                       scope='fc', **conv_kwargs)
        with tf.variable_scope('search'):
            # The conv is applied with the batch dims of `search` merged.
            search, restore = cnn.merge_batch_dims(search)
            search = cnn.slim_conv2d(search, join_dim, kernel_size,
                                     scope='fc', **conv_kwargs)
            search = restore(search)

        template = cnn.get_value(template)
        # Insert the search axis so that the template broadcasts against search.
        template = tf.expand_dims(template, 1)
        # This is a broadcasting addition. Receptive field in template not tracked.
        output = cnn.pixelwise(lambda search: search + template, search)
        batch_norm = partial(slim.batch_norm,
                             is_training=is_training,
                             trainable=trainable)
        output = cnn.pixelwise(batch_norm, output)
        output = cnn.pixelwise(tf.nn.relu, output)

        mlp_kwargs = mlp_kwargs or {}
        output, restore = cnn.merge_batch_dims(output)
        output = cnn.mlp(output,
                         num_layers=mlp_num_layers,
                         num_hidden=mlp_num_hidden,
                         num_outputs=mlp_num_outputs,
                         trainable=trainable, **mlp_kwargs)
        output = restore(output)
        return output
Example #2
0
def mlp(template, search, is_training,
        trainable=True,
        num_layers=2,
        num_hidden=128,
        join_arch='depthwise_xcorr',
        join_params=None,
        scope='mlp_join'):
    '''Runs a named join function and feeds its result through an MLP.

    Args:
        template: [b, ht, wt, c]
        search: [b, s, hs, ws, c]
        is_training: Forwarded to the join function and to `cnn.mlp`.
        trainable: Forwarded to the join function and to `cnn.mlp`.
        num_layers: `num_layers` passed to `cnn.mlp`.
        num_hidden: `num_hidden` passed to `cnn.mlp`.
        join_arch: Key into `BY_NAME` that selects the join function.
        join_params: Optional dict of extra keyword arguments for the join
            function.
        scope: Variable scope name.

    Returns:
        The single-output response of `cnn.mlp`, with the batch dims that were
        merged before the MLP restored.
    '''
    with tf.variable_scope(scope, 'mlp_join'):
        extra_join_args = join_params or {}
        joined = BY_NAME[join_arch](
            template, search,
            is_training=is_training,
            trainable=trainable,
            **extra_join_args)
        # joined: [b, s, h, w, c]
        flat, unmerge = cnn.merge_batch_dims(joined)
        mlp_args = dict(
            num_layers=num_layers,
            num_hidden=num_hidden,
            num_outputs=1,
            is_training=is_training,
            trainable=trainable,
        )
        scores = cnn.mlp(flat, **mlp_args)
        return unmerge(scores)
Example #3
0
 def pre_conv(x):
     '''Applies batch-norm and ReLU to `x`, then a VALID conv with batch-norm.

     Reads `is_training`, `pre_conv_output_dim` and `kernel_size` from the
     enclosing scope. The convolution has no activation and is run with the
     batch dims of `x` merged; they are restored before returning.
     '''
     batch_norm = partial(slim.batch_norm, is_training=is_training)
     activated = cnn.pixelwise(tf.nn.relu, cnn.pixelwise(batch_norm, x))
     merged, unmerge = cnn.merge_batch_dims(activated)
     conv_options = dict(
         padding='VALID',
         activation_fn=None,
         normalizer_fn=slim.batch_norm,
         normalizer_params=dict(is_training=is_training),
         scope='conv',
     )
     projected = cnn.slim_conv2d(merged, pre_conv_output_dim, kernel_size,
                                 **conv_options)
     return unmerge(projected)
Example #4
0
def distance(template, search, is_training,
             trainable=True,
             use_mean=False,
             use_batch_norm=False,
             learn_gain=False,
             gain_init=1,
             scope='distance'):
    '''Computes the Euclidean distance between the template and search windows.

    Uses the expansion |x - y|^2 = |x|^2 - 2 <x, y> + |y|^2, where x is the
    template and y is each template-sized window of the search image.

    Args:
        template: [b, h, w, c]
        search: [b, s, h, w, c]
        is_training: Passed to `_calibrate`.
        trainable: Passed to `_calibrate`.
        use_mean: If true, the squared distance is divided by the number of
            template elements (h * w * c) before the square root, giving a
            root-mean-square difference.
        use_batch_norm: Passed to `_calibrate`.
        learn_gain: Passed to `_calibrate`.
        gain_init: Passed to `_calibrate`.
        scope: Variable scope name.

    Returns:
        The result of `_calibrate` applied to the single-channel distance map.

    Raises:
        ValueError: If `search` does not have 5 dimensions.
    '''
    search = cnn.as_tensor(search)
    num_search_dims = len(search.value.shape)
    if num_search_dims != 5:
        raise ValueError('search should have 5 dims: {}'.format(num_search_dims))

    def channel_sum(x):
        return tf.reduce_sum(x, axis=-1, keepdims=True)

    with tf.variable_scope(scope, 'distance'):
        # Discard receptive field of template and get underlying tf.Tensor.
        template = cnn.get_value(template)

        num_channels = template.shape[-1].value
        template_size = template.shape[-3:-1].as_list()
        # Filter that sums every element (all positions and channels) of a window.
        ones = tf.ones(template_size + [num_channels, 1], tf.float32)

        # <x, y>: reduce over channels *before* combining with the other terms.
        # dot_xx and dot_yy below are already summed over channels, so
        # broadcasting them against a multi-channel dot_xy and summing
        # afterwards would count them once per channel. This is a no-op if
        # diag_xcorr already returns a single channel.
        dot_xy = cnn.diag_xcorr(search, template)
        dot_xy = cnn.pixelwise(channel_sum, dot_xy)
        # |x|^2: one scalar per template, shaped to broadcast over [b, s, h, w, 1].
        dot_xx = tf.reduce_sum(tf.square(template), axis=(-3, -2, -1), keepdims=True)
        dot_xx = tf.expand_dims(dot_xx, 1)
        # |y|^2: sum of squares over each template-sized window of search.
        sq_search = cnn.pixelwise(tf.square, search)
        sq_search, restore = cnn.merge_batch_dims(sq_search)
        dot_yy = cnn.nn_conv2d(sq_search, ones, strides=[1, 1, 1, 1], padding='VALID')
        dot_yy = restore(dot_yy)
        # (x - y)**2 = x**2 - 2 x y + y**2
        sq_dist = cnn.pixelwise_binary(
            lambda dot_xy, dot_yy: dot_xx - 2 * dot_xy + dot_yy, dot_xy, dot_yy)
        if use_mean:
            # Take root-mean-square of difference.
            num_elems = np.prod(template.shape[-3:].as_list())
            scale = 1.0 / float(num_elems)
            sq_dist = cnn.pixelwise(lambda sq_dist: scale * sq_dist, sq_dist)
        # Cancellation in the expansion can give a tiny negative value where the
        # window matches the template; clamp so that sqrt does not return NaN.
        dist = cnn.pixelwise(
            lambda sq_dist: tf.sqrt(tf.maximum(sq_dist, 0.0)), sq_dist)
        return _calibrate(dist, is_training, use_batch_norm, learn_gain, gain_init,
                          trainable=trainable)
Example #5
0
def all_pixel_pairs(template, search, is_training,
                    trainable=True,
                    operation='mul',
                    reduce_channels=True,
                    use_mean=True,
                    use_batch_norm=False,
                    learn_gain=False,
                    gain_init=1,
                    scope='all_pixel_pairs'):
    '''Compares every template pixel with every search pixel, then learns a conv.

    Every (template pixel, search pixel) pair is combined channel-wise with
    `operation`. The template positions and channels are merged into a single
    feature axis, and a learned template-sized convolution with one output
    channel reduces the result to a response map.

    Args:
        template: cnn.Tensor with shape [n, h_t, w_t, c]
        search: cnn.Tensor with shape [n, s, h_s, w_s, c]
        is_training: Passed to `_calibrate`.
        trainable: Whether the conv weights are trainable; also passed to
            `_calibrate`.
        operation: 'mul' (product) or 'abs_diff' (absolute difference).
        reduce_channels: Currently unused; channels are always merged with the
            template positions and reduced by the learned convolution.
        use_mean: If true, the weights are initialized to
            1 / ((h_t * w_t)**2 * c), otherwise to 1.
        use_batch_norm: Passed to `_calibrate`.
        learn_gain: Passed to `_calibrate`.
        gain_init: Passed to `_calibrate`.
        scope: Variable scope name.

    Returns:
        The result of `_calibrate` applied to the response map. The response
        has a single channel and, since the convolution uses VALID padding and
        stride 1, presumably has shape [n, s, h_s - h_t + 1, w_s - w_t + 1, 1].

    Raises:
        ValueError: If `operation` is not one of the supported names.
    '''
    with tf.variable_scope(scope, 'all_pixel_pairs'):
        template = cnn.as_tensor(template)
        search = cnn.as_tensor(search)
        template_size = template.value.shape[-3:-1].as_list()
        num_channels = template.value.shape[-1].value
        num_template_pixels = int(np.prod(template_size))

        # Break template into 1x1 patches.
        # Then "convolve" (multiply) each with the search image.
        t = template.value
        s = search.value
        # template becomes: [n, 1, ...,   1,   1, h_t, w_t, c]
        # search becomes:   [n, s, ..., h_s, w_s,   1,   1, c]
        t = tf.expand_dims(t, 1)
        t = helpers.expand_dims_n(t, -4, 2)
        s = helpers.expand_dims_n(s, -2, 2)
        if operation == 'mul':
            p = t * s
        elif operation == 'abs_diff':
            p = tf.abs(t - s)
        else:
            raise ValueError('unknown operation: "{}"'.format(operation))

        # TODO: Honour `reduce_channels` (reduce over the channel axis here,
        # using the mean or the sum according to `use_mean`).
        # Merge the spatial dimensions of the template into features.
        # response becomes: [n, ..., h_s, w_s, h_t * w_t * c]
        p, _ = helpers.merge_dims(p, -3, None)
        pairs = cnn.Tensor(p, search.fields)

        # TODO: This initialization could be too small?
        # Explicit float division: with integer operands this would evaluate to
        # 0 under Python 2 and the weights would be initialized to zero.
        if use_mean:
            normalizer = 1.0 / float(num_template_pixels ** 2 * num_channels)
        else:
            normalizer = 1.0
        weights_shape = template_size + [num_template_pixels * num_channels, 1]
        weights = tf.get_variable('weights', weights_shape, tf.float32,
                                  initializer=tf.constant_initializer(normalizer),
                                  trainable=trainable)
        # TODO: Support depthwise_conv2d (keep channels).
        pairs, restore = cnn.merge_batch_dims(pairs)
        response = cnn.nn_conv2d(pairs, weights, strides=[1, 1, 1, 1], padding='VALID')
        response = restore(response)

        return _calibrate(response, is_training, use_batch_norm, learn_gain, gain_init,
                          trainable=trainable)