コード例 #1
0
ファイル: wrn.py プロジェクト: zhang405744522/ssl_detection
def residual_block(x,
                   in_filter,
                   out_filter,
                   stride,
                   update_bn=True,
                   activate_before_residual=False):
    """Adds a residual connection to `x` around two BN->ReLU->3x3 Conv units.

    Args:
        x: Tensor that is the output of the previous layer in the model.
        in_filter: Number of filters `x` has.
        out_filter: Number of filters that the output of this block will have.
        stride: Integer stride used by the first 3x3 convolution. The shortcut
          branch is average-pooled with the same stride so that the two
          branches have matching spatial shapes.
        update_bn: Boolean forwarded as `update_stats` to every
          `ops.batch_norm` call in this block.
        activate_before_residual: Boolean on whether the initial BN->ReLU is
          applied to `x` before it is split. If True, both the shortcut and
          the residual branch see the activated tensor. If False, only the
          residual branch is activated and the shortcut carries the raw `x`.

    Returns:
        A Tensor that is the sum of the residual branch (BN->ReLU->3x3 Conv
        applied twice) and the shortcut. The shortcut is average-pooled and
        zero-padded when needed to match the residual branch's shape.
    """
    if activate_before_residual:
        # Pass up ReLU and BN activation: both branches share the result.
        with tf.variable_scope('shared_activation'):
            x = ops.batch_norm(x, update_stats=update_bn, scope='init_bn')
            x = tf.nn.relu(x)
        orig_x = x
        block_x = x
    else:
        # Only the residual branch is activated; the shortcut stays raw.
        orig_x = x
        with tf.variable_scope('residual_only_activation'):
            block_x = ops.batch_norm(x, update_stats=update_bn,
                                     scope='init_bn')
            block_x = tf.nn.relu(block_x)

    with tf.variable_scope('sub1'):
        block_x = ops.conv2d(block_x,
                             out_filter,
                             3,
                             stride=stride,
                             scope='conv1')

    with tf.variable_scope('sub2'):
        block_x = ops.batch_norm(block_x, update_stats=update_bn, scope='bn2')
        block_x = tf.nn.relu(block_x)
        block_x = ops.conv2d(block_x, out_filter, 3, stride=1, scope='conv2')

    with tf.variable_scope('sub_add'):
        # Downsample the shortcut whenever the residual branch is strided.
        # Previously this only happened when the filter counts differed,
        # which left a spatial-shape mismatch for stride != 1 with equal
        # filter counts. The graph is unchanged for all other cases.
        if in_filter != out_filter or stride != 1:
            orig_x = ops.avg_pool(orig_x, stride, stride)
        # If the number of filters does not agree, zero pad the shortcut.
        if in_filter != out_filter:
            orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
    x = orig_x + block_x
    return x
コード例 #2
0
def shortcut(x, output_filters, stride):
    """Makes `x` match a strided block's output via avg pool and zero padding.

    Args:
        x: Rank-4 tensor whose axis 3 holds the filters (channels).
        output_filters: Number of filters the returned tensor must have. Must
          be >= the number of filters in `x`.
        stride: Integer stride of the block. When greater than 1, `x` is
          average-pooled with kernel size and stride equal to `stride`
          ('SAME' padding).

    Returns:
        `x` (possibly pooled), with zero channels appended along axis 3 so
        that it has `output_filters` filters. It is returned unchanged when
        `stride <= 1` and the filter counts already agree.

    Raises:
        ValueError: If `output_filters` is smaller than the number of filters
          in `x`; zero padding cannot reduce the channel count.
    """
    num_filters = int(x.shape[3])
    # Validate explicitly (not with `assert`, which is stripped under -O)
    # and before building any ops.
    if output_filters < num_filters:
        raise ValueError(
            'shortcut cannot reduce filters: input has {} filters but '
            'output_filters={}'.format(num_filters, output_filters))
    if stride > 1:
        x = ops.avg_pool(x, stride, stride=stride, padding='SAME')
    if num_filters != output_filters:
        diff = output_filters - num_filters
        # Append `diff` zero channels at the end of the filter axis only.
        padding = [[0, 0], [0, 0], [0, 0], [0, diff]]
        x = tf.pad(x, padding)
    return x
コード例 #3
0
ファイル: wrn.py プロジェクト: JMFlin/auto-preference-finder
def _res_add(in_filter, out_filter, stride, x, orig_x):
    """Adds `x` with `orig_x`, both of which are layers in the model.

    Args:
        in_filter: Number of filters in `orig_x`.
        out_filter: Number of filters in `x`.
        stride: Integer specifying the stide that should be applied `orig_x`.
        x: Tensor that is the output of the previous layer.
        orig_x: Tensor that is the output of an earlier layer in the network.

    Returns:
        A Tensor that is the result of `x` and `orig_x` being added after
        zero padding and striding are applied to `orig_x` to get the shapes
        to match.
    """
    if in_filter != out_filter:
        orig_x = ops.avg_pool(orig_x, stride, stride)
        orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
    x = x + orig_x
    orig_x = x
    return x, orig_x