Exemple #1
0
    def _build_network(
        self,
        input_specs: tf.keras.layers.InputSpec,
        state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
    ) -> Tuple[TensorMap, Union[TensorMap, Tuple[TensorMap, TensorMap]]]:
        """Assembles the symbolic network described by `self._block_specs`.

        Args:
          input_specs: spec of the image input. Its shape, minus the leading
            dimension, is used for the image `tf.keras.Input` placeholder.
          state_specs: optional dict mapping a state name to that state's
            spec. One `tf.keras.Input` is created per entry under the same
            name. `None` is treated as an empty dict.

        Returns:
          A tuple `(inputs, outputs)`. `inputs` is a dict holding every state
          placeholder plus the image placeholder under the key 'image'.
          `outputs` is the dict of endpoints ('stem', one entry per block
          layer, 'head'), or the pair `(endpoints, states)` when
          `self._output_states` is truthy.

        Raises:
          ValueError: if a block spec carries parameter lists of differing
            lengths, or if a spec is not a stem, block or head spec.
        """
        if state_specs is None:
            state_specs = {}

        image_input = tf.keras.Input(shape=input_specs.shape[1:],
                                     name='inputs')

        # One placeholder per declared state, keyed by the state name.
        states = {}
        for state_name, state_spec in state_specs.items():
            states[state_name] = tf.keras.Input(shape=state_spec.shape[1:],
                                                dtype=state_spec.dtype,
                                                name=state_name)

        # Snapshot the placeholders now: `states` is rebound by every layer
        # call below, while `inputs` must keep the original placeholders.
        inputs = dict(states)
        inputs['image'] = image_input

        # Total number of block layers; the denominator of the per-layer
        # stochastic depth schedule.
        total_layers = 0
        for spec in self._block_specs:
            if isinstance(spec, MovinetBlockSpec):
                total_layers += len(spec.expand_filters)

        # Arguments passed unchanged to the stem, every block layer and the
        # head.
        shared_kwargs = dict(
            conv_type=self._conv_type,
            activation=self._activation,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
        )

        endpoints = {}
        features = image_input
        depth_position = 1  # 1-based index of the block layer being built.
        for block_idx, block in enumerate(self._block_specs):
            if isinstance(block, StemSpec):
                stem = movinet_layers.Stem(
                    block.filters,
                    block.kernel_size,
                    block.strides,
                    causal=self._causal,
                    state_prefix='state_stem',
                    name='stem',
                    **shared_kwargs)
                features, states = stem(features, states=states)
                endpoints['stem'] = features
                continue

            if isinstance(block, MovinetBlockSpec):
                lengths_agree = (
                    len(block.expand_filters) == len(block.kernel_sizes)
                    and len(block.kernel_sizes) == len(block.strides))
                if not lengths_agree:
                    raise ValueError(
                        'Lengths of block parameters differ: {}, {}, {}'.
                        format(len(block.expand_filters),
                               len(block.kernel_sizes), len(block.strides)))

                per_layer_params = zip(block.expand_filters,
                                       block.kernel_sizes, block.strides)
                for layer_idx, layer_params in enumerate(per_layer_params):
                    expand_filters, kernel_size, strides = layer_params
                    # The drop rate grows with depth and reaches the base
                    # rate at the last block layer.
                    drop_rate = (
                        self._stochastic_depth_drop_rate *
                        depth_position / total_layers)
                    # The name uses block_idx - 1, so a block spec at
                    # position 1 yields names starting with 'block0'.
                    name = f'block{block_idx-1}_layer{layer_idx}'
                    movinet_block = movinet_layers.MovinetBlock(
                        block.base_filters,
                        expand_filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        causal=self._causal,
                        gating_activation=self._gating_activation,
                        stochastic_depth_drop_rate=drop_rate,
                        se_type=self._se_type,
                        use_positional_encoding=self._use_positional_encoding
                        and self._causal,
                        state_prefix=f'state_{name}',
                        name=name,
                        **shared_kwargs)
                    features, states = movinet_block(features, states=states)
                    endpoints[name] = features
                    depth_position += 1
                continue

            if isinstance(block, HeadSpec):
                head = movinet_layers.Head(
                    project_filters=block.project_filters,
                    state_prefix='state_head',
                    name='head',
                    **shared_kwargs)
                features, states = head(features, states=states)
                endpoints['head'] = features
                continue

            raise ValueError('Unknown block type {}'.format(block))

        if self._output_states:
            outputs = (endpoints, states)
        else:
            outputs = endpoints

        return inputs, outputs
Exemple #2
0
  def __init__(self,
               model_id: str = 'a0',
               causal: bool = False,
               use_positional_encoding: bool = False,
               conv_type: str = '3d',
               input_specs: Optional[tf.keras.layers.InputSpec] = None,
               activation: str = 'swish',
               use_sync_bn: bool = True,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               kernel_initializer: str = 'HeNormal',
               kernel_regularizer: Optional[str] = None,
               bias_regularizer: Optional[str] = None,
               stochastic_depth_drop_rate: float = 0.,
               **kwargs):
    """MoViNet initialization function.

    Looks up the block specs for `model_id`, builds the stem, block layers
    and head on a symbolic image input, and passes the resulting inputs and
    outputs to the parent constructor.

    Args:
      model_id: name of MoViNet backbone model; used as the key into
        `BLOCK_SPECS`.
      causal: use causal mode, with CausalConv and CausalSE operations.
      use_positional_encoding: if True, adds a positional encoding before
        temporal convolutions and the cumulative global average pooling
        layers. Only forwarded to the blocks as True when `causal` is also
        truthy.
      conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' configures the network
        to use the default 3D convolution. '2plus1d' uses (2+1)D convolution
        with Conv2D operations and 2D reshaping (e.g., a 5x3x3 kernel becomes
        3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
        Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
        by 5x1x1 conv).
      input_specs: the model input spec to use. Defaults to an InputSpec of
        shape `[None, None, None, None, 3]`.
      activation: name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: normalization momentum for the moving average.
      norm_epsilon: small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: kernel regularizer forwarded to the stem, block and
        head layers. Defaults to None.
      bias_regularizer: bias regularizer. Stored on the instance but not
        forwarded to any layer by this constructor. Defaults to None.
      stochastic_depth_drop_rate: the base rate for stochastic depth. Each
        block layer receives this rate scaled by its 1-based position over
        the total number of block layers.
      **kwargs: keyword arguments to be passed to the parent constructor.

    Raises:
      ValueError: if `conv_type` is not one of the supported values, if the
        first spec is not a `StemSpec`, if the last spec is not a `HeadSpec`,
        if a block spec has parameter lists of differing lengths, or if a
        spec has an unknown type.
    """
    # A lookup failure for an unknown `model_id` propagates from here.
    block_specs = BLOCK_SPECS[model_id]
    if input_specs is None:
      input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, None, 3])

    if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
      raise ValueError('Unknown conv type: {}'.format(conv_type))

    self._model_id = model_id
    self._block_specs = block_specs
    self._causal = causal
    self._use_positional_encoding = use_positional_encoding
    self._conv_type = conv_type
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate

    if not isinstance(block_specs[0], StemSpec):
      raise ValueError(
          'Expected first spec to be StemSpec, got {}'.format(block_specs[0]))
    if not isinstance(block_specs[-1], HeadSpec):
      raise ValueError(
          'Expected final spec to be HeadSpec, got {}'.format(block_specs[-1]))
    self._head_filters = block_specs[-1].head_filters

    # Build MoViNet backbone.
    inputs = tf.keras.Input(shape=input_specs.shape[1:], name='inputs')

    x = inputs
    # Starts empty; every layer call returns the dict to hand to the next one.
    states = {}
    endpoints = {}

    # Total number of block layers; the denominator of the per-layer
    # stochastic depth schedule.
    num_layers = sum(len(block.expand_filters) for block in block_specs
                     if isinstance(block, MovinetBlockSpec))
    stochastic_depth_idx = 1
    for block_idx, block in enumerate(block_specs):
      if isinstance(block, StemSpec):
        x, states = movinet_layers.Stem(
            block.filters,
            block.kernel_size,
            block.strides,
            conv_type=self._conv_type,
            causal=self._causal,
            activation=self._activation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
            name='stem')(x, states=states)
        endpoints['stem'] = x
      elif isinstance(block, MovinetBlockSpec):
        if not (len(block.expand_filters) == len(block.kernel_sizes) ==
                len(block.strides)):
          raise ValueError(
              'Lengths of block parameters differ: {}, {}, {}'.format(
                  len(block.expand_filters),
                  len(block.kernel_sizes),
                  len(block.strides)))
        params = list(zip(block.expand_filters,
                          block.kernel_sizes,
                          block.strides))
        for layer_idx, layer in enumerate(params):
          # The drop rate grows with depth and reaches the base rate at the
          # last block layer.
          stochastic_depth_drop_rate = (
              self._stochastic_depth_drop_rate * stochastic_depth_idx /
              num_layers)
          expand_filters, kernel_size, strides = layer
          # The name uses block_idx - 1, so a block spec at position 1
          # yields names starting with 'b0'.
          name = f'b{block_idx-1}/l{layer_idx}'
          x, states = movinet_layers.MovinetBlock(
              block.base_filters,
              expand_filters,
              kernel_size=kernel_size,
              strides=strides,
              causal=self._causal,
              activation=self._activation,
              stochastic_depth_drop_rate=stochastic_depth_drop_rate,
              conv_type=self._conv_type,
              use_positional_encoding=
              self._use_positional_encoding and self._causal,
              kernel_initializer=kernel_initializer,
              kernel_regularizer=kernel_regularizer,
              batch_norm_layer=self._norm,
              batch_norm_momentum=self._norm_momentum,
              batch_norm_epsilon=self._norm_epsilon,
              name=name)(x, states=states)
          endpoints[name] = x
          stochastic_depth_idx += 1
      elif isinstance(block, HeadSpec):
        x, states = movinet_layers.Head(
            project_filters=block.project_filters,
            conv_type=self._conv_type,
            activation=self._activation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon)(x, states=states)
        endpoints['head'] = x
      else:
        raise ValueError('Unknown block type {}'.format(block))

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    # Expose one placeholder per state produced while building the network,
    # shaped like that state without its leading dimension.
    inputs = {
        'image': inputs,
        'states': {
            name: tf.keras.Input(shape=state.shape[1:], name=f'states/{name}')
            for name, state in states.items()
        },
    }
    outputs = (endpoints, states)

    super(Movinet, self).__init__(inputs=inputs, outputs=outputs, **kwargs)