def fcm_loss(inputs, probs, num_iters=5, m=2):
    """Fuzzy c-means consistency loss.

    Runs ``num_iters`` fuzzy c-means updates starting from ``probs``. At each
    step it measures how far the memberships moved (Frobenius norm of the
    difference). It returns the weighted average of those movements, with
    weights ``exp(k * i + b)`` so that later iterations count more.

    Args:
        inputs: [num_samples, 1, x_dims, 1]
        probs: [num_samples, num_clusters, 1, 1], initial memberships.
        num_iters: Number of fuzzy c-means update steps.
        m: Fuzzifier exponent; the update uses ``2 / (m - 1)``, so m must
            differ from 1 (values > 1 are the usual choice).

    Returns:
        A scalar tensor: weighted mean of the per-iteration membership change.
    """
    # Weight schedule for each iteration: exp(k * i + b).
    k = 1
    b = 0
    # Lower bound on distances so a sample lying exactly on a center does not
    # cause a division by zero in the membership update.
    epsilon = 1e-9
    weights = []
    delta_probs = []
    for i in range(num_iters):
        probs_m = tf.pow(probs, m)
        # centers: [1, num_clusters, x_dims, 1]
        centers = (cl.reduce_sum(probs_m * inputs, axis=0, keepdims=True)
                   / cl.reduce_sum(probs_m, axis=0, keepdims=True))
        # distance matrix with shape [num_samples, num_clusters, 1, 1]
        distance_matrix = cl.norm(inputs - centers, axis=(2, 3), keepdims=True)
        distance_matrix = tf.pow(tf.maximum(distance_matrix, epsilon), 2 / (m - 1))
        # Fuzzy c-means membership update:
        #   u_ij = 1 / sum_k (d_ij / d_ik)^p = d_ij^-p / sum_k d_ik^-p,
        # which is normalized so memberships sum to 1 over clusters (axis 1).
        inv_distance = 1. / distance_matrix
        probs_plus = inv_distance / cl.reduce_sum(inv_distance, axis=1, keepdims=True)
        delta_probs.append(tf.norm(probs_plus - probs))
        weights.append(tf.exp(tf.cast(k * i + b, tf.float32)))
        probs = probs_plus
    weights = tf.stack(weights, axis=0)
    delta_probs = tf.stack(delta_probs, axis=0)
    loss = tf.reduce_sum(weights * delta_probs) / tf.reduce_sum(weights)
    return loss
def primaryCaps(inputs, filters, kernel_size, strides, out_caps_dims, method=None, name=None):
    '''Primary capsule layer.

    Args:
        inputs: [batch_size, in_height, in_width, in_channels].
        filters: Integer, the number of capsule types produced per position.
        kernel_size: kernel_size
        strides: strides
        out_caps_dims: A list (or tuple) of 2 integers, the pose shape of
            each capsule.
        method: the method of calculating probability of entity existence
            ('logistic', 'norm', None). None behaves like 'norm'.
        name: Optional variable-scope name, defaults to "primary_capsule".

    Returns:
        pose: A 6-D tensor, [batch_size, out_height, out_width, filters] + out_caps_dims
        activation: A 4-D tensor, [batch_size, out_height, out_width, filters]

    Raises:
        ValueError: if `method` is not one of 'logistic', 'norm' or None.
    '''
    # Fail fast: otherwise `activation` would be unbound at the end.
    if method not in ('logistic', 'norm', None):
        raise ValueError("Unsupported method %r; expected 'logistic', 'norm' or None." % (method,))
    # Accept tuples as well as lists for the list concatenation below.
    out_caps_dims = list(out_caps_dims)
    name = "primary_capsule" if name is None else name
    with tf.variable_scope(name):
        channels = filters * int(np.prod(out_caps_dims))
        if method == "logistic":
            # One extra logit channel per capsule type for the activation.
            channels = channels + filters
        pose = tf.layers.conv2d(inputs, channels,
                                kernel_size=kernel_size,
                                strides=strides,
                                activation=None)
        shape = cl.shape(pose, name="get_pose_shape")
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        shape = [batch_size, height, width, filters] + out_caps_dims

        if method == 'logistic':
            # logistic activation unit
            pose, activation_logit = tf.split(pose, [channels - filters, filters], axis=-1)
            pose = tf.reshape(pose, shape=shape)
            activation = tf.sigmoid(activation_logit)
        else:  # 'norm' or None
            pose = tf.reshape(pose, shape=shape)
            # Vector capsules ([..., n, 1]) squash along -2; matrix capsules along both axes.
            squash_on = -2 if out_caps_dims[-1] == 1 else [-2, -1]
            pose = cl.ops.squash(pose, axis=squash_on)
            activation = cl.norm(pose, axis=(-2, -1))
        # NOTE(review): `1. - 1e-20` evaluates to exactly 1.0 in Python floats,
        # so the upper bound is effectively 1.0 — confirm whether a strict
        # bound below 1 was intended.
        activation = tf.clip_by_value(activation, 1e-20, 1. - 1e-20)

        return (pose, activation)
def squash(inputs, axis=-2, ord="euclidean", name=None, epsilon=1e-9):
    """Squashing function.

    Computes ``(|v|^2 / (1 + |v|^2)) * v / |v|``, where the norm is taken
    over `axis`.

    Args:
        inputs: A tensor with shape [batch_size, 1, num_caps, vec_len, 1] or [batch_size, num_caps, vec_len, 1]
        axis: Axis (or axes) along which the norm is computed.
        ord: Order of the norm. Supported values are 'fro', 'euclidean', 1, 2, np.inf and
            any positive real number yielding the corresponding p-norm. Default is 'euclidean'
            which is equivalent to Frobenius norm if tensor is a matrix and equivalent to
            2-norm for vectors.
        name: Optional name scope, defaults to "squashing".
        epsilon: Lower bound on the norm in the denominator. An all-zero
            capsule then maps to zero instead of NaN (0 / 0).

    Returns:
        A tensor with the same shape as inputs but squashed in `axis` dimension.
    """
    name = "squashing" if name is None else name
    with tf.name_scope(name):
        norm = cl.norm(inputs, ord=ord, axis=axis, keepdims=True)
        norm_squared = tf.square(norm)
        scalar_factor = norm_squared / (1 + norm_squared)
        # Clamp only the denominator: for norms above epsilon the result is unchanged.
        return scalar_factor * (inputs / tf.maximum(norm, epsilon))
def call(self, inputs):
    """Apply the primary capsule convolution and build capsules.

    Reads `self.conv2d`, `self.filters`, `self.channels`, `self.out_caps_dims`
    and `self.method`.

    Args:
        inputs: Input feature map passed to `self.conv2d`.

    Returns:
        (pose, activation). pose has shape
        [batch_size, out_height, out_width, self.filters] + out_caps_dims, and
        activation has shape [batch_size, out_height, out_width, self.filters].

    Raises:
        ValueError: if `self.method` is not one of 'logistic', 'norm' or None.
    """
    # Fail fast: otherwise `activation` would be unbound at the end.
    if self.method not in ('logistic', 'norm', None):
        raise ValueError("Unsupported method %r; expected 'logistic', 'norm' or None." % (self.method,))
    pose = self.conv2d(inputs)
    shape = cl.shape(pose, name="get_pose_shape")
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    # list(...) so that a tuple out_caps_dims also concatenates correctly.
    shape = [batch_size, height, width, self.filters] + list(self.out_caps_dims)

    if self.method == 'logistic':
        # logistic activation unit: the last `self.filters` channels are logits.
        num_or_size_splits = [self.channels - self.filters, self.filters]
        pose, activation_logit = tf.split(pose, num_or_size_splits, axis=-1)
        pose = tf.reshape(pose, shape=shape)
        activation = tf.sigmoid(activation_logit)
    else:  # 'norm' or None
        pose = tf.reshape(pose, shape=shape)
        # Vector capsules ([..., n, 1]) squash along -2; matrix capsules along both axes.
        squash_on = -2 if self.out_caps_dims[-1] == 1 else [-2, -1]
        pose = cl.ops.squash(pose, axis=squash_on)
        activation = cl.norm(pose, axis=(-2, -1))
    # NOTE(review): `1. - 1e-20` evaluates to exactly 1.0 in Python floats,
    # so the upper bound is effectively 1.0 — confirm intent.
    activation = tf.clip_by_value(activation, 1e-20, 1. - 1e-20)

    return (pose, activation)