# Example #1 (0)
 def test_get_data_max_len(self):
     """Checks that samples are truncated to `max_samples_length`."""
     batch_size = 3
     max_len = 3
     assert max_len < self.len
     dataset = get_data.get_data(
         file_patterns=self.precomputed_file_pattern,
         output_dimension=self.output_dim,
         reader=tf.data.TFRecordDataset,
         samples_key=self.samples_key,
         target_key=self.target_key,
         batch_size=batch_size,
         loop_forever=False,
         shuffle=True,
         shuffle_buffer_size=2,
         min_samples_length=1,
         max_samples_length=max_len)
     self.assertLen(dataset.element_spec, 2)
     # Pull a single batch to confirm the pipeline yields data and that the
     # sample axis is capped at `max_len`.
     saw_batch = False
     for wav_samples, targets in dataset:
         self.assertEqual(wav_samples.shape, [batch_size, max_len])
         self.assertEqual(targets.shape, [batch_size, self.output_dim])
         saw_batch = True
         break
     self.assertTrue(saw_batch)
# Example #2 (0)
 def test_get_data_out_len(self):
     """Checks that supplying `label_key` yields a 3-element dataset spec."""
     batch_size = 3
     dataset = get_data.get_data(
         file_patterns=self.precomputed_file_pattern,
         output_dimension=self.output_dim,
         reader=tf.data.TFRecordDataset,
         samples_key=self.samples_key,
         target_key=self.target_key,
         batch_size=batch_size,
         loop_forever=False,
         shuffle=True,
         shuffle_buffer_size=2,
         label_key='label')
     # With a label key present, elements are (samples, targets, labels).
     self.assertLen(dataset.element_spec, 3)
# Example #3 (0)
def _get_ds(file_patterns, step):
    """Gets a tf.Dataset for a file."""
    # Evaluation reads a single deterministic pass: no shuffling, no looping.
    dataset = get_data.get_data(
        file_patterns=file_patterns,
        reader=tf.data.TFRecordDataset,
        samples_key=AUDIO_KEY_,
        batch_size=FLAGS.eval_batch_size,
        loop_forever=False,
        shuffle=False,
        target_key=FLAGS.target_key,
        label_key=FLAGS.label_key,
        speaker_id_key=FLAGS.speaker_id_key,
        samples_are_float=False,
        max_samples_length=None)
    logging.info('Got dataset for eval step: %s.', step)
    # Optionally cap the number of batches consumed per evaluation.
    if FLAGS.take_fixed_data:
        dataset = dataset.take(FLAGS.take_fixed_data)
    return dataset
# Example #4 (0)
 def test_get_data_min_len(self):
     """Checks that a large `min_samples_length` filters out every sample."""
     batch_size = 3
     # Presumably longer than any sample in the test data, so the resulting
     # dataset should be empty.
     min_len = 1000
     dataset = get_data.get_data(
         file_patterns=self.precomputed_file_pattern,
         output_dimension=self.output_dim,
         reader=tf.data.TFRecordDataset,
         samples_key=self.samples_key,
         target_key=self.target_key,
         batch_size=batch_size,
         loop_forever=False,
         shuffle=True,
         shuffle_buffer_size=2,
         min_samples_length=min_len)
     # `any` stops at the first element, mirroring a loop with `break`.
     saw_batch = any(True for _ in dataset)
     self.assertFalse(saw_batch)
# Example #5 (0)
def train_and_report(debug=False, target_dim=1024):
    """Trains the classifier.

    Builds the input pipeline from FLAGS, constructs a Keras model, restores
    the latest checkpoint if one exists, and runs the training loop, logging
    metrics and saving checkpoints periodically.

    Args:
      debug: If True, build all objects but return before the training loop.
      target_dim: Expected dimensionality of the regression targets; the
        dataset's target shape is asserted against this.
    """
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)

    reader = tf.data.TFRecordDataset
    target_key = FLAGS.target_key

    # `loop_forever=True`: the dataset repeats indefinitely, so the for-loop
    # below only terminates when the input pipeline is exhausted externally.
    ds = get_data.get_data(file_patterns=FLAGS.file_patterns,
                           output_dimension=target_dim,
                           reader=reader,
                           samples_key=FLAGS.samples_key,
                           target_key=target_key,
                           batch_size=FLAGS.train_batch_size,
                           loop_forever=True,
                           shuffle=True,
                           max_samples_length=FLAGS.max_sample_length,
                           shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                           samples_are_float=True)
    # Sanity-check the pipeline: (audio, targets) pairs, both rank-2, with
    # the target dimension matching `target_dim`.
    assert len(ds.element_spec) == 2, ds.element_spec
    ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
    ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
    output_dimension = ds.element_spec[1].shape[1]
    assert output_dimension == target_dim, (output_dimension, target_dim)

    # Define loss and optimizer hyparameters.
    loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # The optimizer's iteration counter doubles as the global training step.
    global_step = opt.iterations
    # Create model, loss, and other objects.
    model = models.get_keras_model(model_type=FLAGS.model_type,
                                   frame_hop=FLAGS.frame_hop)
    assert model.trainable_variables
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
    train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss, train_mae,
                                summary_writer)
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(
        checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    # Resume from the latest checkpoint if present; a no-op otherwise.
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    logging.info('Starting loop with tbs: %s', FLAGS.train_batch_size)
    for inputs, targets in ds:
        # Inputs are audio vectors.
        inputs.shape.assert_has_rank(2)
        inputs.shape.assert_is_compatible_with([FLAGS.train_batch_size, None])
        targets.shape.assert_has_rank(2)
        targets.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, target_dim])
        train_step(inputs, targets, global_step)
        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train mean abs error: %f',
                         global_step, train_loss.result(), train_mae.result())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    # Final save in case the loop exits between store intervals.
    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')