def residual_block(x, in_filter, out_filter, stride, update_bn=True,
                   activate_before_residual=False):
    """Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv.

    Args:
      x: Tensor that is the output of the previous layer in the model.
      in_filter: Number of filters `x` has.
      out_filter: Number of filters that the output of this layer will have.
      stride: Integer that specifies what stride should be applied to `x`.
      update_bn: Forwarded as `update_stats` to every `ops.batch_norm` call in
        this block (presumably toggles updating of the batch-norm statistics;
        see `ops.batch_norm`).
      activate_before_residual: If True, a shared BN->ReLU is applied to `x`
        and the activated tensor feeds both the residual branch and the
        shortcut. If False, BN->ReLU is applied only on the residual branch
        and the shortcut uses the un-activated `x`.

    Returns:
      A Tensor that is the result of applying two sequences of
      BN->ReLU->3x3 Conv and then adding that Tensor to the shortcut, which
      is average-pooled and/or zero padded as needed so the shapes agree.
    """
    if activate_before_residual:
        # Pass up RELU and BN activation for resnet: the activated input is
        # shared by the residual branch and the shortcut.
        with tf.variable_scope('shared_activation'):
            x = ops.batch_norm(x, update_stats=update_bn, scope='init_bn')
            x = tf.nn.relu(x)
            orig_x = x
    else:
        orig_x = x

    block_x = x
    if not activate_before_residual:
        # Pre-activation applied to the residual branch only.
        with tf.variable_scope('residual_only_activation'):
            block_x = ops.batch_norm(
                block_x, update_stats=update_bn, scope='init_bn')
            block_x = tf.nn.relu(block_x)

    with tf.variable_scope('sub1'):
        block_x = ops.conv2d(
            block_x, out_filter, 3, stride=stride, scope='conv1')

    with tf.variable_scope('sub2'):
        block_x = ops.batch_norm(block_x, update_stats=update_bn, scope='bn2')
        block_x = tf.nn.relu(block_x)
        block_x = ops.conv2d(block_x, out_filter, 3, stride=1, scope='conv2')

    with tf.variable_scope('sub_add'):
        # conv1 downsamples the residual branch by `stride`, so the shortcut
        # must be pooled whenever stride != 1 (not only when the filter count
        # changes), or the addition below fails on mismatched spatial shapes.
        if in_filter != out_filter or stride != 1:
            orig_x = ops.avg_pool(orig_x, stride, stride)
        # If number of filters do not agree then zero pad them.
        if in_filter != out_filter:
            orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
    x = orig_x + block_x
    return x
def _res_add(in_filter, out_filter, stride, x, orig_x):
    """Adds `x` with `orig_x`, both of which are layers in the model.

    Args:
      in_filter: Number of filters in `orig_x`.
      out_filter: Number of filters in `x`.
      stride: Integer specifying the stride that should be applied to
        `orig_x`.
      x: Tensor that is the output of the previous layer.
      orig_x: Tensor that is the output of an earlier layer in the network.

    Returns:
      A tuple `(x, orig_x)` whose two elements are the same Tensor: the sum of
      `x` and `orig_x` after average-pool striding and zero padding are
      applied to `orig_x` to get the shapes to match. The duplicate is
      presumably so callers can reuse the sum as the next shortcut source.
    """
    # Pool whenever the spatial size differs (stride != 1), not only when the
    # filter count changes; otherwise the addition has mismatched shapes.
    if in_filter != out_filter or stride != 1:
        orig_x = ops.avg_pool(orig_x, stride, stride)
    if in_filter != out_filter:
        orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
    x = x + orig_x
    orig_x = x
    return x, orig_x