Example #1
0
def residual_block(x,
                   in_filter,
                   out_filter,
                   stride,
                   activate_before_residual=False):
    """Builds one residual unit: two BN->ReLU->3x3 conv stages plus a shortcut.

    Args:
      x: Tensor that is the output of the previous layer in the model.
      in_filter: Number of filters `x` has.
      out_filter: Number of filters that the output of this layer will have.
      stride: Integer stride used by the first convolution and, when the
        filter counts differ, by the average pooling of the shortcut.
      activate_before_residual: If True, the initial BN->ReLU is applied to
        `x` itself, so both the shortcut and the conv branch start from the
        activated tensor. If False, only the conv branch is activated and
        the shortcut keeps the raw input.

    Returns:
      The sum of the (possibly pooled and zero-padded) shortcut and the
      output of the two-convolution branch.
    """
    if activate_before_residual:
        # Pass the BN/ReLU activation up so the shortcut shares it.
        with tf.variable_scope('shared_activation'):
            x = tf.nn.relu(ops.batch_norm(x, scope='init_bn'))
        shortcut = x
        branch = x
    else:
        # The shortcut bypasses the activation; only the branch sees it.
        shortcut = x
        with tf.variable_scope('residual_only_activation'):
            branch = tf.nn.relu(ops.batch_norm(x, scope='init_bn'))

    # First 3x3 convolution; this is the one that applies `stride`.
    with tf.variable_scope('sub1'):
        branch = ops.conv2d(
            branch, out_filter, 3, stride=stride, scope='conv1')

    # Second BN->ReLU->3x3 convolution, always with stride 1.
    with tf.variable_scope('sub2'):
        branch = ops.batch_norm(branch, scope='bn2')
        branch = tf.nn.relu(branch)
        branch = ops.conv2d(branch, out_filter, 3, stride=1, scope='conv2')

    # If the number of filters do not agree, pool the shortcut and zero pad
    # it so it can be added to the branch output.
    with tf.variable_scope('sub_add'):
        if in_filter != out_filter:
            shortcut = ops.avg_pool(shortcut, stride, stride)
            shortcut = ops.zero_pad(shortcut, in_filter, out_filter)

    return shortcut + branch
Example #2
0
    def __call__(self,
                 images,
                 name,
                 reuse=True,
                 training=True,
                 custom_getter=None):
        """Builds the WRN model.

        Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.

        Args:
          images: Tensor of images that will be fed into the Wide ResNet
            model.
          name: Name of the variable scope the model is built under.
          reuse: Forwarded to `tf.variable_scope` as its `reuse` argument.
            Unless it is exactly the bool `True`, the regularization loss,
            `self.init` and `self.count_parameters` steps are run as well.
          training: Forwarded to `setup_arg_scopes`.
          custom_getter: custom_getter function for variable_scope.

        Returns:
          The logits of the Wide ResNet model.
        """
        num_classes = self.num_classes
        width = self.wrn_size

        conv_kernel = 3
        blocks_per_group = 4
        filters = [min(width, 16), width, width * 2, width * 4]
        strides = [1, 2, 2]  # stride for each resblock

        with setup_arg_scopes(training), tf.variable_scope(
                name, reuse=reuse, custom_getter=custom_getter):

            # Run the first conv
            with tf.variable_scope('init'):
                x = ops.conv2d(images,
                               filters[0],
                               conv_kernel,
                               scope='init_conv')

            first_x = x  # Res from the beginning
            orig_x = x  # Res from previous block

            for block_num in range(1, 4):
                in_ch = filters[block_num - 1]
                out_ch = filters[block_num]
                group_stride = strides[block_num - 1]

                # First unit of the group changes width/stride; only the
                # very first group activates before the residual.
                with tf.variable_scope('unit_{}_0'.format(block_num)):
                    x = residual_block(
                        x,
                        in_ch,
                        out_ch,
                        group_stride,
                        activate_before_residual=(block_num == 1))

                # Remaining units keep the width and use stride 1.
                for i in range(1, blocks_per_group):
                    with tf.variable_scope('unit_{}_{}'.format(
                            block_num, i)):
                        x = residual_block(x,
                                           out_ch,
                                           out_ch,
                                           1,
                                           activate_before_residual=False)

                # Residual connection spanning the whole group.
                x, orig_x = _res_add(in_ch, out_ch, group_stride, x, orig_x)

            # Residual connection from the initial conv output to the end,
            # using the product of all group strides.
            total_stride = np.prod(strides)
            x, _ = _res_add(filters[0], filters[3], total_stride, x, first_x)

            with tf.variable_scope('unit_last'):
                x = ops.batch_norm(x, scope='final_bn')
                x = tf.nn.relu(x)
                x = ops.global_avg_pool(x)
                logits = ops.fc(x, num_classes)

            # Skipped only when `reuse` is literally the bool True.
            if not (isinstance(reuse, bool) and reuse):
                self.regularization_loss = decay_weights(
                    self.wd, utils.get_var(tf.trainable_variables(), name))
                self.init(name, with_name='moving', outputs=logits)
                self.count_parameters(name)
        return logits