def test_get_data_max_len(self):
  """Input pipeline truncates audio to `max_samples_length` samples."""
  batch_size = 3
  truncation_len = 3
  # Truncation is only meaningful if the stored examples are longer.
  assert truncation_len < self.len
  ds = get_data.get_data(
      file_patterns=self.precomputed_file_pattern,
      output_dimension=self.output_dim,
      reader=tf.data.TFRecordDataset,
      samples_key=self.samples_key,
      target_key=self.target_key,
      batch_size=batch_size,
      loop_forever=False,
      shuffle=True,
      shuffle_buffer_size=2,
      min_samples_length=1,
      max_samples_length=truncation_len)
  # Expect (samples, targets) pairs.
  self.assertLen(ds.element_spec, 2)

  # Read one minibatch from the pipeline and verify that the audio axis
  # was truncated to `truncation_len`.
  saw_batch = False
  for wav_samples, targets in ds:
    self.assertEqual(wav_samples.shape, [batch_size, truncation_len])
    self.assertEqual(targets.shape, [batch_size, self.output_dim])
    saw_batch = True
    break
  self.assertTrue(saw_batch)
def test_get_data_out_len(self):
  """Passing `label_key` adds a third element to each dataset tuple."""
  batch_size = 3
  ds = get_data.get_data(
      file_patterns=self.precomputed_file_pattern,
      output_dimension=self.output_dim,
      reader=tf.data.TFRecordDataset,
      samples_key=self.samples_key,
      target_key=self.target_key,
      batch_size=batch_size,
      loop_forever=False,
      shuffle=True,
      shuffle_buffer_size=2,
      label_key='label')
  # (samples, targets, labels) -> three components per element.
  self.assertLen(ds.element_spec, 3)
def _get_ds(file_patterns, step):
  """Builds the evaluation dataset for the given file patterns.

  Args:
    file_patterns: File pattern(s) forwarded to `get_data.get_data`.
    step: Training step this eval dataset corresponds to; used for
      logging only.

  Returns:
    A `tf.data.Dataset`, optionally capped to `FLAGS.take_fixed_data`
    batches.
  """
  dataset = get_data.get_data(
      file_patterns=file_patterns,
      reader=tf.data.TFRecordDataset,
      samples_key=AUDIO_KEY_,
      batch_size=FLAGS.eval_batch_size,
      loop_forever=False,
      shuffle=False,
      target_key=FLAGS.target_key,
      label_key=FLAGS.label_key,
      speaker_id_key=FLAGS.speaker_id_key,
      samples_are_float=False,
      max_samples_length=None)
  logging.info('Got dataset for eval step: %s.', step)
  # Optionally cap the amount of eval data for faster, fixed-size runs.
  if FLAGS.take_fixed_data:
    dataset = dataset.take(FLAGS.take_fixed_data)
  return dataset
def test_get_data_min_len(self):
  """A large `min_samples_length` filters out every example."""
  batch_size = 3
  # Presumably longer than any example in the test data, so the
  # resulting dataset should be empty.
  required_len = 1000
  ds = get_data.get_data(
      file_patterns=self.precomputed_file_pattern,
      output_dimension=self.output_dim,
      reader=tf.data.TFRecordDataset,
      samples_key=self.samples_key,
      target_key=self.target_key,
      batch_size=batch_size,
      loop_forever=False,
      shuffle=True,
      shuffle_buffer_size=2,
      min_samples_length=required_len)
  saw_batch = False
  for _ in ds:
    saw_batch = True
    break
  self.assertFalse(saw_batch)
def train_and_report(debug=False, target_dim=1024):
  """Trains the classifier.

  Builds the input pipeline, model, optimizer, and checkpointing, then
  runs the training loop, periodically logging metrics and saving
  checkpoints.

  Args:
    debug: If True, set everything up but return before training.
    target_dim: Expected dimensionality of the teacher-embedding targets.
  """
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)
  reader = tf.data.TFRecordDataset
  target_key = FLAGS.target_key
  ds = get_data.get_data(
      file_patterns=FLAGS.file_patterns,
      output_dimension=target_dim,
      reader=reader,
      samples_key=FLAGS.samples_key,
      target_key=target_key,
      batch_size=FLAGS.train_batch_size,
      loop_forever=True,
      shuffle=True,
      max_samples_length=FLAGS.max_sample_length,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      samples_are_float=True)
  # Sanity-check the pipeline: (samples, targets) pairs, both rank 2,
  # with targets matching the expected teacher-embedding dimension.
  assert len(ds.element_spec) == 2, ds.element_spec
  ds.element_spec[0].shape.assert_has_rank(2)  # audio samples
  ds.element_spec[1].shape.assert_has_rank(2)  # teacher embeddings
  output_dimension = ds.element_spec[1].shape[1]
  assert output_dimension == target_dim, (output_dimension, target_dim)

  # Define loss and optimizer hyperparameters.
  loss_obj = tf.keras.losses.MeanSquaredError(name='mse_loss')
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  # Use the optimizer's iteration counter as the global step so it is
  # saved/restored with the checkpoint.
  global_step = opt.iterations
  # Create model, loss, and other objects.
  model = models.get_keras_model(
      model_type=FLAGS.model_type, frame_hop=FLAGS.frame_hop)
  assert model.trainable_variables
  # Add additional metrics to track.
  train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss')
  train_mae = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_mae, summary_writer)
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=FLAGS.checkpoint_max_to_keep)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  # Resume from the latest checkpoint if one exists (no-op otherwise).
  checkpoint.restore(manager.latest_checkpoint)

  if debug:
    return

  logging.info('Starting loop with tbs: %s', FLAGS.train_batch_size)
  for inputs, targets in ds:
    # Inputs are audio vectors.
    inputs.shape.assert_has_rank(2)
    inputs.shape.assert_is_compatible_with([FLAGS.train_batch_size, None])
    targets.shape.assert_has_rank(2)
    targets.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, target_dim])
    train_step(inputs, targets, global_step)
    # Optional print output and save model.
    if global_step % 10 == 0:
      logging.info('step: %i, train loss: %f, train mean abs error: %f',
                   global_step, train_loss.result(), train_mae.result())
    if global_step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=global_step)

  # NOTE(review): with loop_forever=True the dataset repeats
  # indefinitely, so this final save/log is not normally reached —
  # confirm the intended placement relative to the loop.
  manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')