def setUp(self):
    super(DsTcResnetTest, self).setUp()

    config = tf1.ConfigProto()
    config.gpu_options.allow_growth = True
    self.sess = tf1.Session(config=config)
    tf1.keras.backend.set_session(self.sess)
    tf.keras.backend.set_learning_phase(0)

    test_utils.set_seed(123)
    self.params = utils.ds_tc_resnet_model_params(True)

    self.model = ds_tc_resnet.model(self.params)
    self.model.summary()

    self.input_data = np.random.rand(self.params.batch_size,
                                     self.params.desired_samples)

    # run non streaming inference
    self.non_stream_out = self.model.predict(self.input_data)
    def test_ds_tc_resnet_stream_internal_tflite(self):
        """Test tflite streaming with internal state."""
        test_utils.set_seed(123)
        tf.keras.backend.set_learning_phase(0)
        params = utils.ds_tc_resnet_model_params(True)

        model = ds_tc_resnet.model(params)
        model.summary()

        input_data = np.random.rand(params.batch_size, params.desired_samples)

        # run non streaming inference
        non_stream_out = model.predict(input_data)

        tflite_streaming_model = utils.model_to_tflite(
            None, model, params, Modes.STREAM_INTERNAL_STATE_INFERENCE)

        interpreter = tf.lite.Interpreter(model_content=tflite_streaming_model)
        interpreter.allocate_tensors()

        stream_out = inference.run_stream_inference_classification_tflite(
            params, interpreter, input_data, input_states=None)

        self.assertAllClose(stream_out, non_stream_out, atol=1e-5)
示例#3
0
  def setUp(self):
    super(DsTcResnetTest, self).setUp()

    config = tf1.ConfigProto()
    config.gpu_options.allow_growth = True
    self.sess = tf1.Session(config=config)
    tf1.keras.backend.set_session(self.sess)
    test_utils.set_seed(123)
    tf.keras.backend.set_learning_phase(0)

    # model parameters
    model_name = 'ds_tc_resnet'
    self.params = model_params.HOTWORD_MODEL_PARAMS[model_name]
    self.params.clip_duration_ms = 160
    self.params.window_size_ms = 4.0
    self.params.window_stride_ms = 2.0
    self.params.wanted_words = 'a,b,c'
    self.params.ds_padding = "'causal','causal','causal'"
    self.params.ds_filters = '8,8,4'
    self.params.ds_repeat = '1,1,1'
    self.params.ds_residual = '0,1,1'  # residual can not be applied with stride
    self.params.ds_kernel_size = '3,3,3'
    self.params.ds_stride = '2,1,1'  # streaming conv with stride
    self.params.ds_dilation = '1,1,1'
    self.params.ds_pool = '1,2,1'  # streaming conv with pool
    self.params.ds_filter_separable = '1,1,1'

    # convert ms to samples and compute labels count
    self.params = model_flags.update_flags(self.params)

    # compute total stride
    pools = utils.parse(self.params.ds_pool)
    strides = utils.parse(self.params.ds_stride)
    time_stride = [1]
    for pool in pools:
      if pool > 1:
        time_stride.append(pool)
    for stride in strides:
      if stride > 1:
        time_stride.append(stride)
    total_stride = np.prod(time_stride)

    # overide input data shape for streaming model with stride/pool
    self.params.data_stride = total_stride
    self.params.data_frame_padding = 'causal'

    # set desired number of frames in model
    frames_number = 16
    frames_per_call = total_stride
    frames_number = (frames_number // frames_per_call) * frames_per_call
    # number of input audio samples required to produce one output frame
    framing_stride = max(
        self.params.window_stride_samples,
        max(0, self.params.window_size_samples -
            self.params.window_stride_samples))
    signal_size = framing_stride * frames_number

    # desired number of samples in the input data to train non streaming model
    self.params.desired_samples = signal_size

    self.params.batch_size = 1
    self.model = ds_tc_resnet.model(self.params)
    self.model.summary()

    self.input_data = np.random.rand(self.params.batch_size,
                                     self.params.desired_samples)

    # run non streaming inference
    self.non_stream_out = self.model.predict(self.input_data)