def test_invalid_model(self):
  """An unknown mobilenet size embedded in the model string raises KeyError."""
  invalid_mobilenet_size = 'huuuge'
  # `assertRaises` already verifies the exception type, so no post-check of
  # `exception_context.exception` is needed (the old isinstance check was
  # dead code).
  with self.assertRaises(KeyError):
    models.get_keras_model(
        f'mobilenet_{invalid_mobilenet_size}_1.0_False', 3, 5)
def get_tflite_friendly_model(checkpoint_folder_path,
                              params,
                              checkpoint_number=None,
                              include_frontend=False):
  """Given folder & training params, exports SavedModel without frontend."""
  # Only build a compressor when the training run used compression.
  compressor = get_default_compressor() if params['cop'] else None
  model = models.get_keras_model(
      bottleneck_dimension=params['bd'],
      output_dimension=0,  # Don't include the unnecessary final layer.
      alpha=params['al'],
      mobilenet_size=params['ms'],
      frontend=include_frontend,
      avg_pool=params['ap'],
      compressor=compressor,
      quantize_aware_training=params['qat'],
      tflite=True)
  ckpt = tf.train.Checkpoint(model=model)
  if checkpoint_number:
    # Load an explicitly requested checkpoint; fail fast if it's missing.
    to_restore = os.path.join(checkpoint_folder_path,
                              f'ckpt-{checkpoint_number}')
    assert tf.train.load_checkpoint(to_restore)
  else:
    to_restore = tf.train.latest_checkpoint(checkpoint_folder_path)
  ckpt.restore(to_restore).expect_partial()
  return model
def test_model_no_frontend(self):
  """Without a frontend, the model consumes log Mel spectrograms directly."""
  spectrogram = tf.zeros([1, 96, 64, 1], dtype=tf.float32)  # log Mel spectrogram
  model = models.get_keras_model(3, 5, frontend=False)
  out = model(spectrogram)
  out.shape.assert_has_rank(2)
  self.assertEqual(out.shape[1], 5)
def load_and_write_model(logdir, checkpoint_filename, output_directory):
  """Restores the latest checkpoint from `logdir` and exports a SavedModel."""
  model = models.get_keras_model(
      FLAGS.bottleneck_dimension, FLAGS.output_dimension, alpha=FLAGS.alpha)
  ckpt = tf.train.Checkpoint(model=model)
  latest = tf.train.latest_checkpoint(logdir, checkpoint_filename)
  ckpt.restore(latest)
  tf.keras.models.save_model(model, output_directory)
def get_model(checkpoint_folder_path,
              params,
              tflite_friendly,
              checkpoint_number=None,
              include_frontend=False):
  """Given folder & training params, exports SavedModel without frontend."""
  # Optionally override frontend flags from
  # `non_semantic_speech_benchmark/export_model/tf_frontend.py`
  for flag_name in ('frame_hop', 'n_required', 'num_mel_bins', 'frame_width',
                    'pad_mode'):
    if flag_name in params:
      setattr(flags.FLAGS, flag_name, params[flag_name])
  model = models.get_keras_model(
      params['mt'],
      output_dimension=1024,
      truncate_output=params.get('tr', False),
      frontend=include_frontend,
      tflite=tflite_friendly)
  ckpt = tf.train.Checkpoint(model=model)
  if checkpoint_number:
    # Restore a specific checkpoint; assert it actually exists.
    to_restore = os.path.join(checkpoint_folder_path,
                              f'ckpt-{checkpoint_number}')
    assert tf.train.load_checkpoint(to_restore)
  else:
    to_restore = tf.train.latest_checkpoint(checkpoint_folder_path)
  ckpt.restore(to_restore).expect_partial()
  return model
def test_tflite_model(self, add_compression):
  """TFLite-mode model returns an embedding dict; compression is optional."""
  bottleneck_dimension = 3
  if add_compression:
    hparams = compression.CompressionOp.get_default_hparams().parse('')
    compressor = compression_wrapper.get_apply_compression(
        hparams, global_step=0)
  else:
    compressor = None
  model = models.get_keras_model(
      bottleneck_dimension,
      5,
      frontend=False,
      mobilenet_size='small',
      compressor=compressor,
      tflite=True)
  outputs = model(tf.zeros([1, 96, 64, 1], dtype=tf.float32))
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  self.assertEqual(emb.shape[0], 1)
  self.assertEqual(emb.shape[1], bottleneck_dimension)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[0], 1)
  self.assertEqual(target.shape[1], 5)
  if add_compression:
    # Compression replaces the dense kernel with factored matrices.
    self.assertIsNone(model.get_layer('distilled_output').kernel)
    self.assertIsNone(
        model.get_layer('distilled_output').compression_op.a_matrix_tfvar)
def test_valid_mobilenet_size(self):
  """Every supported mobilenet size builds and produces rank-2 output."""
  audio = tf.zeros([2, 32000], dtype=tf.float32)
  for size in ('tiny', 'small', 'large'):
    model = models.get_keras_model(3, 5, mobilenet_size=size)
    out = model(audio)
    out.shape.assert_has_rank(2)
    self.assertEqual(out.shape[1], 5)
def get_model(checkpoint_folder_path,
              params,
              tflite_friendly,
              checkpoint_number=None,
              include_frontend=False):
  """Given folder & training params, exports SavedModel without frontend."""
  # Only build a compressor when the training run used compression.
  compressor = get_default_compressor() if params['cop'] else None
  # Optionally override frontend flags from
  # `non_semantic_speech_benchmark/export_model/tf_frontend.py`
  for flag_name in ('frame_hop', 'n_required', 'num_mel_bins', 'frame_width'):
    if flag_name in params:
      setattr(flags.FLAGS, flag_name, params[flag_name])
  model = models.get_keras_model(
      params['mt'],
      bottleneck_dimension=params['bd'],
      output_dimension=0,  # Don't include the unnecessary final layer.
      frontend=include_frontend,
      compressor=compressor,
      quantize_aware_training=params['qat'],
      tflite=tflite_friendly)
  ckpt = tf.train.Checkpoint(model=model)
  if checkpoint_number:
    # Restore the requested checkpoint; assert it actually exists.
    to_restore = os.path.join(checkpoint_folder_path,
                              f'ckpt-{checkpoint_number}')
    assert tf.train.load_checkpoint(to_restore)
  else:
    to_restore = tf.train.latest_checkpoint(checkpoint_folder_path)
  ckpt.restore(to_restore).expect_partial()
  return model
def test_valid_mobilenet_size(self, mobilenet_size):
  """A supported mobilenet size yields embedding and target outputs."""
  audio = tf.zeros([2, 32000], dtype=tf.float32)
  model = models.get_keras_model(3, 5, mobilenet_size=mobilenet_size)
  outputs = model(audio)
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  self.assertEqual(emb.shape[1], 3)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[1], 5)
def test_get_keras_model_frontend_input_shapes(self, model_type):
  """Model with a frontend accepts raw audio at flag-configured sizes."""
  # Frontend flags.
  flags.FLAGS.frame_hop = 5
  flags.FLAGS.num_mel_bins = 80
  flags.FLAGS.frame_width = 5
  flags.FLAGS.n_required = 32000
  model = models.get_keras_model(
      model_type=model_type,
      output_dimension=0,
      frontend=True,
      tflite=False,
      spec_augment=False)
  # Longer-than-required audio should still be accepted.
  model(tf.zeros([2, 40000], tf.float32))
def test_truncation(self):
  """`truncate_output=True` trims outputs to `output_dimension`."""
  model = models.get_keras_model(
      'efficientnetb0',
      frontend=False,
      output_dimension=5,
      truncate_output=True)
  spectrogram = tf.zeros([3, 96, 64, 1], dtype=tf.float32)
  outputs = model(spectrogram)
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  # `embedding` is the original size, but `embedding_to_target` should be the
  # right size.
  self.assertEqual(emb.shape[1], 5)
  self.assertEqual(target.shape[1], 5)
def test_tflite_model(self):
  """TFLite model without frontend produces a dict of rank-2 outputs."""
  model = models.get_keras_model(
      'mobilenet_debug_1.0_False', 5, frontend=False, tflite=True)
  spectrogram = tf.zeros([2, 96, 64, 1], dtype=tf.float32)
  outputs = model(spectrogram)
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  self.assertEqual(emb.shape[0], 2)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[0], 2)
  self.assertEqual(target.shape[1], 5)
def test_valid_model_type(self, model_type):
  """Each supported model type builds and runs with a frontend."""
  # Frontend flags.
  flags.FLAGS.frame_hop = 5
  flags.FLAGS.num_mel_bins = 80
  flags.FLAGS.frame_width = 5
  audio = tf.zeros([2, 16000], dtype=tf.float32)
  model = models.get_keras_model(
      model_type, 5, frontend=True, truncate_output=True)
  target = model(audio)['embedding_to_target']
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[1], 5)
def test_model_frontend(self, use_frontend, bottleneck):
  """Model accepts audio (with frontend) or spectrograms (without)."""
  if use_frontend:
    input_tensor_shape = [2, 32000]  # audio signal
  else:
    input_tensor_shape = [1, 96, 64, 1]  # log Mel spectrogram
  model = models.get_keras_model(bottleneck, 5, frontend=use_frontend)
  outputs = model(tf.zeros(input_tensor_shape, dtype=tf.float32))
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  if bottleneck:
    # Only check the embedding width when a bottleneck layer exists.
    self.assertEqual(emb.shape[1], bottleneck)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[1], 5)
def test_model_frontend(self, frontend, bottleneck, tflite):
  """Frontend/tflite combinations produce correctly shaped outputs."""
  if frontend:
    input_tensor_shape = [1 if tflite else 2, 32000]  # audio signal.
  else:
    input_tensor_shape = [3, 96, 64, 1]  # log Mel spectrogram.
  model = models.get_keras_model(
      bottleneck, 5, frontend=frontend, tflite=tflite)
  outputs = model(tf.zeros(input_tensor_shape, dtype=tf.float32))
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  if bottleneck:
    # Only check the embedding width when a bottleneck layer exists.
    self.assertEqual(emb.shape[1], bottleneck)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[1], 5)
def test_model_spec_augment(self, frontend, tflite, spec_augment):
  """SpecAugment option works across frontend/tflite combinations."""
  if frontend:
    input_tensor_shape = [1 if tflite else 2, 32000]  # audio signal.
  else:
    input_tensor_shape = [3, 96, 64, 1]  # log Mel spectrogram.
  model = models.get_keras_model(
      'mobilenet_debug_1.0_False',
      5,
      frontend=frontend,
      tflite=tflite,
      spec_augment=spec_augment)
  outputs = model(tf.zeros(input_tensor_shape, dtype=tf.float32))
  emb = outputs['embedding']
  target = outputs['embedding_to_target']
  emb.shape.assert_has_rank(2)
  target.shape.assert_has_rank(2)
  self.assertEqual(target.shape[1], 5)
def test_model_no_bottleneck(self):
  """A bottleneck dimension of 0 still yields a valid output layer."""
  audio = tf.zeros([2, 32000], dtype=tf.float32)
  model = models.get_keras_model(0, 5)
  out = model(audio)
  out.shape.assert_has_rank(2)
  self.assertEqual(out.shape[1], 5)
def test_invalid_mobilenet_size(self):
  """An unrecognized mobilenet size raises ValueError."""
  invalid_mobilenet_size = 'huuuge'
  # `assertRaises` already verifies the exception type, so the old
  # `isinstance` re-check after the context manager was dead code.
  with self.assertRaises(ValueError):
    models.get_keras_model(3, 5, mobilenet_size=invalid_mobilenet_size)
def train_and_report(debug=False):
  """Trains the student model against precomputed or teacher-derived targets.

  Builds the dataset (either precomputed spectrogram/target pairs or raw audio
  fed through a teacher hub model), constructs the Keras student model
  (optionally with a compression op), restores the latest checkpoint, and runs
  the training loop forever, periodically logging metrics and saving
  checkpoints.

  Args:
    debug: If True, set everything up but return before the training loop.
  """
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)
  reader = tf.data.TFRecordDataset
  if FLAGS.precomputed_frontend_and_targets:
    # Frontend features and teacher targets were computed offline.
    ds = get_data.get_precomputed_data(
        file_pattern=FLAGS.file_pattern,
        output_dimension=FLAGS.output_dimension,
        frontend_key=FLAGS.frontend_key,
        target_key=FLAGS.target_key,
        batch_size=FLAGS.train_batch_size,
        num_epochs=FLAGS.num_epochs,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)
    ds.element_spec[0].shape.assert_has_rank(3)  # log Mel spectrograms
    ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
  else:
    # Compute targets on the fly by running audio through the teacher model.
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        teacher_fn=get_data.savedmodel_to_func(
            hub.load(FLAGS.teacher_model_hub), FLAGS.output_key),
        output_dimension=FLAGS.output_dimension,
        reader=reader,
        samples_key=FLAGS.samples_key,
        min_length=FLAGS.min_length,
        batch_size=FLAGS.train_batch_size,
        loop_forever=True,
        shuffle=True,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)
    assert len(ds.element_spec) == 2, ds.element_spec
    ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
    ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
  output_dimension = ds.element_spec[1].shape[1]
  assert output_dimension == FLAGS.output_dimension

  # Define loss and optimizer hyperparameters.
  loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  # The optimizer's iteration counter doubles as the global step.
  global_step = opt.iterations

  # Create model, loss, and other objects.
  compressor = None
  if FLAGS.compression_op:
    # Build the compression hparams string from individual flags.
    custom_params = ','.join([
        'compression_frequency=%d',
        'rank=%d',
        'begin_compression_step=%d',
        'end_compression_step=%d',
        'alpha_decrement_value=%d',
    ]) % (FLAGS.comp_freq, FLAGS.comp_rank, FLAGS.comp_begin_step,
          FLAGS.comp_end_step, FLAGS.alpha_step_size)
    compression_params = compression.CompressionOp.get_default_hparams().parse(
        custom_params)
    compressor = compression_wrapper.get_apply_compression(
        compression_params, global_step=global_step)
  model = models.get_keras_model(
      bottleneck_dimension=FLAGS.bottleneck_dimension,
      output_dimension=output_dimension,
      alpha=FLAGS.alpha,
      mobilenet_size=FLAGS.mobilenet_size,
      # The in-graph frontend is only needed when inputs are raw audio.
      frontend=not FLAGS.precomputed_frontend_and_targets,
      avg_pool=FLAGS.average_pool,
      compressor=compressor,
      quantize_aware_training=FLAGS.quantize_aware_training)
  model.summary()
  # Add additional metrics to track.
  train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
  train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_mae, summary_writer)
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  # Resume from the latest checkpoint if one exists (no-op otherwise).
  checkpoint.restore(manager.latest_checkpoint)

  if debug:
    return

  for inputs, targets in ds:
    if FLAGS.precomputed_frontend_and_targets:  # inputs are spectrograms
      inputs.shape.assert_has_rank(3)
      inputs.shape.assert_is_compatible_with(
          [FLAGS.train_batch_size, 96, 64])
    else:  # inputs are audio vectors
      inputs.shape.assert_has_rank(2)
      inputs.shape.assert_is_compatible_with(
          [FLAGS.train_batch_size, FLAGS.min_length])
    targets.shape.assert_has_rank(2)
    targets.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, FLAGS.output_dimension])
    train_step(inputs, targets, global_step)
    # Optional print output and save model.
    if global_step % 10 == 0:
      logging.info('step: %i, train loss: %f, train mean abs error: %f',
                   global_step, train_loss.result(), train_mae.result())
    if global_step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=global_step)

  # Final save when the dataset is exhausted (only reachable if the dataset
  # is finite).
  manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')
def load_and_write_model(keras_model_args, checkpoint_to_load,
                         output_directory):
  """Builds a model, restores the given checkpoint, and exports a SavedModel."""
  model = models.get_keras_model(**keras_model_args)
  ckpt = tf.train.Checkpoint(model=model)
  ckpt.restore(checkpoint_to_load).expect_partial()
  tf.keras.models.save_model(model, output_directory)
def eval_and_report():
  """Eval on voxceleb.

  Waits for training checkpoints as they appear, and for each one computes
  MSE/MAE between the student model's predictions and the teacher embeddings,
  writing the results as TensorBoard summaries.
  """
  # Bug fix: was `tf.logging.info`, a TF1 API removed in TF2. The rest of this
  # function uses TF2-only APIs (`tf.summary.create_file_writer`, `.numpy()`),
  # so `tf.logging` would raise AttributeError; use absl `logging` like the
  # adjacent lines.
  logging.info('samples_key: %s', FLAGS.samples_key)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.batch_size)

  writer = tf.summary.create_file_writer(FLAGS.eval_dir)
  model = models.get_keras_model(
      bottleneck_dimension=FLAGS.bottleneck_dimension,
      output_dimension=FLAGS.output_dimension,
      alpha=FLAGS.alpha,
      mobilenet_size=FLAGS.mobilenet_size,
      frontend=not FLAGS.precomputed_frontend_and_targets,
      avg_pool=FLAGS.average_pool)
  checkpoint = tf.train.Checkpoint(model=model)

  # Blocks until a new checkpoint appears (or `timeout` elapses).
  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.logdir, timeout=FLAGS.timeout):
    assert 'ckpt-' in ckpt, ckpt
    step = ckpt.split('ckpt-')[-1]
    logging.info('Starting to evaluate step: %s.', step)
    checkpoint.restore(ckpt)
    logging.info('Loaded weights for eval step: %s.', step)

    reader = tf.data.TFRecordDataset
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        teacher_fn=get_data.savedmodel_to_func(
            hub.load(FLAGS.teacher_model_hub), FLAGS.output_key),
        output_dimension=FLAGS.output_dimension,
        reader=reader,
        samples_key=FLAGS.samples_key,
        min_length=FLAGS.min_length,
        batch_size=FLAGS.batch_size,
        loop_forever=False,
        shuffle=False)
    logging.info('Got dataset for eval step: %s.', step)
    if FLAGS.take_fixed_data:
      # Evaluate on a fixed-size subset for faster/comparable evals.
      ds = ds.take(FLAGS.take_fixed_data)

    mse_m = tf.keras.metrics.MeanSquaredError()
    mae_m = tf.keras.metrics.MeanAbsoluteError()

    logging.info('Starting the ds loop...')
    count, ex_count = 0, 0
    s = time.time()
    for wav_samples, targets in ds:
      wav_samples.shape.assert_is_compatible_with([None, FLAGS.min_length])
      targets.shape.assert_is_compatible_with([None, FLAGS.output_dimension])

      logits = model(wav_samples, training=False)['embedding_to_target']
      logits.shape.assert_is_compatible_with(targets.shape)
      mse_m.update_state(y_true=targets, y_pred=logits)
      mae_m.update_state(y_true=targets, y_pred=logits)
      ex_count += logits.shape[0]
      count += 1
      logging.info('Saw %i examples after %i iterations as %.2f secs...',
                   ex_count, count, time.time() - s)
    with writer.as_default():
      tf.summary.scalar('mse', mse_m.result().numpy(), step=int(step))
      tf.summary.scalar('mae', mae_m.result().numpy(), step=int(step))
    logging.info('Done with eval step: %s in %.2f secs.', step,
                 time.time() - s)
def train_and_report(debug=False):
  """Trains the student model against precomputed or teacher-derived targets.

  Builds the dataset (raw audio with either precomputed targets or targets
  produced by a teacher hub model), constructs the Keras student model,
  restores the latest checkpoint, and runs the training loop forever,
  periodically logging metrics and saving checkpoints.

  Args:
    debug: If True, set everything up but return before the training loop.
  """
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)
  reader = tf.data.TFRecordDataset
  target_key = FLAGS.target_key
  if FLAGS.precomputed_targets:
    # Targets are stored in the examples; no teacher model is needed.
    teacher_fn = None
    assert target_key is not None
    assert FLAGS.output_key is None
  else:
    # Targets are computed on the fly by the teacher model.
    teacher_fn = get_data.savedmodel_to_func(
        hub.load(FLAGS.teacher_model_hub), FLAGS.output_key)
    assert target_key is None
  ds = get_data.get_data(
      file_patterns=FLAGS.file_patterns,
      output_dimension=FLAGS.output_dimension,
      reader=reader,
      samples_key=FLAGS.samples_key,
      min_length=FLAGS.min_length,
      batch_size=FLAGS.train_batch_size,
      loop_forever=True,
      shuffle=True,
      teacher_fn=teacher_fn,
      target_key=target_key,
      normalize_to_pm_one=FLAGS.normalize_to_pm_one,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)
  assert len(ds.element_spec) == 2, ds.element_spec
  ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
  ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
  output_dimension = ds.element_spec[1].shape[1]
  assert output_dimension == FLAGS.output_dimension

  # Define loss and optimizer hyperparameters.
  loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  # The optimizer's iteration counter doubles as the global step.
  global_step = opt.iterations

  # Create model, loss, and other objects.
  model = models.get_keras_model(
      model_type=FLAGS.model_type,
      output_dimension=output_dimension,
      truncate_output=FLAGS.truncate_output,
      frontend=True,
      spec_augment=FLAGS.spec_augment)
  model.summary()
  # Add additional metrics to track.
  train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
  train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_mae, summary_writer)
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  # Resume from the latest checkpoint if one exists (no-op otherwise).
  checkpoint.restore(manager.latest_checkpoint)

  if debug:
    return

  for inputs, targets in ds:
    # Inputs are audio vectors.
    inputs.shape.assert_has_rank(2)
    inputs.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, FLAGS.min_length])
    targets.shape.assert_has_rank(2)
    targets.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, FLAGS.output_dimension])
    train_step(inputs, targets, global_step)
    # Optional print output and save model.
    if global_step % 10 == 0:
      logging.info('step: %i, train loss: %f, train mean abs error: %f',
                   global_step, train_loss.result(), train_mae.result())
    if global_step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=global_step)

  # Final save when the dataset is exhausted (only reachable if the dataset
  # is finite).
  manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')
def test_valid_mobilenet_size(self, mobilenet_size):
  """Parameterized mobilenet size builds and gives rank-2 output."""
  audio = tf.zeros([2, 32000], dtype=tf.float32)
  model = models.get_keras_model(3, 5, mobilenet_size=mobilenet_size)
  out = model(audio)
  out.shape.assert_has_rank(2)
  self.assertEqual(out.shape[1], 5)
def _model():
  """Builds the EfficientNetV2-B0 embedding model with an audio frontend."""
  return models.get_keras_model(
      model_type='efficientnetv2b0',
      output_dimension=0,
      frontend=True,
      tflite=False,
      spec_augment=False)
def test_model_frontend(self):
  """Default model consumes raw audio through its frontend."""
  audio = tf.zeros([2, 32000], dtype=tf.float32)  # audio signal
  model = models.get_keras_model(3, 5)
  out = model(audio)
  out.shape.assert_has_rank(2)
  self.assertEqual(out.shape[1], 5)