def build(self, input_shape):
  """Creates the sub-layers of the transformer encoder.

  Optionally instantiates a positional-embedding layer, then the input
  dropout layer, the stack of transformer encoder blocks, and the final
  layer normalization.

  Args:
    input_shape: Shape of the input tensor; forwarded to the base class.
  """
  if self._add_pos_embed:
    self._pos_embed = AddPositionEmbs(
        posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
        name='posembed_input')
  self._dropout = layers.Dropout(rate=self._dropout_rate)
  # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
  # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
  self._encoder_layers = [
      nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          # Stochastic depth rate grows linearly with block index.
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, layer_idx + 1,
              self._num_layers),
          norm_epsilon=1e-6)
      for layer_idx in range(self._num_layers)
  ]
  self._norm = layers.LayerNormalization(epsilon=1e-6)
  super().build(input_shape)
def _build_scale_permuted_network(self,
                                  net,
                                  input_width,
                                  weighted_fusion=False):
  """Builds scale-permuted network.

  Args:
    net: A list of feature tensors from the initial blocks; outputs of newly
      built blocks are appended to it as the network grows.
    input_width: An `int`, the spatial width of the network input. Target
      widths are derived from it as input_width / 2**level.
    weighted_fusion: If True, fuses parent features with learned,
      ReLU-rectified and normalized scalar weights instead of a plain sum.

  Returns:
    A dict mapping the output level (as a string) to the feature tensor of
    that level.

  Raises:
    ValueError: If more than one output block targets the same level.
  """
  # Initial blocks all sit at 1 / 2**2 of the input width.
  net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
  net_block_fns = [self._init_block_fn] * len(net)
  # Outdegree per block; blocks that end up with outdegree 0 are merged
  # into a compatible output block below.
  num_outgoing_connections = [0] * len(net)

  endpoints = {}
  for i, block_spec in enumerate(self._block_specs):
    # Find out specs for the target block.
    target_width = int(math.ceil(input_width / 2**block_spec.level))
    target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
                             self._filter_size_scale)
    target_block_fn = block_spec.block_fn

    # Resample then merge input0 and input1.
    parents = []
    input0 = block_spec.input_offsets[0]
    input1 = block_spec.input_offsets[1]

    x0 = self._resample_with_alpha(
        inputs=net[input0],
        input_width=net_sizes[input0],
        input_block_fn=net_block_fns[input0],
        target_width=target_width,
        target_num_filters=target_num_filters,
        target_block_fn=target_block_fn,
        alpha=self._resample_alpha)
    parents.append(x0)
    num_outgoing_connections[input0] += 1

    x1 = self._resample_with_alpha(
        inputs=net[input1],
        input_width=net_sizes[input1],
        input_block_fn=net_block_fns[input1],
        target_width=target_width,
        target_num_filters=target_num_filters,
        target_block_fn=target_block_fn,
        alpha=self._resample_alpha)
    parents.append(x1)
    num_outgoing_connections[input1] += 1

    # Merge 0 outdegree blocks to the output block.
    if block_spec.is_output:
      for j, (j_feat, j_connections) in enumerate(
          zip(net, num_outgoing_connections)):
        # Only merge features whose width and channel count already match
        # the target (indexes 2/3 presumably NHWC — TODO confirm layout).
        if j_connections == 0 and (j_feat.shape[2] == target_width and
                                   j_feat.shape[3] == x0.shape[3]):
          parents.append(j_feat)
          num_outgoing_connections[j] += 1

    # pylint: disable=g-direct-tensorflow-import
    if weighted_fusion:
      dtype = parents[0].dtype
      # One learnable scalar per parent; ReLU keeps weights non-negative.
      parent_weights = [
          tf.nn.relu(
              tf.cast(
                  tf.Variable(1.0, name='block{}_fusion{}'.format(i, j)),
                  dtype=dtype)) for j in range(len(parents))
      ]
      weights_sum = tf.add_n(parent_weights)
      # Normalize so weights sum to ~1 (epsilon avoids division by zero).
      # NOTE: the comprehension-local `i` below does not leak into the
      # enclosing loop variable `i` in Python 3.
      parents = [
          parents[i] * parent_weights[i] / (weights_sum + 0.0001)
          for i in range(len(parents))
      ]

    # Fuse all parent nodes then build a new block.
    x = tf_utils.get_activation(self._activation_fn)(tf.add_n(parents))
    x = self._block_group(
        inputs=x,
        filters=target_num_filters,
        strides=1,
        block_fn_cand=target_block_fn,
        block_repeats=self._block_repeats,
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 1, len(self._block_specs)),
        name='scale_permuted_block_{}'.format(i + 1))

    # Register the new block so later blocks can consume it.
    net.append(x)
    net_sizes.append(target_width)
    net_block_fns.append(target_block_fn)
    num_outgoing_connections.append(0)

    # Save output feats.
    if block_spec.is_output:
      if block_spec.level in endpoints:
        raise ValueError(
            'Duplicate feats found for output level {}.'.format(
                block_spec.level))
      if (block_spec.level < self._min_level or
          block_spec.level > self._max_level):
        logging.warning(
            'SpineNet output level out of range [min_level, max_level] = '
            '[%s, %s] will not be used for further processing.',
            self._min_level, self._max_level)
      endpoints[str(block_spec.level)] = x

  return endpoints
def __init__(
    self,
    model_id: int,
    temporal_strides: List[int],
    temporal_kernel_sizes: List[Tuple[int]],
    use_self_gating: List[bool] = None,
    input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
    stem_type='v0',
    stem_conv_temporal_kernel_size=5,
    stem_conv_temporal_stride=2,
    stem_pool_temporal_stride=2,
    init_stochastic_depth_rate=0.0,
    activation='relu',
    se_ratio=None,
    use_sync_bn=False,
    norm_momentum=0.99,
    norm_epsilon=0.001,
    kernel_initializer='VarianceScaling',
    kernel_regularizer=None,
    bias_regularizer=None,
    **kwargs):
  """Initializes a 3D ResNet model.

  Args:
    model_id: An `int` of depth of ResNet backbone model.
    temporal_strides: A list of integers that specifies the temporal strides
      for all 3d blocks.
    temporal_kernel_sizes: A list of tuples that specifies the temporal kernel
      sizes for all 3d blocks in different block groups.
    use_self_gating: A list of booleans to specify applying self-gating module
      or not in each block group. If None, self-gating is not applied.
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
      `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
    stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
      first conv layer.
    stem_conv_temporal_stride: An `int` of temporal stride for the first conv
      layer.
    stem_pool_temporal_stride: An `int` of temporal stride for the first pool
      layer.
    init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
    activation: A `str` of name of the activation function.
    se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
    use_sync_bn: If True, use synchronized batch normalization.
    norm_momentum: A `float` of normalization momentum for the moving average.
    norm_epsilon: A `float` added to variance to avoid dividing by zero.
    kernel_initializer: A str for kernel initializer of convolutional layers.
    kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Default to None.
    bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
      Default to None.
    **kwargs: Additional keyword arguments to be passed.
  """
  self._model_id = model_id
  self._temporal_strides = temporal_strides
  self._temporal_kernel_sizes = temporal_kernel_sizes
  self._input_specs = input_specs
  self._stem_type = stem_type
  self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
  self._stem_conv_temporal_stride = stem_conv_temporal_stride
  self._stem_pool_temporal_stride = stem_pool_temporal_stride
  self._use_self_gating = use_self_gating
  self._se_ratio = se_ratio
  self._init_stochastic_depth_rate = init_stochastic_depth_rate
  self._use_sync_bn = use_sync_bn
  self._activation = activation
  self._norm_momentum = norm_momentum
  self._norm_epsilon = norm_epsilon
  # Store the normalization layer class (not an instance); it is
  # instantiated per use below.
  if use_sync_bn:
    self._norm = layers.experimental.SyncBatchNormalization
  else:
    self._norm = layers.BatchNormalization
  self._kernel_initializer = kernel_initializer
  self._kernel_regularizer = kernel_regularizer
  self._bias_regularizer = bias_regularizer
  # Normalize over the channel axis, whichever position it occupies.
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = -1
  else:
    bn_axis = 1

  # Build ResNet3D backbone.
  inputs = tf.keras.Input(shape=input_specs.shape[1:])

  # Build stem.
  if stem_type == 'v0':
    # Single (temporal x 7 x 7) conv stem.
    x = layers.Conv3D(
        filters=64,
        kernel_size=[stem_conv_temporal_kernel_size, 7, 7],
        strides=[stem_conv_temporal_stride, 2, 2],
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  elif stem_type == 'v1':
    # ResNet-D style stem: three stacked 3x3 (spatial) convs.
    x = layers.Conv3D(
        filters=32,
        kernel_size=[stem_conv_temporal_kernel_size, 3, 3],
        strides=[stem_conv_temporal_stride, 2, 2],
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv3D(
        filters=32,
        kernel_size=[1, 3, 3],
        strides=[1, 1, 1],
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv3D(
        filters=64,
        kernel_size=[1, 3, 3],
        strides=[1, 1, 1],
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  else:
    raise ValueError(f'Stem type {stem_type} not supported.')

  # A temporal pool window of 1 avoids mixing frames when there is no
  # temporal downsampling in the stem pool.
  temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3
  x = layers.MaxPool3D(
      pool_size=[temporal_kernel_size, 3, 3],
      strides=[stem_pool_temporal_stride, 2, 2],
      padding='same')(x)

  # Build intermediate blocks and endpoints.
  resnet_specs = RESNET_SPECS[model_id]
  if len(temporal_strides) != len(resnet_specs) or len(
      temporal_kernel_sizes) != len(resnet_specs):
    raise ValueError(
        'Number of blocks in temporal specs should equal to resnet_specs.')

  endpoints = {}
  for i, resnet_spec in enumerate(resnet_specs):
    if resnet_spec[0] == 'bottleneck3d':
      block_fn = nn_blocks_3d.BottleneckBlock3D
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(
          resnet_spec[0]))

    x = self._block_group(
        inputs=x,
        filters=resnet_spec[1],
        temporal_kernel_sizes=temporal_kernel_sizes[i],
        temporal_strides=temporal_strides[i],
        spatial_strides=(1 if i == 0 else 2),
        block_fn=block_fn,
        block_repeats=resnet_spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 2, 5),
        use_self_gating=use_self_gating[i] if use_self_gating else False,
        name='block_group_l{}'.format(i + 2))
    # Endpoint keys are the feature level: stage i produces level i + 2.
    endpoints[str(i + 2)] = x

  self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

  super(ResNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def __init__(self,
             model_id,
             input_specs=layers.InputSpec(shape=[None, None, None, 3]),
             depth_multiplier=1.0,
             stem_type='v0',
             se_ratio=None,
             init_stochastic_depth_rate=0.0,
             activation='relu',
             use_sync_bn=False,
             norm_momentum=0.99,
             norm_epsilon=0.001,
             kernel_initializer='VarianceScaling',
             kernel_regularizer=None,
             bias_regularizer=None,
             **kwargs):
  """ResNet initialization function.

  Args:
    model_id: `int` depth of ResNet backbone model.
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    depth_multiplier: `float` a depth multiplier to uniformly scale up all
      layers in channel size in ResNet.
    stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
      use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
    se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
    init_stochastic_depth_rate: `float` initial stochastic depth rate.
    activation: `str` name of the activation function.
    use_sync_bn: if True, use synchronized batch normalization.
    norm_momentum: `float` normalization momentum for the moving average.
    norm_epsilon: `float` small float added to variance to avoid dividing by
      zero.
    kernel_initializer: kernel_initializer for convolutional layers.
    kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      Default to None.
    bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
      Default to None.
    **kwargs: keyword arguments to be passed.
  """
  self._model_id = model_id
  self._input_specs = input_specs
  self._depth_multiplier = depth_multiplier
  self._stem_type = stem_type
  self._se_ratio = se_ratio
  self._init_stochastic_depth_rate = init_stochastic_depth_rate
  self._use_sync_bn = use_sync_bn
  self._activation = activation
  self._norm_momentum = norm_momentum
  self._norm_epsilon = norm_epsilon
  # Store the normalization layer class; instantiated per use below.
  if use_sync_bn:
    self._norm = layers.experimental.SyncBatchNormalization
  else:
    self._norm = layers.BatchNormalization
  self._kernel_initializer = kernel_initializer
  self._kernel_regularizer = kernel_regularizer
  self._bias_regularizer = bias_regularizer
  # Normalize over the channel axis, whichever position it occupies.
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = -1
  else:
    bn_axis = 1

  # Build ResNet.
  inputs = tf.keras.Input(shape=input_specs.shape[1:])

  if stem_type == 'v0':
    # Classic single 7x7 stride-2 conv stem.
    x = layers.Conv2D(filters=int(64 * self._depth_multiplier),
                      kernel_size=7,
                      strides=2,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  elif stem_type == 'v1':
    # ResNet-C style stem: three stacked 3x3 convs replace the 7x7.
    x = layers.Conv2D(filters=int(32 * self._depth_multiplier),
                      kernel_size=3,
                      strides=2,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv2D(filters=int(32 * self._depth_multiplier),
                      kernel_size=3,
                      strides=1,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv2D(filters=int(64 * self._depth_multiplier),
                      kernel_size=3,
                      strides=1,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  else:
    raise ValueError('Stem type {} not supported.'.format(stem_type))

  x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

  endpoints = {}
  for i, spec in enumerate(RESNET_SPECS[model_id]):
    if spec[0] == 'residual':
      block_fn = nn_blocks.ResidualBlock
    elif spec[0] == 'bottleneck':
      block_fn = nn_blocks.BottleneckBlock
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
    x = self._block_group(
        inputs=x,
        filters=int(spec[1] * self._depth_multiplier),
        strides=(1 if i == 0 else 2),
        block_fn=block_fn,
        block_repeats=spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 2, 5),
        name='block_group_l{}'.format(i + 2))
    # Endpoint keys are the feature level: stage i produces level i + 2.
    endpoints[str(i + 2)] = x

  self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

  super(ResNet, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def __init__(self,
             model_id,
             output_stride,
             input_specs=layers.InputSpec(shape=[None, None, None, 3]),
             stem_type='v0',
             se_ratio=None,
             init_stochastic_depth_rate=0.0,
             multigrid=None,
             last_stage_repeats=1,
             activation='relu',
             use_sync_bn=False,
             norm_momentum=0.99,
             norm_epsilon=0.001,
             kernel_initializer='VarianceScaling',
             kernel_regularizer=None,
             bias_regularizer=None,
             **kwargs):
  """Initializes a ResNet model with DeepLab modification.

  Args:
    model_id: An `int` specifies depth of ResNet backbone model.
    output_stride: An `int` of output stride, ratio of input to output
      resolution.
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    stem_type: A `str` of stem type. Can be `v0` or `v1`. `v1` replaces the
      7x7 conv by three 3x3 convs.
    se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
    init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
    multigrid: A tuple of the same length as the number of blocks in the last
      resnet stage.
    last_stage_repeats: An `int` that specifies how many times last stage is
      repeated.
    activation: A `str` name of the activation function.
    use_sync_bn: If True, use synchronized batch normalization.
    norm_momentum: A `float` of normalization momentum for the moving
      average.
    norm_epsilon: A `float` added to variance to avoid dividing by zero.
    kernel_initializer: A str for kernel initializer of convolutional layers.
    kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Default to None.
    bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Default to None.
    **kwargs: Additional keyword arguments to be passed.
  """
  self._model_id = model_id
  self._output_stride = output_stride
  self._input_specs = input_specs
  self._use_sync_bn = use_sync_bn
  self._activation = activation
  self._norm_momentum = norm_momentum
  self._norm_epsilon = norm_epsilon
  # Store the normalization layer class; instantiated per use below.
  if use_sync_bn:
    self._norm = layers.experimental.SyncBatchNormalization
  else:
    self._norm = layers.BatchNormalization
  self._kernel_initializer = kernel_initializer
  self._kernel_regularizer = kernel_regularizer
  self._bias_regularizer = bias_regularizer
  self._stem_type = stem_type
  self._se_ratio = se_ratio
  self._init_stochastic_depth_rate = init_stochastic_depth_rate
  # Normalize over the channel axis, whichever position it occupies.
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = -1
  else:
    bn_axis = 1

  # Build ResNet.
  inputs = tf.keras.Input(shape=input_specs.shape[1:])

  if stem_type == 'v0':
    # Classic single 7x7 stride-2 conv stem.
    x = layers.Conv2D(filters=64,
                      kernel_size=7,
                      strides=2,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  elif stem_type == 'v1':
    # DeepLab-style stem: three stacked 3x3 convs replace the 7x7.
    x = layers.Conv2D(filters=64,
                      kernel_size=3,
                      strides=2,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv2D(filters=64,
                      kernel_size=3,
                      strides=1,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.Conv2D(filters=128,
                      kernel_size=3,
                      strides=1,
                      use_bias=False,
                      padding='same',
                      kernel_initializer=self._kernel_initializer,
                      kernel_regularizer=self._kernel_regularizer,
                      bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(axis=bn_axis, momentum=norm_momentum,
                   epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
  else:
    raise ValueError('Stem type {} not supported.'.format(stem_type))

  x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

  # Stages up to this index keep normal strides; later stages use dilation
  # instead of striding so the output stride stays at `output_stride`.
  normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2

  endpoints = {}
  for i in range(normal_resnet_stage + 1):
    spec = RESNET_SPECS[model_id][i]
    if spec[0] == 'bottleneck':
      block_fn = nn_blocks.BottleneckBlock
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
    x = self._block_group(
        inputs=x,
        filters=spec[1],
        strides=(1 if i == 0 else 2),
        dilation_rate=1,
        block_fn=block_fn,
        block_repeats=spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 2, 4 + last_stage_repeats),
        name='block_group_l{}'.format(i + 2))
    endpoints[str(i + 2)] = x

  dilation_rate = 2
  for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
    # Repeats of the last stage reuse the last spec entry.
    spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
    if spec[0] == 'bottleneck':
      block_fn = nn_blocks.BottleneckBlock
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
    x = self._block_group(
        inputs=x,
        filters=spec[1],
        strides=1,
        dilation_rate=dilation_rate,
        block_fn=block_fn,
        block_repeats=spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 2, 4 + last_stage_repeats),
        multigrid=multigrid if i >= 3 else None,
        name='block_group_l{}'.format(i + 2))
    dilation_rate *= 2
  # Only the final dilated feature map is exposed for the dilated stages.
  endpoints[str(normal_resnet_stage + 2)] = x

  self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

  super(DilatedResNet, self).__init__(
      inputs=inputs, outputs=endpoints, **kwargs)
def __init__(
    self,
    model_id: int,
    input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
        shape=[None, None, None, 3]),
    depth_multiplier: float = 1.0,
    stem_type: str = 'v0',
    resnetd_shortcut: bool = False,
    replace_stem_max_pool: bool = False,
    se_ratio: Optional[float] = None,
    init_stochastic_depth_rate: float = 0.0,
    activation: str = 'relu',
    use_sync_bn: bool = False,
    norm_momentum: float = 0.99,
    norm_epsilon: float = 0.001,
    kernel_initializer: str = 'VarianceScaling',
    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    **kwargs):
  """Initializes a ResNet model.

  Args:
    model_id: An `int` of the depth of ResNet backbone model.
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    depth_multiplier: A `float` of the depth multiplier to uniformly scale up
      all layers in channel size. This argument is also referred to as
      `width_multiplier` in (https://arxiv.org/abs/2103.07579).
    stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
      `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
    resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
      downsampling blocks.
    replace_stem_max_pool: A `bool` of whether to replace the max pool in
      stem with a stride-2 conv.
    se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
    init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
    activation: A `str` name of the activation function.
    use_sync_bn: If True, use synchronized batch normalization.
    norm_momentum: A `float` of normalization momentum for the moving
      average.
    norm_epsilon: A small `float` added to variance to avoid dividing by
      zero.
    kernel_initializer: A str for kernel initializer of convolutional layers.
    kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Default to None.
    bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Default to None.
    **kwargs: Additional keyword arguments to be passed.
  """
  self._model_id = model_id
  self._input_specs = input_specs
  self._depth_multiplier = depth_multiplier
  self._stem_type = stem_type
  self._resnetd_shortcut = resnetd_shortcut
  self._replace_stem_max_pool = replace_stem_max_pool
  self._se_ratio = se_ratio
  self._init_stochastic_depth_rate = init_stochastic_depth_rate
  self._use_sync_bn = use_sync_bn
  self._activation = activation
  self._norm_momentum = norm_momentum
  self._norm_epsilon = norm_epsilon
  # Store the normalization layer class; instantiated per use below.
  if use_sync_bn:
    self._norm = layers.experimental.SyncBatchNormalization
  else:
    self._norm = layers.BatchNormalization
  self._kernel_initializer = kernel_initializer
  self._kernel_regularizer = kernel_regularizer
  self._bias_regularizer = bias_regularizer
  # Normalize over the channel axis, whichever position it occupies.
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = -1
  else:
    bn_axis = 1

  # Build ResNet.
  inputs = tf.keras.Input(shape=input_specs.shape[1:])

  if stem_type == 'v0':
    # Classic single 7x7 stride-2 conv stem.
    x = layers.Conv2D(
        filters=int(64 * self._depth_multiplier),
        kernel_size=7,
        strides=2,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
  elif stem_type == 'v1':
    # ResNet-D style stem: three stacked 3x3 convs replace the 7x7.
    x = layers.Conv2D(
        filters=int(32 * self._depth_multiplier),
        kernel_size=3,
        strides=2,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
    x = layers.Conv2D(
        filters=int(32 * self._depth_multiplier),
        kernel_size=3,
        strides=1,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            x)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
    x = layers.Conv2D(
        filters=int(64 * self._depth_multiplier),
        kernel_size=3,
        strides=1,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            x)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
  else:
    raise ValueError('Stem type {} not supported.'.format(stem_type))

  if replace_stem_max_pool:
    # ResNet-RS variant: stride-2 conv instead of max pooling.
    x = layers.Conv2D(
        filters=int(64 * self._depth_multiplier),
        kernel_size=3,
        strides=2,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            x)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
  else:
    x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

  endpoints = {}
  for i, spec in enumerate(RESNET_SPECS[model_id]):
    if spec[0] == 'residual':
      block_fn = nn_blocks.ResidualBlock
    elif spec[0] == 'bottleneck':
      block_fn = nn_blocks.BottleneckBlock
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
    x = self._block_group(
        inputs=x,
        filters=int(spec[1] * self._depth_multiplier),
        strides=(1 if i == 0 else 2),
        block_fn=block_fn,
        block_repeats=spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, i + 2, 5),
        name='block_group_l{}'.format(i + 2))
    # Endpoint keys are the feature level: stage i produces level i + 2.
    endpoints[str(i + 2)] = x

  self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

  super(ResNet, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)