def eval_and_report():
  """Evaluates every checkpoint that appears in `FLAGS.logdir` (voxceleb eval).

  For each checkpoint yielded by `tf.train.checkpoints_iterator` (called with
  `timeout=FLAGS.timeout`), this function:
    1. restores the model weights from that checkpoint,
    2. rebuilds the (non-shuffled, single-pass) eval dataset, and
    3. accumulates accuracy and categorical cross-entropy over it.
  It then writes the two metrics as scalar summaries to `FLAGS.eval_dir`,
  using the numeric suffix of the checkpoint path as the summary step.
  """
  # Log through absl `logging` everywhere. `tf.logging` does not exist in TF2,
  # and this function otherwise relies on TF2-only APIs
  # (`tf.summary.create_file_writer`, `step=` summaries, eager `.numpy()`).
  logging.info('samples_key: %s', FLAGS.samples_key)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.batch_size)

  writer = tf.summary.create_file_writer(FLAGS.eval_dir)
  num_classes = len(FLAGS.label_list)
  # NOTE(review): `FLAGS.ubn` is passed as the second *positional* argument,
  # while the training code passes `input_length=` / `use_batchnorm=` by
  # keyword. Confirm this matches the current `models.get_keras_model`
  # signature.
  model = models.get_keras_model(num_classes, FLAGS.ubn, num_clusters=FLAGS.nc)
  checkpoint = tf.train.Checkpoint(model=model)

  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.logdir, timeout=FLAGS.timeout):
    assert 'ckpt-' in ckpt, ckpt
    # The text after the last 'ckpt-' is the step number; it is used below
    # as the summary step.
    step = ckpt.split('ckpt-')[-1]
    logging.info('Starting to evaluate step: %s.', step)
    checkpoint.restore(ckpt)
    logging.info('Loaded weights for eval step: %s.', step)

    reader = tf.data.TFRecordDataset
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        reader=reader,
        samples_key=FLAGS.samples_key,
        min_length=FLAGS.min_length,
        label_key=FLAGS.label_key,
        label_list=FLAGS.label_list,
        batch_size=FLAGS.batch_size,
        loop_forever=False,
        shuffle=False)
    logging.info('Got dataset for eval step: %s.', step)
    if FLAGS.take_fixed_data:
      # Optionally cap evaluation to a fixed number of batches.
      ds = ds.take(FLAGS.take_fixed_data)

    # Fresh metric objects per checkpoint, so results never mix across steps.
    acc_m = tf.keras.metrics.Accuracy()
    xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

    logging.info('Starting the ds loop...')
    count, ex_count = 0, 0
    s = time.time()
    for wav_samples, y_onehot in ds:
      wav_samples.shape.assert_is_compatible_with([None, FLAGS.min_length])
      y_onehot.shape.assert_is_compatible_with(
          [None, len(FLAGS.label_list)])
      logits = model(wav_samples, training=False)
      # Plain `Accuracy` compares values elementwise, so both the labels and
      # the logits are reduced to class indices first.
      acc_m.update_state(
          y_true=tf.argmax(y_onehot, 1), y_pred=tf.argmax(logits, 1))
      xent_m.update_state(y_true=y_onehot, y_pred=logits)
      ex_count += logits.shape[0]
      count += 1
      logging.info('Saw %i examples after %i iterations as %.2f secs...',
                   ex_count, count, time.time() - s)

    with writer.as_default():
      tf.summary.scalar('accuracy', acc_m.result().numpy(), step=int(step))
      tf.summary.scalar('xent_loss', xent_m.result().numpy(), step=int(step))
    # This loop can wait a long time for the next checkpoint, so push this
    # step's summaries to disk now rather than waiting for the writer's
    # periodic flush.
    writer.flush()
    logging.info('Done with eval step: %s in %.2f secs.', step,
                 time.time() - s)
def train_and_report(debug=False):
  """Trains the classifier, logging metrics and checkpointing periodically.

  Builds a shuffled, repeating training dataset from `FLAGS.file_pattern`,
  then constructs the Keras model, loss, Adam optimizer and metrics. It
  restores the latest checkpoint in `FLAGS.logdir`, if there is one, and runs
  the train step on every batch:
    * every 10 steps it logs the running loss and accuracy;
    * every `FLAGS.measurement_store_interval` steps it saves a checkpoint;
    * it saves a final checkpoint when the dataset loop ends.

  Args:
    debug: If True, return right after building everything and restoring the
      checkpoint, without running any training steps.
  """
  # Log through absl `logging` (also used by the eval code in this file).
  # `tf.logging` does not exist in TF2, whose APIs this function otherwise
  # depends on (`tf.summary.create_file_writer`, eager dataset iteration).
  logging.info('samples_key: %s', FLAGS.samples_key)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)
  logging.info('label_list: %s', FLAGS.label_list)

  reader = tf.data.TFRecordDataset
  ds = get_data.get_data(
      file_pattern=FLAGS.file_pattern,
      reader=reader,
      samples_key=FLAGS.samples_key,
      min_length=FLAGS.min_length,
      label_key=FLAGS.label_key,
      label_list=FLAGS.label_list,
      batch_size=FLAGS.train_batch_size,
      loop_forever=True,
      shuffle=True,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  # Create model, loss, and other objects.
  # The number of classes is read from the dataset's label spec, which must
  # be rank 2: (batch, num_classes).
  y_onehot_spec = ds.element_spec[1]
  assert len(y_onehot_spec.shape) == 2, y_onehot_spec.shape
  num_classes = y_onehot_spec.shape[1]
  model = models.get_keras_model(
      num_classes,
      input_length=FLAGS.min_length,
      use_batchnorm=FLAGS.use_batch_normalization,
      num_clusters=FLAGS.num_clusters,
      alpha_init=FLAGS.alpha_init)

  # Define loss and optimizer hyperparameters.
  loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

  # Add additional metrics to track.
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  # NOTE(review): the labels in `ds` are one-hot, but this is the *Sparse*
  # accuracy metric, which expects integer class ids. Confirm that
  # `get_train_step` converts the labels before updating this metric.
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_accuracy, summary_writer)

  # The optimizer's iteration counter serves as the global step. It is
  # checkpointed with the model, so a resumed run continues its numbering.
  global_step = opt.iterations
  # NOTE(review): only `model` and `global_step` are tracked here, so the
  # optimizer's slot variables (Adam moments) are not saved and restart from
  # zero on resume.
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=None)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  # On a fresh run `latest_checkpoint` is None, and nothing is restored.
  checkpoint.restore(manager.latest_checkpoint)

  if debug:
    return

  for wav_samples, y_onehot in ds:
    wav_samples.shape.assert_has_rank(2)
    wav_samples.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, FLAGS.min_length])
    y_onehot.shape.assert_is_compatible_with(
        [FLAGS.train_batch_size, len(FLAGS.label_list)])

    train_step(wav_samples, y_onehot, global_step)

    # Read the step once as a plain Python int. The checks, the log
    # formatting and the checkpoint number below then do not depend on how a
    # TF variable behaves in `if`, `%i` or `%`.
    step = int(global_step.numpy())
    # Optional print output and save model.
    if step % 10 == 0:
      logging.info('step: %i, train loss: %f, train accuracy: %f',
                   step, train_loss.result(), train_accuracy.result())
    if step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=step)

  # Reached only if the dataset loop terminates.
  manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')