예제 #1
0
  def test_external_streaming_shapes(self, model_name):
    """Feeds zeros to the external-state streaming model and checks shapes.

    Args:
      model_name: model identifier forwarded to
        `utils.get_model_with_default_params`.
    """
    stream_mode = modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE
    model = utils.get_model_with_default_params(model_name, mode=stream_mode)

    # In this mode the leading inputs are the ones the model takes in
    # non-streaming mode; the remaining inputs carry the internal states of
    # the model's layers. Every input is fed as an all-zero float32 array.
    zero_inputs = []
    for input_shape in model.input_shapes:
      zero_inputs.append(np.zeros(input_shape, dtype=np.float32))

    predictions = model.predict(zero_inputs)
    # NOTE(review): zip() stops at the shorter sequence, so a mismatch in the
    # number of outputs would go unnoticed; and if `predict` returns a single
    # array instead of a list, this loop walks that array's first axis.
    # Confirm against the model API.
    for prediction, want_shape in zip(predictions, model.output_shapes):
      self.assertEqual(prediction.shape, want_shape)
    def __init__(self):
        """Builds the flag-selected model and exposes it as `self.call`.

        The model named by `FLAGS.model` is built in the mode obtained by
        looking up `FLAGS.mode` in `MODE_ENUM_TO_MODE`, and is stored on
        `self.m`. `self.call` is that model invoked with `training=False`,
        wrapped by `tf_test_utils.tf_function_unit_test`.
        """
        super().__init__()
        selected_mode = MODE_ENUM_TO_MODE[FLAGS.mode]
        self.m = utils.get_model_with_default_params(FLAGS.model, selected_mode)

        # One TensorSpec per model input, reusing that input's shape; dtype is
        # left at the TensorSpec default.
        signature = []
        for model_input in self.m.inputs:
            signature.append(tf.TensorSpec(model_input.shape))

        # Always invoke the model with training=False.
        run_inference = lambda *args: self.m(*args, training=False)
        unit_test_wrapper = tf_test_utils.tf_function_unit_test(
            input_signature=signature, name="call", atol=1e-5)
        self.call = unit_test_wrapper(run_inference)
 def create_module(cls):
     """Builds the flag-selected model and returns a `cls` instance wrapping it.

     As a side effect, stores the shape of each model input on the class as
     `cls.input_shapes`.
     NOTE(review): takes `cls`, so presumably decorated with @classmethod
     outside this fragment; confirm.
     """
     chosen_mode = MODE_ENUM_TO_MODE[FLAGS.mode]
     built = utils.get_model_with_default_params(FLAGS.model, chosen_mode)
     shapes = []
     for model_input in built.inputs:
         shapes.append(model_input.shape)
     cls.input_shapes = shapes
     return cls(built)