def make_graph(emb, y_onehot, ubn=None, nc=None):
  """Build the TF1-style training graph for the classifier.

  Args:
    emb: Embedding tensor fed to the model.
    y_onehot: One-hot label tensor; its second dimension fixes the number of
      classes.
    ubn: Passed through to `models.get_keras_model` (batch-norm switch).
    nc: Passed through to `models.get_keras_model` as `num_clusters`.

  Returns:
    A `(total_loss, train_op)` pair for use with a TF1 session loop.
  """
  num_classes = y_onehot.shape[1]
  classifier = models.get_keras_model(num_classes, ubn, num_clusters=nc)

  logits = classifier(emb, training=True)
  logits.shape.assert_is_compatible_with(y_onehot.shape)

  # Keras update ops are not run automatically in graph mode; register each
  # one on the global UPDATE_OPS collection so the train op depends on them.
  for update in classifier.updates:
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)

  # Loss.
  xent = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(
      y_true=y_onehot, y_pred=logits)
  tf.summary.scalar('xent_loss', xent)

  # Gradient.
  optimizer = tf.train.AdamOptimizer(
      learning_rate=FLAGS.lr, beta1=0.9, beta2=0.999, epsilon=1e-8)
  pending_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
  # `tf.identity` under the control dependency forces the update ops to run
  # whenever the loss is evaluated.
  with tf.control_dependencies(pending_updates):
    total_loss = tf.identity(xent)

  trainable = tf.trainable_variables()
  assert trainable
  train_op = optimizer.minimize(
      total_loss, tf.train.get_or_create_global_step(), trainable)

  return total_loss, train_op
def _load_keras_model(num_classes, use_batch_normalization, num_clusters,
                      checkpoint_filename):
  """Load the model, and make sure weights have been loaded properly.

  Args:
    num_classes: Number of output classes for the classifier head.
    use_batch_normalization: Passed through to `models.get_keras_model`.
    num_clusters: Passed through to `models.get_keras_model`.
    checkpoint_filename: Checkpoint prefix to restore the model weights from.

  Returns:
    The restored Keras model.

  Raises:
    ValueError: If the model's outputs are unchanged after the restore,
      which indicates the weights were not actually loaded.
  """
  # Running a dummy batch builds the model's variables so the restore below
  # has concrete objects to match.
  # NOTE(review): assumes the model expects rank-3 [batch, time, 2048]
  # input — confirm against the embedding dimension used at training time.
  dummy_input = tf.random.uniform([3, 100, 2048])
  agg_model = models.get_keras_model(
      num_classes, use_batch_normalization, num_clusters=num_clusters)
  checkpoint = tf.train.Checkpoint(model=agg_model)

  o1 = agg_model(dummy_input)
  # Fail fast if the checkpoint does not actually match the model's
  # variables; a bare `restore()` silently succeeds even on no match.
  checkpoint.restore(checkpoint_filename).assert_existing_objects_matched()
  o2 = agg_model(dummy_input)

  # Restored (trained) weights must change the randomly-initialized outputs.
  # Raise instead of `assert` so the check survives `python -O`.
  if np.allclose(o1, o2):
    raise ValueError(
        'Model outputs unchanged after restoring %s; weights not loaded.' %
        checkpoint_filename)
  return agg_model
def test_get_model(self, num_clusters):
  """Checks that a built train step advances the optimizer's step counter."""
  num_classes = 4
  embeddings = tf.zeros([3, 5, 8])
  labels_onehot = tf.one_hot([0, 1, 2], num_classes)

  model = models.get_keras_model(
      num_classes, use_batchnorm=True, num_clusters=num_clusters)
  loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  optimizer = tf.keras.optimizers.Adam()
  loss_metric = tf.keras.metrics.Mean()
  acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
  writer = tf.summary.create_file_writer(absltest.get_default_test_tmpdir())

  train_step = train_keras.get_train_step(
      model, loss_fn, optimizer, loss_metric, acc_metric, writer)
  gstep = optimizer.iterations

  # Each invocation must bump the optimizer's iteration counter by one.
  for expected in (1, 2):
    train_step(embeddings, labels_onehot, gstep)
    self.assertEqual(expected, gstep)
def eval_and_report():
  """Eval on voxceleb.

  Watches `FLAGS.logdir` for new checkpoints, evaluates each one over the
  full (non-shuffled) eval dataset, and writes accuracy / cross-entropy /
  AUC / d-prime (and optionally EER) summaries to `FLAGS.eval_dir`.
  Returns when `checkpoints_iterator` times out.
  """
  logging.info('embedding_name: %s', FLAGS.embedding_name)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.batch_size)

  writer = tf.summary.create_file_writer(FLAGS.eval_dir)
  num_classes = len(FLAGS.label_list)
  # The model is built once; each checkpoint's weights are restored into it.
  model = models.get_keras_model(
      num_classes, FLAGS.use_batch_normalization,
      num_clusters=FLAGS.num_clusters, alpha_init=FLAGS.alpha_init)
  checkpoint = tf.train.Checkpoint(model=model)

  # Blocks until a new checkpoint appears, up to FLAGS.timeout.
  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.logdir, timeout=FLAGS.timeout):
    # The step number is recovered from the checkpoint filename suffix.
    assert 'ckpt-' in ckpt, ckpt
    step = ckpt.split('ckpt-')[-1]
    logging.info('Starting to evaluate step: %s.', step)

    checkpoint.restore(ckpt)
    logging.info('Loaded weights for eval step: %s.', step)

    reader = tf.data.TFRecordDataset
    # Fresh dataset per checkpoint: loop_forever=False so iteration ends
    # after one pass; shuffle=False for deterministic eval.
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        reader=reader,
        embedding_name=FLAGS.embedding_name,
        embedding_dim=FLAGS.embedding_dimension,
        label_name=FLAGS.label_name,
        label_list=FLAGS.label_list,
        bucket_boundaries=FLAGS.bucket_boundaries,
        bucket_batch_sizes=[FLAGS.batch_size] * (len(FLAGS.bucket_boundaries) + 1),  # pylint:disable=line-too-long
        loop_forever=False,
        shuffle=False)
    logging.info('Got dataset for eval step: %s.', step)
    if FLAGS.take_fixed_data:
      # Cap the number of batches for faster (approximate) evaluation.
      ds = ds.take(FLAGS.take_fixed_data)

    acc_m = tf.keras.metrics.Accuracy()
    xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

    logging.info('Starting the ds loop...')
    count, ex_count = 0, 0
    # Per-example scores/labels for AUC / d-prime / EER computation below.
    # NOTE(review): only column 1 is kept, which presumes a binary task with
    # class 1 as the positive class — confirm against FLAGS.label_list.
    all_logits, all_real = [], []
    s = time.time()
    for emb, y_onehot in ds:
      # Sanity-check batch shapes: emb is [batch, time, dim], labels are
      # [batch, num_classes].
      emb.shape.assert_has_rank(3)
      assert emb.shape[2] == FLAGS.embedding_dimension
      y_onehot.shape.assert_has_rank(2)
      assert y_onehot.shape[1] == len(FLAGS.label_list)

      logits = model(emb, training=False)
      all_logits.extend(logits.numpy()[:, 1])
      all_real.extend(y_onehot.numpy()[:, 1])
      acc_m.update_state(y_true=tf.argmax(y_onehot, 1),
                         y_pred=tf.argmax(logits, 1))
      xent_m.update_state(y_true=y_onehot, y_pred=logits)
      ex_count += logits.shape[0]
      count += 1
      logging.info('Saw %i examples after %i iterations as %.2f secs...',
                   ex_count, count, time.time() - s)

    # EER is optional (and only written below under the same flag guard).
    if FLAGS.calculate_equal_error_rate:
      eer_score = metrics.calculate_eer(all_real, all_logits)
    auc_score = metrics.calculate_auc(all_real, all_logits)
    dprime_score = metrics.dprime_from_auc(auc_score)

    with writer.as_default():
      tf.summary.scalar('accuracy', acc_m.result().numpy(), step=int(step))
      tf.summary.scalar('xent_loss', xent_m.result().numpy(), step=int(step))
      tf.summary.scalar('auc', auc_score, step=int(step))
      tf.summary.scalar('dprime', dprime_score, step=int(step))
      if FLAGS.calculate_equal_error_rate:
        tf.summary.scalar('eer', eer_score, step=int(step))
    logging.info('Done with eval step: %s in %.2f secs.', step,
                 time.time() - s)
def get_model(num_classes, ubn=None, nc=None):
  """Builds the Keras model and registers its update ops on the graph."""
  keras_model = models.get_keras_model(num_classes, ubn, num_clusters=nc)
  # In graph mode Keras update ops (such as batch-norm statistics, if
  # present) are not run automatically; add them to the global collection so
  # the training loop can depend on them.
  for update in keras_model.updates:
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)
  return keras_model
def train_and_report(debug=False):
  """Trains the classifier.

  Builds the input pipeline, model, optimizer and checkpointing, resumes
  from the latest checkpoint in `FLAGS.logdir` if one exists, then loops
  forever over the (infinitely repeating) dataset, periodically logging
  metrics and saving checkpoints.

  Args:
    debug: If True, set everything up but return before the training loop.
  """
  logging.info('embedding_name: %s', FLAGS.embedding_name)
  logging.info('embedding_dimension: %s', FLAGS.embedding_dimension)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.train_batch_size)

  reader = tf.data.TFRecordDataset
  # loop_forever=True makes the dataset repeat indefinitely, so the `for`
  # loop below never terminates on its own.
  ds = get_data.get_data(
      file_pattern=FLAGS.file_pattern,
      reader=reader,
      embedding_name=FLAGS.embedding_name,
      embedding_dim=FLAGS.embedding_dimension,
      label_name=FLAGS.label_name,
      label_list=FLAGS.label_list,
      bucket_boundaries=FLAGS.bucket_boundaries,
      bucket_batch_sizes=[FLAGS.train_batch_size] * (len(FLAGS.bucket_boundaries) + 1),  # pylint:disable=line-too-long
      loop_forever=True,
      shuffle=True,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  # Create model, loss, and other objects.
  # The number of classes is inferred from the dataset's one-hot label spec.
  y_onehot_spec = ds.element_spec[1]
  assert len(y_onehot_spec.shape) == 2
  num_classes = y_onehot_spec.shape[1]
  model = models.get_keras_model(
      num_classes, FLAGS.use_batch_normalization,
      num_clusters=FLAGS.num_clusters, alpha_init=FLAGS.alpha_init)
  # Define loss and optimizer hyparameters.
  loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  opt = tf.keras.optimizers.Adam(
      learning_rate=FLAGS.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
  # Add additional metrics to track.
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
  train_step = get_train_step(
      model, loss_obj, opt, train_loss, train_accuracy, summary_writer)
  # The optimizer's iteration counter doubles as the global step, and is
  # checkpointed so training resumes at the right step.
  global_step = opt.iterations
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint, FLAGS.logdir, max_to_keep=None)
  logging.info('Checkpoint prefix: %s', FLAGS.logdir)
  # Resume from the latest checkpoint; a no-op if none exists yet.
  checkpoint.restore(manager.latest_checkpoint)

  if debug: return

  for emb, y_onehot in ds:
    # Sanity-check batch shapes: emb is [batch, time, dim], labels are
    # [batch, num_classes].
    emb.shape.assert_has_rank(3)
    assert emb.shape[2] == FLAGS.embedding_dimension
    y_onehot.shape.assert_has_rank(2)
    assert y_onehot.shape[1] == len(FLAGS.label_list)
    train_step(emb, y_onehot, global_step)

    # Optional print output and save model.
    if global_step % 10 == 0:
      logging.info('step: %i, train loss: %f, train accuracy: %f',
                   global_step, train_loss.result(), train_accuracy.result())
    if global_step % FLAGS.measurement_store_interval == 0:
      manager.save(checkpoint_number=global_step)

  # Only reached if the dataset is exhausted (e.g. loop_forever disabled).
  manager.save(checkpoint_number=global_step)
  logging.info('Finished training.')
def eval_and_report():
  """Eval on voxceleb.

  Watches `FLAGS.logdir` for new checkpoints, evaluates each one over the
  full (non-shuffled) eval dataset, and writes accuracy and cross-entropy
  summaries to `FLAGS.eval_dir`. Returns when `checkpoints_iterator` times
  out.

  NOTE(review): this file also contains another `eval_and_report` variant
  using long-form flag names — confirm which definition is intended to win.
  """
  logging.info('embedding_name: %s', FLAGS.en)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.batch_size)

  writer = tf.summary.create_file_writer(FLAGS.eval_dir)
  num_classes = len(FLAGS.label_list)
  # Blocks until a new checkpoint appears, up to FLAGS.timeout.
  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.logdir, timeout=FLAGS.timeout):
    # The step number is recovered from the checkpoint filename suffix.
    assert 'ckpt-' in ckpt, ckpt
    step = ckpt.split('ckpt-')[-1]
    logging.info('Starting to evaluate step: %s.', step)

    # A fresh model is built for every checkpoint and its weights loaded.
    model = models.get_keras_model(
        num_classes, FLAGS.ubn, num_clusters=FLAGS.nc)
    model.load_weights(ckpt)
    logging.info('Loaded weights for eval step: %s.', step)

    reader = tf.data.TFRecordDataset
    # loop_forever=False gives a single pass; shuffle=False keeps eval
    # deterministic.
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        reader=reader,
        embedding_name=FLAGS.en,
        embedding_dim=FLAGS.ed,
        preaveraged=False,
        label_name=FLAGS.label_name,
        label_list=FLAGS.label_list,
        batch_size=FLAGS.batch_size,
        loop_forever=False,
        shuffle=False)
    logging.info('Got dataset for eval step: %s.', step)
    if FLAGS.take_fixed_data:
      # Cap the number of batches for faster (approximate) evaluation.
      ds = ds.take(FLAGS.take_fixed_data)

    acc_m = tf.keras.metrics.Accuracy()
    xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

    logging.info('Starting the ds loop...')
    count, ex_count = 0, 0
    s = time.time()
    for emb, y_onehot in ds:
      # Sanity-check batch shapes: emb is [batch, time, dim], labels are
      # [batch, num_classes].
      emb.shape.assert_has_rank(3)
      assert emb.shape[2] == FLAGS.ed
      y_onehot.shape.assert_has_rank(2)
      assert y_onehot.shape[1] == len(FLAGS.label_list)

      logits = model(emb, training=False)
      acc_m.update_state(y_true=tf.argmax(y_onehot, 1),
                         y_pred=tf.argmax(logits, 1))
      xent_m.update_state(y_true=y_onehot, y_pred=logits)
      ex_count += logits.shape[0]
      count += 1
      logging.info('Saw %i examples after %i iterations as %.2f secs...',
                   ex_count, count, time.time() - s)

    with writer.as_default():
      tf.summary.scalar('accuracy', acc_m.result().numpy(), step=int(step))
      tf.summary.scalar('xent_loss', xent_m.result().numpy(), step=int(step))
    logging.info('Done with eval step: %s in %.2f secs.', step,
                 time.time() - s)