Example #1
0
def residual_block(x,
                   in_filter,
                   out_filter,
                   stride,
                   update_bn=True,
                   activate_before_residual=False):
    """Adds a residual connection around two BN->ReLU->3x3 Conv sub-blocks.

    Args:
      x: Tensor that is the output of the previous layer in the model.
      in_filter: Number of filters `x` has.
      out_filter: Number of filters that the output of this layer will have.
      stride: Integer stride applied by the first 3x3 convolution. Whenever it
        is not 1 (or the filter counts differ), the shortcut is passed through
        `ops.avg_pool(orig_x, stride, stride)` so the two branches can be
        summed.
      update_bn: Boolean forwarded as `update_stats` to every
        `ops.batch_norm` call in this block.
      activate_before_residual: Boolean on whether the initial BN->ReLU is
        applied to `x` before it is split into the shortcut and residual
        branches. When False, the BN->ReLU is applied only to the residual
        branch and the shortcut carries the un-activated `x`.

    Returns:
      A Tensor that is the sum of the shortcut (pooled and passed through
      `ops.zero_pad` when shapes differ) and the output of the two conv
      sub-blocks.
    """
    if activate_before_residual:
        # Both the shortcut and the residual branch see the activated input.
        with tf.variable_scope('shared_activation'):
            x = ops.batch_norm(x, update_stats=update_bn, scope='init_bn')
            x = tf.nn.relu(x)
        orig_x = x
        block_x = x
    else:
        # Only the residual branch is activated; the shortcut keeps raw `x`.
        orig_x = x
        with tf.variable_scope('residual_only_activation'):
            block_x = ops.batch_norm(x,
                                     update_stats=update_bn,
                                     scope='init_bn')
            block_x = tf.nn.relu(block_x)

    with tf.variable_scope('sub1'):
        block_x = ops.conv2d(block_x,
                             out_filter,
                             3,
                             stride=stride,
                             scope='conv1')

    with tf.variable_scope('sub2'):
        block_x = ops.batch_norm(block_x, update_stats=update_bn, scope='bn2')
        block_x = tf.nn.relu(block_x)
        block_x = ops.conv2d(block_x, out_filter, 3, stride=1, scope='conv2')

    with tf.variable_scope('sub_add'):
        # The shortcut must agree with the residual branch in spatial size
        # (changed by `stride` in conv1) and channel count (changed to
        # `out_filter`) before the two can be added. Pooling is keyed on the
        # stride too, so an equal-filter strided block no longer produces a
        # shape mismatch.
        if in_filter != out_filter or stride != 1:
            orig_x = ops.avg_pool(orig_x, stride, stride)
        if in_filter != out_filter:
            orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
    x = orig_x + block_x
    return x
Example #2
0
def _shake_shake_skip_connection(x, output_filters, stride):
  """Builds the shortcut branch of a shake-shake block.

  If `x` already has `output_filters` channels it is returned unchanged.
  Otherwise two strided 1x1-conv paths are built, the second on a copy of `x`
  shifted by one pixel in height and width. The two paths are concatenated
  along the channel axis and batch-normalized.

  Args:
    x: NHWC Tensor; its channel count is read from axis 3.
    output_filters: Number of channels the returned Tensor must have.
    stride: Spatial stride applied to both paths.

  Returns:
    `x` itself, or a new Tensor with exactly `output_filters` channels.
  """
  curr_filters = int(x.shape[3])
  if curr_filters == output_filters:
    return x
  stride_spec = ops.stride_arr(stride, stride)
  concat_axis = 3

  # Split the output channels between the two paths so that they always sum
  # to `output_filters`. Halving both sides would drop a channel when
  # `output_filters` is odd.
  path1_filters = output_filters // 2
  path2_filters = output_filters - path1_filters

  # Skip path 1: strided subsample (1x1 avg pool) followed by a 1x1 conv.
  path1 = tf.nn.avg_pool(
      x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
  path1 = ops.conv2d(path1, path1_filters, 1, scope='path1_conv')

  # Skip path 2: shift the input by one pixel (zero-pad bottom/right, then
  # crop the first row and column). With stride > 1 it therefore samples
  # different spatial positions than path 1.
  pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
  path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
  path2 = tf.nn.avg_pool(
      path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
  path2 = ops.conv2d(path2, path2_filters, 1, scope='path2_conv')

  # Concat and apply BN.
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = ops.batch_norm(final_path, scope='final_path_bn')
  return final_path
Example #3
0
def build_shake_shake_model(images, num_classes, hparams, is_training):
  """Builds the Shake-Shake model.

  Build the Shake-Shake model from https://arxiv.org/abs/1705.07485.

  Args:
    images: Tensor of images that will be fed into the Shake-Shake model.
    num_classes: Number of classes that the model needs to predict.
    hparams: tf.HParams object that contains additional hparams needed to
      construct the model. In this case it is the `shake_shake_widen_factor`
      that is used to determine how many filters the model has.
    is_training: Is the model training or not. Forwarded to every
      `_shake_shake_layer` call.

  Returns:
    The logits of the Shake-Shake model.
  """
  depth = 26
  k = hparams.shake_shake_widen_factor  # The widen factor
  # Count passed to each `_shake_shake_layer`; (26 - 2) // 6 == 4.
  n = (depth - 2) // 6

  x = ops.conv2d(images, 16, 3, scope='init_conv')
  x = ops.batch_norm(x, scope='init_bn')

  # (variable scope, base filter count before widening, stride) per stage.
  stages = (('L1', 16, 1), ('L2', 32, 2), ('L3', 64, 2))
  for scope_name, base_filters, stride in stages:
    with tf.variable_scope(scope_name):
      x = _shake_shake_layer(x, base_filters * k, n, stride, is_training)

  x = tf.nn.relu(x)
  x = ops.global_avg_pool(x)

  # Fully connected
  logits = ops.fc(x, num_classes)
  return logits
Example #4
0
def build_wrn_model(images, num_classes, wrn_size):
    """Builds the WRN model.

    Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.

    Args:
      images: Tensor of images that will be fed into the Wide ResNet Model.
      num_classes: Number of classes that the model needs to predict.
      wrn_size: Parameter that scales the number of filters in the Wide ResNet
        model.

    Returns:
      The logits of the Wide ResNet model.
    """
    base_filters = wrn_size  # Filter count of the first stage.
    conv_kernel_size = 3  # Spatial kernel size of the initial conv.
    num_blocks_per_resnet = 4
    # filters[0] is the channel count after the initial conv; filters[i] for
    # i >= 1 is the output channel count of stage i.
    filters = [
        min(base_filters, 16), base_filters, base_filters * 2,
        base_filters * 4
    ]
    # Stride of the first residual block of each stage (and of that stage's
    # `_res_add` shortcut); later blocks in a stage use stride 1.
    strides = [1, 2, 2]

    # Run the first conv
    with tf.variable_scope('init'):
        x = ops.conv2d(images, filters[0], conv_kernel_size, scope='init_conv')

    first_x = x  # Res from the beginning
    orig_x = x  # Res from previous block

    for block_num, stride in enumerate(strides, start=1):
        in_filters = filters[block_num - 1]
        out_filters = filters[block_num]
        with tf.variable_scope('unit_{}_0'.format(block_num)):
            x = residual_block(x,
                               in_filters,
                               out_filters,
                               stride,
                               activate_before_residual=(block_num == 1))
        for i in range(1, num_blocks_per_resnet):
            with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
                x = residual_block(x,
                                   out_filters,
                                   out_filters,
                                   1,
                                   activate_before_residual=False)
        # NOTE(review): `_res_add` is defined elsewhere; judging from its use
        # it adds a shortcut from `orig_x` and returns (new x, next shortcut
        # source) -- confirm against its definition.
        x, orig_x = _res_add(in_filters, out_filters, stride, x, orig_x)

    # A final shortcut spans the whole network, so it uses the product of
    # all stage strides.
    final_stride_val = np.prod(strides)
    x, _ = _res_add(filters[0], filters[-1], final_stride_val, x, first_x)
    with tf.variable_scope('unit_last'):
        x = ops.batch_norm(x, scope='final_bn')
        x = tf.nn.relu(x)
        x = ops.global_avg_pool(x)
        logits = ops.fc(x, num_classes)
    return logits