        def _body(i, posterior, center, wx, activation_biases, sigma_biases,
                  input_activation, tile_filter):
            """Body of EM while loop."""
            tf.logging.info('  Wx: %s', wx)

            beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
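            # beta anneals from 0.05 * final_beta at i == 0 toward final_beta as the
            # routing iterations progress, sharpening the routing logits.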

            posterior = tf.Print(
                posterior,
                [
                    layer_name, i, h, ih,
                    tf.reduce_min(posterior),
                    tf.reduce_max(posterior)
                ],
                message='posterior')
            # route: [outdim, height?, width?, batch, indim]
            with tf.name_scope('vote_conf'):
                vote_conf = posterior * input_activation
                vote_conf = tf.maximum(vote_conf, 0.0)

            # masses: [batch, 1, outdim, 1, height, width, 1, 1]
            with tf.name_scope('masses'):
                masses = tf.reduce_sum(vote_conf,
                                       axis=[1, -1, -2],
                                       keepdims=True,
                                       name='masses_calculation') + 0.0000001
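            # Weight each transformed vote (wx) by its routing confidence before
            # averaging into the new capsule centers.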
            with tf.name_scope('preactivate_unrolled'):
                preactivate_unrolled = vote_conf * wx

            # center: [batch, 1, outdim, outatom, height, width]
            with tf.name_scope('center'):
                center = .9 * tf.reduce_sum(
                    preactivate_unrolled, axis=[1, -1, -2],
                    keepdims=True) / masses + .1 * center
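            # The .9/.1 mix is an exponential moving average between the freshly
            # computed weighted mean and the previous center estimate.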

            # Rematerialization to save GPU memory. (+22ms/-1.6GB)
            # @tf.contrib.layers.recompute_grad
            def compute_noise_and_variance(wx, center, vote_conf, masses):
                noise = tf.squared_difference(wx, center)
                variance = min_var + tf.reduce_sum(
                    vote_conf * noise,
                    axis=[1, -1, -2],
                    keepdims=True,
                    name='variance_calculation') / masses
                return noise, variance

            with tf.name_scope('compute_noise_and_variance'):
                noise, variance = compute_noise_and_variance(
                    wx, center, vote_conf, masses)

            with tf.name_scope('win'):
                log_variance = tf.log(variance)
                p_i = -1 * tf.reduce_sum(log_variance, axis=3, keepdims=True)
                log_2pi = tf.log(2 * math.pi)
                sigma_b = tf.log(sigma_biases * sigma_biases + min_var)
                win = masses * (p_i - num_out_atoms *
                                (sigma_b + log_2pi + 1.0))
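            # Turn the agreement measure into routing logits; activation_biases is a
            # learned per-capsule threshold (scaled here by 50 * num_out_atoms).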
            with tf.name_scope('logit'):
                logit = beta * (win - activation_biases * 50 * num_out_atoms)
            with tf.name_scope('activation_update'):
                activation_update = tf.minimum(
                    0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
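            # The expression above is a numerically stable log(sigmoid(logit)):
            # min(0, x) - log(1 + exp(-|x|)).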
            with tf.name_scope('sigma_update'):
                log_det_sigma = -1 * p_i
                sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
            with tf.name_scope('exp_update'):
                exp_update = tf.reduce_sum(noise / (2 * variance),
                                           axis=3,
                                           keepdims=True)
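            # Unnormalized log routing priors: the activation term minus the Gaussian
            # log-normalizer and exponent terms.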
            prior_update = tf.subtract(activation_update - sigma_update,
                                       exp_update,
                                       name='prior_update_sub')
            max_prior_update = tf.reduce_max(prior_update,
                                             axis=[2, 3, 4, 5, 6, 7],
                                             keepdims=True,
                                             name='max_prior_update')
            prior_normal = tf.add(prior_update, -1 * max_prior_update)
            prior_exp = tf.exp(prior_normal)
            prior_exp_out = tf.reduce_sum(prior_exp,
                                          axis=2,
                                          keepdims=True,
                                          name='prior_exp_out')
            prior_exp_reshape = tf.reshape(prior_exp_out, [-1, h, h, k * k],
                                           name='prior_exp_reshape')
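            # Scatter the per-patch prior sums back onto the input grid with a
            # transposed convolution so that overlapping receptive fields share a
            # consistent normalizer (tile_filter is presumably a fixed tiling kernel).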

            sum_prior = tf.nn.conv2d_transpose(prior_exp_reshape,
                                               tile_filter,
                                               output_shape=[b * c, ih, ih, 1],
                                               strides=[1, s, s, 1],
                                               padding='VALID')
            sum_prior = tf.maximum(1e-6, sum_prior)

            sum_prior_patch = utils.kernel_tile(sum_prior,
                                                k,
                                                s,
                                                1,
                                                name='sum_prior_patch')

            with utils.maybe_jit_scope(), tf.name_scope('posterior'):
                sum_prior_reshape = tf.reshape(
                    sum_prior_patch, [-1, input_dim, 1, 1, h, h, k, k])
                posterior = prior_exp / sum_prior_reshape

            return (i + 1, posterior, logit, center, masses)
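        # Hedged sketch (not from the source): _body is presumably iterated with a
        # tf.while_loop inside update_conv_routing_fast, with the non-loop arguments
        # bound by a lambda, along the lines of:
        #
        #   _, posterior, logits, center, _ = tf.while_loop(
        #       lambda i, *_unused: i < num_routing,
        #       lambda i, p, _l, c, _m: _body(i, p, c, wx, activation_biases,
        #                                     sigma_biases, input_activation,
        #                                     tile_filter),
        #       loop_vars=(0, initial_posterior, initial_logits, initial_center,
        #                  initial_masses))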
def conv_capsule_mat_fast(
    input_tensor,
    input_activation,
    input_dim,
    output_dim,
    layer_name,
    num_routing=3,
    num_in_atoms=3,
    num_out_atoms=3,
    stride=2,
    kernel_size=5,
    min_var=0.0005,
    final_beta=1.0,
):
    """Convolutional Capsule layer with fast EM routing.

  Args:
    input_tensor: The input capsule features.
    input_activation: The input capsule activations.
    input_dim: Number of input capsule types.
    output_dim: Number of output capsule types.
    layer_name: Name of this layer, e.g. conv_capsule1
    num_routing: Number of routing iterations.
    num_in_atoms: Number of features in each of the input capsules.
    num_out_atoms: Number of features in each of the output capsules.
    stride: Stride of the convolution.
    kernel_size: kernel size of the convolution.
    min_var: Minimum varience for each capsule to avoid NaNs.
    final_beta: beta for making the routing factors sharp.

  Returns:
    The final capsule center and activations.
  """
    tf.logging.info('conv_capsule_mat %s', layer_name)
    tf.logging.info('input_shape %s', input_tensor.shape.as_list())
    in_atom_sq = num_in_atoms * num_in_atoms
    with tf.variable_scope(layer_name):
        # This should be fully defined...
        # input_shape = tf.shape(input_tensor)
        input_shape = input_tensor.shape.as_list()
        batch, _, _, in_height, in_width = input_shape
        o_height = (in_height - kernel_size) // stride + 1
        o_width = (in_width - kernel_size) // stride + 1

        # This Variable will hold the state of the weights for the layer.
        kernel = utils.weight_variable(shape=[
            input_dim, kernel_size, kernel_size, num_in_atoms,
            output_dim * num_out_atoms
        ],
                                       stddev=0.1)
        activation_biases = utils.bias_variable(
            [1, 1, output_dim, 1, 1, 1, 1, 1],
            init_value=0.2,
            name='activation_biases')
        sigma_biases = utils.bias_variable([1, 1, output_dim, 1, 1, 1, 1, 1],
                                           init_value=.5,
                                           name='sigma_biases')

        with utils.maybe_jit_scope(), tf.name_scope('conv'):
            input_tensor_reshaped = tf.reshape(
                input_tensor,
                [batch * input_dim * in_atom_sq, in_height, in_width, 1])
            input_act_reshaped = tf.reshape(
                input_activation, [batch * input_dim, in_height, in_width, 1])

            conv_patches = utils.kernel_tile(input_tensor_reshaped,
                                             kernel_size, stride)
            act_patches = utils.kernel_tile(input_act_reshaped, kernel_size,
                                            stride)

            patches = tf.reshape(conv_patches,
                                 (batch, input_dim, in_atom_sq, o_height,
                                  o_width, kernel_size, kernel_size))
            patch_trans = tf.transpose(patches, [1, 5, 6, 0, 3, 4, 2])
            patch_split = tf.reshape(
                patch_trans,
                (input_dim, kernel_size, kernel_size,
                 batch * o_height * o_width * num_in_atoms, num_in_atoms),
                name='patch_split')
            a_patches = tf.reshape(act_patches,
                                   (batch, input_dim, 1, 1, o_height, o_width,
                                    kernel_size, kernel_size),
                                   name='a_patches')

        # Recompute Wx on backprop to save memory (perhaps redo patches as well?)
        # @tf.contrib.layers.recompute_grad
        def compute_wx(patch_split, kernel, is_recomputing=False):
            tf.logging.info('compute_wx(is_recomputing=%s)', is_recomputing)
            with utils.maybe_jit_scope(), tf.name_scope('wx'):
                wx = tf.matmul(patch_split, kernel)
                wx = tf.reshape(
                    wx, (input_dim, kernel_size, kernel_size, batch, o_height,
                         o_width, num_in_atoms * num_out_atoms, output_dim))
                wx = tf.transpose(wx, [3, 0, 7, 6, 4, 5, 1, 2])
            return wx

        wx = compute_wx(patch_split, kernel.value())

        with utils.maybe_jit_scope():
            # Routing
            logit_shape = [
                input_dim, output_dim, 1, o_height, o_width, kernel_size,
                kernel_size
            ]
            tf.logging.info('logit_shape: %s', logit_shape)
            activation, center = update_conv_routing_fast(
                wx=wx,
                input_activation=a_patches,
                activation_biases=activation_biases,
                sigma_biases=sigma_biases,
                logit_shape=logit_shape,
                num_out_atoms=num_out_atoms * num_out_atoms,
                input_dim=input_dim,
                num_routing=num_routing,
                output_dim=output_dim,
                min_var=min_var,
                final_beta=4 * final_beta,
                stride=stride,
                layer_name=layer_name,
            )

        with utils.maybe_jit_scope():
            out_activation = tf.squeeze(activation,
                                        axis=[1, 3, 6, 7],
                                        name='out_activation')
            out_center = tf.squeeze(center, axis=[1, 6, 7], name='out_center')
            out_activation = tf.sigmoid(out_activation)

        with tf.name_scope('center'):
            utils.activation_summary(out_center)

        return out_activation, out_center
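
# Hypothetical usage sketch (not from the source). Shapes follow the reshapes in
# the 'conv' scope above: input_tensor is [batch, input_dim, num_in_atoms**2,
# height, width] and input_activation is [batch, input_dim, height, width].
def _example_conv_capsule_usage():
    pose = tf.placeholder(tf.float32, [32, 8, 9, 14, 14])
    activation = tf.placeholder(tf.float32, [32, 8, 14, 14])
    return conv_capsule_mat_fast(
        input_tensor=pose,
        input_activation=activation,
        input_dim=8,
        output_dim=16,
        layer_name='conv_capsule1',
        kernel_size=3,
        stride=2)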
def multi_gpu_model(features):
    """Build the Graph and train the model on multiple gpus."""
    if FLAGS.use_caps:
        if FLAGS.use_em:
            inference = em_model.inference
        else:
            raise NotImplementedError(
                'Capsule models without EM routing are not supported.')
    else:
        inference = simple_model.conv_inference
    with tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        lr = tf.train.exponential_decay(FLAGS.learning_rate,
                                        global_step,
                                        FLAGS.decay_steps,
                                        FLAGS.decay_rate,
                                        staircase=FLAGS.staircase)
        if FLAGS.clip_lr:
            lr = tf.maximum(lr, 1e-6)

        if FLAGS.adam:
            opt = tf.train.AdamOptimizer(lr)
        else:
            opt = tf.train.GradientDescentOptimizer(lr)

        tower_grads = []
        corrects = []
        almosts = []
        result = {}
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(FLAGS.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('tower_%d' % (i)) as scope:
                        label_ = features[i]['labels']
                        y, result['recons_1'], result['recons_2'], result[
                            'mid_act'] = inference(features[i])
                        result['logits'] = y

                        losses, correct, almost = layers.optimizer(
                            logits=y,
                            labels=label_,
                            multi=FLAGS.multi and FLAGS.data_set == 'mnist',
                            scope=scope,
                            softmax=FLAGS.softmax,
                            rate=FLAGS.loss_rate,
                            step=global_step,
                        )
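                        # Reuse variables so every tower shares the same weights;
                        # only the first tower creates them.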
                        tf.get_variable_scope().reuse_variables()
                        corrects.append(correct)
                        almosts.append(almost)
                        # summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                        grads = opt.compute_gradients(
                            losses,
                            gate_gradients=tf.train.Optimizer.GATE_NONE,
                        )
                        tower_grads.append(grads)

        with utils.maybe_jit_scope(), tf.name_scope('average_gradients'):
            grads = _average_gradients(tower_grads)
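        # Averaging the tower gradients gives synchronous data-parallel training:
        # each GPU processes its own input shard, but all apply the same update.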
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        if FLAGS.verbose:
            for grad, var in grads:
                if grad is not None:
                    summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))
        summaries.append(tf.summary.scalar('learning_rate', lr))
        result['summary'] = tf.summary.merge(summaries)
        result['train'] = opt.apply_gradients(grads, global_step=global_step)
        # result['train'] = y

        cors = tf.stack(corrects)
        alms = tf.stack(almosts)
        result['correct'] = tf.reduce_sum(cors, 0)
        result['almost'] = tf.reduce_sum(alms, 0)

        return result