def _build_capsule(input_tensor, input_atom, position_grid, num_classes):
  """Stack capsule layers."""
  # input_tensor: [x, ch, atom, c1, c2] (64, 5x5, 2 conv1)
  print('hidden: ')
  print(input_tensor.get_shape())
  conv_caps_act, conv_caps_center = em_layers.primary_caps(
      input_tensor,
      input_atom,
      FLAGS.num_prime_capsules,
      FLAGS.num_primary_atoms,
  )
  with tf.name_scope('primary_act'):
    utils.activation_summary(conv_caps_act)
  with tf.name_scope('primary_center'):
    utils.activation_summary(conv_caps_center)
  last_dim = FLAGS.num_prime_capsules
  if FLAGS.extra_caps > 0:
    for i in range(FLAGS.extra_caps):
      conv_caps_act, conv_caps_center = em_layers.conv_capsule_mat(
          conv_caps_center,
          conv_caps_act,
          last_dim,
          int(FLAGS.caps_dims.split(',')[i]),
          'convCaps{}'.format(i),
          FLAGS.routing_iteration,
          num_in_atoms=int(math.sqrt(FLAGS.num_primary_atoms)),
          num_out_atoms=int(math.sqrt(FLAGS.num_primary_atoms)),
          stride=int(FLAGS.caps_strides.split(',')[i]),
          kernel_size=int(FLAGS.caps_kernels.split(',')[i]),
          final_beta=FLAGS.final_beta,
      )
      position_grid = simple_model.conv_pos(
          position_grid,
          int(FLAGS.caps_kernels.split(',')[i]),
          int(FLAGS.caps_strides.split(',')[i]), 'VALID')
      last_dim = int(FLAGS.caps_dims.split(',')[i])
  print(conv_caps_center.get_shape())
  print(conv_caps_act.get_shape())
  capsule1_act = tf.layers.flatten(conv_caps_act)
  position_grid = tf.squeeze(position_grid, axis=[0])
  position_grid = tf.transpose(position_grid, [1, 2, 0])
  return em_layers.connector_capsule_mat(
      input_tensor=conv_caps_center,
      position_grid=position_grid,
      input_activation=capsule1_act,
      input_dim=last_dim,
      output_dim=num_classes,
      layer_name='capsule2',
      num_routing=FLAGS.routing_iteration,
      num_in_atoms=int(math.sqrt(FLAGS.num_primary_atoms)),
      num_out_atoms=int(math.sqrt(FLAGS.num_primary_atoms)),
      leaky=FLAGS.leaky,
      final_beta=FLAGS.final_beta,
  ), conv_caps_act
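
# Illustrative only: example flag values (assumptions, not defaults from this
# code) showing how the comma-separated capsule flags are consumed by the loop
# above, one entry per extra conv-capsule layer:
#
#   --extra_caps=2 --caps_dims=32,32 --caps_strides=2,1 --caps_kernels=3,3
#
# With these assumed settings, layer i uses caps_dims[i] output capsule types
# and a caps_kernels[i] x caps_kernels[i] kernel with stride caps_strides[i].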


def conv_capsule_mat(input_tensor,
                     input_activation,
                     input_dim,
                     output_dim,
                     layer_name,
                     num_routing=3,
                     num_in_atoms=3,
                     num_out_atoms=3,
                     stride=2,
                     kernel_size=5,
                     min_var=0.0005,
                     final_beta=1.0):
  """Convolutional Capsule layer with Pose Matrices."""
  print('caps conv stride: {}'.format(stride))
  in_atom_sq = num_in_atoms * num_in_atoms
  with tf.variable_scope(layer_name):
    input_shape = tf.shape(input_tensor)
    _, _, _, in_height, in_width = input_tensor.get_shape()
    # This Variable will hold the state of the weights for the layer
    kernel = utils.weight_variable(
        shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
        stddev=0.3)
    # kernel = tf.clip_by_norm(kernel, 3.0, axes=[1, 2, 3])
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=0.5,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=.5,
        name='sigma_biases')
    with tf.name_scope('conv'):
      print('conv input:')
      # input_tensor: [x, 128, 8, c1, c2] -> [x*128, 8, c1, c2]
      print(input_tensor.get_shape())
      input_tensor_reshaped = tf.reshape(input_tensor, [
          input_shape[0] * input_dim * in_atom_sq, input_shape[3],
          input_shape[4], 1
      ])
      input_tensor_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                       input_tensor.get_shape()[4], 1))
      input_act_reshaped = tf.reshape(input_activation, [
          input_shape[0] * input_dim, input_shape[3], input_shape[4], 1
      ])
      input_act_reshaped.set_shape((None, input_tensor.get_shape()[3],
                                    input_tensor.get_shape()[4], 1))
      print(input_tensor_reshaped.get_shape())
      # conv: [x*128, out*out_at, c3, c4]
      conv_patches = tf.extract_image_patches(
          images=input_tensor_reshaped,
          ksizes=[1, kernel_size, kernel_size, 1],
          strides=[1, stride, stride, 1],
          rates=[1, 1, 1, 1],
          padding='VALID',
      )
      act_patches = tf.extract_image_patches(
          images=input_act_reshaped,
          ksizes=[1, kernel_size, kernel_size, 1],
          strides=[1, stride, stride, 1],
          rates=[1, 1, 1, 1],
          padding='VALID',
      )
      o_height = (in_height - kernel_size) // stride + 1
      o_width = (in_width - kernel_size) // stride + 1
      patches = tf.reshape(conv_patches,
                           (input_shape[0], input_dim, in_atom_sq, o_height,
                            o_width, kernel_size, kernel_size))
      patches.set_shape((None, input_dim, in_atom_sq, o_height, o_width,
                         kernel_size, kernel_size))
      patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
      patch_split = tf.reshape(
          patch_trans,
          (input_dim, kernel_size, kernel_size,
           input_shape[0] * o_height * o_width * num_in_atoms, num_in_atoms))
      patch_split.set_shape(
          (input_dim, kernel_size, kernel_size, None, num_in_atoms))
      a_patches = tf.reshape(act_patches,
                             (input_shape[0], input_dim, 1, 1, o_height,
                              o_width, kernel_size, kernel_size))
      a_patches.set_shape((None, input_dim, 1, 1, o_height, o_width,
                           kernel_size, kernel_size))
      with tf.name_scope('input_act'):
        utils.activation_summary(
            tf.reduce_sum(
                tf.reduce_sum(tf.reduce_sum(a_patches, axis=1), axis=-1),
                axis=-1))
      with tf.name_scope('Wx'):
        wx = tf.matmul(patch_split, kernel)
        wx = tf.reshape(wx, (input_dim, kernel_size, kernel_size,
                             input_shape[0], o_height, o_width,
                             num_in_atoms * num_out_atoms, output_dim))
        wx.set_shape((input_dim, kernel_size, kernel_size, None, o_height,
                      o_width, num_in_atoms * num_out_atoms, output_dim))
        wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
        utils.activation_summary(wx)
    with tf.name_scope('routing'):
      # Routing
      # logits: [x, 128, 10, c3, c4]
      logit_shape = [
          input_dim, output_dim, 1, o_height, o_width, kernel_size, kernel_size
      ]
      activation, center = update_conv_routing(
          wx=wx,
          input_activation=a_patches,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms * num_out_atoms,
          input_dim=input_dim,
          num_routing=num_routing,
          output_dim=output_dim,
          min_var=min_var,
          final_beta=final_beta,
      )
      # activations: [x, 10, 8, c3, c4]
      out_activation = tf.squeeze(activation, axis=[1, 3, 6, 7])
      out_center = tf.squeeze(center, axis=[1, 6, 7])
      with tf.name_scope('center'):
        utils.activation_summary(out_center)
    return tf.sigmoid(out_activation), out_center
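
# A minimal usage sketch of conv_capsule_mat. The shapes and values below are
# illustrative assumptions, not taken from the training setup: 32 input capsule
# types with 3x3 pose matrices on a 14x14 grid, routed into 32 output types
# with a 5x5 kernel and stride 2.
#
#   poses = tf.placeholder(tf.float32, [None, 32, 9, 14, 14])  # [x, in, atoms, h, w]
#   acts = tf.placeholder(tf.float32, [None, 32, 14, 14])      # [x, in, h, w]
#   out_act, out_center = conv_capsule_mat(
#       poses, acts, input_dim=32, output_dim=32, layer_name='conv_caps1',
#       num_routing=3, num_in_atoms=3, num_out_atoms=3, stride=2, kernel_size=5)
#   # out_act: [x, 32, 5, 5], out_center: [x, 32, 9, 5, 5]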


def connector_capsule_mat(input_tensor,
                          position_grid,
                          input_activation,
                          input_dim,
                          output_dim,
                          layer_name,
                          num_routing=3,
                          num_in_atoms=3,
                          num_out_atoms=3,
                          leaky=False,
                          final_beta=1.0,
                          min_var=0.0005):
  """Final Capsule Layer with Pose Matrices and Shared connections."""
  # One weight tensor for each capsule of the layer below: w: [8*128, 8*10]
  with tf.variable_scope(layer_name):
    # This Variable will hold the state of the weights for the layer
    with tf.name_scope('input_center_connector'):
      utils.activation_summary(input_tensor)
    weights = utils.weight_variable(
        [input_dim, num_out_atoms, output_dim * num_out_atoms], stddev=0.01)
    # weights = tf.clip_by_norm(weights, 1.0, axes=[1])
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1],
        init_value=1.0,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1], init_value=2.0, name='sigma_biases')
    with tf.name_scope('Wx_plus_b'):
      # input_tensor: [x, 128, 8, h, w]
      input_shape = tf.shape(input_tensor)
      input_trans = tf.transpose(input_tensor, [1, 0, 3, 4, 2])
      input_share = tf.reshape(input_trans, [input_dim, -1, num_in_atoms])
      # input_expanded: [x, 128, 8, 1]
      wx_share = tf.matmul(input_share, weights)
      # sqr_num_out_atoms = num_out_atoms
      num_out_atoms *= num_out_atoms
      wx_trans = tf.reshape(wx_share, [
          input_dim, input_shape[0], input_shape[3], input_shape[4],
          num_out_atoms, output_dim
      ])
      wx_trans.set_shape((input_dim, None, input_tensor.get_shape()[3],
                          input_tensor.get_shape()[4], num_out_atoms,
                          output_dim))
      h, w, _ = position_grid.get_shape()
      height = h
      width = w
      # t_pose = tf.transpose(position_grid, [2, 0, 1])
      # t_pose_exp = tf.scatter_nd([[sqr_num_out_atoms - 1],
      #                             [2 * sqr_num_out_atoms - 1]], t_pose,
      #                            [num_out_atoms, height, width])
      # pose_g_exp = tf.transpose(t_pose_exp, [1, 2, 0])
      zero_grid = tf.zeros([height, width, num_out_atoms - 2])
      pose_g_exp = tf.concat([position_grid, zero_grid], axis=2)
      pose_g = tf.expand_dims(
          tf.expand_dims(tf.expand_dims(pose_g_exp, -1), 0), 0)
      wx_posed = wx_trans + pose_g
      wx_posed_t = tf.transpose(wx_posed, [1, 0, 2, 3, 5, 4])
      # Wx_reshaped: [x, 128, 10, 8]
      wx = tf.reshape(wx_posed_t, [
          -1, input_dim * height * width, output_dim, num_out_atoms, 1, 1
      ])
    with tf.name_scope('routing'):
      # Routing
      # logits: [x, 128, 10]
      logit_shape = [input_dim * height * width, output_dim, 1, 1, 1]
      for _ in range(4):
        input_activation = tf.expand_dims(input_activation, axis=-1)
      activation, center = update_em_routing(
          wx=wx,
          input_activation=input_activation,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms,
          num_routing=num_routing,
          output_dim=output_dim,
          leaky=leaky,
          final_beta=final_beta / 4,
          min_var=min_var,
      )
    out_activation = tf.squeeze(activation, axis=[1, 3, 4, 5])
    out_center = tf.squeeze(center, axis=[1, 4, 5])
    return tf.sigmoid(out_activation), out_center
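
# A minimal usage sketch of connector_capsule_mat. The shapes are illustrative
# assumptions: a 5x5 grid of 32 capsule types with 3x3 poses routed into 10
# class capsules, with a [5, 5, 2] grid of (row, col) positions.
#
#   poses = tf.placeholder(tf.float32, [None, 32, 9, 5, 5])  # [x, in, atoms, h, w]
#   acts = tf.placeholder(tf.float32, [None, 32 * 5 * 5])    # flattened activations
#   pos_grid = tf.placeholder(tf.float32, [5, 5, 2])
#   class_act, class_center = connector_capsule_mat(
#       poses, pos_grid, acts, input_dim=32, output_dim=10,
#       layer_name='class_caps', num_routing=3, num_in_atoms=3, num_out_atoms=3)
#   # class_act: [x, 10], class_center: [x, 10, 9]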


def update_em_routing(wx, input_activation, activation_biases, sigma_biases,
                      logit_shape, num_out_atoms, num_routing, output_dim,
                      leaky, final_beta, min_var):
  """Fully connected routing with EM for Mixture of Gaussians."""
  # Wx: [batch, indim, outdim, outatom, height, width]
  # logit_shape: [indim, outdim, 1, height, width]
  # input_activations: [batch, indim, 1, 1, 1, 1]
  # activation_biases: [1, 1, outdim, 1, height, width]
  # prior = utils.bias_variable([1] + logit_shape, name='prior')
  update = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], logit_shape[0], logit_shape[1],
          logit_shape[2], logit_shape[3], logit_shape[4]
      ]), 0.0)
  out_activation = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4]
      ]), 0.0)
  out_center = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
          logit_shape[3], logit_shape[4]
      ]), 0.0)

  def _body(i, update, activation, center):
    """Body of the EM while loop."""
    del activation
    beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
    # beta = final_beta
    # route: [outdim, height?, width?, batch, indim]
    if leaky:
      posterior = layers.leaky_routing(update, output_dim)
    else:
      posterior = tf.nn.softmax(update, dim=2)
    vote_conf = posterior * input_activation
    # masses: [batch, 1, outdim, 1, height, width]
    masses = tf.reduce_sum(vote_conf, axis=1, keep_dims=True) + 0.00001
    preactivate_unrolled = vote_conf * wx
    # center: [batch, 1, outdim, outatom, height, width]
    center = .9 * tf.reduce_sum(
        preactivate_unrolled, axis=1, keep_dims=True) / masses + .1 * center
    noise = (wx - center) * (wx - center)
    variance = min_var + tf.reduce_sum(
        vote_conf * noise, axis=1, keep_dims=True) / masses
    log_variance = tf.log(variance)
    p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    log_2pi = tf.log(2 * math.pi)
    win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
    logit = beta * (win - activation_biases * 5000)
    activation_update = tf.minimum(
        0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
    # return activation, center
    log_det_sigma = tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
    exp_update = tf.reduce_sum(noise / (2 * variance), axis=3, keep_dims=True)
    prior_update = activation_update - sigma_update - exp_update
    return (prior_update, logit, center)

  # activations = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # centers = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # updates = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # updates.write(0, prior_update)
  for i in range(num_routing):
    update, out_activation, out_center = _body(i, update, out_activation,
                                               out_center)
  # for j in range(num_routing):
  #   _, prior_update, out_activation, out_center = _body(
  #       i, prior_update, start_activation, start_center)
  with tf.name_scope('out_activation'):
    utils.activation_summary(tf.sigmoid(out_activation))
  with tf.name_scope('noise'):
    utils.variable_summaries((wx - out_center) * (wx - out_center))
  with tf.name_scope('Wx'):
    utils.variable_summaries(wx)
  # for i in range(num_routing):
  #   utils.activation_summary(activations.read(i))
  # return activations.read(num_routing - 1), centers.read(num_routing - 1)
  return out_activation, out_center
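
# A minimal NumPy reference for one M-step of the Gaussian-mixture routing
# above, kept small for readability. This is an illustrative sketch only: the
# tensors are flattened to [in_caps, out_caps(, atoms)], the bias and beta
# terms are plain scalars, and the 0.9/0.1 center smoothing and the large
# activation-bias scaling used in the graph code are omitted. It assumes numpy
# is imported as np in this module.
def _em_m_step_reference(wx, r, act_in, beta, act_bias, sigma_bias,
                         min_var=0.0005):
  """Illustrative M-step: weighted mean, variance, and activation logit."""
  conf = r * act_in[:, None]                              # vote confidence [in, out]
  mass = conf.sum(0) + 1e-5                               # total mass per output capsule
  center = (conf[..., None] * wx).sum(0) / mass[:, None]  # weighted mean of votes
  var = min_var + (conf[..., None] *
                   (wx - center) ** 2).sum(0) / mass[:, None]
  # Log-likelihood style score per output capsule, mirroring `win` / `logit`.
  win = mass * (-np.log(var).sum(-1) -
                sigma_bias * wx.shape[-1] * (np.log(2 * math.pi) + 1.0))
  logit = beta * (win - act_bias)
  activation = 1.0 / (1.0 + np.exp(-logit))               # sigmoid(logit)
  return activation, center, var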


def update_conv_routing(wx, input_activation, activation_biases, sigma_biases,
                        logit_shape, num_out_atoms, input_dim, num_routing,
                        output_dim, final_beta, min_var):
  """Convolutional Routing with EM for Mixture of Gaussians."""
  # Wx: [batch, indim, outdim, outatom, height, width, k, k]
  # logit_shape: [indim, outdim, 1, height, width, k, k]
  # input_activations: [batch, indim, 1, 1, height, width, k, k]
  # activation_biases: [1, 1, outdim, 1, height, width]
  # prior = utils.bias_variable([1] + logit_shape, name='prior')
  post = tf.nn.softmax(
      tf.fill(
          tf.stack([
              tf.shape(input_activation)[0], logit_shape[0], logit_shape[1],
              logit_shape[2], logit_shape[3], logit_shape[4], logit_shape[5],
              logit_shape[6]
          ]), 0.0),
      dim=2)
  out_activation = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4], 1, 1
      ]), 0.0)
  out_center = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
          logit_shape[3], logit_shape[4], 1, 1
      ]), 0.0)
  out_mass = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4], 1, 1
      ]), 0.0)
  n = logit_shape[3]
  k = logit_shape[5]

  def _body(i, posterior, activation, center, masses):
    """Body of the EM while loop."""
    del activation
    beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
    # beta = final_beta
    # route: [outdim, height?, width?, batch, indim]
    vote_conf = posterior * input_activation
    # masses: [batch, 1, outdim, 1, height, width, 1, 1]
    masses = tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(vote_conf, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) + 0.0000001
    preactivate_unrolled = vote_conf * wx
    # center: [batch, 1, outdim, outatom, height, width]
    center = .9 * tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(preactivate_unrolled, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) / masses + .1 * center
    noise = (wx - center) * (wx - center)
    variance = min_var + tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(vote_conf * noise, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) / masses
    log_variance = tf.log(variance)
    p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    log_2pi = tf.log(2 * math.pi)
    win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
    logit = beta * (win - activation_biases * 5000)
    activation_update = tf.minimum(
        0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
    # return activation, center
    log_det_sigma = -1 * p_i
    sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
    exp_update = tf.reduce_sum(noise / (2 * variance), axis=3, keep_dims=True)
    prior_update = activation_update - sigma_update - exp_update
    max_prior_update = tf.reduce_max(
        tf.reduce_max(
            tf.reduce_max(
                tf.reduce_max(prior_update, axis=-1, keep_dims=True),
                axis=-2,
                keep_dims=True),
            axis=-3,
            keep_dims=True),
        axis=-4,
        keep_dims=True)
    prior_normal = tf.add(prior_update, -1 * max_prior_update)
    prior_exp = tf.exp(prior_normal)
    t_prior = tf.transpose(prior_exp, [0, 1, 2, 3, 4, 6, 5, 7])
    c_prior = tf.reshape(t_prior, [-1, n * k, n * k, 1])
    pad_prior = tf.pad(c_prior,
                       [[0, 0], [(k - 1) * (k - 1), (k - 1) * (k - 1)],
                        [(k - 1) * (k - 1), (k - 1) * (k - 1)], [0, 0]],
                       'CONSTANT')
    patch_prior = tf.extract_image_patches(
        images=pad_prior,
        ksizes=[1, k, k, 1],
        strides=[1, k, k, 1],
        rates=[1, k - 1, k - 1, 1],
        padding='VALID')
    sum_prior = tf.reduce_sum(patch_prior, axis=-1, keep_dims=True)
    sum_prior_patch = tf.extract_image_patches(
        images=sum_prior,
        ksizes=[1, k, k, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    sum_prior_reshape = tf.reshape(
        sum_prior_patch,
        [-1, input_dim, output_dim, 1, n, n, k, k]) + 0.0000001
    posterior = prior_exp / sum_prior_reshape
    return (posterior, logit, center, masses)

  # activations = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # centers = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # updates = tf.TensorArray(
  #     dtype=tf.float32, size=num_routing, clear_after_read=False)
  # updates.write(0, prior_update)
  for i in range(num_routing):
    post, out_activation, out_center, out_mass = _body(i, post, out_activation,
                                                       out_center, out_mass)
  # for j in range(num_routing):
  #   _, prior_update, out_activation, out_center = _body(
  #       i, prior_update, start_activation, start_center)
  with tf.name_scope('out_activation'):
    utils.activation_summary(tf.sigmoid(out_activation))
  with tf.name_scope('masses'):
    utils.activation_summary(tf.sigmoid(out_mass))
  with tf.name_scope('posterior'):
    utils.activation_summary(post)
  with tf.name_scope('noise'):
    utils.variable_summaries((wx - out_center) * (wx - out_center))
  with tf.name_scope('Wx'):
    utils.variable_summaries(wx)
  # for i in range(num_routing):
  #   utils.activation_summary(activations.read(i))
  # return activations.read(num_routing - 1), centers.read(num_routing - 1)
  return out_activation, out_center


def update_conv_routing_fast(wx, input_activation, activation_biases,
                             sigma_biases, logit_shape, num_out_atoms,
                             input_dim, num_routing, output_dim, final_beta,
                             min_var, stride, layer_name):
  """Fast Convolutional Routing with EM for Mixture of Gaussians.

  The main difference with conv_routing is replacing extract_image_patches
  with utils.kernel_tile, which uses a special conv-deconv operation.

  Args:
    wx: [batch, indim, outdim, outatom, height, width, kernel, kernel]
    input_activation: [batch, indim, 1, 1, height, width, kernel, kernel]
    activation_biases: [1, 1, outdim, 1, height, width]
    sigma_biases: [1, 1, outdim, 1, height, width]
    logit_shape: [indim, outdim, 1, height, width, kernel, kernel]
    num_out_atoms: number of atoms in each capsule, e.g. 9 or 16.
    input_dim: number of input capsule types, e.g. 32.
    num_routing: number of routing iterations, e.g. 3.
    output_dim: number of output capsule types, e.g. 32.
    final_beta: the temperature for making routing factors sharper.
    min_var: minimum variance for each capsule to avoid NaNs.
    stride: the stride with which wx was calculated, e.g. 2 or 1.
    layer_name: the name of this layer, e.g. conv_capsule1.

  Returns:
    out_activation and out_center: final activation and capsule values.
  """
  # prior = utils.bias_variable([1] + logit_shape, name='prior')
  tf.logging.info(
      'update_conv_routing_fast: Wx=%s act=%s act_bias=%s sigma_bias=%s '
      'logit_shape=%s', wx, input_activation, activation_biases, sigma_biases,
      logit_shape)
  with tf.name_scope('update_conv_routing_fast'):
    # With known shapes, these could all be replaced with tf.zeros
    with tf.name_scope('start_posterior'):
      start_posterior = tf.nn.softmax(
          tf.fill(
              tf.stack([
                  tf.shape(input_activation)[0], logit_shape[0],
                  logit_shape[1], logit_shape[2], logit_shape[3],
                  logit_shape[4], logit_shape[5], logit_shape[6]
              ]), 0.0),
          dim=2)
    with tf.name_scope('start_center'):
      start_center = tf.fill(
          tf.stack([
              tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
              logit_shape[3], logit_shape[4], 1, 1
          ]), 0.0)

    b = tf.shape(input_activation)[0]
    c = output_dim
    h = logit_shape[3]
    k = logit_shape[5]
    s = stride
    ih = h + (h - 1) * (s - 1) + (k - 1)
    tile_filter = np.zeros(shape=[k, k, 1, k * k], dtype=np.float32)
    for i in range(k):
      for j in range(k):
        tile_filter[i, j, :, i * k + j] = 1.0

    # Body of routing loop.
    def _body(i, posterior, center, wx, activation_biases, sigma_biases,
              input_activation, tile_filter):
      """Body of EM while loop."""
      tf.logging.info('  Wx: %s', wx)
      beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
      posterior = tf.Print(
          posterior, [
              layer_name, i, h, ih,
              tf.reduce_min(posterior),
              tf.reduce_max(posterior)
          ],
          message='posterior')
      # route: [outdim, height?, width?, batch, indim]
      with tf.name_scope('vote_conf'):
        vote_conf = posterior * input_activation
        vote_conf = tf.maximum(vote_conf, 0.0)
      # masses: [batch, 1, outdim, 1, height, width, 1, 1]
      with tf.name_scope('masses'):
        masses = tf.reduce_sum(
            vote_conf,
            axis=[1, -1, -2],
            keepdims=True,
            name='masses_calculation') + 0.0000001
      with tf.name_scope('preactivate_unrolled'):
        preactivate_unrolled = vote_conf * wx
      # center: [batch, 1, outdim, outatom, height, width]
      with tf.name_scope('center'):
        center = .9 * tf.reduce_sum(
            preactivate_unrolled, axis=[1, -1, -2],
            keepdims=True) / masses + .1 * center

      # Rematerialization to save GPU memory. (+22ms/-1.6GB)
      # @tf.contrib.layers.recompute_grad
      def compute_noise_and_variance(wx, center, vote_conf, masses):
        noise = tf.squared_difference(wx, center)
        variance = min_var + tf.reduce_sum(
            vote_conf * noise,
            axis=[1, -1, -2],
            keepdims=True,
            name='variance_calculation') / masses
        return noise, variance

      with tf.name_scope('compute_noise_and_variance'):
        noise, variance = compute_noise_and_variance(wx, center, vote_conf,
                                                     masses)

      with tf.name_scope('win'):
        log_variance = tf.log(variance)
        p_i = -1 * tf.reduce_sum(log_variance, axis=3, keepdims=True)
        log_2pi = tf.log(2 * math.pi)
        sigma_b = tf.log(sigma_biases * sigma_biases + min_var)
        win = masses * (p_i - num_out_atoms * (sigma_b + log_2pi + 1.0))
      with tf.name_scope('logit'):
        logit = beta * (win - activation_biases * 50 * num_out_atoms)
      with tf.name_scope('activation_update'):
        activation_update = tf.minimum(
            0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
      with tf.name_scope('sigma_update'):
        log_det_sigma = -1 * p_i
        sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
      with tf.name_scope('exp_update'):
        exp_update = tf.reduce_sum(
            noise / (2 * variance), axis=3, keepdims=True)
      prior_update = tf.subtract(
          activation_update - sigma_update,
          exp_update,
          name='prior_update_sub')
      max_prior_update = tf.reduce_max(
          prior_update,
          axis=[2, 3, 4, 5, 6, 7],
          keepdims=True,
          name='max_prior_update')
      prior_normal = tf.add(prior_update, -1 * max_prior_update)
      prior_exp = tf.exp(prior_normal)
      prior_exp_out = tf.reduce_sum(
          prior_exp, axis=2, keepdims=True, name='prior_exp_out')
      prior_exp_reshape = tf.reshape(
          prior_exp_out, [-1, h, h, k * k], name='prior_exp_reshape')
      sum_prior = tf.nn.conv2d_transpose(
          prior_exp_reshape,
          tile_filter,
          output_shape=[b * c, ih, ih, 1],
          strides=[1, s, s, 1],
          padding='VALID')
      sum_prior = tf.maximum(1e-6, sum_prior)
      sum_prior_patch = utils.kernel_tile(
          sum_prior, k, s, 1, name='sum_prior_patch')

      with utils.maybe_jit_scope(), tf.name_scope('posterior'):
        sum_prior_reshape = tf.reshape(sum_prior_patch,
                                       [-1, input_dim, 1, 1, h, h, k, k])
        posterior = prior_exp / sum_prior_reshape

      return (i + 1, posterior, logit, center, masses)

    posterior, center = start_posterior, start_center
    for j in range(num_routing):
      with tf.name_scope('iter{}'.format(j)):
        tf.logging.info('iteration %d %s', j, '=' * 80)
        jj = tf.constant(j, dtype=tf.int32)
        _, posterior, activation, center, mass = _body(jj, posterior, center,
                                                       wx, activation_biases,
                                                       sigma_biases,
                                                       input_activation,
                                                       tile_filter)
    post, out_activation, out_center, out_mass = (posterior, activation,
                                                  center, mass)

    with tf.name_scope('out_activation'):
      utils.activation_summary(tf.sigmoid(out_activation))
    with tf.name_scope('masses'):
      utils.activation_summary(tf.sigmoid(out_mass))
    with tf.name_scope('posterior'):
      utils.activation_summary(post)

  return out_activation, out_center
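
# The docstring above notes that utils.kernel_tile replaces
# tf.extract_image_patches with a conv-based operation. The function below is
# an illustrative sketch of that trick, written as an assumption about the
# idea rather than the repo's actual implementation: convolving a
# single-channel NHWC tensor with a one-hot [k, k, 1, k*k] filter copies pixel
# (i, j) of every k x k patch into output channel i*k + j.
def _kernel_tile_sketch(x, k, stride):
  """Patch extraction via a one-hot conv filter (illustrative only)."""
  tile_filter = np.zeros([k, k, 1, k * k], dtype=np.float32)
  for i in range(k):
    for j in range(k):
      tile_filter[i, j, 0, i * k + j] = 1.0
  # For a single-channel input this matches
  # tf.extract_image_patches(x, ksizes=[1, k, k, 1],
  #     strides=[1, stride, stride, 1], rates=[1, 1, 1, 1], padding='VALID').
  return tf.nn.conv2d(
      x,
      tf.constant(tile_filter),
      strides=[1, stride, stride, 1],
      padding='VALID')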


def conv_capsule_mat_fast(
    input_tensor,
    input_activation,
    input_dim,
    output_dim,
    layer_name,
    num_routing=3,
    num_in_atoms=3,
    num_out_atoms=3,
    stride=2,
    kernel_size=5,
    min_var=0.0005,
    final_beta=1.0,
):
  """Convolutional Capsule layer with fast EM routing.

  Args:
    input_tensor: The input capsule features.
    input_activation: The input capsule activations.
    input_dim: Number of input capsule types.
    output_dim: Number of output capsule types.
    layer_name: Name of this layer, e.g. conv_capsule1.
    num_routing: Number of routing iterations.
    num_in_atoms: Number of features in each of the input capsules.
    num_out_atoms: Number of features in each of the output capsules.
    stride: Stride of the convolution.
    kernel_size: Kernel size of the convolution.
    min_var: Minimum variance for each capsule to avoid NaNs.
    final_beta: beta for making the routing factors sharp.

  Returns:
    The final capsule center and activations.
  """
  tf.logging.info('conv_capsule_mat %s', layer_name)
  tf.logging.info('input_shape %s', input_tensor.shape.as_list())
  in_atom_sq = num_in_atoms * num_in_atoms
  with tf.variable_scope(layer_name):
    # This should be fully defined...
    # input_shape = tf.shape(input_tensor)
    input_shape = input_tensor.shape.as_list()
    batch, _, _, in_height, in_width = input_shape
    o_height = (in_height - kernel_size) // stride + 1
    o_width = (in_width - kernel_size) // stride + 1

    # This Variable will hold the state of the weights for the layer.
    kernel = utils.weight_variable(
        shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
        stddev=0.1)
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=0.2,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=.5,
        name='sigma_biases')

    with utils.maybe_jit_scope(), tf.name_scope('conv'):
      input_tensor_reshaped = tf.reshape(
          input_tensor,
          [batch * input_dim * in_atom_sq, in_height, in_width, 1])
      input_act_reshaped = tf.reshape(
          input_activation, [batch * input_dim, in_height, in_width, 1])

      conv_patches = utils.kernel_tile(input_tensor_reshaped, kernel_size,
                                       stride)
      act_patches = utils.kernel_tile(input_act_reshaped, kernel_size, stride)

      patches = tf.reshape(conv_patches,
                           (batch, input_dim, in_atom_sq, o_height, o_width,
                            kernel_size, kernel_size))
      patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
      patch_split = tf.reshape(
          patch_trans,
          (input_dim, kernel_size, kernel_size,
           batch * o_height * o_width * num_in_atoms, num_in_atoms),
          name='patch_split')
      a_patches = tf.reshape(
          act_patches,
          (batch, input_dim, 1, 1, o_height, o_width, kernel_size,
           kernel_size),
          name='a_patches')

    # Recompute Wx on backprop to save memory (perhaps redo patches as well?)
    # @tf.contrib.layers.recompute_grad
    def compute_wx(patch_split, kernel, is_recomputing=False):
      tf.logging.info('compute_wx(is_recomputing=%s)', is_recomputing)
      with utils.maybe_jit_scope(), tf.name_scope('wx'):
        wx = tf.matmul(patch_split, kernel)
        wx = tf.reshape(wx,
                        (input_dim, kernel_size, kernel_size, batch, o_height,
                         o_width, num_in_atoms * num_out_atoms, output_dim))
        wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
      return wx

    wx = compute_wx(patch_split, kernel.value())

    with utils.maybe_jit_scope():
      # Routing
      logit_shape = [
          input_dim, output_dim, 1, o_height, o_width, kernel_size, kernel_size
      ]
      tf.logging.info('logit_shape: %s', logit_shape)
      activation, center = update_conv_routing_fast(
          wx=wx,
          input_activation=a_patches,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms * num_out_atoms,
          input_dim=input_dim,
          num_routing=num_routing,
          output_dim=output_dim,
          min_var=min_var,
          final_beta=4 * final_beta,
          stride=stride,
          layer_name=layer_name,
      )

    with utils.maybe_jit_scope():
      out_activation = tf.squeeze(
          activation, axis=[1, 3, 6, 7], name='out_activation')
      out_center = tf.squeeze(center, axis=[1, 6, 7], name='out_center')
      out_activation = tf.sigmoid(out_activation)

    with tf.name_scope('center'):
      utils.activation_summary(out_center)

    return out_activation, out_center
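
# A minimal usage sketch of conv_capsule_mat_fast. The shapes are illustrative
# assumptions; the fast path needs a fully defined input shape, including the
# batch size.
#
#   poses = tf.zeros([8, 32, 9, 14, 14])  # [batch, in, atoms, h, w]
#   acts = tf.zeros([8, 32, 14, 14])      # [batch, in, h, w]
#   fast_act, fast_center = conv_capsule_mat_fast(
#       poses, acts, input_dim=32, output_dim=32, layer_name='conv_caps_fast',
#       num_routing=3, num_in_atoms=3, num_out_atoms=3, stride=2, kernel_size=5)
#   # fast_act: [8, 32, 5, 5], fast_center: [8, 32, 9, 5, 5]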