def conv_capsule_mat(input_tensor,
                     input_activation,
                     input_dim,
                     output_dim,
                     layer_name,
                     num_routing=3,
                     num_in_atoms=3,
                     num_out_atoms=3,
                     stride=2,
                     kernel_size=5,
                     min_var=0.0005,
                     final_beta=1.0):
  """Convolutional Capsule layer with Pose Matrices."""
  tf.logging.info('caps conv stride: %s', stride)
  in_atom_sq = num_in_atoms * num_in_atoms
  with tf.variable_scope(layer_name):
    input_shape = tf.shape(input_tensor)
    _, _, _, in_height, in_width = input_tensor.get_shape()
    # This Variable will hold the state of the weights for the layer.
    kernel = utils.weight_variable(
        shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
        stddev=0.3)
    # kernel = tf.clip_by_norm(kernel, 3.0, axes=[1, 2, 3])
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=0.5,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=.5,
        name='sigma_biases')
    with tf.name_scope('conv'):
      # input_tensor: [x, 128, 8, c1, c2] -> [x*128, 8, c1, c2]
      tf.logging.info('input_tensor shape: %s', input_tensor.get_shape())
      input_tensor_reshaped = tf.reshape(input_tensor, [
          input_shape[0] * input_dim * in_atom_sq, input_shape[3],
          input_shape[4], 1
      ])
      input_tensor_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                       input_tensor.get_shape()[4], 1))
      input_act_reshaped = tf.reshape(
          input_activation,
          [input_shape[0] * input_dim, input_shape[3], input_shape[4], 1])
      input_act_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                    input_tensor.get_shape()[4], 1))
      tf.logging.info('input_tensor_reshaped shape: %s',
                      input_tensor_reshaped.get_shape())
      # conv: [x*128, out*out_at, c3, c4]
      conv_patches = tf.extract_image_patches(
          images=input_tensor_reshaped,
          ksizes=[1, kernel_size, kernel_size, 1],
          strides=[1, stride, stride, 1],
          rates=[1, 1, 1, 1],
          padding='VALID',
      )
      act_patches = tf.extract_image_patches(
          images=input_act_reshaped,
          ksizes=[1, kernel_size, kernel_size, 1],
          strides=[1, stride, stride, 1],
          rates=[1, 1, 1, 1],
          padding='VALID',
      )
      o_height = (in_height - kernel_size) // stride + 1
      o_width = (in_width - kernel_size) // stride + 1
      patches = tf.reshape(conv_patches,
                           (input_shape[0], input_dim, in_atom_sq, o_height,
                            o_width, kernel_size, kernel_size))
      patches.set_shape((None, input_dim, in_atom_sq, o_height, o_width,
                         kernel_size, kernel_size))
      patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
      patch_split = tf.reshape(
          patch_trans,
          (input_dim, kernel_size, kernel_size,
           input_shape[0] * o_height * o_width * num_in_atoms, num_in_atoms))
      patch_split.set_shape(
          (input_dim, kernel_size, kernel_size, None, num_in_atoms))
      a_patches = tf.reshape(act_patches,
                             (input_shape[0], input_dim, 1, 1, o_height,
                              o_width, kernel_size, kernel_size))
      a_patches.set_shape((None, input_dim, 1, 1, o_height, o_width,
                           kernel_size, kernel_size))
      with tf.name_scope('input_act'):
        utils.activation_summary(
            tf.reduce_sum(
                tf.reduce_sum(tf.reduce_sum(a_patches, axis=1), axis=-1),
                axis=-1))
      with tf.name_scope('Wx'):
        wx = tf.matmul(patch_split, kernel)
        wx = tf.reshape(wx,
                        (input_dim, kernel_size, kernel_size, input_shape[0],
                         o_height, o_width, num_in_atoms * num_out_atoms,
                         output_dim))
        wx.set_shape((input_dim, kernel_size, kernel_size, None, o_height,
                      o_width, num_in_atoms * num_out_atoms, output_dim))
        wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
        utils.activation_summary(wx)
    with tf.name_scope('routing'):
      # Routing.
      # logits: [x, 128, 10, c3, c4]
      logit_shape = [
          input_dim, output_dim, 1, o_height, o_width, kernel_size, kernel_size
      ]
      activation, center = update_conv_routing(
          wx=wx,
          input_activation=a_patches,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms * num_out_atoms,
          input_dim=input_dim,
          num_routing=num_routing,
          output_dim=output_dim,
          min_var=min_var,
          final_beta=final_beta,
      )
    # activations: [x, 10, 8, c3, c4]
    out_activation = tf.squeeze(activation, axis=[1, 3, 6, 7])
    out_center = tf.squeeze(center, axis=[1, 6, 7])
    with tf.name_scope('center'):
      utils.activation_summary(out_center)
  return tf.sigmoid(out_activation), out_center
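
# Illustrative usage sketch (not part of the original model). It only shows the
# tensor layout the layer above appears to expect: poses of shape
# [batch, input_dim, num_in_atoms**2, height, width] and activations of shape
# [batch, input_dim, height, width]. The placeholder sizes and the helper name
# `_example_conv_capsule_mat` are assumptions made for illustration.
def _example_conv_capsule_mat():
  # 16 input capsule types with 3x3 poses on a 14x14 grid; dynamic batch.
  poses = tf.placeholder(tf.float32, [None, 16, 9, 14, 14])
  acts = tf.placeholder(tf.float32, [None, 16, 14, 14])
  out_act, out_center = conv_capsule_mat(
      input_tensor=poses,
      input_activation=acts,
      input_dim=16,
      output_dim=16,
      layer_name='conv_capsule_example',
      kernel_size=5,
      stride=2)
  # With VALID patches: o = (14 - 5) // 2 + 1 = 5.
  # out_act:    [batch, 16, 5, 5]
  # out_center: [batch, 16, 9, 5, 5]
  return out_act, out_center
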

def conv_capsule_mat_fast(
    input_tensor,
    input_activation,
    input_dim,
    output_dim,
    layer_name,
    num_routing=3,
    num_in_atoms=3,
    num_out_atoms=3,
    stride=2,
    kernel_size=5,
    min_var=0.0005,
    final_beta=1.0,
):
  """Convolutional Capsule layer with fast EM routing.

  Args:
    input_tensor: The input capsule features.
    input_activation: The input capsule activations.
    input_dim: Number of input capsule types.
    output_dim: Number of output capsule types.
    layer_name: Name of this layer, e.g. conv_capsule1.
    num_routing: Number of routing iterations.
    num_in_atoms: Number of features in each of the input capsules.
    num_out_atoms: Number of features in each of the output capsules.
    stride: Stride of the convolution.
    kernel_size: Kernel size of the convolution.
    min_var: Minimum variance for each capsule to avoid NaNs.
    final_beta: Beta for making the routing factors sharp.

  Returns:
    The final capsule activations and centers.
  """
  tf.logging.info('conv_capsule_mat %s', layer_name)
  tf.logging.info('input_shape %s', input_tensor.shape.as_list())
  in_atom_sq = num_in_atoms * num_in_atoms
  with tf.variable_scope(layer_name):
    # The input shape should be fully defined here (static batch size).
    # input_shape = tf.shape(input_tensor)
    input_shape = input_tensor.shape.as_list()
    batch, _, _, in_height, in_width = input_shape
    o_height = (in_height - kernel_size) // stride + 1
    o_width = (in_width - kernel_size) // stride + 1

    # This Variable will hold the state of the weights for the layer.
    kernel = utils.weight_variable(
        shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
        stddev=0.1)
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=0.2,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=.5,
        name='sigma_biases')

    with utils.maybe_jit_scope(), tf.name_scope('conv'):
      input_tensor_reshaped = tf.reshape(
          input_tensor,
          [batch * input_dim * in_atom_sq, in_height, in_width, 1])
      input_act_reshaped = tf.reshape(
          input_activation, [batch * input_dim, in_height, in_width, 1])

      conv_patches = utils.kernel_tile(input_tensor_reshaped, kernel_size,
                                       stride)
      act_patches = utils.kernel_tile(input_act_reshaped, kernel_size, stride)

      patches = tf.reshape(conv_patches,
                           (batch, input_dim, in_atom_sq, o_height, o_width,
                            kernel_size, kernel_size))
      patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
      patch_split = tf.reshape(
          patch_trans,
          (input_dim, kernel_size, kernel_size,
           batch * o_height * o_width * num_in_atoms, num_in_atoms),
          name='patch_split')
      a_patches = tf.reshape(
          act_patches,
          (batch, input_dim, 1, 1, o_height, o_width, kernel_size,
           kernel_size),
          name='a_patches')

    # Recompute Wx on backprop to save memory (perhaps redo patches as well?)
    # @tf.contrib.layers.recompute_grad
    def compute_wx(patch_split, kernel, is_recomputing=False):
      tf.logging.info('compute_wx(is_recomputing=%s)', is_recomputing)
      with utils.maybe_jit_scope(), tf.name_scope('wx'):
        wx = tf.matmul(patch_split, kernel)
        wx = tf.reshape(
            wx, (input_dim, kernel_size, kernel_size, batch, o_height, o_width,
                 num_in_atoms * num_out_atoms, output_dim))
        wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
      return wx

    wx = compute_wx(patch_split, kernel.value())

    with utils.maybe_jit_scope():
      # Routing.
      logit_shape = [
          input_dim, output_dim, 1, o_height, o_width, kernel_size, kernel_size
      ]
      tf.logging.info('logit_shape: %s', logit_shape)
      activation, center = update_conv_routing_fast(
          wx=wx,
          input_activation=a_patches,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms * num_out_atoms,
          input_dim=input_dim,
          num_routing=num_routing,
          output_dim=output_dim,
          min_var=min_var,
          final_beta=4 * final_beta,
          stride=stride,
          layer_name=layer_name,
      )

    with utils.maybe_jit_scope():
      out_activation = tf.squeeze(
          activation, axis=[1, 3, 6, 7], name='out_activation')
      out_center = tf.squeeze(center, axis=[1, 6, 7], name='out_center')
      out_activation = tf.sigmoid(out_activation)

    with tf.name_scope('center'):
      utils.activation_summary(out_center)
  return out_activation, out_center
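
# Illustrative usage sketch (not part of the original model). Unlike
# conv_capsule_mat above, the fast variant unpacks input_tensor.shape.as_list()
# into Python ints, so the batch dimension must be statically known. The sizes
# and the helper name `_example_conv_capsule_mat_fast` are assumptions.
def _example_conv_capsule_mat_fast():
  poses = tf.placeholder(tf.float32, [32, 16, 9, 14, 14])  # static batch of 32
  acts = tf.placeholder(tf.float32, [32, 16, 14, 14])
  out_act, out_center = conv_capsule_mat_fast(
      input_tensor=poses,
      input_activation=acts,
      input_dim=16,
      output_dim=16,
      layer_name='conv_capsule_fast_example',
      kernel_size=5,
      stride=2)
  # out_act:    [32, 16, 5, 5]
  # out_center: [32, 16, 9, 5, 5]
  return out_act, out_center
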

def connector_capsule_mat(input_tensor,
                          position_grid,
                          input_activation,
                          input_dim,
                          output_dim,
                          layer_name,
                          num_routing=3,
                          num_in_atoms=3,
                          num_out_atoms=3,
                          leaky=False,
                          final_beta=1.0,
                          min_var=0.0005):
  """Final Capsule Layer with Pose Matrices and Shared connections."""
  # One weight tensor for each capsule of the layer below: w: [8*128, 8*10]
  with tf.variable_scope(layer_name):
    # This Variable will hold the state of the weights for the layer.
    with tf.name_scope('input_center_connector'):
      utils.activation_summary(input_tensor)
    weights = utils.weight_variable(
        [input_dim, num_out_atoms, output_dim * num_out_atoms], stddev=0.01)
    # weights = tf.clip_by_norm(weights, 1.0, axes=[1])
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1],
        init_value=1.0,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1], init_value=2.0, name='sigma_biases')
    with tf.name_scope('Wx_plus_b'):
      # input_tensor: [x, 128, 8, h, w]
      input_shape = tf.shape(input_tensor)
      input_trans = tf.transpose(input_tensor, [1, 0, 3, 4, 2])
      input_share = tf.reshape(input_trans, [input_dim, -1, num_in_atoms])
      # input_expanded: [x, 128, 8, 1]
      wx_share = tf.matmul(input_share, weights)
      # sqr_num_out_atoms = num_out_atoms
      num_out_atoms *= num_out_atoms
      wx_trans = tf.reshape(wx_share, [
          input_dim, input_shape[0], input_shape[3], input_shape[4],
          num_out_atoms, output_dim
      ])
      wx_trans.set_shape((input_dim, None, input_tensor.get_shape()[3],
                          input_tensor.get_shape()[4], num_out_atoms,
                          output_dim))
      h, w, _ = position_grid.get_shape()
      height = h
      width = w
      # t_pose = tf.transpose(position_grid, [2, 0, 1])
      # t_pose_exp = tf.scatter_nd([[sqr_num_out_atoms - 1],
      #                             [2 * sqr_num_out_atoms - 1]], t_pose,
      #                            [num_out_atoms, height, width])
      # pose_g_exp = tf.transpose(t_pose_exp, [1, 2, 0])
      zero_grid = tf.zeros([height, width, num_out_atoms - 2])
      pose_g_exp = tf.concat([position_grid, zero_grid], axis=2)
      pose_g = tf.expand_dims(
          tf.expand_dims(tf.expand_dims(pose_g_exp, -1), 0), 0)
      wx_posed = wx_trans + pose_g
      wx_posed_t = tf.transpose(wx_posed, [1, 0, 2, 3, 5, 4])

      # Wx_reshaped: [x, 128, 10, 8]
      wx = tf.reshape(wx_posed_t, [
          -1, input_dim * height * width, output_dim, num_out_atoms, 1, 1
      ])
    with tf.name_scope('routing'):
      # Routing.
      # logits: [x, 128, 10]
      logit_shape = [input_dim * height * width, output_dim, 1, 1, 1]
      for _ in range(4):
        input_activation = tf.expand_dims(input_activation, axis=-1)
      activation, center = update_em_routing(
          wx=wx,
          input_activation=input_activation,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms,
          num_routing=num_routing,
          output_dim=output_dim,
          leaky=leaky,
          final_beta=final_beta / 4,
          min_var=min_var,
      )
    out_activation = tf.squeeze(activation, axis=[1, 3, 4, 5])
    out_center = tf.squeeze(center, axis=[1, 4, 5])
  return tf.sigmoid(out_activation), out_center
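
# Illustrative usage sketch (not part of the original model). From the reshapes
# above, the layer appears to expect poses of shape
# [batch, input_dim, num_in_atoms**2, height, width], a position grid of shape
# [height, width, 2], and activations already flattened to
# [batch, input_dim * height * width]; treat these layouts and the helper name
# `_example_connector_capsule_mat` as assumptions.
def _example_connector_capsule_mat():
  poses = tf.placeholder(tf.float32, [None, 16, 9, 5, 5])
  grid = tf.placeholder(tf.float32, [5, 5, 2])
  acts = tf.placeholder(tf.float32, [None, 16 * 5 * 5])
  out_act, out_center = connector_capsule_mat(
      input_tensor=poses,
      position_grid=grid,
      input_activation=acts,
      input_dim=16,
      output_dim=10,
      layer_name='connector_example')
  # out_act:    [batch, 10]
  # out_center: [batch, 10, 9]
  return out_act, out_center
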

def add_convs(features):
  """Stack Convolution layers."""
  image_dim = features['height']
  image_depth = features['depth']
  image = features['images']
  position_grid = tf.reshape(
      tf.constant(
          np.mgrid[(-image_dim // 2):((image_dim + 1) // 2),
                   (-image_dim // 2):((image_dim + 1) // 2)],
          dtype=tf.float32) / 100.0, (1, 2, image_dim, image_dim))
  if FLAGS.verbose_image:
    with tf.name_scope('input_reshape'):
      image_shaped_input = tf.reshape(image,
                                      [-1, image_dim, image_dim, image_depth])
      tf.summary.image('input', image_shaped_input, 10)
  with tf.variable_scope('conv1') as scope:
    kernel = utils.weight_variable(
        shape=[
            FLAGS.kernel_size, FLAGS.kernel_size, image_depth,
            FLAGS.num_start_conv
        ],
        stddev=5e-2)
    image_reshape = tf.reshape(image, [-1, image_depth, image_dim, image_dim])
    if FLAGS.cpu_way:
      image_reshape = tf.transpose(image_reshape, [0, 2, 3, 1])
      data_format = 'NHWC'
      strides = [1, FLAGS.stride_1, FLAGS.stride_1, 1]
    else:
      data_format = 'NCHW'
      strides = [1, 1, FLAGS.stride_1, FLAGS.stride_1]
    conv = tf.nn.conv2d(
        image_reshape,
        kernel,
        strides,
        padding=FLAGS.padding,
        data_format=data_format)
    biases = utils.bias_variable([FLAGS.num_start_conv])
    pre_activation = tf.nn.bias_add(conv, biases, data_format=data_format)
    if FLAGS.cpu_way:
      pre_activation = tf.transpose(pre_activation, [0, 3, 1, 2])
    position_grid = conv_pos(position_grid, FLAGS.kernel_size, FLAGS.stride_1,
                             FLAGS.padding)
    conv1 = tf.nn.relu(pre_activation, name=scope.name)
    if FLAGS.verbose:
      tf.summary.histogram('activation', conv1)
    if FLAGS.pooling:
      pool1 = tf.nn.max_pool2d(
          conv1, ksize=2, strides=2, data_format='NCHW', padding='SAME')
      convs = [pool1]
    else:
      convs = [conv1]
  conv_outputs = [FLAGS.num_start_conv]
  for i in range(int(FLAGS.extra_conv)):
    conv_outputs += [int(FLAGS.conv_dims.split(',')[i])]
    with tf.variable_scope('conv{}'.format(i + 2)) as scope:
      kernel = utils.weight_variable(
          shape=[
              int(FLAGS.conv_kernels.split(',')[i]),
              int(FLAGS.conv_kernels.split(',')[i]), conv_outputs[i],
              conv_outputs[i + 1]
          ],
          stddev=5e-2)
      conv = tf.nn.conv2d(
          convs[i],
          kernel, [
              1, 1,
              int(FLAGS.conv_strides.split(',')[i]),
              int(FLAGS.conv_strides.split(',')[i])
          ],
          padding=FLAGS.padding,
          data_format='NCHW')
      position_grid = conv_pos(position_grid,
                               int(FLAGS.conv_kernels.split(',')[i]),
                               int(FLAGS.conv_strides.split(',')[i]),
                               FLAGS.padding)
      biases = utils.bias_variable([conv_outputs[i + 1]])
      pre_activation = tf.nn.bias_add(conv, biases, data_format='NCHW')
      cur_conv = tf.nn.relu(pre_activation, name=scope.name)
      if FLAGS.pooling:
        convs += [
            tf.nn.max_pool2d(
                cur_conv,
                ksize=2,
                strides=2,
                data_format='NCHW',
                padding='SAME')
        ]
      else:
        convs += [cur_conv]
      if FLAGS.verbose:
        tf.summary.histogram('activation', convs[-1])
  return convs[-1], conv_outputs[-1], position_grid
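
# Illustrative call sketch (not part of the original model). add_convs reads
# its configuration from FLAGS (kernel_size, stride_1, num_start_conv, padding,
# pooling, extra_conv, conv_dims, conv_kernels, conv_strides, cpu_way, verbose,
# verbose_image) and expects a feature dictionary like the one below. The
# literal sizes here are assumptions for illustration only.
#
#   features = {
#       'height': 28,                        # spatial size of the square image
#       'depth': 1,                          # number of image channels
#       'images': tf.placeholder(tf.float32, [None, 28 * 28]),
#   }
#   conv_out, conv_channels, position_grid = add_convs(features)
#   # conv_out is NCHW: [batch, conv_channels, out_height, out_width].
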