Exemplo n.º 1
0
 def test_invalid_model(self):
     """Unknown model-name strings should raise KeyError."""
     invalid_mobilenet_size = 'huuuge'
     # `assertRaises` already fails the test when KeyError is not raised,
     # so the original's explicit isinstance re-check was redundant and
     # has been removed.
     with self.assertRaises(KeyError):
         models.get_keras_model(
             f'mobilenet_{invalid_mobilenet_size}_1.0_False', 3, 5)
def get_tflite_friendly_model(checkpoint_folder_path,
                              params,
                              checkpoint_number=None,
                              include_frontend=False):
    """Given folder & training params, exports SavedModel without frontend.

    Args:
        checkpoint_folder_path: Directory containing training checkpoints.
        params: Dict of training hyperparameters; keys 'cop', 'bd', 'al',
            'ms', 'ap', and 'qat' are read here.
        checkpoint_number: Optional specific checkpoint number to restore.
            `None` means "use the latest checkpoint in the folder".
        include_frontend: Whether the exported model includes the audio
            frontend.

    Returns:
        A TFLite-friendly Keras model with restored weights.

    Raises:
        ValueError: If the requested checkpoint cannot be loaded.
    """
    compressor = None
    if params['cop']:
        compressor = get_default_compressor()
    static_model = models.get_keras_model(
        bottleneck_dimension=params['bd'],
        output_dimension=0,  # Don't include the unnecessary final layer.
        alpha=params['al'],
        mobilenet_size=params['ms'],
        frontend=include_frontend,
        avg_pool=params['ap'],
        compressor=compressor,
        quantize_aware_training=params['qat'],
        tflite=True)
    checkpoint = tf.train.Checkpoint(model=static_model)
    # Compare to None explicitly: a checkpoint number of 0 is falsy and
    # would previously have silently fallen back to the latest checkpoint.
    if checkpoint_number is not None:
        checkpoint_to_load = os.path.join(checkpoint_folder_path,
                                          f'ckpt-{checkpoint_number}')
        # Don't use `assert` for validation: it is stripped under `python -O`.
        if not tf.train.load_checkpoint(checkpoint_to_load):
            raise ValueError(
                f'Could not load checkpoint: {checkpoint_to_load}')
    else:
        checkpoint_to_load = tf.train.latest_checkpoint(checkpoint_folder_path)
    checkpoint.restore(checkpoint_to_load).expect_partial()
    return static_model
Exemplo n.º 3
0
 def test_model_no_frontend(self):
     """A model built without a frontend accepts log Mel spectrogram input."""
     spectrogram = tf.zeros([1, 96, 64, 1],
                            dtype=tf.float32)  # log Mel spectrogram
     model = models.get_keras_model(3, 5, frontend=False)
     output = model(spectrogram)
     output.shape.assert_has_rank(2)
     self.assertEqual(output.shape[1], 5)
Exemplo n.º 4
0
def load_and_write_model(logdir, checkpoint_filename, output_directory):
  """Builds a model from flags, restores the latest checkpoint, saves it."""
  model = models.get_keras_model(
      FLAGS.bottleneck_dimension, FLAGS.output_dimension, alpha=FLAGS.alpha)
  ckpt = tf.train.Checkpoint(model=model)
  latest = tf.train.latest_checkpoint(logdir, checkpoint_filename)
  ckpt.restore(latest)
  tf.keras.models.save_model(model, output_directory)
def get_model(checkpoint_folder_path,
              params,
              tflite_friendly,
              checkpoint_number=None,
              include_frontend=False):
    """Given folder & training params, exports SavedModel without frontend.

    Args:
        checkpoint_folder_path: Directory containing training checkpoints.
        params: Dict of training hyperparameters; 'mt' (model type) is
            required, 'tr' (truncate output) is optional, and any present
            frontend flag names override the corresponding global flags.
        tflite_friendly: Whether to build the model in TFLite-friendly mode.
        checkpoint_number: Optional specific checkpoint number to restore.
            `None` means "use the latest checkpoint in the folder".
        include_frontend: Whether the model includes the audio frontend.

    Returns:
        A Keras model with restored weights.

    Raises:
        ValueError: If the requested checkpoint cannot be loaded.
    """
    # Optionally override frontend flags from
    # `non_semantic_speech_benchmark/export_model/tf_frontend.py`
    override_flag_names = [
        'frame_hop', 'n_required', 'num_mel_bins', 'frame_width', 'pad_mode'
    ]
    for flag_name in override_flag_names:
        if flag_name in params:
            setattr(flags.FLAGS, flag_name, params[flag_name])

    static_model = models.get_keras_model(
        params['mt'],
        output_dimension=1024,
        truncate_output=params.get('tr', False),  # idiomatic default lookup
        frontend=include_frontend,
        tflite=tflite_friendly)
    checkpoint = tf.train.Checkpoint(model=static_model)
    # Compare to None explicitly: a checkpoint number of 0 is falsy and
    # would previously have silently fallen back to the latest checkpoint.
    if checkpoint_number is not None:
        checkpoint_to_load = os.path.join(checkpoint_folder_path,
                                          f'ckpt-{checkpoint_number}')
        # Don't use `assert` for validation: it is stripped under `python -O`.
        if not tf.train.load_checkpoint(checkpoint_to_load):
            raise ValueError(
                f'Could not load checkpoint: {checkpoint_to_load}')
    else:
        checkpoint_to_load = tf.train.latest_checkpoint(checkpoint_folder_path)
    checkpoint.restore(checkpoint_to_load).expect_partial()
    return static_model
Exemplo n.º 6
0
    def test_tflite_model(self, add_compression):
        """TFLite-friendly model returns dict outputs of the expected shapes."""
        bottleneck_dimension = 3
        if add_compression:
            default_hparams = compression.CompressionOp.get_default_hparams()
            compressor = compression_wrapper.get_apply_compression(
                default_hparams.parse(''), global_step=0)
        else:
            compressor = None
        model = models.get_keras_model(bottleneck_dimension,
                                       5,
                                       frontend=False,
                                       mobilenet_size='small',
                                       compressor=compressor,
                                       tflite=True)

        spectrogram = tf.zeros([1, 96, 64, 1], dtype=tf.float32)
        outputs = model(spectrogram)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        embedding.shape.assert_has_rank(2)
        self.assertEqual(embedding.shape[0], 1)
        self.assertEqual(embedding.shape[1], bottleneck_dimension)
        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[0], 1)
        self.assertEqual(target.shape[1], 5)

        if add_compression:
            # With compression, the dense kernel and the compression op's
            # a-matrix variable are expected to be absent (None).
            self.assertIsNone(model.get_layer('distilled_output').kernel)
            self.assertIsNone(
                model.get_layer('distilled_output').compression_op.a_matrix_tfvar)
Exemplo n.º 7
0
 def test_valid_mobilenet_size(self):
     """All supported mobilenet sizes build and produce rank-2 output."""
     audio = tf.zeros([2, 32000], dtype=tf.float32)
     for size in ('tiny', 'small', 'large'):
         model = models.get_keras_model(3, 5, mobilenet_size=size)
         output = model(audio)
         output.shape.assert_has_rank(2)
         self.assertEqual(output.shape[1], 5)
def get_model(checkpoint_folder_path,
              params,
              tflite_friendly,
              checkpoint_number=None,
              include_frontend=False):
  """Given folder & training params, exports SavedModel without frontend.

  Args:
    checkpoint_folder_path: Directory containing training checkpoints.
    params: Dict of training hyperparameters; 'cop', 'mt', 'bd', and 'qat'
      are read here, and any present frontend flag names override the
      corresponding global flags.
    tflite_friendly: Whether to build the model in TFLite-friendly mode.
    checkpoint_number: Optional specific checkpoint number to restore.
      `None` means "use the latest checkpoint in the folder".
    include_frontend: Whether the model includes the audio frontend.

  Returns:
    A Keras model with restored weights.

  Raises:
    ValueError: If the requested checkpoint cannot be loaded.
  """
  compressor = None
  if params['cop']:
    compressor = get_default_compressor()
  # Optionally override frontend flags from
  # `non_semantic_speech_benchmark/export_model/tf_frontend.py`
  override_flag_names = ['frame_hop', 'n_required', 'num_mel_bins',
                         'frame_width']
  for flag_name in override_flag_names:
    if flag_name in params:
      setattr(flags.FLAGS, flag_name, params[flag_name])

  static_model = models.get_keras_model(
      params['mt'],
      bottleneck_dimension=params['bd'],
      output_dimension=0,  # Don't include the unnecessary final layer.
      frontend=include_frontend,
      compressor=compressor,
      quantize_aware_training=params['qat'],
      tflite=tflite_friendly)
  checkpoint = tf.train.Checkpoint(model=static_model)
  # Compare to None explicitly: a checkpoint number of 0 is falsy and
  # would previously have silently fallen back to the latest checkpoint.
  if checkpoint_number is not None:
    checkpoint_to_load = os.path.join(
        checkpoint_folder_path, f'ckpt-{checkpoint_number}')
    # Don't use `assert` for validation: it is stripped under `python -O`.
    if not tf.train.load_checkpoint(checkpoint_to_load):
      raise ValueError(f'Could not load checkpoint: {checkpoint_to_load}')
  else:
    checkpoint_to_load = tf.train.latest_checkpoint(checkpoint_folder_path)
  checkpoint.restore(checkpoint_to_load).expect_partial()
  return static_model
Exemplo n.º 9
0
    def test_valid_mobilenet_size(self, mobilenet_size):
        """A supported mobilenet size yields correct embedding/output shapes."""
        audio = tf.zeros([2, 32000], dtype=tf.float32)
        model = models.get_keras_model(3, 5, mobilenet_size=mobilenet_size)
        outputs = model(audio)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        embedding.shape.assert_has_rank(2)
        self.assertEqual(embedding.shape[1], 3)
        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[1], 5)
Exemplo n.º 10
0
 def test_get_keras_model_frontend_input_shapes(self, model_type):
     """A model with a frontend runs a forward pass on raw audio samples."""
     # Configure the frontend through global flags.
     flags.FLAGS.frame_hop = 5
     flags.FLAGS.num_mel_bins = 80
     flags.FLAGS.frame_width = 5
     flags.FLAGS.n_required = 32000
     model = models.get_keras_model(model_type=model_type,
                                    output_dimension=0,
                                    frontend=True,
                                    tflite=False,
                                    spec_augment=False)
     audio = tf.zeros([2, 40000], tf.float32)
     # Only checking that the forward pass runs without error.
     model(audio)
Exemplo n.º 11
0
    def test_truncation(self):
        """`truncate_output` trims the target output to the requested size."""
        model = models.get_keras_model('efficientnetb0',
                                       frontend=False,
                                       output_dimension=5,
                                       truncate_output=True)

        spectrogram = tf.zeros([3, 96, 64, 1], dtype=tf.float32)
        outputs = model(spectrogram)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        # `embedding` is the original size, but `embedding_to_target` should
        # be the right size.
        self.assertEqual(embedding.shape[1], 5)
        self.assertEqual(target.shape[1], 5)
Exemplo n.º 12
0
    def test_tflite_model(self):
        """TFLite debug model produces batched rank-2 dict outputs."""
        model = models.get_keras_model('mobilenet_debug_1.0_False',
                                       5,
                                       frontend=False,
                                       tflite=True)

        spectrogram = tf.zeros([2, 96, 64, 1], dtype=tf.float32)
        outputs = model(spectrogram)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        embedding.shape.assert_has_rank(2)
        self.assertEqual(embedding.shape[0], 2)
        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[0], 2)
        self.assertEqual(target.shape[1], 5)
Exemplo n.º 13
0
    def test_valid_model_type(self, model_type):
        """A supported model type runs end-to-end with the frontend."""
        # Frontend flags.
        flags.FLAGS.frame_hop = 5
        flags.FLAGS.num_mel_bins = 80
        flags.FLAGS.frame_width = 5

        audio = tf.zeros([2, 16000], dtype=tf.float32)
        model = models.get_keras_model(model_type,
                                       5,
                                       frontend=True,
                                       truncate_output=True)
        target = model(audio)['embedding_to_target']

        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[1], 5)
Exemplo n.º 14
0
    def test_model_frontend(self, use_frontend, bottleneck):
        """Model accepts audio or spectrogram input, depending on frontend."""
        # Audio signal when the frontend is present; log Mel spectrogram
        # otherwise.
        shape = [2, 32000] if use_frontend else [1, 96, 64, 1]
        inputs = tf.zeros(shape, dtype=tf.float32)

        model = models.get_keras_model(bottleneck, 5, frontend=use_frontend)
        outputs = model(inputs)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        embedding.shape.assert_has_rank(2)
        if bottleneck:
            self.assertEqual(embedding.shape[1], bottleneck)
        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[1], 5)
Exemplo n.º 15
0
  def test_model_frontend(self, frontend, bottleneck, tflite):
    """Checks output shapes across frontend/bottleneck/tflite combinations."""
    if frontend:
      # Batch size 1 for the tflite case, else 2; raw audio signal.
      shape = [1 if tflite else 2, 32000]
    else:
      shape = [3, 96, 64, 1]  # log Mel spectrogram.
    inputs = tf.zeros(shape, dtype=tf.float32)
    output_dimension = 5

    model = models.get_keras_model(
        bottleneck, output_dimension, frontend=frontend, tflite=tflite)
    outputs = model(inputs)
    embedding = outputs['embedding']
    target = outputs['embedding_to_target']

    embedding.shape.assert_has_rank(2)
    if bottleneck:
      self.assertEqual(embedding.shape[1], bottleneck)
    target.shape.assert_has_rank(2)
    self.assertEqual(target.shape[1], output_dimension)
Exemplo n.º 16
0
    def test_model_spec_augment(self, frontend, tflite, spec_augment):
        """SpecAugment option still yields the expected output shapes."""
        if frontend:
            # Batch size 1 for the tflite case, else 2; raw audio signal.
            shape = [1 if tflite else 2, 32000]
        else:
            shape = [3, 96, 64, 1]  # log Mel spectrogram.
        inputs = tf.zeros(shape, dtype=tf.float32)
        output_dimension = 5

        model = models.get_keras_model('mobilenet_debug_1.0_False',
                                       output_dimension,
                                       frontend=frontend,
                                       tflite=tflite,
                                       spec_augment=spec_augment)
        outputs = model(inputs)
        embedding = outputs['embedding']
        target = outputs['embedding_to_target']

        embedding.shape.assert_has_rank(2)
        target.shape.assert_has_rank(2)
        self.assertEqual(target.shape[1], output_dimension)
Exemplo n.º 17
0
 def test_model_no_bottleneck(self):
     """A zero bottleneck dimension still produces the requested output."""
     audio = tf.zeros([2, 32000], dtype=tf.float32)
     model = models.get_keras_model(0, 5)
     output = model(audio)
     output.shape.assert_has_rank(2)
     self.assertEqual(output.shape[1], 5)
Exemplo n.º 18
0
 def test_invalid_mobilenet_size(self):
     """Unsupported mobilenet sizes should raise ValueError."""
     invalid_mobilenet_size = 'huuuge'
     # `assertRaises` already fails the test when ValueError is not raised,
     # so the original's explicit isinstance re-check was redundant and
     # has been removed.
     with self.assertRaises(ValueError):
         models.get_keras_model(3, 5, mobilenet_size=invalid_mobilenet_size)
Exemplo n.º 19
0
def train_and_report(debug=False):
    """Trains the classifier.

    Builds the input dataset (either precomputed frontend features and
    targets, or raw audio run through a teacher model), constructs the
    student Keras model (optionally with a compression op), then runs the
    training loop: one `train_step` per batch, logging metrics every 10
    steps and checkpointing periodically.

    Args:
        debug: If True, return right after model/checkpoint setup and skip
            the training loop entirely.
    """
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)

    reader = tf.data.TFRecordDataset
    if FLAGS.precomputed_frontend_and_targets:
        # Inputs are precomputed log Mel spectrograms; no teacher model is
        # needed at training time.
        ds = get_data.get_precomputed_data(
            file_pattern=FLAGS.file_pattern,
            output_dimension=FLAGS.output_dimension,
            frontend_key=FLAGS.frontend_key,
            target_key=FLAGS.target_key,
            batch_size=FLAGS.train_batch_size,
            num_epochs=FLAGS.num_epochs,
            shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        ds.element_spec[0].shape.assert_has_rank(3)  # log Mel spectrograms
        ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    else:
        # Inputs are raw audio; targets are produced on the fly by the
        # teacher model loaded from TF Hub.
        ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                               teacher_fn=get_data.savedmodel_to_func(
                                   hub.load(FLAGS.teacher_model_hub),
                                   FLAGS.output_key),
                               output_dimension=FLAGS.output_dimension,
                               reader=reader,
                               samples_key=FLAGS.samples_key,
                               min_length=FLAGS.min_length,
                               batch_size=FLAGS.train_batch_size,
                               loop_forever=True,
                               shuffle=True,
                               shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        assert len(ds.element_spec) == 2, ds.element_spec
        ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
        ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    output_dimension = ds.element_spec[1].shape[1]
    assert output_dimension == FLAGS.output_dimension
    # Define loss and optimizer hyparameters.
    loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # The optimizer's iteration counter doubles as the global step.
    global_step = opt.iterations
    # Create model, loss, and other objects.
    compressor = None
    if FLAGS.compression_op:
        # Build the compression hparams string from the individual flags.
        custom_params = ','.join([
            'compression_frequency=%d',
            'rank=%d',
            'begin_compression_step=%d',
            'end_compression_step=%d',
            'alpha_decrement_value=%d',
        ]) % (FLAGS.comp_freq, FLAGS.comp_rank, FLAGS.comp_begin_step,
              FLAGS.comp_end_step, FLAGS.alpha_step_size)
        compression_params = compression.CompressionOp.get_default_hparams(
        ).parse(custom_params)
        compressor = compression_wrapper.get_apply_compression(
            compression_params, global_step=global_step)
    model = models.get_keras_model(
        bottleneck_dimension=FLAGS.bottleneck_dimension,
        output_dimension=output_dimension,
        alpha=FLAGS.alpha,
        mobilenet_size=FLAGS.mobilenet_size,
        # The model needs its own frontend only when inputs are raw audio.
        frontend=not FLAGS.precomputed_frontend_and_targets,
        avg_pool=FLAGS.average_pool,
        compressor=compressor,
        quantize_aware_training=FLAGS.quantize_aware_training)
    model.summary()
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
    train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss, train_mae,
                                summary_writer)
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(
        checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    # Resume from the latest checkpoint if one exists (no-op otherwise).
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    for inputs, targets in ds:
        if FLAGS.precomputed_frontend_and_targets:  # inputs are spectrograms
            inputs.shape.assert_has_rank(3)
            inputs.shape.assert_is_compatible_with(
                [FLAGS.train_batch_size, 96, 64])
        else:  # inputs are audio vectors
            inputs.shape.assert_has_rank(2)
            inputs.shape.assert_is_compatible_with(
                [FLAGS.train_batch_size, FLAGS.min_length])
        targets.shape.assert_has_rank(2)
        targets.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, FLAGS.output_dimension])
        train_step(inputs, targets, global_step)
        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train mean abs error: %f',
                         global_step, train_loss.result(), train_mae.result())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')
Exemplo n.º 20
0
def load_and_write_model(keras_model_args, checkpoint_to_load,
                         output_directory):
  """Builds a model, restores the given checkpoint, and saves it to disk."""
  model = models.get_keras_model(**keras_model_args)
  ckpt = tf.train.Checkpoint(model=model)
  ckpt.restore(checkpoint_to_load).expect_partial()
  tf.keras.models.save_model(model, output_directory)
Exemplo n.º 21
0
def eval_and_report():
    """Eval on voxceleb.

    Watches the training logdir for new checkpoints; for each one, restores
    the weights, runs the eval dataset through the model, and writes MSE/MAE
    summaries tagged with the checkpoint's global step.
    """
    # Fixed: original called `tf.logging.info`, the removed TF1 API; the
    # rest of this function (and its siblings) uses TF2 APIs and absl's
    # `logging` module.
    logging.info('samples_key: %s', FLAGS.samples_key)
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.batch_size)

    writer = tf.summary.create_file_writer(FLAGS.eval_dir)
    model = models.get_keras_model(
        bottleneck_dimension=FLAGS.bottleneck_dimension,
        output_dimension=FLAGS.output_dimension,
        alpha=FLAGS.alpha,
        mobilenet_size=FLAGS.mobilenet_size,
        frontend=not FLAGS.precomputed_frontend_and_targets,
        avg_pool=FLAGS.average_pool)
    checkpoint = tf.train.Checkpoint(model=model)

    # Blocks waiting for new checkpoints until `timeout` expires.
    for ckpt in tf.train.checkpoints_iterator(FLAGS.timeout and FLAGS.logdir
                                              or FLAGS.logdir,
                                              timeout=FLAGS.timeout):
        assert 'ckpt-' in ckpt, ckpt
        # The checkpoint filename encodes the global step.
        step = ckpt.split('ckpt-')[-1]
        logging.info('Starting to evaluate step: %s.', step)

        checkpoint.restore(ckpt)

        logging.info('Loaded weights for eval step: %s.', step)

        reader = tf.data.TFRecordDataset
        ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                               teacher_fn=get_data.savedmodel_to_func(
                                   hub.load(FLAGS.teacher_model_hub),
                                   FLAGS.output_key),
                               output_dimension=FLAGS.output_dimension,
                               reader=reader,
                               samples_key=FLAGS.samples_key,
                               min_length=FLAGS.min_length,
                               batch_size=FLAGS.batch_size,
                               loop_forever=False,
                               shuffle=False)
        logging.info('Got dataset for eval step: %s.', step)
        if FLAGS.take_fixed_data:
            ds = ds.take(FLAGS.take_fixed_data)

        mse_m = tf.keras.metrics.MeanSquaredError()
        mae_m = tf.keras.metrics.MeanAbsoluteError()

        logging.info('Starting the ds loop...')
        count, ex_count = 0, 0
        s = time.time()
        for wav_samples, targets in ds:
            wav_samples.shape.assert_is_compatible_with(
                [None, FLAGS.min_length])
            targets.shape.assert_is_compatible_with(
                [None, FLAGS.output_dimension])

            logits = model(wav_samples, training=False)['embedding_to_target']
            logits.shape.assert_is_compatible_with(targets.shape)

            mse_m.update_state(y_true=targets, y_pred=logits)
            mae_m.update_state(y_true=targets, y_pred=logits)
            ex_count += logits.shape[0]
            count += 1
            logging.info('Saw %i examples after %i iterations as %.2f secs...',
                         ex_count, count,
                         time.time() - s)
        with writer.as_default():
            tf.summary.scalar('mse', mse_m.result().numpy(), step=int(step))
            tf.summary.scalar('mae', mae_m.result().numpy(), step=int(step))
        logging.info('Done with eval step: %s in %.2f secs.', step,
                     time.time() - s)
Exemplo n.º 22
0
def train_and_report(debug=False):
    """Trains the classifier.

    Builds the input dataset (raw audio with targets either precomputed or
    produced by a teacher model from TF Hub), constructs the student Keras
    model with its frontend, then runs the training loop: one `train_step`
    per batch, logging metrics every 10 steps and checkpointing periodically.

    Args:
        debug: If True, return right after model/checkpoint setup and skip
            the training loop entirely.
    """
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)

    reader = tf.data.TFRecordDataset
    target_key = FLAGS.target_key
    if FLAGS.precomputed_targets:
        # Targets are stored in the input records; no teacher model needed.
        teacher_fn = None
        assert target_key is not None
        assert FLAGS.output_key is None
    else:
        # Targets are computed on the fly by the teacher model.
        teacher_fn = get_data.savedmodel_to_func(
            hub.load(FLAGS.teacher_model_hub), FLAGS.output_key)
        assert target_key is None
    ds = get_data.get_data(file_patterns=FLAGS.file_patterns,
                           output_dimension=FLAGS.output_dimension,
                           reader=reader,
                           samples_key=FLAGS.samples_key,
                           min_length=FLAGS.min_length,
                           batch_size=FLAGS.train_batch_size,
                           loop_forever=True,
                           shuffle=True,
                           teacher_fn=teacher_fn,
                           target_key=target_key,
                           normalize_to_pm_one=FLAGS.normalize_to_pm_one,
                           shuffle_buffer_size=FLAGS.shuffle_buffer_size)
    assert len(ds.element_spec) == 2, ds.element_spec
    ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
    ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    output_dimension = ds.element_spec[1].shape[1]
    assert output_dimension == FLAGS.output_dimension

    # Define loss and optimizer hyparameters.
    loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # The optimizer's iteration counter doubles as the global step.
    global_step = opt.iterations
    # Create model, loss, and other objects.
    model = models.get_keras_model(model_type=FLAGS.model_type,
                                   output_dimension=output_dimension,
                                   truncate_output=FLAGS.truncate_output,
                                   frontend=True,
                                   spec_augment=FLAGS.spec_augment)
    model.summary()
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
    train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss, train_mae,
                                summary_writer)
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(
        checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    # Resume from the latest checkpoint if one exists (no-op otherwise).
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    for inputs, targets in ds:
        # Inputs are audio vectors.
        inputs.shape.assert_has_rank(2)
        inputs.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, FLAGS.min_length])
        targets.shape.assert_has_rank(2)
        targets.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, FLAGS.output_dimension])
        train_step(inputs, targets, global_step)
        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train mean abs error: %f',
                         global_step, train_loss.result(), train_mae.result())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')
Exemplo n.º 23
0
 def test_valid_mobilenet_size(self, mobilenet_size):
     """Each parameterized mobilenet size builds and runs on audio input."""
     audio = tf.zeros([2, 32000], dtype=tf.float32)
     model = models.get_keras_model(3, 5, mobilenet_size=mobilenet_size)
     output = model(audio)
     output.shape.assert_has_rank(2)
     self.assertEqual(output.shape[1], 5)
Exemplo n.º 24
0
 def _model():
     """Builds the efficientnetv2b0 model under test (frontend, no head)."""
     return models.get_keras_model(
         model_type='efficientnetv2b0',
         output_dimension=0,
         frontend=True,
         tflite=False,
         spec_augment=False)
Exemplo n.º 25
0
 def test_model_frontend(self):
     """Default model consumes raw audio through its built-in frontend."""
     audio = tf.zeros([2, 32000], dtype=tf.float32)  # audio signal
     model = models.get_keras_model(3, 5)
     output = model(audio)
     output.shape.assert_has_rank(2)
     self.assertEqual(output.shape[1], 5)