Example #1

# Imports needed to make this example self-contained. The `models` and
# `get_data` modules, and the flag definitions referenced through FLAGS,
# are assumed to be defined elsewhere in this project.
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import get_data
import models

FLAGS = flags.FLAGS

def eval_and_report():
    """Eval on voxceleb."""
    logging.info('samples_key: %s', FLAGS.samples_key)
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.batch_size)

    writer = tf.summary.create_file_writer(FLAGS.eval_dir)
    num_classes = len(FLAGS.label_list)
    # Build the same architecture as in train_and_report(), using the same
    # flags, so the training checkpoints can be restored.
    model = models.get_keras_model(num_classes,
                                   input_length=FLAGS.min_length,
                                   use_batchnorm=FLAGS.use_batch_normalization,
                                   num_clusters=FLAGS.num_clusters,
                                   alpha_init=FLAGS.alpha_init)
    checkpoint = tf.train.Checkpoint(model=model)

    for ckpt in tf.train.checkpoints_iterator(FLAGS.logdir,
                                              timeout=FLAGS.timeout):
        assert 'ckpt-' in ckpt, ckpt
        step = ckpt.split('ckpt-')[-1]
        logging.info('Starting to evaluate step: %s.', step)

        checkpoint.restore(ckpt)

        logging.info('Loaded weights for eval step: %s.', step)

        reader = tf.data.TFRecordDataset
        ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                               reader=reader,
                               samples_key=FLAGS.samples_key,
                               min_length=FLAGS.min_length,
                               label_key=FLAGS.label_key,
                               label_list=FLAGS.label_list,
                               batch_size=FLAGS.batch_size,
                               loop_forever=False,
                               shuffle=False)
        logging.info('Got dataset for eval step: %s.', step)
        if FLAGS.take_fixed_data:
            ds = ds.take(FLAGS.take_fixed_data)

        acc_m = tf.keras.metrics.Accuracy()
        xent_m = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)

        logging.info('Starting the ds loop...')
        count, ex_count = 0, 0
        s = time.time()
        for wav_samples, y_onehot in ds:
            wav_samples.shape.assert_is_compatible_with(
                [None, FLAGS.min_length])
            y_onehot.shape.assert_is_compatible_with(
                [None, len(FLAGS.label_list)])

            logits = model(wav_samples, training=False)
            acc_m.update_state(y_true=tf.argmax(y_onehot, 1),
                               y_pred=tf.argmax(logits, 1))
            xent_m.update_state(y_true=y_onehot, y_pred=logits)
            ex_count += logits.shape[0]
            count += 1
            logging.info('Saw %i examples after %i iterations in %.2f secs...',
                         ex_count, count,
                         time.time() - s)
        with writer.as_default():
            tf.summary.scalar('accuracy',
                              acc_m.result().numpy(),
                              step=int(step))
            tf.summary.scalar('xent_loss',
                              xent_m.result().numpy(),
                              step=int(step))
        logging.info('Done with eval step: %s in %.2f secs.', step,
                     time.time() - s)
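

def get_train_step(model, loss_obj, opt, train_loss, train_accuracy,
                   summary_writer):
    """Returns a function that runs a single optimizer step.

    `get_train_step` is called by `train_and_report` below but its body is
    not part of this example; this is a minimal sketch of what it might look
    like, not the project's actual implementation.
    """

    @tf.function
    def train_step(wav_samples, y_onehot, global_step):
        with tf.GradientTape() as tape:
            logits = model(wav_samples, training=True)
            loss = loss_obj(y_true=y_onehot, y_pred=logits)
        # Apply one gradient update.
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
        # Update the tracked metrics and write a scalar summary.
        train_loss.update_state(loss)
        train_accuracy.update_state(tf.argmax(y_onehot, axis=1), logits)
        with summary_writer.as_default():
            tf.summary.scalar('train_loss', loss, step=global_step)

    return train_step
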
def train_and_report(debug=False):
    """Trains the classifier."""
    logging.info('samples_key: %s', FLAGS.samples_key)
    logging.info('Logdir: %s', FLAGS.logdir)
    logging.info('Batch size: %s', FLAGS.train_batch_size)
    logging.info('label_list: %s', FLAGS.label_list)

    reader = tf.data.TFRecordDataset
    ds = get_data.get_data(file_pattern=FLAGS.file_pattern,
                           reader=reader,
                           samples_key=FLAGS.samples_key,
                           min_length=FLAGS.min_length,
                           label_key=FLAGS.label_key,
                           label_list=FLAGS.label_list,
                           batch_size=FLAGS.train_batch_size,
                           loop_forever=True,
                           shuffle=True,
                           shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    # Create model, loss, and other objects.
    y_onehot_spec = ds.element_spec[1]
    assert len(y_onehot_spec.shape) == 2, y_onehot_spec.shape
    num_classes = y_onehot_spec.shape[1]
    model = models.get_keras_model(num_classes,
                                   input_length=FLAGS.min_length,
                                   use_batchnorm=FLAGS.use_batch_normalization,
                                   num_clusters=FLAGS.num_clusters,
                                   alpha_init=FLAGS.alpha_init)
    # Define the loss and optimizer hyperparameters.
    loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   epsilon=1e-8)
    # Add additional metrics to track.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir)
    train_step = get_train_step(model, loss_obj, opt, train_loss,
                                train_accuracy, summary_writer)
    global_step = opt.iterations
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    manager = tf.train.CheckpointManager(checkpoint,
                                         FLAGS.logdir,
                                         max_to_keep=None)
    logging.info('Checkpoint prefix: %s', FLAGS.logdir)
    checkpoint.restore(manager.latest_checkpoint)

    if debug: return
    for wav_samples, y_onehot in ds:
        wav_samples.shape.assert_has_rank(2)
        wav_samples.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size, FLAGS.min_length])
        y_onehot.shape.assert_is_compatible_with(
            [FLAGS.train_batch_size,
             len(FLAGS.label_list)])

        train_step(wav_samples, y_onehot, global_step)

        # Optional print output and save model.
        if global_step % 10 == 0:
            logging.info('step: %i, train loss: %f, train accuracy: %f',
                         global_step.numpy(), train_loss.result().numpy(),
                         train_accuracy.result().numpy())
        if global_step % FLAGS.measurement_store_interval == 0:
            manager.save(checkpoint_number=global_step)

    manager.save(checkpoint_number=global_step)
    logging.info('Finished training.')
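

def main(unused_argv):
    # Hypothetical wiring, shown for completeness: run training by default.
    # A real binary would typically dispatch on a flag (e.g. train vs. eval)
    # to choose between train_and_report() and eval_and_report().
    train_and_report()


if __name__ == '__main__':
    app.run(main)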