def testWrapper_CreatesProperCompressorOption1(self, low_rank_mock):
    """get_apply_compression with option-1 hparams builds an ApplyCompression
    wired to the low-rank compressor instance."""
    spec = self._create_compression_op_spec(1)
    fake_compressor = MatrixCompressorInterfaceMock(
        self._default_compressor_spec(spec))
    # The patched low-rank factory hands back our mock compressor.
    low_rank_mock.side_effect = [fake_compressor]

    with mock.patch.object(comp_op, 'ApplyCompression') as apply_mock:
        compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
        # The wrapper must forward the spec, compressor and step untouched.
        apply_mock.assert_called_with(
            scope='default_scope',
            compression_spec=spec,
            compressor=fake_compressor,
            global_step=_GLOBAL_STEP)
# Esempio n. 2
# 0
    def test_tflite_model(self, add_compression):
        """TFLite-mode Keras model yields correctly shaped embedding/target
        outputs; with compression enabled, the uncompressed weights are gone."""
        bottleneck_dimension = 3
        compressor = None
        if add_compression:
            params = compression.CompressionOp.get_default_hparams().parse('')
            compressor = compression_wrapper.get_apply_compression(
                params, global_step=0)

        m = models.get_keras_model(bottleneck_dimension,
                                   5,
                                   frontend=False,
                                   mobilenet_size='small',
                                   compressor=compressor,
                                   tflite=True)

        # Single log-Mel spectrogram patch: batch 1, 96 frames, 64 bins.
        o_dict = m(tf.zeros([1, 96, 64, 1], dtype=tf.float32))
        emb = o_dict['embedding']
        o = o_dict['embedding_to_target']

        emb.shape.assert_has_rank(2)
        o.shape.assert_has_rank(2)
        self.assertEqual(emb.shape[0], 1)
        self.assertEqual(emb.shape[1], bottleneck_dimension)
        self.assertEqual(o.shape[0], 1)
        self.assertEqual(o.shape[1], 5)

        if add_compression:
            # In tflite mode the original kernel and the compression op's
            # a_matrix variable must have been stripped from the layer.
            layer = m.get_layer('distilled_output')
            self.assertIsNone(layer.kernel)
            self.assertIsNone(layer.compression_op.a_matrix_tfvar)
# Esempio n. 3
# 0
    def test_compressed_dense_inference(self):
        """Verify forward pass and removal of the uncompressed kernel."""
        params = compression.CompressionOp.get_default_hparams().parse("")
        compressor = compression_wrapper.get_apply_compression(
            params, global_step=0)

        model = tf.keras.Sequential([
            tf.keras.layers.Input((5, )),
            layers.CompressedDense(10,
                                   compression_obj=compressor,
                                   name="compressed")
        ])

        # Drop the uncompressed kernel; inference should rely solely on the
        # compressed representation.
        dense = model.get_layer("compressed")
        dense.kernel = None
        dense.compression_op.a_matrix_tfvar = None

        result = model(tf.zeros((1, 5), dtype=tf.float32), training=False)
        self.assertEqual(result.shape[1], 10)
# Esempio n. 4
# 0
    def test_compressed_dense_training_failure(self):
        """Verify forward pass fails when training flag is True.

        With the uncompressed kernel removed, a training-mode call cannot
        proceed and must raise ValueError.
        """
        compression_params = compression.CompressionOp\
          .get_default_hparams().parse("")
        compressor = compression_wrapper.get_apply_compression(
            compression_params, global_step=0)

        in_tensor = tf.zeros((1, 5), dtype=tf.float32)
        m = tf.keras.Sequential([
            tf.keras.layers.Input((5, )),
            layers.CompressedDense(10,
                                   compression_obj=compressor,
                                   name="compressed")
        ])

        # remove uncompressed kernel
        m.get_layer("compressed").kernel = None
        m.get_layer("compressed").compression_op.a_matrix_tfvar = None

        # assertRaises already guarantees the raised exception is a
        # ValueError; the former post-hoc isinstance re-check was dead code
        # and has been removed.
        with self.assertRaises(ValueError):
            m(in_tensor, training=True)
# Esempio n. 5
# 0
def train_and_report(debug=False):
    """Trains the classifier.

    Builds the input pipeline (precomputed frontend or on-the-fly teacher
    distillation), constructs the student Keras model (optionally with
    matrix compression), restores the latest checkpoint, and runs the
    training loop until the dataset is exhausted.

    Args:
      debug: If True, set everything up (data, model, checkpoint restore)
        and return before entering the training loop.
    """
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)

    reader = tf.data.TFRecordDataset
    if FLAGS.precomputed_frontend_and_targets:
        # Inputs are precomputed spectrograms + teacher embeddings.
        ds = get_data.get_precomputed_data(
            file_pattern=FLAGS.file_pattern,
            output_dimension=FLAGS.output_dimension,
            frontend_key=FLAGS.frontend_key,
            target_key=FLAGS.target_key,
            batch_size=FLAGS.train_batch_size,
            num_epochs=FLAGS.num_epochs,
            shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        ds.element_spec[0].shape.assert_has_rank(3)  # log Mel spectrograms
        ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    else:
        # Inputs are raw audio; targets come from running the teacher
        # (loaded from TF-Hub) on the fly inside the data pipeline.
        ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                               teacher_fn=get_data.savedmodel_to_func(
                                   hub.load(FLAGS.teacher_model_hub),
                                   FLAGS.output_key),
                               output_dimension=FLAGS.output_dimension,
                               reader=reader,
                               samples_key=FLAGS.samples_key,
                               min_length=FLAGS.min_length,
                               batch_size=FLAGS.train_batch_size,
                               loop_forever=True,
                               shuffle=True,
                               shuffle_buffer_size=FLAGS.shuffle_buffer_size)
        assert len(ds.element_spec) == 2, ds.element_spec
        ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
        ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    output_dimension = ds.element_spec[1].shape[1]
    assert output_dimension == FLAGS.output_dimension
    # Define loss and optimizer hyparameters.
    loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # The optimizer's iteration counter doubles as the global step, so it is
    # checkpointed and restored together with the model.
    global_step = opt.iterations
    # Create model, loss, and other objects.
    compressor = None
    if FLAGS.compression_op:
        # Assemble the hparams override string from the compression flags.
        # NOTE(review): alpha_decrement_value is formatted with %d — if
        # FLAGS.alpha_step_size is fractional it will be truncated; confirm
        # against the flag definition.
        custom_params = ','.join([
            'compression_frequency=%d',
            'rank=%d',
            'begin_compression_step=%d',
            'end_compression_step=%d',
            'alpha_decrement_value=%d',
        ]) % (FLAGS.comp_freq, FLAGS.comp_rank, FLAGS.comp_begin_step,
              FLAGS.comp_end_step, FLAGS.alpha_step_size)
        compression_params = compression.CompressionOp.get_default_hparams(
        ).parse(custom_params)
        compressor = compression_wrapper.get_apply_compression(
            compression_params, global_step=global_step)
    model = models.get_keras_model(
        bottleneck_dimension=FLAGS.bottleneck_dimension,
        output_dimension=output_dimension,
        alpha=FLAGS.alpha,
        mobilenet_size=FLAGS.mobilenet_size,
        frontend=not FLAGS.precomputed_frontend_and_targets,
        avg_pool=FLAGS.average_pool,
        compressor=compressor,
        quantize_aware_training=FLAGS.quantize_aware_training)
    model.summary()
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
    train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss, train_mae,
                                summary_writer)
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(
        checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    # Restoring a non-existent checkpoint is a no-op, so this is safe on a
    # fresh logdir.
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    for inputs, targets in ds:
        # Sanity-check batch shapes before stepping.
        if FLAGS.precomputed_frontend_and_targets:  # inputs are spectrograms
            inputs.shape.assert_has_rank(3)
            inputs.shape.assert_is_compatible_with(
                [FLAGS.train_batch_size, 96, 64])
        else:  # inputs are audio vectors
            inputs.shape.assert_has_rank(2)
            inputs.shape.assert_is_compatible_with(
                [FLAGS.train_batch_size, FLAGS.min_length])
        targets.shape.assert_has_rank(2)
        targets.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, FLAGS.output_dimension])
        train_step(inputs, targets, global_step)
        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train mean abs error: %f',
                         global_step, train_loss.result(), train_mae.result())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    # Final save after the dataset is exhausted.
    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')
def get_default_compressor():
  """Return an apply-compression object built from the default hparams."""
  params = compression.CompressionOp.get_default_hparams().parse('')
  return compression_wrapper.get_apply_compression(params, global_step=0)