Example #1
def make_graph(emb, y_onehot, ubn=None, nc=None):
    """Make a graph on data."""
    num_classes = y_onehot.shape[1]
    model = models.get_keras_model(num_classes, ubn, num_clusters=nc)
    logits = model(emb, training=True)
    logits.shape.assert_is_compatible_with(y_onehot.shape)
    for u_op in model.updates:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u_op)

    # Loss.
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(
        y_true=y_onehot, y_pred=logits)
    tf.summary.scalar('xent_loss', loss)

    # Gradient.
    opt = tf.train.AdamOptimizer(learning_rate=FLAGS.lr,
                                 beta1=0.9,
                                 beta2=0.999,
                                 epsilon=1e-8)
    update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    with tf.control_dependencies(update_ops):
        total_loss = tf.identity(loss)
    var_list = tf.trainable_variables()
    assert var_list
    train_op = opt.minimize(total_loss, tf.train.get_or_create_global_step(),
                            var_list)

    return total_loss, train_op
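
For orientation, a hypothetical TF1-style driver for make_graph. The placeholder shapes, the ubn/nc values, and the fake batch below are illustrative assumptions; FLAGS.lr and the models module must already be defined:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

emb = tf.placeholder(tf.float32, [None, 100, 2048])  # [batch, time, dim]
y_onehot = tf.placeholder(tf.float32, [None, 4])     # [batch, num_classes]
loss, train_op = make_graph(emb, y_onehot, ubn=True, nc=8)

# A made-up batch, just to show the feed shapes.
batch_emb = np.random.rand(3, 100, 2048).astype(np.float32)
batch_y = np.eye(4, dtype=np.float32)[[0, 1, 2]]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_np, _ = sess.run([loss, train_op],
                          feed_dict={emb: batch_emb, y_onehot: batch_y})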
Example #2
def _load_keras_model(num_classes, use_batch_normalization, num_clusters,
                      checkpoint_filename):
    """Load the model, and make sure weights have been loaded properly."""
    dummy_input = tf.random.uniform([3, 100, 2048])

    agg_model = models.get_keras_model(num_classes,
                                       use_batch_normalization,
                                       num_clusters=num_clusters)
    checkpoint = tf.train.Checkpoint(model=agg_model)

    o1 = agg_model(dummy_input)
    checkpoint.restore(checkpoint_filename)
    o2 = agg_model(dummy_input)

    # If the restore worked, the outputs before and after should differ.
    assert not np.allclose(o1, o2)

    return agg_model
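
A hypothetical call site for the loader above (the checkpoint path and hyperparameter values are made up for illustration):

model = _load_keras_model(num_classes=4,
                          use_batch_normalization=True,
                          num_clusters=8,
                          checkpoint_filename='/tmp/logdir/ckpt-100')
logits = model(tf.random.uniform([3, 100, 2048]), training=False)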
Example #3
    def test_get_model(self, num_clusters):
        num_classes = 4
        emb = tf.zeros([3, 5, 8])
        y_onehot = tf.one_hot([0, 1, 2], num_classes)

        model = models.get_keras_model(num_classes,
                                       use_batchnorm=True,
                                       num_clusters=num_clusters)

        loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        opt = tf.keras.optimizers.Adam()
        train_loss = tf.keras.metrics.Mean()
        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        summary_writer = tf.summary.create_file_writer(
            absltest.get_default_test_tmpdir())
        train_step = train_keras.get_train_step(model, loss_obj, opt,
                                                train_loss, train_accuracy,
                                                summary_writer)
        gstep = opt.iterations  # Increments once per optimizer step.
        train_step(emb, y_onehot, gstep)
        self.assertEqual(1, gstep)
        train_step(emb, y_onehot, gstep)
        self.assertEqual(2, gstep)
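
For reference, a minimal sketch of a get_train_step compatible with the call above: a gradient-tape step wrapped in tf.function. This is an assumption about the interface, not necessarily train_keras's actual implementation:

import tensorflow as tf

def get_train_step(model, loss_obj, opt, train_loss, train_accuracy,
                   summary_writer):
  """Returns a tf.function that runs one optimization step (sketch)."""

  @tf.function
  def train_step(emb, y_onehot, step):
    with tf.GradientTape() as tape:
      logits = model(emb, training=True)
      loss = loss_obj(y_true=y_onehot, y_pred=logits)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    # SparseCategoricalAccuracy expects integer labels, hence the argmax.
    train_loss.update_state(loss)
    train_accuracy.update_state(tf.argmax(y_onehot, axis=1), logits)
    with summary_writer.as_default():
      tf.summary.scalar('train_loss', train_loss.result(), step=step)

  return train_step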
Example #4
def eval_and_report():
  """Eval on voxceleb."""
  logging.info('embedding_name: %s', FLAGS.embedding_name)
  logging.info('Logdir: %s', FLAGS.logdir)
  logging.info('Batch size: %s', FLAGS.batch_size)

  writer = tf.summary.create_file_writer(FLAGS.eval_dir)
  num_classes = len(FLAGS.label_list)
  model = models.get_keras_model(
      num_classes, FLAGS.use_batch_normalization,
      num_clusters=FLAGS.num_clusters, alpha_init=FLAGS.alpha_init)
  checkpoint = tf.train.Checkpoint(model=model)

  for ckpt in tf.train.checkpoints_iterator(
      FLAGS.logdir, timeout=FLAGS.timeout):
    assert 'ckpt-' in ckpt, ckpt
    step = ckpt.split('ckpt-')[-1]
    logging.info('Starting to evaluate step: %s.', step)

    checkpoint.restore(ckpt)

    logging.info('Loaded weights for eval step: %s.', step)

    reader = tf.data.TFRecordDataset
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        reader=reader,
        embedding_name=FLAGS.embedding_name,
        embedding_dim=FLAGS.embedding_dimension,
        label_name=FLAGS.label_name,
        label_list=FLAGS.label_list,
        bucket_boundaries=FLAGS.bucket_boundaries,
        bucket_batch_sizes=[FLAGS.batch_size] * (len(FLAGS.bucket_boundaries) + 1),  # pylint:disable=line-too-long
        loop_forever=False,
        shuffle=False)
    logging.info('Got dataset for eval step: %s.', step)
    if FLAGS.take_fixed_data:
      ds = ds.take(FLAGS.take_fixed_data)

    acc_m = tf.keras.metrics.Accuracy()
    xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

    logging.info('Starting the ds loop...')
    count, ex_count = 0, 0
    all_logits, all_real = [], []
    s = time.time()
    for emb, y_onehot in ds:
      emb.shape.assert_has_rank(3)
      assert emb.shape[2] == FLAGS.embedding_dimension
      y_onehot.shape.assert_has_rank(2)
      assert y_onehot.shape[1] == len(FLAGS.label_list)

      logits = model(emb, training=False)
      # Binary task: keep the positive-class (index 1) scores for AUC/EER.
      all_logits.extend(logits.numpy()[:, 1])
      all_real.extend(y_onehot.numpy()[:, 1])
      acc_m.update_state(y_true=tf.argmax(y_onehot, 1),
                         y_pred=tf.argmax(logits, 1))
      xent_m.update_state(y_true=y_onehot, y_pred=logits)
      ex_count += logits.shape[0]
      count += 1
      logging.info('Saw %i examples after %i iterations in %.2f secs...',
                   ex_count, count,
                   time.time() - s)
    if FLAGS.calculate_equal_error_rate:
      eer_score = metrics.calculate_eer(all_real, all_logits)
    auc_score = metrics.calculate_auc(all_real, all_logits)
    dprime_score = metrics.dprime_from_auc(auc_score)
    with writer.as_default():
      tf.summary.scalar('accuracy', acc_m.result().numpy(), step=int(step))
      tf.summary.scalar('xent_loss', xent_m.result().numpy(), step=int(step))
      tf.summary.scalar('auc', auc_score, step=int(step))
      tf.summary.scalar('dprime', dprime_score, step=int(step))
      if FLAGS.calculate_equal_error_rate:
        tf.summary.scalar('eer', eer_score, step=int(step))
    logging.info('Done with eval step: %s in %.2f secs.', step, time.time() - s)
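
The metrics.dprime_from_auc call above corresponds to the standard relation d' = sqrt(2) * Phi^-1(AUC), where Phi^-1 is the inverse normal CDF. A minimal sketch assuming scipy is available (the repository's own implementation may differ):

import numpy as np
from scipy.stats import norm

def dprime_from_auc(auc):
  """Convert an AUC score to d-prime: d' = sqrt(2) * norm.ppf(auc)."""
  return np.sqrt(2) * norm.ppf(auc)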
Example #5
def get_model(num_classes, ubn=None, nc=None):
    model = models.get_keras_model(num_classes, ubn, num_clusters=nc)
    for u_op in model.updates:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u_op)
    return model
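
A hypothetical graph-mode call site (argument values are illustrative). Registering model.updates in UPDATE_OPS matters in TF1 because batch-norm moving-average updates only run if they are wired into the train op, e.g. via tf.control_dependencies as in Example #1:

model = get_model(num_classes=4, ubn=True, nc=8)
logits = model(tf.placeholder(tf.float32, [None, 100, 2048]), training=True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)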
Example #6
def train_and_report(debug=False):
    """Trains the classifier."""
    logging.info('embedding_name: %s', FLAGS.embedding_name)
    logging.info('embedding_dimension: %s', FLAGS.embedding_dimension)
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)

    reader = tf.data.TFRecordDataset
    ds = get_data.get_data(
        file_pattern=FLAGS.file_pattern,
        reader=reader,
        embedding_name=FLAGS.embedding_name,
        embedding_dim=FLAGS.embedding_dimension,
        label_name=FLAGS.label_name,
        label_list=FLAGS.label_list,
        bucket_boundaries=FLAGS.bucket_boundaries,
        bucket_batch_sizes=[FLAGS.train_batch_size] *
        (len(FLAGS.bucket_boundaries) + 1),  # pylint:disable=line-too-long
        loop_forever=True,
        shuffle=True,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    # Create model, loss, and other objects.
    y_onehot_spec = ds.element_spec[1]
    assert len(y_onehot_spec.shape) == 2
    num_classes = y_onehot_spec.shape[1]
    model = models.get_keras_model(num_classes,
                                   FLAGS.use_batch_normalization,
                                   num_clusters=FLAGS.num_clusters,
                                   alpha_init=FLAGS.alpha_init)
    # Define loss and optimizer hyperparameters.
    loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss,
                                train_accuracy, summary_writer)
    global_step = opt.iterations
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(checkpoint,
                                         FLAGS.logdir,
                                         max_to_keep=None)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    for emb, y_onehot in ds:
        emb.shape.assert_has_rank(3)
        assert emb.shape[2] == FLAGS.embedding_dimension
        y_onehot.shape.assert_has_rank(2)
        assert y_onehot.shape[1] == len(FLAGS.label_list)

        train_step(emb, y_onehot, global_step)

        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train accuracy: %f',
                         global_step, train_loss.result(),
                         train_accuracy.result())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')
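
One detail worth noting in the dataset call above: N bucket_boundaries split sequence lengths into N + 1 buckets, so bucket_batch_sizes needs exactly one more entry than the boundary list. A toy illustration (values made up):

bucket_boundaries = [100, 200, 400]                       # 3 boundaries -> 4 buckets
bucket_batch_sizes = [32] * (len(bucket_boundaries) + 1)  # [32, 32, 32, 32]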
Example #7
def eval_and_report():
    """Eval on voxceleb."""
    logging.info('embedding_name: %s', FLAGS.en)
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.batch_size)

    writer = tf.summary.create_file_writer(FLAGS.eval_dir)
    num_classes = len(FLAGS.label_list)

    for ckpt in tf.train.checkpoints_iterator(FLAGS.logdir,
                                              timeout=FLAGS.timeout):
        assert 'ckpt-' in ckpt, ckpt
        step = ckpt.split('ckpt-')[-1]
        logging.info('Starting to evaluate step: %s.', step)

        model = models.get_keras_model(num_classes,
                                       FLAGS.ubn,
                                       num_clusters=FLAGS.nc)
        model.load_weights(ckpt)

        logging.info('Loaded weights for eval step: %s.', step)

        reader = tf.data.TFRecordDataset
        ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                               reader=reader,
                               embedding_name=FLAGS.en,
                               embedding_dim=FLAGS.ed,
                               preaveraged=False,
                               label_name=FLAGS.label_name,
                               label_list=FLAGS.label_list,
                               batch_size=FLAGS.batch_size,
                               loop_forever=False,
                               shuffle=False)
        logging.info('Got dataset for eval step: %s.', step)
        if FLAGS.take_fixed_data:
            ds = ds.take(FLAGS.take_fixed_data)

        acc_m = tf.keras.metrics.Accuracy()
        xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

        logging.info('Starting the ds loop...')
        count, ex_count = 0, 0
        s = time.time()
        for emb, y_onehot in ds:
            emb.shape.assert_has_rank(3)
            assert emb.shape[2] == FLAGS.ed
            y_onehot.shape.assert_has_rank(2)
            assert y_onehot.shape[1] == len(FLAGS.label_list)

            logits = model(emb, training=False)
            acc_m.update_state(y_true=tf.argmax(y_onehot, 1),
                               y_pred=tf.argmax(logits, 1))
            xent_m.update_state(y_true=y_onehot, y_pred=logits)
            ex_count += logits.shape[0]
            count += 1
            logging.info('Saw %i examples after %i iterations in %.2f secs...',
                         ex_count, count,
                         time.time() - s)
        with writer.as_default():
            tf.summary.scalar('accuracy',
                              acc_m.result().numpy(),
                              step=int(step))
            tf.summary.scalar('xent_loss',
                              xent_m.result().numpy(),
                              step=int(step))
        logging.info('Done with eval step: %s in %.2f secs.', step,
                     time.time() - s)