def _build_network(
    self,
    input_specs: tf.keras.layers.InputSpec,
    state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[TensorMap, Union[TensorMap, Tuple[TensorMap, TensorMap]]]:
  """Assembles the MoViNet layer graph described by `self._block_specs`.

  Args:
    input_specs: spec of the image input. Its shape, minus the leading
      dimension, is used to create the Keras input stored under `image`.
    state_specs: optional mapping from a state name to that state's spec.
      One Keras input is created per entry and named after its key. Treated
      as an empty mapping when None.

  Returns:
    A tuple `(inputs, outputs)`. `inputs` is a dict holding every state
    input plus the image input under the key `image`. `outputs` is the dict
    of endpoints, or the tuple `(endpoints, states)` when
    `self._output_states` is truthy.

  Raises:
    ValueError: if a `MovinetBlockSpec` has `expand_filters`, `kernel_sizes`
      and `strides` of differing lengths, or if a block spec is of an
      unrecognized type.
  """
  if state_specs is None:
    state_specs = {}

  image_input = tf.keras.Input(shape=input_specs.shape[1:], name='inputs')

  states = {}
  for state_name, state_spec in state_specs.items():
    states[state_name] = tf.keras.Input(
        shape=state_spec.shape[1:], dtype=state_spec.dtype, name=state_name)

  inputs = dict(states)
  inputs['image'] = image_input

  endpoints = {}
  x = image_input

  # Total number of MovinetBlock layers; used to scale the stochastic depth
  # drop rate linearly with the position of each layer in the network.
  num_layers = 0
  for spec in self._block_specs:
    if isinstance(spec, MovinetBlockSpec):
      num_layers += len(spec.expand_filters)

  # Keyword arguments handed to every layer type (stem, block and head).
  shared_kwargs = dict(
      conv_type=self._conv_type,
      activation=self._activation,
      kernel_initializer=self._kernel_initializer,
      kernel_regularizer=self._kernel_regularizer,
      batch_norm_layer=self._norm,
      batch_norm_momentum=self._norm_momentum,
      batch_norm_epsilon=self._norm_epsilon,
  )

  layer_position = 1
  for block_idx, block in enumerate(self._block_specs):
    if isinstance(block, StemSpec):
      stem = movinet_layers.Stem(
          block.filters,
          block.kernel_size,
          block.strides,
          causal=self._causal,
          state_prefix='state_stem',
          name='stem',
          **shared_kwargs)
      x, states = stem(x, states=states)
      endpoints['stem'] = x
      continue

    if isinstance(block, MovinetBlockSpec):
      spec_lengths = (
          len(block.expand_filters),
          len(block.kernel_sizes),
          len(block.strides),
      )
      if len(set(spec_lengths)) != 1:
        raise ValueError(
            'Lengths of block parameters differ: {}, {}, {}'.format(
                *spec_lengths))

      layer_params = zip(
          block.expand_filters, block.kernel_sizes, block.strides)
      for layer_idx, (expand_filters, kernel_size,
                      strides) in enumerate(layer_params):
        drop_rate = (
            self._stochastic_depth_drop_rate * layer_position / num_layers)
        # Positional encoding is only enabled together with causal mode.
        positional_encoding = self._use_positional_encoding and self._causal
        layer_name = f'block{block_idx-1}_layer{layer_idx}'
        block_layer = movinet_layers.MovinetBlock(
            block.base_filters,
            expand_filters,
            kernel_size=kernel_size,
            strides=strides,
            causal=self._causal,
            gating_activation=self._gating_activation,
            stochastic_depth_drop_rate=drop_rate,
            se_type=self._se_type,
            use_positional_encoding=positional_encoding,
            state_prefix=f'state_{layer_name}',
            name=layer_name,
            **shared_kwargs)
        x, states = block_layer(x, states=states)
        endpoints[layer_name] = x
        layer_position += 1
      continue

    if not isinstance(block, HeadSpec):
      raise ValueError('Unknown block type {}'.format(block))

    head = movinet_layers.Head(
        project_filters=block.project_filters,
        state_prefix='state_head',
        name='head',
        **shared_kwargs)
    x, states = head(x, states=states)
    endpoints['head'] = x

  if self._output_states:
    outputs = (endpoints, states)
  else:
    outputs = endpoints
  return inputs, outputs
def __init__(self,
             model_id: str = 'a0',
             causal: bool = False,
             use_positional_encoding: bool = False,
             conv_type: str = '3d',
             input_specs: Optional[tf.keras.layers.InputSpec] = None,
             activation: str = 'swish',
             use_sync_bn: bool = True,
             norm_momentum: float = 0.99,
             norm_epsilon: float = 0.001,
             kernel_initializer: str = 'HeNormal',
             kernel_regularizer: Optional[str] = None,
             bias_regularizer: Optional[str] = None,
             stochastic_depth_drop_rate: float = 0.,
             **kwargs):
  """MoViNet initialization function.

  Args:
    model_id: name of MoViNet backbone model; used to look up the block
      specs in `BLOCK_SPECS`.
    causal: use causal mode, with CausalConv and CausalSE operations.
    use_positional_encoding: if True, adds a positional encoding before
      temporal convolutions and the cumulative global average pooling
      layers. Only takes effect when `causal` is also True.
    conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' configures the network
      to use the default 3D convolution. '2plus1d' uses (2+1)D convolution
      with Conv2D operations and 2D reshaping (e.g., a 5x3x3 kernel becomes
      3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
      Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3
      followed by 5x1x1 conv).
    input_specs: the model input spec to use. Defaults to a spec of shape
      `[None, None, None, None, 3]` when None.
    activation: name of the activation function.
    use_sync_bn: if True, use synchronized batch normalization.
    norm_momentum: normalization momentum for the moving average.
    norm_epsilon: small float added to variance to avoid dividing by zero.
    kernel_initializer: kernel_initializer forwarded to every Stem,
      MovinetBlock and Head layer.
    kernel_regularizer: kernel regularizer forwarded to every Stem,
      MovinetBlock and Head layer. Defaults to None.
    bias_regularizer: bias regularizer. Stored on the instance but not
      forwarded to any layer by this constructor. Defaults to None.
    stochastic_depth_drop_rate: the base rate for stochastic depth. The rate
      applied to a block layer is scaled linearly by its position among all
      block layers.
    **kwargs: keyword arguments to be passed to the parent constructor.

  Raises:
    ValueError: if `conv_type` is not one of the supported values, if the
      first block spec is not a `StemSpec`, if the last block spec is not a
      `HeadSpec`, if a `MovinetBlockSpec` has parameter lists of differing
      lengths, or if a block spec is of an unrecognized type.
  """
  # NOTE(review): an unknown `model_id` surfaces as whatever `BLOCK_SPECS`
  # raises on a missing key (KeyError if it is a plain dict) -- confirm.
  block_specs = BLOCK_SPECS[model_id]
  if input_specs is None:
    input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, None, 3])

  if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
    raise ValueError('Unknown conv type: {}'.format(conv_type))

  self._model_id = model_id
  self._block_specs = block_specs
  self._causal = causal
  self._use_positional_encoding = use_positional_encoding
  self._conv_type = conv_type
  self._input_specs = input_specs
  self._use_sync_bn = use_sync_bn
  self._activation = activation
  self._norm_momentum = norm_momentum
  self._norm_epsilon = norm_epsilon
  if use_sync_bn:
    self._norm = tf.keras.layers.experimental.SyncBatchNormalization
  else:
    self._norm = tf.keras.layers.BatchNormalization
  self._kernel_initializer = kernel_initializer
  self._kernel_regularizer = kernel_regularizer
  self._bias_regularizer = bias_regularizer
  self._stochastic_depth_drop_rate = stochastic_depth_drop_rate

  if not isinstance(block_specs[0], StemSpec):
    raise ValueError(
        'Expected first spec to be StemSpec, got {}'.format(block_specs[0]))
  if not isinstance(block_specs[-1], HeadSpec):
    raise ValueError(
        'Expected final spec to be HeadSpec, got {}'.format(block_specs[-1]))
  self._head_filters = block_specs[-1].head_filters

  # Build MoViNet backbone.
  inputs = tf.keras.Input(shape=input_specs.shape[1:], name='inputs')
  x = inputs
  states = {}
  endpoints = {}

  # Total number of MovinetBlock layers; the stochastic depth drop rate of a
  # layer is the base rate scaled by (layer position / num_layers).
  num_layers = sum(
      len(block.expand_filters)
      for block in block_specs
      if isinstance(block, MovinetBlockSpec))
  stochastic_depth_idx = 1
  for block_idx, block in enumerate(block_specs):
    if isinstance(block, StemSpec):
      x, states = movinet_layers.Stem(
          block.filters,
          block.kernel_size,
          block.strides,
          conv_type=self._conv_type,
          causal=self._causal,
          activation=self._activation,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          batch_norm_layer=self._norm,
          batch_norm_momentum=self._norm_momentum,
          batch_norm_epsilon=self._norm_epsilon,
          name='stem')(x, states=states)
      endpoints['stem'] = x
    elif isinstance(block, MovinetBlockSpec):
      if not (len(block.expand_filters) == len(block.kernel_sizes) ==
              len(block.strides)):
        raise ValueError(
            'Lengths of block parameters differ: {}, {}, {}'.format(
                len(block.expand_filters),
                len(block.kernel_sizes),
                len(block.strides)))
      params = list(zip(block.expand_filters,
                        block.kernel_sizes,
                        block.strides))
      for layer_idx, layer in enumerate(params):
        stochastic_depth_drop_rate = (
            self._stochastic_depth_drop_rate * stochastic_depth_idx /
            num_layers)
        expand_filters, kernel_size, strides = layer
        # Block numbering starts at 0 for the spec that follows the stem,
        # hence the `block_idx-1`.
        name = f'b{block_idx-1}/l{layer_idx}'
        x, states = movinet_layers.MovinetBlock(
            block.base_filters,
            expand_filters,
            kernel_size=kernel_size,
            strides=strides,
            causal=self._causal,
            activation=self._activation,
            stochastic_depth_drop_rate=stochastic_depth_drop_rate,
            conv_type=self._conv_type,
            # Positional encoding is only enabled together with causal mode.
            use_positional_encoding=(
                self._use_positional_encoding and self._causal),
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
            name=name)(x, states=states)
        endpoints[name] = x
        stochastic_depth_idx += 1
    elif isinstance(block, HeadSpec):
      x, states = movinet_layers.Head(
          project_filters=block.project_filters,
          conv_type=self._conv_type,
          activation=self._activation,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          batch_norm_layer=self._norm,
          batch_norm_momentum=self._norm_momentum,
          batch_norm_epsilon=self._norm_epsilon)(x, states=states)
      endpoints['head'] = x
    else:
      raise ValueError('Unknown block type {}'.format(block))

  self._output_specs = {
      endpoint_name: endpoint.get_shape()
      for endpoint_name, endpoint in endpoints.items()
  }

  # Expose one Keras input per state returned by the layers, shaped like
  # that state without its leading dimension.
  inputs = {
      'image': inputs,
      'states': {
          name: tf.keras.Input(shape=state.shape[1:], name=f'states/{name}')
          for name, state in states.items()
      },
  }
  outputs = (endpoints, states)

  super(Movinet, self).__init__(inputs=inputs, outputs=outputs, **kwargs)