示例#1
0
  def test_stream_movinet_block_none_se(self):
    """Chunked calls to a causal block with `se_type='none'` match one call.

    The block is first applied to the whole input. It is then applied to
    equal-sized pieces of the input taken along axis 1, feeding the states
    returned by each call into the next one. The concatenated piecewise
    outputs must match the whole-input output.
    """
    movinet_block = movinet_layers.MovinetBlock(
        out_filters=3,
        expand_filters=6,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        causal=True,
        se_type='none',
        state_prefix='test',
    )

    # Values 1..4 laid out along axis 1, then tiled to shape [1, 4, 2, 1, 3].
    base = tf.reshape(tf.range(4, dtype=tf.float32) + 1., [1, 4, 1, 1, 1])
    inputs = tf.tile(base, [1, 1, 2, 1, 3])
    expected, expected_states = movinet_block(inputs)

    axis_length = inputs.shape[1]
    for chunk_length in (1, 2, 4):
      # An integer second argument to tf.split is the number of equal pieces,
      # so each piece spans `chunk_length` entries of axis 1.
      num_chunks = axis_length // chunk_length
      chunks = tf.split(inputs, num_chunks, axis=1)

      # Start every run from empty states and carry them from call to call.
      states = {}
      chunk_outputs = []
      for chunk in chunks:
        chunk_output, states = movinet_block(chunk, states=states)
        chunk_outputs.append(chunk_output)
      predicted = tf.concat(chunk_outputs, axis=1)

      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)

    # The whole-input call is expected to expose exactly one state entry.
    self.assertAllEqual(list(expected_states.keys()), ['test_stream_buffer'])
示例#2
0
  def _build_network(
      self,
      input_specs: tf.keras.layers.InputSpec,
      state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
  ) -> Tuple[TensorMap, Union[TensorMap, Tuple[TensorMap, TensorMap]]]:
    """Assembles the network graph described by `self._block_specs`.

    The block specs are visited in order. Each one yields a `Stem`, a run of
    `MovinetBlock` layers, or a `Head`; the activation tensor and the states
    dict are threaded through every layer, and each layer's output is
    recorded as an endpoint under the layer's name.

    Args:
      input_specs: the model input spec. Its shape minus the leading
        dimension is used for the base `tf.keras.Input`.
      state_specs: a dict mapping a state name to the corresponding state
        spec. One `tf.keras.Input` named after the key is created per entry.
        Treated as empty when None.

    Returns:
      A tuple `(inputs, outputs)`. `inputs` is a dict with the state inputs
      plus the base input under the key 'image'. `outputs` is the dict of
      endpoints, or the tuple `(endpoints, states)` when
      `self._output_states` is truthy.

    Raises:
      ValueError: if a `MovinetBlockSpec` has `expand_filters`,
        `kernel_sizes` and `strides` of differing lengths, or if a block spec
        is none of `StemSpec`, `MovinetBlockSpec` or `HeadSpec`.
    """
    if state_specs is None:
      state_specs = {}

    image_input = tf.keras.Input(shape=input_specs.shape[1:], name='inputs')

    states = {}
    for state_name, state_spec in state_specs.items():
      states[state_name] = tf.keras.Input(
          shape=state_spec.shape[1:], dtype=state_spec.dtype, name=state_name)

    inputs = dict(states)
    inputs['image'] = image_input
    endpoints = {}

    # Constructor arguments that the stem, the blocks and the head all share.
    shared_kwargs = dict(
        conv_type=self._conv_type,
        activation=self._activation,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        batch_norm_layer=self._norm,
        batch_norm_momentum=self._norm_momentum,
        batch_norm_epsilon=self._norm_epsilon,
    )

    # Total number of MovinetBlock layers; the stochastic depth rate of each
    # layer is scaled by its 1-based position over this total.
    num_layers = 0
    for spec in self._block_specs:
      if isinstance(spec, MovinetBlockSpec):
        num_layers += len(spec.expand_filters)

    x = image_input
    stochastic_depth_idx = 1
    for block_idx, block in enumerate(self._block_specs):
      if isinstance(block, StemSpec):
        stem = movinet_layers.Stem(
            block.filters,
            block.kernel_size,
            block.strides,
            causal=self._causal,
            state_prefix='state_stem',
            name='stem',
            **shared_kwargs)
        x, states = stem(x, states=states)
        endpoints['stem'] = x
        continue

      if isinstance(block, MovinetBlockSpec):
        num_expand = len(block.expand_filters)
        num_kernels = len(block.kernel_sizes)
        num_strides = len(block.strides)
        if num_expand != num_kernels or num_kernels != num_strides:
          raise ValueError(
              'Lengths of block parameters differ: {}, {}, {}'.format(
                  num_expand, num_kernels, num_strides))

        layer_params = zip(
            block.expand_filters, block.kernel_sizes, block.strides)
        for layer_idx, (expand_filters, kernel_size, strides) in enumerate(
            layer_params):
          stochastic_depth_drop_rate = (
              self._stochastic_depth_drop_rate * stochastic_depth_idx /
              num_layers)
          # The block number in the name is the spec index minus one.
          name = f'block{block_idx-1}_layer{layer_idx}'
          movinet_block = movinet_layers.MovinetBlock(
              block.base_filters,
              expand_filters,
              kernel_size=kernel_size,
              strides=strides,
              causal=self._causal,
              gating_activation=self._gating_activation,
              stochastic_depth_drop_rate=stochastic_depth_drop_rate,
              se_type=self._se_type,
              use_positional_encoding=(
                  self._use_positional_encoding and self._causal),
              state_prefix=f'state_{name}',
              name=name,
              **shared_kwargs)
          x, states = movinet_block(x, states=states)
          endpoints[name] = x
          stochastic_depth_idx += 1
        continue

      if isinstance(block, HeadSpec):
        head = movinet_layers.Head(
            project_filters=block.project_filters,
            average_pooling_type=self._average_pooling_type,
            state_prefix='state_head',
            name='head',
            **shared_kwargs)
        x, states = head(x, states=states)
        endpoints['head'] = x
        continue

      raise ValueError('Unknown block type {}'.format(block))

    if self._output_states:
      outputs = (endpoints, states)
    else:
      outputs = endpoints

    return inputs, outputs