def test_movinet_export_a0_stream_with_tfhub(self):
    """Exports a causal a0 model and runs it frame by frame via TF Hub.

    The exported SavedModel is wrapped in a `hub.KerasLayer`, and a Keras
    model is built around it whose inputs are the image plus one input per
    state returned by the `init_states` signature. Feeding the clip one frame
    at a time while threading the states through must produce the same output
    (within 1e-5) as feeding the whole clip in a single call.
    """
    export_dir = self.get_temp_dir()

    FLAGS.model_id = 'a0'
    FLAGS.causal = True
    FLAGS.num_classes = 600
    FLAGS.export_path = export_dir

    export_saved_model.main('unused_args')

    hub_layer = hub.KerasLayer(export_dir, trainable=True)
    init_states_fn = hub_layer.resolved_object.signatures['init_states']

    image_input = tf.keras.layers.Input(
        name='image', shape=[None, None, None, 3], dtype=tf.float32)

    # Call the signature once with a placeholder shape purely to discover the
    # name, shape and dtype of every state tensor.
    probe_states = init_states_fn(tf.constant([0, 0, 0, 0, 3]))
    model_inputs = {}
    for state_name, state in probe_states.items():
      # Non-positive dimensions are treated as unknown.
      dims = [dim if dim > 0 else None for dim in state.shape]
      # tf.keras.Input expects the shape without the leading dimension
      # (presumably the batch dimension of the state).
      model_inputs[state_name] = tf.keras.Input(
          shape=dims[1:], dtype=state.dtype, name=state_name)
    model_inputs['image'] = image_input

    model_outputs = hub_layer(model_inputs)
    model = tf.keras.Model(model_inputs, model_outputs)

    clip = tf.ones([1, 8, 172, 172, 3])
    num_frames = clip.shape[1]
    init_states = init_states_fn(tf.shape(clip))

    # Reference result: the whole clip in one call.
    expected_outputs, _ = model(dict(init_states, image=clip))

    # Streaming result: one frame per call, carrying the states forward.
    states = init_states
    for frame in tf.split(clip, num_frames, axis=1):
      streamed_outputs, states = model(dict(states, image=frame))

    self.assertAllEqual(streamed_outputs.shape, [1, 600])
    self.assertNotEmpty(states)
    self.assertAllClose(
        streamed_outputs, expected_outputs, rtol=1e-5, atol=1e-5)
  def test_movinet_export_a0_stream_with_tflite(self):
    """Exports a causal a0 model, converts it to TFLite and streams a video.

    The SavedModel is converted with `tf.lite.TFLiteConverter` and driven
    through the interpreter's signature runner. States start as zeros shaped
    after the interpreter's input details, and every non-`logits` output of
    one call becomes the states of the next call.
    """
    export_dir = self.get_temp_dir()

    # Model variant and export destination.
    FLAGS.export_path = export_dir
    FLAGS.model_id = 'a0'
    FLAGS.causal = True
    FLAGS.num_classes = 600

    # Layer and activation choices used for this export.
    FLAGS.conv_type = '2plus1d'
    FLAGS.se_type = '2plus3d'
    FLAGS.activation = 'hard_swish'
    FLAGS.gating_activation = 'hard_sigmoid'
    FLAGS.use_positional_encoding = False

    # Input dimensions: a single 172x172 frame per call, batch size 1.
    FLAGS.batch_size = 1
    FLAGS.num_frames = 1
    FLAGS.image_size = 172
    FLAGS.bundle_input_init_states_fn = False

    export_saved_model.main('unused_args')

    tflite_model = tf.lite.TFLiteConverter.from_saved_model(
        export_dir).convert()
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    runner = interpreter.get_signature_runner()

    # Input tensor names are reduced to the bare input name by cutting off
    # the length of 'serving_default_' at the front and ':0' at the back.
    name_span = slice(len('serving_default_'), -len(':0'))
    states = {}
    for detail in interpreter.get_input_details():
      states[detail['name'][name_span]] = tf.zeros(
          detail['shape'], dtype=detail['dtype'])
    # The image is supplied per frame below, so it is not a state.
    states.pop('image')

    video = tf.ones([1, 8, 172, 172, 3])
    for frame in tf.split(video, video.shape[1], axis=1):
      # Whatever remains after removing the logits is fed back as states.
      states = runner(**states, image=frame)
      logits = states.pop('logits')

    self.assertAllEqual(logits.shape, [1, 600])
    self.assertNotEmpty(states)
    def test_movinet_export_a0_base_with_tfhub(self):
        """Exports a non-causal a0 model and runs a whole clip via TF Hub.

        The exported SavedModel is wrapped in a `hub.KerasLayer` that takes
        the video under the `image` key; an 8-frame 172x172 clip must yield
        an output of shape [1, 600].
        """
        export_dir = self.get_temp_dir()

        FLAGS.model_id = 'a0'
        FLAGS.causal = False
        FLAGS.num_classes = 600
        FLAGS.export_path = export_dir

        export_saved_model.main('unused_args')

        hub_layer = hub.KerasLayer(export_dir, trainable=True)

        video_input = tf.keras.layers.Input(
            dtype=tf.float32, shape=[None, None, None, 3])
        model = tf.keras.Model(
            inputs=video_input, outputs=hub_layer({'image': video_input}))

        logits = model(tf.ones([1, 8, 172, 172, 3]))

        self.assertEqual(logits.shape, [1, 600])