def concat_fc(template, search, is_training, trainable=True,
              join_dim=128,
              mlp_num_outputs=1,
              mlp_num_layers=2,
              mlp_num_hidden=128,
              mlp_kwargs=None,
              scope=None):
    '''Joins template and search windows with a split FC layer, then an MLP.

    A fully-connected layer applied to the concatenation [template; window]
    equals W_t * template + W_s * window. Rather than concatenating for every
    sliding window, this applies one VALID conv (kernel = template size) to
    the template and another to the search image, then sums the results.
    Batch norm and ReLU are applied after the sum, followed by `cnn.mlp`.

    Args:
        template: [b, h, w, c]
        search: [b, s, h, w, c]
        is_training: Passed to the batch norm applied after the sum.
        trainable: Whether variables created here (FC weights, batch norm
            parameters, MLP) are trainable.
        join_dim: Number of output channels of the split FC layer.
        mlp_num_outputs: `num_outputs` for `cnn.mlp`.
        mlp_num_layers: `num_layers` for `cnn.mlp`.
        mlp_num_hidden: `num_hidden` for `cnn.mlp`.
        mlp_kwargs: Extra keyword arguments for `cnn.mlp` (optional).
        scope: Variable scope (default name 'concat_fc').

    Returns:
        Output of `cnn.mlp`, with the batch dims of `search` restored.
    '''
    with tf.variable_scope(scope, 'concat_fc'):
        template = cnn.as_tensor(template)
        search = cnn.as_tensor(search)

        # The FC kernel spans the whole template.
        kernel_size = template.value.shape[-3:-1].as_list()
        conv_kwargs = dict(
            padding='VALID',
            # Activation and normalizer are applied after the sum, not per branch.
            activation_fn=None,
            normalizer_fn=None,
            # Disable bias because bnorm is performed later.
            biases_initializer=None,
            # Honour `trainable` for the FC weights too, not only for the MLP.
            trainable=trainable,
        )
        with tf.variable_scope('template'):
            template = cnn.slim_conv2d(template, join_dim, kernel_size,
                                       scope='fc', **conv_kwargs)
        with tf.variable_scope('search'):
            search, restore = cnn.merge_batch_dims(search)
            search = cnn.slim_conv2d(search, join_dim, kernel_size,
                                     scope='fc', **conv_kwargs)
            search = restore(search)

        # Insert a singleton axis at position 1 so the template response
        # broadcasts over the `s` dimension of search.
        template = cnn.get_value(template)
        template = tf.expand_dims(template, 1)
        # This is a broadcasting addition. Receptive field in template not tracked.
        output = cnn.pixelwise(lambda search: search + template, search)
        output = cnn.pixelwise(
            partial(slim.batch_norm, is_training=is_training, trainable=trainable),
            output)
        output = cnn.pixelwise(tf.nn.relu, output)

        mlp_kwargs = mlp_kwargs or {}
        output, restore = cnn.merge_batch_dims(output)
        output = cnn.mlp(output,
                         num_layers=mlp_num_layers,
                         num_hidden=mlp_num_hidden,
                         num_outputs=mlp_num_outputs,
                         trainable=trainable,
                         **mlp_kwargs)
        output = restore(output)
        return output
def mlp(template, search, is_training, trainable=True,
        num_layers=2,
        num_hidden=128,
        join_arch='depthwise_xcorr',
        join_params=None,
        scope='mlp_join'):
    '''Runs a base join function, then maps its output through an MLP.

    Args:
        template: [b, ht, wt, c]
        search: [b, s, hs, ws, c]
        is_training: Forwarded to the base join and to `cnn.mlp`.
        trainable: Forwarded to the base join and to `cnn.mlp`.
        num_layers: `num_layers` for `cnn.mlp`.
        num_hidden: `num_hidden` for `cnn.mlp`.
        join_arch: Key into `BY_NAME` selecting the base join function.
        join_params: Extra keyword arguments for the base join (optional).
        scope: Variable scope (default name 'mlp_join').

    Returns:
        `cnn.mlp` output with a single output channel, with the batch dims
        of the join output restored.
    '''
    with tf.variable_scope(scope, 'mlp_join'):
        extra_params = join_params if join_params else {}
        base_join = BY_NAME[join_arch]
        joined = base_join(template, search,
                           is_training=is_training,
                           trainable=trainable,
                           **extra_params)
        # joined: [b, s, h, w, c] -> fold the leading dims for the MLP.
        flat, unmerge = cnn.merge_batch_dims(joined)
        scores = cnn.mlp(flat,
                         num_layers=num_layers,
                         num_hidden=num_hidden,
                         num_outputs=1,
                         is_training=is_training,
                         trainable=trainable)
        return unmerge(scores)
def pre_conv(x):
    '''Batch norm -> ReLU -> VALID conv (with batch norm, no activation).

    NOTE(review): reads `is_training`, `pre_conv_output_dim` and `kernel_size`
    as free variables; they must be bound in an enclosing scope that is not
    visible here -- confirm where this closure is defined.

    Args:
        x: Tensor accepted by `cnn.pixelwise` / `cnn.merge_batch_dims`; its
            batch dims are merged for the conv and restored afterwards.

    Returns:
        The conv output with the original batch dims restored.
    '''
    norm = partial(slim.batch_norm, is_training=is_training)
    for op in (norm, tf.nn.relu):
        x = cnn.pixelwise(op, x)
    merged, unmerge = cnn.merge_batch_dims(x)
    conv_out = cnn.slim_conv2d(
        merged, pre_conv_output_dim, kernel_size,
        padding='VALID',
        activation_fn=None,
        normalizer_fn=slim.batch_norm,
        normalizer_params={'is_training': is_training},
        scope='conv')
    return unmerge(conv_out)
def distance(template, search, is_training, trainable=True,
             use_mean=False,
             use_batch_norm=False,
             learn_gain=False,
             gain_init=1,
             scope='distance'):
    '''Euclidean distance between the template and every search window.

    Uses the expansion ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2: the cross
    term comes from `cnn.diag_xcorr`, ||y||^2 from a VALID convolution of the
    squared search image with a kernel of ones.

    Args:
        template: [b, h, w, c]
        search: [b, s, h, w, c]
        is_training: Forwarded to `_calibrate`.
        trainable: Forwarded to `_calibrate`.
        use_mean: If true, divide the squared distance by the number of
            template elements (h * w * c) before the square root, giving the
            root-mean-square difference.
        use_batch_norm: Forwarded to `_calibrate`.
        learn_gain: Forwarded to `_calibrate`.
        gain_init: Forwarded to `_calibrate`.
        scope: Variable scope (default name 'distance').

    Returns:
        Result of `_calibrate` applied to the single-channel distance map.

    Raises:
        ValueError: If `search` does not have 5 dims.
    '''
    search = cnn.as_tensor(search)
    num_search_dims = len(search.value.shape)
    if num_search_dims != 5:
        raise ValueError('search should have 5 dims: {}'.format(num_search_dims))

    with tf.variable_scope(scope, 'distance'):
        # Discard receptive field of template and get underlying tf.Tensor.
        template = cnn.get_value(template)
        num_channels = template.shape[-1].value
        template_size = template.shape[-3:-1].as_list()

        # <x, y> for every search window. `cnn.diag_xcorr` is presumably a
        # per-channel (depthwise) correlation -- TODO confirm. Summing over
        # channels here is correct whether it returns c channels or one,
        # because ||x||^2 and ||y||^2 below are already totals over channels.
        dot_xy = cnn.diag_xcorr(search, template)
        dot_xy = cnn.pixelwise(
            lambda v: tf.reduce_sum(v, axis=-1, keepdims=True), dot_xy)

        # ||x||^2 per example: [b, 1, 1, 1] -> [b, 1, 1, 1, 1] to broadcast over s.
        dot_xx = tf.reduce_sum(tf.square(template), axis=(-3, -2, -1), keepdims=True)
        dot_xx = tf.expand_dims(dot_xx, 1)

        # ||y||^2 per window: sum of squares under an (h, w, c) -> 1 kernel of ones.
        ones = tf.ones(template_size + [num_channels, 1], tf.float32)
        sq_search = cnn.pixelwise(tf.square, search)
        sq_search, restore = cnn.merge_batch_dims(sq_search)
        dot_yy = cnn.nn_conv2d(sq_search, ones, strides=[1, 1, 1, 1], padding='VALID')
        dot_yy = restore(dot_yy)

        # (x - y)**2 = x**2 - 2 x y + y**2
        sq_dist = cnn.pixelwise_binary(
            lambda dot_xy, dot_yy: dot_xx - 2 * dot_xy + dot_yy, dot_xy, dot_yy)
        # Cancellation in the expansion can give small negative values, for
        # which sqrt would return NaN.
        sq_dist = cnn.pixelwise(lambda v: tf.maximum(v, 0.0), sq_dist)
        if use_mean:
            # Take root-mean-square of difference.
            num_elems = np.prod(template.shape[-3:].as_list())
            sq_dist = cnn.pixelwise(lambda v: v / float(num_elems), sq_dist)
        dist = cnn.pixelwise(tf.sqrt, sq_dist)
        return _calibrate(dist, is_training, use_batch_norm, learn_gain, gain_init,
                          trainable=trainable)
def all_pixel_pairs(template, search, is_training, trainable=True,
                    operation='mul',
                    reduce_channels=True,
                    use_mean=True,
                    use_batch_norm=False,
                    learn_gain=False,
                    gain_init=1,
                    scope='all_pixel_pairs'):
    '''Pairs every template pixel with every search pixel, then applies a learned linear map.

    Each template pixel (1x1 patch) is combined channel-wise with each search
    pixel via `operation`. The template's spatial positions and channels are
    merged into the feature dimension. A learned VALID convolution, with a
    kernel of the template's size and one output channel, then weights every
    (window offset, template pixel, channel) term.

    Args:
        template: cnn.Tensor with shape [n, h_t, w_t, c]
        search: cnn.Tensor with shape [n, s, h_s, w_s, c]
        is_training: Forwarded to `_calibrate`.
        trainable: Whether the conv weights (and `_calibrate` variables) are
            trainable.
        operation: 'mul' (product) or 'abs_diff' (absolute difference).
        reduce_channels: Currently ignored; channels are always kept in the
            pair features and mixed by the learned weights.
        use_mean: If true, weights are initialised to
            1 / ((h_t * w_t)**2 * c), so the initial response is the mean of
            all pair terms; otherwise they are initialised to 1.
        use_batch_norm: Forwarded to `_calibrate`.
        learn_gain: Forwarded to `_calibrate`.
        gain_init: Forwarded to `_calibrate`.
        scope: Variable scope (default name 'all_pixel_pairs').

    Returns:
        Result of `_calibrate` applied to a cnn.Tensor of shape
        [n, s, h_s - h_t + 1, w_s - w_t + 1, 1].

    Raises:
        ValueError: If `operation` is not recognised.
    '''
    pair_ops = {
        'mul': lambda a, b: a * b,
        'abs_diff': lambda a, b: tf.abs(a - b),
    }
    if operation not in pair_ops:
        raise ValueError('unknown operation: "{}"'.format(operation))

    with tf.variable_scope(scope, 'all_pixel_pairs'):
        template = cnn.as_tensor(template)
        search = cnn.as_tensor(search)
        template_size = template.value.shape[-3:-1].as_list()
        num_channels = template.value.shape[-1].value
        num_template_pixels = int(np.prod(template_size))

        # Break template into 1x1 patches and pair each with every search
        # pixel by broadcasting.
        # template becomes: [n, 1, 1, 1, h_t, w_t, c]
        t = tf.expand_dims(template.value, 1)
        t = helpers.expand_dims_n(t, -4, 2)
        # search becomes: [n, s, h_s, w_s, 1, 1, c]
        s = helpers.expand_dims_n(search.value, -2, 2)
        # p: [n, s, h_s, w_s, h_t, w_t, c]
        p = pair_ops[operation](t, s)

        # Merge the spatial dimensions of the template into features.
        # pairs becomes: [n, s, h_s, w_s, h_t * w_t * c]
        p, _ = helpers.merge_dims(p, -3, None)
        pairs = cnn.Tensor(p, search.fields)

        # TODO: This initialization could be too small?
        # Float numerator so the mean normaliser cannot integer-divide to 0.
        normalizer = (1.0 / (num_template_pixels ** 2 * num_channels)
                      if use_mean else 1)
        weights_shape = template_size + [num_template_pixels * num_channels, 1]
        weights = tf.get_variable('weights', weights_shape, tf.float32,
                                  initializer=tf.constant_initializer(normalizer),
                                  trainable=trainable)
        # TODO: Support depthwise_conv2d (keep channels).
        pairs, restore = cnn.merge_batch_dims(pairs)
        response = cnn.nn_conv2d(pairs, weights, strides=[1, 1, 1, 1], padding='VALID')
        response = restore(response)
        return _calibrate(response, is_training, use_batch_norm, learn_gain, gain_init,
                          trainable=trainable)