def _body(i, posterior, center, wx, activation_biases, sigma_biases,
          input_activation, tile_filter):
  """Body of EM while loop."""
  tf.logging.info('  Wx: %s', wx)
  beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
  posterior = tf.Print(
      posterior, [
          layer_name, i, h, ih,
          tf.reduce_min(posterior),
          tf.reduce_max(posterior)
      ],
      message='posterior')
  # route: [outdim, height?, width?, batch, indim]
  with tf.name_scope('vote_conf'):
    vote_conf = posterior * input_activation
    vote_conf = tf.maximum(vote_conf, 0.0)
  # masses: [batch, 1, outdim, 1, height, width, 1, 1]
  with tf.name_scope('masses'):
    masses = tf.reduce_sum(
        vote_conf,
        axis=[1, -1, -2],
        keepdims=True,
        name='masses_calculation') + 0.0000001
  with tf.name_scope('preactivate_unrolled'):
    preactivate_unrolled = vote_conf * wx
  # center: [batch, 1, outdim, outatom, height, width]
  with tf.name_scope('center'):
    center = .9 * tf.reduce_sum(
        preactivate_unrolled, axis=[1, -1, -2],
        keepdims=True) / masses + .1 * center

  # Rematerialization to save GPU memory. (+22ms/-1.6GB)
  # @tf.contrib.layers.recompute_grad
  def compute_noise_and_variance(wx, center, vote_conf, masses):
    noise = tf.squared_difference(wx, center)
    variance = min_var + tf.reduce_sum(
        vote_conf * noise,
        axis=[1, -1, -2],
        keepdims=True,
        name='variance_calculation') / masses
    return noise, variance

  with tf.name_scope('compute_noise_and_variance'):
    noise, variance = compute_noise_and_variance(wx, center, vote_conf,
                                                 masses)

  with tf.name_scope('win'):
    log_variance = tf.log(variance)
    p_i = -1 * tf.reduce_sum(log_variance, axis=3, keepdims=True)
    log_2pi = tf.log(2 * math.pi)
    sigma_b = tf.log(sigma_biases * sigma_biases + min_var)
    win = masses * (p_i - num_out_atoms * (sigma_b + log_2pi + 1.0))
  with tf.name_scope('logit'):
    logit = beta * (win - activation_biases * 50 * num_out_atoms)
  with tf.name_scope('activation_update'):
    activation_update = tf.minimum(0.0,
                                   logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
  with tf.name_scope('sigma_update'):
    log_det_sigma = -1 * p_i
    sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
  with tf.name_scope('exp_update'):
    exp_update = tf.reduce_sum(noise / (2 * variance), axis=3, keepdims=True)
  prior_update = tf.subtract(
      activation_update - sigma_update, exp_update, name='prior_update_sub')
  max_prior_update = tf.reduce_max(
      prior_update,
      axis=[2, 3, 4, 5, 6, 7],
      keepdims=True,
      name='max_prior_update')
  prior_normal = tf.add(prior_update, -1 * max_prior_update)
  prior_exp = tf.exp(prior_normal)
  prior_exp_out = tf.reduce_sum(
      prior_exp, axis=2, keepdims=True, name='prior_exp_out')
  prior_exp_reshape = tf.reshape(
      prior_exp_out, [-1, h, h, k * k], name='prior_exp_reshape')
  sum_prior = tf.nn.conv2d_transpose(
      prior_exp_reshape,
      tile_filter,
      output_shape=[b * c, ih, ih, 1],
      strides=[1, s, s, 1],
      padding='VALID')
  sum_prior = tf.maximum(1e-6, sum_prior)
  sum_prior_patch = utils.kernel_tile(
      sum_prior, k, s, 1, name='sum_prior_patch')
  with utils.maybe_jit_scope(), tf.name_scope('posterior'):
    sum_prior_reshape = tf.reshape(
        sum_prior_patch, [-1, input_dim, 1, 1, h, h, k, k])
    posterior = prior_exp / sum_prior_reshape
  return (i + 1, posterior, logit, center, masses)
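
# Illustrative note (not part of the original layer): `_body` normalizes the
# routing prior by subtracting its per-element maximum before exponentiating,
# the standard trick for exp-normalization without overflow. The helper below
# is a hypothetical, self-contained NumPy sketch of that max-subtract / exp /
# normalize step on a plain array; it is not used by the routing code and
# deliberately omits the transposed-convolution spreading of the denominator.
def _stable_normalize_sketch(prior_update, axis=-1):
  """Numerically stable exp-normalization, mirroring the trick in `_body`."""
  import numpy as np  # Local import; the sketch only needs NumPy.
  shifted = prior_update - np.max(prior_update, axis=axis, keepdims=True)
  exp = np.exp(shifted)
  return exp / np.sum(exp, axis=axis, keepdims=True)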
def conv_capsule_mat_fast(
    input_tensor,
    input_activation,
    input_dim,
    output_dim,
    layer_name,
    num_routing=3,
    num_in_atoms=3,
    num_out_atoms=3,
    stride=2,
    kernel_size=5,
    min_var=0.0005,
    final_beta=1.0,
):
  """Convolutional Capsule layer with fast EM routing.

  Args:
    input_tensor: The input capsule features.
    input_activation: The input capsule activations.
    input_dim: Number of input capsule types.
    output_dim: Number of output capsule types.
    layer_name: Name of this layer, e.g. conv_capsule1.
    num_routing: Number of routing iterations.
    num_in_atoms: Number of features in each of the input capsules.
    num_out_atoms: Number of features in each of the output capsules.
    stride: Stride of the convolution.
    kernel_size: Kernel size of the convolution.
    min_var: Minimum variance for each capsule to avoid NaNs.
    final_beta: beta for making the routing factors sharp.

  Returns:
    The final capsule activations and centers.
  """
  tf.logging.info('conv_capsule_mat %s', layer_name)
  tf.logging.info('input_shape %s', input_tensor.shape.as_list())
  in_atom_sq = num_in_atoms * num_in_atoms
  with tf.variable_scope(layer_name):
    # This should be fully defined...
    # input_shape = tf.shape(input_tensor)
    input_shape = input_tensor.shape.as_list()
    batch, _, _, in_height, in_width = input_shape
    o_height = (in_height - kernel_size) // stride + 1
    o_width = (in_width - kernel_size) // stride + 1

    # This Variable will hold the state of the weights for the layer.
    kernel = utils.weight_variable(
        shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
        stddev=0.1)
    activation_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=0.2,
        name='activation_biases')
    sigma_biases = utils.bias_variable(
        [1, 1, output_dim, 1, 1, 1, 1, 1],
        init_value=.5,
        name='sigma_biases')

    with utils.maybe_jit_scope(), tf.name_scope('conv'):
      input_tensor_reshaped = tf.reshape(
          input_tensor,
          [batch * input_dim * in_atom_sq, in_height, in_width, 1])
      input_act_reshaped = tf.reshape(
          input_activation, [batch * input_dim, in_height, in_width, 1])

      conv_patches = utils.kernel_tile(input_tensor_reshaped, kernel_size,
                                       stride)
      act_patches = utils.kernel_tile(input_act_reshaped, kernel_size, stride)

      patches = tf.reshape(conv_patches,
                           (batch, input_dim, in_atom_sq, o_height, o_width,
                            kernel_size, kernel_size))
      patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
      patch_split = tf.reshape(
          patch_trans,
          (input_dim, kernel_size, kernel_size,
           batch * o_height * o_width * num_in_atoms, num_in_atoms),
          name='patch_split')
      a_patches = tf.reshape(
          act_patches,
          (batch, input_dim, 1, 1, o_height, o_width, kernel_size,
           kernel_size),
          name='a_patches')

    # Recompute Wx on backprop to save memory (perhaps redo patches as well?)
    # @tf.contrib.layers.recompute_grad
    def compute_wx(patch_split, kernel, is_recomputing=False):
      tf.logging.info('compute_wx(is_recomputing=%s)', is_recomputing)
      with utils.maybe_jit_scope(), tf.name_scope('wx'):
        wx = tf.matmul(patch_split, kernel)
        wx = tf.reshape(
            wx, (input_dim, kernel_size, kernel_size, batch, o_height,
                 o_width, num_in_atoms * num_out_atoms, output_dim))
        wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
      return wx

    wx = compute_wx(patch_split, kernel.value())

    with utils.maybe_jit_scope():
      # Routing
      logit_shape = [
          input_dim, output_dim, 1, o_height, o_width, kernel_size,
          kernel_size
      ]
      tf.logging.info('logit_shape: %s', logit_shape)
      activation, center = update_conv_routing_fast(
          wx=wx,
          input_activation=a_patches,
          activation_biases=activation_biases,
          sigma_biases=sigma_biases,
          logit_shape=logit_shape,
          num_out_atoms=num_out_atoms * num_out_atoms,
          input_dim=input_dim,
          num_routing=num_routing,
          output_dim=output_dim,
          min_var=min_var,
          final_beta=4 * final_beta,
          stride=stride,
          layer_name=layer_name,
      )

    with utils.maybe_jit_scope():
      out_activation = tf.squeeze(
          activation, axis=[1, 3, 6, 7], name='out_activation')
      out_center = tf.squeeze(center, axis=[1, 6, 7], name='out_center')
      out_activation = tf.sigmoid(out_activation)

    with tf.name_scope('center'):
      utils.activation_summary(out_center)
    return out_activation, out_center
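
# Illustrative usage sketch (not part of the original file): one plausible way
# to call `conv_capsule_mat_fast`. The pose layout
# [batch, input_dim, num_in_atoms**2, height, width] and the activation layout
# [batch, input_dim, height, width] are inferred from the reshapes above; the
# concrete sizes here are made up for the example.
def _conv_capsule_mat_fast_example():
  batch, input_dim, num_in_atoms, height, width = 8, 32, 3, 14, 14
  pose = tf.placeholder(
      tf.float32,
      [batch, input_dim, num_in_atoms * num_in_atoms, height, width])
  activation = tf.placeholder(tf.float32, [batch, input_dim, height, width])
  return conv_capsule_mat_fast(
      input_tensor=pose,
      input_activation=activation,
      input_dim=input_dim,
      output_dim=32,
      layer_name='conv_capsule1',
      num_routing=3,
      num_in_atoms=num_in_atoms,
      num_out_atoms=3,
      stride=2,
      kernel_size=3)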
def multi_gpu_model(features):
  """Build the Graph and train the model on multiple gpus."""
  if FLAGS.use_caps:
    if FLAGS.use_em:
      inference = em_model.inference
    else:
      print('not supported')
  else:
    inference = simple_model.conv_inference
  with tf.device('/cpu:0'):
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0),
        trainable=False)
    lr = tf.train.exponential_decay(
        FLAGS.learning_rate,
        global_step,
        FLAGS.decay_steps,
        FLAGS.decay_rate,
        staircase=FLAGS.staircase)
    if FLAGS.clip_lr:
      lr = tf.maximum(lr, 1e-6)

    if FLAGS.adam:
      opt = tf.train.AdamOptimizer(lr)
    else:
      opt = tf.train.GradientDescentOptimizer(lr)

    tower_grads = []
    corrects = []
    almosts = []
    result = {}
    with tf.variable_scope(tf.get_variable_scope()):
      for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('tower_%d' % (i)) as scope:
            label_ = features[i]['labels']
            y, result['recons_1'], result['recons_2'], result[
                'mid_act'] = inference(features[i])
            result['logits'] = y
            losses, correct, almost = layers.optimizer(
                logits=y,
                labels=label_,
                multi=FLAGS.multi and FLAGS.data_set == 'mnist',
                scope=scope,
                softmax=FLAGS.softmax,
                rate=FLAGS.loss_rate,
                step=global_step,
            )
            tf.get_variable_scope().reuse_variables()
            corrects.append(correct)
            almosts.append(almost)
            # summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
            grads = opt.compute_gradients(
                losses,
                gate_gradients=tf.train.Optimizer.GATE_NONE,
            )
            tower_grads.append(grads)

    with utils.maybe_jit_scope(), tf.name_scope('average_gradients'):
      grads = _average_gradients(tower_grads)

    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    if FLAGS.verbose:
      for grad, var in grads:
        if grad is not None:
          summaries.append(
              tf.summary.histogram(var.op.name + '/gradients', grad))
    summaries.append(tf.summary.scalar('learning_rate', lr))
    result['summary'] = tf.summary.merge(summaries)
    result['train'] = opt.apply_gradients(grads, global_step=global_step)
    # result['train'] = y

    cors = tf.stack(corrects)
    alms = tf.stack(almosts)
    result['correct'] = tf.reduce_sum(cors, 0)
    result['almost'] = tf.reduce_sum(alms, 0)
    return result
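
# Illustrative sketch (not part of the original file): `_average_gradients` is
# defined elsewhere in this codebase. The helper below shows the conventional
# multi-GPU pattern it is assumed to follow, averaging each variable's
# gradient across the per-tower (grad, var) lists; the name and structure are
# assumptions for illustration, not the project's actual implementation.
def _average_gradients_sketch(tower_grads):
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...) for one var.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
    mean_grad = tf.reduce_mean(tf.concat(grads, 0), 0)
    average_grads.append((mean_grad, grad_and_vars[0][1]))
  return average_grads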