def testWrapper_CreatesProperCompressorOption1(self, low_rank_mock):
  """Checks the wrapper forwards the mocked compressor to ApplyCompression.

  Builds a compression spec with option 1, makes `low_rank_mock` yield a
  MatrixCompressorInterfaceMock, and asserts that `comp_op.ApplyCompression`
  was last called with that spec, the mocked compressor and `_GLOBAL_STEP`.
  """
  spec = self._create_compression_op_spec(1)
  fake_compressor = MatrixCompressorInterfaceMock(
      self._default_compressor_spec(spec))
  # The patched factory hands out exactly one object: our fake compressor.
  low_rank_mock.side_effect = [fake_compressor]

  with mock.patch.object(comp_op, 'ApplyCompression') as patched_apply:
    compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
    patched_apply.assert_called_with(
        scope='default_scope',
        compression_spec=spec,
        compressor=fake_compressor,
        global_step=_GLOBAL_STEP)
def test_tflite_model(self, add_compression):
  """Builds the TFLite variant of the Keras model and checks its outputs.

  Feeds a single all-zeros [1, 96, 64, 1] input and checks that both the
  'embedding' and 'embedding_to_target' outputs are rank-2 with batch size 1
  and the expected width. When `add_compression` is set, a default
  compression wrapper is attached. The test also checks that the
  'distilled_output' layer's kernel and its compression op's
  `a_matrix_tfvar` are None.
  """
  bottleneck_dimension = 3
  target_dimension = 5

  if add_compression:
    default_hparams = compression.CompressionOp.get_default_hparams().parse('')
    compressor = compression_wrapper.get_apply_compression(
        default_hparams, global_step=0)
  else:
    compressor = None

  model = models.get_keras_model(
      bottleneck_dimension,
      target_dimension,
      frontend=False,
      mobilenet_size='small',
      compressor=compressor,
      tflite=True)
  outputs = model(tf.zeros([1, 96, 64, 1], dtype=tf.float32))

  expected_widths = (
      (outputs['embedding'], bottleneck_dimension),
      (outputs['embedding_to_target'], target_dimension),
  )
  for tensor, width in expected_widths:
    tensor.shape.assert_has_rank(2)
    self.assertEqual(tensor.shape[0], 1)
    self.assertEqual(tensor.shape[1], width)

  if add_compression:
    distilled_layer = model.get_layer('distilled_output')
    self.assertIsNone(distilled_layer.kernel)
    self.assertIsNone(distilled_layer.compression_op.a_matrix_tfvar)
def test_compressed_dense_inference(self):
  """Verify forward pass and removal of the uncompressed kernel.

  A CompressedDense(10) layer is built on 5-dim inputs. Its `kernel` and its
  compression op's `a_matrix_tfvar` are set to None. An inference-mode
  (`training=False`) call must still produce 10 output features.
  """
  default_hparams = compression.CompressionOp.get_default_hparams().parse("")
  compressor = compression_wrapper.get_apply_compression(
      default_hparams, global_step=0)
  batch = tf.zeros((1, 5), dtype=tf.float32)

  model = tf.keras.Sequential([
      tf.keras.layers.Input((5,)),
      layers.CompressedDense(10, compression_obj=compressor, name="compressed"),
  ])

  # Drop the uncompressed kernel (and the compression op's A matrix).
  dense = model.get_layer("compressed")
  dense.kernel = None
  dense.compression_op.a_matrix_tfvar = None

  result = model(batch, training=False)
  self.assertEqual(result.shape[1], 10)
def test_compressed_dense_training_failure(self):
  """Verify forward pass fails when training flag is True.

  A CompressedDense(10) layer is built on 5-dim inputs. Its `kernel` and its
  compression op's `a_matrix_tfvar` are set to None. Calling the model with
  `training=True` must then raise ValueError.
  """
  compression_params = compression.CompressionOp.get_default_hparams().parse("")
  compressor = compression_wrapper.get_apply_compression(
      compression_params, global_step=0)
  in_tensor = tf.zeros((1, 5), dtype=tf.float32)
  m = tf.keras.Sequential([
      tf.keras.layers.Input((5,)),
      layers.CompressedDense(10, compression_obj=compressor, name="compressed"),
  ])
  # Remove the uncompressed kernel (and the compression op's A matrix).
  compressed_layer = m.get_layer("compressed")
  compressed_layer.kernel = None
  compressed_layer.compression_op.a_matrix_tfvar = None
  # assertRaises(ValueError) already fails the test unless a ValueError is
  # raised, so no follow-up isinstance() check on the exception is needed.
  with self.assertRaises(ValueError):
    m(in_tensor, training=True)
def _build_dataset():
  """Builds the (inputs, targets) training dataset selected by FLAGS.

  With `--precomputed_frontend_and_targets` the dataset is read via
  `get_data.get_precomputed_data`. Otherwise raw samples are read from
  TFRecords and paired with outputs of the TF-Hub teacher model loaded from
  `--teacher_model_hub`.

  Returns:
    A dataset whose element_spec has (at least) an inputs and a targets entry.

  Raises:
    ValueError: if the teacher-model dataset does not yield exactly two
      components per element.
  """
  if FLAGS.precomputed_frontend_and_targets:
    ds = get_data.get_precomputed_data(
        file_pattern=FLAGS.file_pattern,
        output_dimension=FLAGS.output_dimension,
        frontend_key=FLAGS.frontend_key,
        target_key=FLAGS.target_key,
        batch_size=FLAGS.train_batch_size,
        num_epochs=FLAGS.num_epochs,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)
    ds.element_spec[0].shape.assert_has_rank(3)  # log Mel spectrograms
    ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    return ds

  ds = get_data.get_data(
      file_pattern=FLAGS.file_pattern,
      teacher_fn=get_data.savedmodel_to_func(
          hub.load(FLAGS.teacher_model_hub), FLAGS.output_key),
      output_dimension=FLAGS.output_dimension,
      reader=tf.data.TFRecordDataset,
      samples_key=FLAGS.samples_key,
      min_length=FLAGS.min_length,
      batch_size=FLAGS.train_batch_size,
      loop_forever=True,
      shuffle=True,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if len(ds.element_spec) != 2:
    raise ValueError(
        'Expected dataset elements of (inputs, targets), got: %s' %
        (ds.element_spec,))
  ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
  ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
  return ds


def _build_compressor(global_step):
  """Returns a compression wrapper configured from FLAGS, or None if disabled.

  Args:
    global_step: step counter handed to the compression wrapper.
  """
  if not FLAGS.compression_op:
    return None
  custom_params = ','.join([
      'compression_frequency=%d',
      'rank=%d',
      'begin_compression_step=%d',
      'end_compression_step=%d',
      # %s (not %d) so a fractional decrement value is not silently truncated
      # to an integer; for integer flags the rendered text is identical.
      'alpha_decrement_value=%s',
  ]) % (FLAGS.comp_freq, FLAGS.comp_rank, FLAGS.comp_begin_step,
        FLAGS.comp_end_step, FLAGS.alpha_step_size)
  compression_params = compression.CompressionOp.get_default_hparams().parse(
      custom_params)
  return compression_wrapper.get_apply_compression(
      compression_params, global_step=global_step)


def train_and_report(debug=False):
  """Trains the classifier.

  Setup phase:
    * Builds the dataset, the optional compressor and the Keras model from
      FLAGS.
    * Restores the latest checkpoint in `FLAGS.logdir`, if any.

  Training loop:
    * Logs loss/MAE every 10 steps.
    * Saves a checkpoint every `FLAGS.measurement_store_interval` steps.
    * Saves once more when the dataset is exhausted, unless that step was
      just saved.

  Args:
    debug: if True, return right after model and checkpoint setup, without
      running the training loop.

  Raises:
    ValueError: if the dataset's target width differs from
      `FLAGS.output_dimension`, or the dataset structure is unexpected.
  """
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)

  ds = _build_dataset()
  output_dimension = ds.element_spec[1].shape[1]
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if output_dimension != FLAGS.output_dimension:
    raise ValueError(
        'Dataset target dimension %s does not match --output_dimension %s' %
        (output_dimension, FLAGS.output_dimension))

  # Define loss and optimizer hyperparameters.
  loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  global_step = opt.iterations

  # Create model, loss, and other objects.
  compressor = _build_compressor(global_step)
  model = models.get_keras_model(
      bottleneck_dimension=FLAGS.bottleneck_dimension,
      output_dimension=output_dimension,
      alpha=FLAGS.alpha,
      mobilenet_size=FLAGS.mobilenet_size,
      frontend=not FLAGS.precomputed_frontend_and_targets,
      avg_pool=FLAGS.average_pool,
      compressor=compressor,
      quantize_aware_training=FLAGS.quantize_aware_training)
  model.summary()

  # Add additional metrics to track.
  train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
  train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_mae, summary_writer)
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  checkpoint.restore(manager.latest_checkpoint)

  if debug:
    return

  # Expected per-batch shapes are loop-invariant, so compute them once.
  if FLAGS.precomputed_frontend_and_targets:
    # Inputs are spectrograms.
    input_rank = 3
    input_shape = [FLAGS.train_batch_size, 96, 64]
  else:
    # Inputs are audio vectors.
    input_rank = 2
    input_shape = [FLAGS.train_batch_size, FLAGS.min_length]
  target_shape = [FLAGS.train_batch_size, FLAGS.output_dimension]

  last_saved_step = None
  for inputs, targets in ds:
    inputs.shape.assert_has_rank(input_rank)
    inputs.shape.assert_is_compatible_with(input_shape)
    targets.shape.assert_has_rank(2)
    targets.shape.assert_is_compatible_with(target_shape)

    train_step(inputs, targets, global_step)

    # Read the step counter once per iteration as a plain Python int rather
    # than re-evaluating Variable arithmetic for each check below.
    step = int(global_step.numpy())
    # Optional print output and save model.
    if step % 10 == 0:
      logging.info('step: %i, train loss: %f, train mean abs error: %f',
                   step, train_loss.result(), train_mae.result())
    if step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=step)
      last_saved_step = step

  # Final checkpoint; skipped when the last step was already saved above, to
  # avoid rewriting an identical checkpoint.
  if int(global_step.numpy()) != last_saved_step:
    manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')
def get_default_compressor():
  """Returns a compressor built from the default CompressionOp hparams.

  The hparams are the CompressionOp defaults parsed with an empty override
  string. The compressor is created with `global_step=0`.
  """
  default_hparams = compression.CompressionOp.get_default_hparams().parse('')
  return compression_wrapper.get_apply_compression(
      default_hparams, global_step=0)