def vec_transformationByMat(poses, input_capsule_dim, input_capsule_num,
                            output_capsule_dim, output_capsule_num,
                            shared=True):
    """Compute capsule prediction vectors by explicit batched matmul.

    Every input pose vector is multiplied by a learned
    ``[output_capsule_dim, input_capsule_dim]`` matrix for each output capsule.

    Args:
      poses: rank-3 tensor; the 4-element tile multiples below imply it is
        indexed as ``[batch, input_capsule_num, input_capsule_dim]``
        (TODO confirm against callers).
      input_capsule_dim: length of each input pose vector.
      input_capsule_num: number of input capsules.
      output_capsule_dim: length of each output (vote) vector.
      output_capsule_num: number of output capsules.
      shared: if True, one weight matrix per output capsule is shared by all
        input capsules (weights shape ``[1, 1, out_num, out_dim, in_dim]``);
        otherwise every (input, output) pair has its own matrix
        (``[1, in_num, out_num, out_dim, in_dim]``).

    Returns:
      Votes tensor transposed to
      ``[batch, output_capsule_num, input_capsule_num, output_capsule_dim]``.
    """
    static_batch = poses.get_shape().as_list()[0]
    # tf.tile cannot take None as a multiple; fall back to the runtime batch
    # size when it is unknown at graph-build time. Keep the static value when
    # known so the output's static shape is unchanged.
    batch = static_batch if static_batch is not None else tf.shape(poses)[0]

    # [batch, in_num, in_dim] -> [batch, in_num, out_num, in_dim]
    poses = tf.tile(poses[..., tf.newaxis, :], [1, 1, output_capsule_num, 1])
    # Column vectors for matmul: [batch, in_num, out_num, in_dim, 1]
    pose_columns = poses[..., tf.newaxis]

    if shared:
        kernel = capsule_utils._get_weights_wrapper(
            name='weights',
            shape=[1, 1, output_capsule_num, output_capsule_dim,
                   input_capsule_dim],
            weights_decay_factor=0.0)
        kernel = tf.tile(kernel, [batch, input_capsule_num, 1, 1, 1])
    else:
        kernel = capsule_utils._get_weights_wrapper(
            name='weights',
            shape=[1, input_capsule_num, output_capsule_num,
                   output_capsule_dim, input_capsule_dim],
            weights_decay_factor=0.0)
        kernel = tf.tile(kernel, [batch, 1, 1, 1, 1])
    tf.logging.info('poses: {}'.format(pose_columns.get_shape()))
    tf.logging.info('kernel: {}'.format(kernel.get_shape()))

    # [.., out_dim, in_dim] x [.., in_dim, 1] -> [.., out_dim, 1] -> squeeze
    u_hat_vecs = tf.squeeze(tf.matmul(kernel, pose_columns), axis=-1)
    # [batch, in_num, out_num, out_dim] -> [batch, out_num, in_num, out_dim]
    u_hat_vecs = tf.transpose(u_hat_vecs, (0, 2, 1, 3))
    return u_hat_vecs
def vec_transformationByConv(poses, input_capsule_dim, input_capsule_num,
                             output_capsule_dim, output_capsule_num):
    """Compute capsule prediction vectors with a width-1 1-D convolution.

    A single ``[1, input_capsule_dim, output_capsule_dim * output_capsule_num]``
    kernel projects every input pose to all output capsules at once. The
    weights are therefore shared across input capsules.

    Args:
      poses: rank-3 tensor consumed by ``tf.nn.conv1d``; expected to be
        ``[batch, input_capsule_num, input_capsule_dim]`` (TODO confirm).
      input_capsule_dim: length of each input pose vector.
      input_capsule_num: number of input capsules.
      output_capsule_dim: length of each output (vote) vector.
      output_capsule_num: number of output capsules.

    Returns:
      Votes tensor
      ``[batch, output_capsule_num, input_capsule_num, output_capsule_dim]``.
    """
    projection_width = output_capsule_dim * output_capsule_num
    weights = capsule_utils._get_weights_wrapper(
        name='weights',
        shape=[1, input_capsule_dim, projection_width],
        weights_decay_factor=0.0)
    tf.logging.info('poses: {}'.format(poses.get_shape()))
    tf.logging.info('kernel: {}'.format(weights.get_shape()))

    projected = tf.nn.conv1d(poses, weights, stride=1, padding="VALID")
    # Split the stacked channel axis into (capsule, dim).
    votes_shape = (-1, input_capsule_num, output_capsule_num,
                   output_capsule_dim)
    votes = tf.reshape(projected, votes_shape)
    # Put the output-capsule axis ahead of the input-capsule axis.
    return tf.transpose(votes, (0, 2, 1, 3))
def capsules_init(inputs, shape, strides, padding, pose_shape, add_bias, name):
    """Build the primary capsule layer from a conventional feature map.

    A 2-D convolution produces ``shape[-1] * pose_shape`` channels. These are
    reshaped into ``shape[-1]`` capsules of length ``pose_shape`` each and
    passed through ``squash_v1``.

    Args:
      inputs: feature map fed to ``capsule_utils._conv2d_wrapper``.
      shape: convolution shape list; its last entry is the number of
        capsule types (it is scaled by ``pose_shape`` for the conv).
      strides: forwarded to the conv wrapper.
      padding: forwarded to the conv wrapper.
      pose_shape: length of each capsule pose vector.
      add_bias: forwarded to the conv wrapper.
      name: variable scope name.

    Returns:
      ``(poses, activations)``, where poses are
      ``[batch, H, W, shape[-1], pose_shape]`` and activations are the L2
      norm of each squashed pose plus a learned ``beta_a`` bias
      (``[batch, H, W, shape[-1]]``).
    """
    capsule_types = shape[-1]
    conv_shape = shape[:-1] + [capsule_types * pose_shape]
    with tf.variable_scope(name):
        stacked = capsule_utils._conv2d_wrapper(
            inputs,
            shape=conv_shape,
            strides=strides,
            padding=padding,
            add_bias=add_bias,
            activation_fn=None,
            name='pose_stacked')
        stacked_dims = stacked.get_shape().as_list()
        height, width = stacked_dims[1], stacked_dims[2]
        poses = tf.reshape(
            stacked, [-1, height, width, capsule_types, pose_shape])

        beta_a = capsule_utils._get_weights_wrapper(
            name='beta_a', shape=[1, capsule_types])
        poses = squash_v1(poses, axis=-1)
        # Activation = length of the squashed pose, shifted by beta_a
        # (broadcast over batch and spatial axes).
        pose_lengths = tf.sqrt(tf.reduce_sum(tf.square(poses), axis=-1))
        activations = pose_lengths + beta_a
        tf.logging.info("prim poses dimension:{}".format(poses.get_shape()))
        return poses, activations
def capsule_fc_layer(nets, output_capsule_num, iterations, name):
    """Fully connected capsule layer: vote by shared projection, then route.

    Args:
      nets: ``(poses, activations)`` pair. Axis 1 of poses is treated as the
        input-capsule count and the last axis as the pose length.
      output_capsule_num: number of output capsules.
      iterations: number of routing iterations, forwarded to ``routing``.
      name: variable scope name.

    Returns:
      The ``(poses, activations)`` pair produced by ``routing``. Output pose
      length equals input pose length.
    """
    with tf.variable_scope(name):
        in_poses, in_activations = nets
        in_dims = in_poses.get_shape().as_list()
        pose_len = in_dims[-1]
        num_inputs = in_dims[1]

        votes = vec_transformationByConv(
            in_poses,
            input_capsule_dim=pose_len,
            input_capsule_num=num_inputs,
            output_capsule_dim=pose_len,
            output_capsule_num=output_capsule_num)
        tf.logging.info('votes shape: {}'.format(votes.get_shape()))

        beta_a = capsule_utils._get_weights_wrapper(
            name='beta_a', shape=[1, output_capsule_num])
        out_poses, out_activations = routing(
            votes, beta_a, iterations, output_capsule_num, in_activations)
        tf.logging.info('capsule fc shape: {}'.format(out_poses.get_shape()))
        return out_poses, out_activations
def capsule_conv_layer(nets, shape, strides, iterations, name):
    """Convolutional capsule layer: route within each sliding kernel window.

    Windows are placed without padding: offsets stop at
    ``size + 1 - kernel``. Each window's capsules are projected to votes with
    ``vec_transformationByConv`` and combined by ``routing``.

    Args:
      nets: ``(poses, activations)`` pair. The transpose perms below fix
        poses as rank 5 (presumably ``[batch, H, W, caps, pose_dim]``) and
        activations as rank 4 (``[batch, H, W, caps]``) -- TODO confirm.
      shape: ``shape[0]``/``shape[1]`` are the kernel height/width,
        ``shape[2]`` is used as the input capsules per position, and
        ``shape[3]`` is the number of output capsules.
      strides: ``strides[1]`` and ``strides[2]`` are the height and width
        strides.
      iterations: number of routing iterations, forwarded to ``routing``.
      name: variable scope name.

    Returns:
      ``(poses, activations)`` reshaped to
      ``[batch, out_h, out_w, shape[3], pose_dim]`` and
      ``[batch, out_h, out_w, shape[3]]``.
    """
    with tf.variable_scope(name):
        poses, i_activations = nets
        kernel_h, kernel_w = shape[0], shape[1]
        window_size = shape[0] * shape[1] * shape[2]
        out_caps = shape[3]
        in_shape = poses.get_shape().as_list()

        # Row/column indices covered by each window position.
        hk_offsets = [[h_offset + k_offset for k_offset in range(0, kernel_h)]
                      for h_offset in range(0, in_shape[1] + 1 - kernel_h,
                                            strides[1])]
        wk_offsets = [[w_offset + k_offset for k_offset in range(0, kernel_w)]
                      for w_offset in range(0, in_shape[2] + 1 - kernel_w,
                                            strides[2])]

        # Two gathers add window axes: [b, oh, kh, ow, kw, ...]; the transpose
        # brings both output-position axes first: [b, oh, ow, kh, kw, ...].
        inputs_poses_patches = tf.transpose(
            tf.gather(
                tf.gather(poses, hk_offsets, axis=1,
                          name='gather_poses_height_kernel'),
                wk_offsets, axis=3, name='gather_poses_width_kernel'),
            perm=[0, 1, 3, 2, 4, 5, 6],
            name='inputs_poses_patches')
        tf.logging.info('i_poses_patches shape: {}'.format(
            inputs_poses_patches.get_shape()))
        patches_shape = inputs_poses_patches.get_shape().as_list()
        pose_dim = patches_shape[-1]
        inputs_poses_patches = tf.reshape(
            inputs_poses_patches, [-1, window_size, pose_dim])

        i_activations_patches = tf.transpose(
            tf.gather(
                tf.gather(i_activations, hk_offsets, axis=1,
                          name='gather_activations_height_kernel'),
                wk_offsets, axis=3, name='gather_activations_width_kernel'),
            perm=[0, 1, 3, 2, 4, 5],
            name='inputs_activations_patches')
        tf.logging.info('i_activations_patches shape: {}'.format(
            i_activations_patches.get_shape()))
        i_activations_patches = tf.reshape(
            i_activations_patches, [-1, window_size])

        u_hat_vecs = vec_transformationByConv(
            inputs_poses_patches,
            pose_dim,
            window_size,
            pose_dim,
            out_caps,
        )
        tf.logging.info('capsule conv votes shape: {}'.format(
            u_hat_vecs.get_shape()))

        beta_a = capsule_utils._get_weights_wrapper(
            name='beta_a', shape=[1, out_caps])
        poses, activations = routing(u_hat_vecs, beta_a, iterations, out_caps,
                                     i_activations_patches)

        # tf.reshape rejects None, so use -1 when the static batch size is
        # unknown (capsules_init reshapes the same way). The static value is
        # kept when known so static shapes are unchanged.
        batch = patches_shape[0] if patches_shape[0] is not None else -1
        out_h, out_w = patches_shape[1], patches_shape[2]
        poses = tf.reshape(poses, [batch, out_h, out_w, out_caps, pose_dim])
        activations = tf.reshape(activations, [batch, out_h, out_w, out_caps])
        nets = poses, activations
        tf.logging.info("capsule conv poses dimension:{}".format(
            poses.get_shape()))
        tf.logging.info("capsule conv activations dimension:{}".format(
            activations.get_shape()))
        return nets