Exemplo n.º 1
0
  def test_stream_classifier_head(self):
    """Streaming the head over frame chunks must match one full-clip pass.

    The clip is fed to `head` in chunks of 1, 2 and 4 frames, threading the
    returned `states` dict between calls. The classifier output after the
    final chunk is compared against the output computed on the whole clip.
    """
    head = movinet_layers.Head(project_filters=5)
    classifier_head = movinet_layers.ClassifierHead(
        head_filters=10, num_classes=4)

    # A [1, 4, 2, 1, 3] clip whose values along axis 1 are 1, 2, 3, 4.
    clip = tf.reshape(tf.range(4, dtype=tf.float32) + 1., [1, 4, 1, 1, 1])
    clip = tf.tile(clip, [1, 1, 2, 1, 3])

    # Reference: run the head over every frame at once (states discarded).
    full_features, _ = head(clip)
    expected = classifier_head(full_features)

    num_frames = clip.shape[1]
    for frames_per_chunk in [1, 2, 4]:
      # tf.split with an int argument yields that many equal-sized chunks.
      chunks = tf.split(clip, num_frames // frames_per_chunk, axis=1)
      stream_states = {}
      for chunk in chunks:
        chunk_features, stream_states = head(chunk, states=stream_states)
        predicted = classifier_head(chunk_features)

      # Only the prediction after the last chunk is checked.
      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)
Exemplo n.º 2
0
    def _build_network(
        self,
        backbone: tf.keras.Model,
        input_specs: Mapping[str, tf.keras.layers.InputSpec],
        state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
    ) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[  # pytype: disable=invalid-annotation  # typed-keras
            str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str,
                                                               tf.Tensor]]]:
        """Builds the classification network on top of the backbone.

        Calls `self._build_backbone`, takes its `'head'` endpoint and feeds it
        through a `movinet_layers.ClassifierHead`. The head's hyperparameters
        (num_classes, dropout, initializer, regularizer, activation) come from
        this model. Its `head_filters` and `conv_type` come from `backbone`.

        Args:
          backbone: the model backbone.
          input_specs: the model input spec to use.
          state_specs: a dict of states such that, if any of the keys match for
            a layer, will overwrite the contents of the buffer(s).

        Returns:
          A tuple `(inputs, outputs)`. `inputs` is the first value returned by
          `self._build_backbone` (per the return annotation, presumably a dict
          of Keras inputs; confirm against `_build_backbone`). `outputs` is the
          ClassifierHead output when `self._output_states` is false.
          Otherwise it is the pair `(classifier_output, states)`, where
          `states` is the third value returned by `self._build_backbone`.
        """
        model_inputs, endpoints, states = self._build_backbone(
            backbone=backbone,
            input_specs=input_specs,
            state_specs=state_specs)
        head_features = endpoints['head']

        classifier = movinet_layers.ClassifierHead(
            head_filters=backbone.head_filters,
            num_classes=self._num_classes,
            dropout_rate=self._dropout_rate,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            conv_type=backbone.conv_type,
            activation=self._activation)
        class_outputs = classifier(head_features)

        if self._output_states:
            # Streaming mode: expose the backbone states alongside the output.
            return model_inputs, (class_outputs, states)
        return model_inputs, class_outputs