def rgb_conv_stem(inputs,
                  num_frames,
                  filters,
                  temporal_dilation,
                  bn_decay: float = rf.BATCH_NORM_DECAY,
                  bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                  use_sync_bn: bool = False):
  """Builds the stem applied to RGB frames.

  The stem chains `conv2d_fixed_padding` (kernel size 7, stride 2), batch norm
  and ReLU, then the temporal block `reshape_temporal_conv1d_bn` (kernel size
  5), and finally a max pool (pool size 3, stride 2, 'SAME' padding).

  Args:
    inputs: A `Tensor` of size `[batch*time, height, width, channels]`.
    num_frames: `int` number of frames in the input tensor.
    filters: `int` number of filters used by both convolutions.
    temporal_dilation: `int` temporal dilation size for the 1D conv. Values
      below 1 are replaced by 1.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    The output `Tensor`.
  """
  layout = tf.keras.backend.image_data_format()
  assert layout == 'channels_last'

  # Anything below 1 falls back to the non-dilated temporal convolution.
  temporal_dilation = 1 if temporal_dilation < 1 else temporal_dilation

  # Spatial part of the stem.
  net = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=7, strides=2)
  net = tf.identity(net, name='initial_conv')
  stem_norm = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)
  net = tf.nn.relu(stem_norm(net))

  # Temporal part of the stem (this helper applies its own batch norm + ReLU).
  net = reshape_temporal_conv1d_bn(
      inputs=net,
      filters=filters,
      kernel_size=5,
      num_frames=num_frames,
      temporal_dilation=temporal_dilation,
      bn_decay=bn_decay,
      bn_epsilon=bn_epsilon,
      use_sync_bn=use_sync_bn)

  max_pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')
  net = max_pool(inputs=net)
  return tf.identity(net, name='initial_max_pool')
def flow_conv_stem(inputs,
                   filters,
                   temporal_dilation,
                   bn_decay: float = rf.BATCH_NORM_DECAY,
                   bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                   use_sync_bn: bool = False):
  """Layers for an optical flow stem.

  The stem chains `conv2d_fixed_padding` (kernel size 7, stride 2), batch norm,
  ReLU and a max pool (pool size 2, stride 2, 'SAME' padding). It contains no
  temporal convolution.

  Args:
    inputs: A `Tensor` of size `[batch*time, height, width, channels]`.
    filters: `int` number of filters in the convolution.
    temporal_dilation: Accepted but not used by this stem, since it has no
      temporal (1D) convolution to dilate.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    The output `Tensor`.
  """
  # The argument was previously clamped to >= 1 and then never read; make the
  # fact that it has no effect explicit instead of keeping dead code.
  del temporal_dilation  # Unused.

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=7, strides=2)
  inputs = tf.identity(inputs, 'initial_conv')
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  inputs = tf.keras.layers.MaxPool2D(
      pool_size=2, strides=2, padding='SAME')(
          inputs=inputs)
  inputs = tf.identity(inputs, 'initial_max_pool')

  return inputs
def reshape_temporal_conv1d_bn(inputs: tf.Tensor,
                               filters: int,
                               kernel_size: int,
                               num_frames: int = 32,
                               temporal_dilation: int = 1,
                               bn_decay: float = rf.BATCH_NORM_DECAY,
                               bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                               use_sync_bn: bool = False):
  """Applies a temporal 1D conv followed by batch norm and ReLU, via reshaping.

  The input is reshaped to `[-1, num_frames, height * width, channels]` so that
  a `Conv2D` with a `(kernel_size, 1)` kernel convolves along the frame axis
  only. The convolution output is then reshaped back to
  `[-1, height, width, out_channels]` before batch norm and ReLU are applied.

  Height, width and channels are read from the static shape of `inputs`.

  Args:
    inputs: `Tensor` of size `[batch*time, height, width, channels]`. Only
      supports 'channels_last' as the data format.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` kernel size of the convolution along the frame axis.
    num_frames: `int` number of frames in the input tensor.
    temporal_dilation: `int` temporal dilation size for the 1D conv. When it
      equals 1 a plain convolution with a `VarianceScaling` initializer is
      used; otherwise a dilated convolution with a `TruncatedNormal`
      initializer is used.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    A `Tensor` reshaped to `[-1, height, width, out_channels]`, where
    `out_channels` is the last dimension of the convolution output.
  """
  layout = tf.keras.backend.image_data_format()
  assert layout == 'channels_last'

  in_shape = inputs.shape
  height, width, in_channels = in_shape[1], in_shape[2], in_shape[3]

  # Expose the frame axis and flatten the spatial axes into one.
  frames_major = tf.reshape(
      inputs, [-1, num_frames, height * width, in_channels])

  # Settings shared by the dilated and non-dilated variants.
  conv_kwargs = {
      'filters': filters,
      'kernel_size': (kernel_size, 1),
      'strides': 1,
      'padding': 'SAME',
      'use_bias': False,
  }
  if temporal_dilation == 1:
    conv_kwargs['kernel_initializer'] = (
        tf.keras.initializers.VarianceScaling())
  else:
    conv_kwargs['dilation_rate'] = (temporal_dilation, 1)
    conv_kwargs['kernel_initializer'] = tf.keras.initializers.TruncatedNormal(
        stddev=math.sqrt(2.0 / (kernel_size * in_channels)))

  temporal_conv = tf.keras.layers.Conv2D(**conv_kwargs)
  net = temporal_conv(inputs=frames_major)

  # Restore the per-frame spatial layout.
  out_channels = net.shape[3]
  net = tf.reshape(net, [-1, height, width, out_channels])

  batch_norm = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)
  return tf.nn.relu(batch_norm(net))
def bottleneck_block_interleave(inputs: tf.Tensor,
                                filters: int,
                                inter_filters: int,
                                strides: int,
                                use_projection: bool = False,
                                num_frames: int = 32,
                                temporal_dilation: int = 1,
                                bn_decay: float = rf.BATCH_NORM_DECAY,
                                bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                                use_sync_bn: bool = False,
                                step: int = 1):
  """Interleaves a standard 2D residual module and (2+1)D residual module.

  Bottleneck block variant for residual networks with BN after convolutions.
  The first layer of the block is a temporal 1D convolution (kernel size 3)
  when `step` is odd, and a 1x1 2D convolution when `step` is even. It is
  followed by a 3x3 convolution with `inter_filters` filters and a 1x1
  convolution with `4 * filters` filters, whose output is added to the
  shortcut.

  Args:
    inputs: `Tensor` of size `[batch*time, height, width, channels]`. The
      temporal branch goes through `reshape_temporal_conv1d_bn`, which only
      supports the 'channels_last' data format.
    filters: `int` number of filters for the first conv. layer. The last conv.
      layer will use 4 times as many filters.
    inter_filters: `int` number of filters for the second conv. layer.
    strides: `int` block stride. If greater than 1, this block will ultimately
      downsample the input spatially.
    use_projection: `bool` for whether this block should use a projection
      shortcut (versus the default identity shortcut). This is usually `True`
      for the first block of a block group, which may change the number of
      filters and the resolution.
    num_frames: `int` number of frames in the input tensor.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.
    step: `int` to decide whether to put 2D module or (2+1)D module. Odd
      values select the (2+1)D module, even values the 2D module.

  Returns:
    The output `Tensor` of the block.

  Raises:
    ValueError: If `strides > 1` while `use_projection` is false.
  """
  if strides > 1 and not use_projection:
    raise ValueError('strides > 1 requires use_projection=True, otherwise the '
                     'inputs and shortcut will have shape mismatch')

  # Bottleneck blocks end with 4 times the number of filters.
  filters_out = 4 * filters

  shortcut = inputs
  if use_projection:
    # Projection shortcut only in first block within a group.
    shortcut = conv2d_fixed_padding(
        inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
    shortcut = rf.build_batch_norm(
        bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
            shortcut)

  if step % 2 == 1:
    # (2+1)D module: the helper applies its own batch norm and ReLU.
    inputs = reshape_temporal_conv1d_bn(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        num_frames=num_frames,
        temporal_dilation=temporal_dilation,
        bn_decay=bn_decay,
        bn_epsilon=bn_epsilon,
        use_sync_bn=use_sync_bn)
  else:
    # Plain 2D module: 1x1 convolution, batch norm, ReLU.
    inputs = conv2d_fixed_padding(
        inputs=inputs, filters=filters, kernel_size=1, strides=1)
    inputs = rf.build_batch_norm(
        bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
            inputs)
    inputs = tf.nn.relu(inputs)

  # The only layer that applies the block stride on the main branch.
  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=inter_filters, kernel_size=3, strides=strides)
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters_out, kernel_size=1, strides=1)
  # NOTE(review): `init_zero=True` is only passed for this final norm;
  # presumably it zero-initializes the scale so the residual branch starts
  # near zero -- confirm against `rf.build_batch_norm`.
  inputs = rf.build_batch_norm(
      init_zero=True,
      bn_decay=bn_decay,
      bn_epsilon=bn_epsilon,
      use_sync_bn=use_sync_bn)(
          inputs)

  return tf.nn.relu(inputs + shortcut)