def _dataset():
    """Returns the `MS-ASL` dataset example for the word `dance`.

    The training split is passed through `transform_for_prediction` and then narrowed down to the example whose
    label index is 84 and whose signer index is 347.

    Returns:
        A dataset yielding at most one batch of size 1 that holds the matching example.
    """
    dataset = tf_record_dataset(DatasetName.MSASL, DatasetType.TRAIN)
    # The transformation is mapped over batches (presumably it expects batched input), so single examples are
    # batched temporarily and unbatched again right after the transformation.
    dataset = dataset.batch(1)
    dataset = dataset.map(transform_for_prediction)
    dataset = dataset.unbatch()
    # Combine both predicates with `tf.math.logical_and`: the Python `and` operator cannot evaluate symbolic
    # tensors unless AutoGraph manages to rewrite the lambda.
    dataset = dataset.filter(
        lambda frames, label, signer: tf.math.logical_and(tf.math.equal(label, 84), tf.math.equal(signer, 347)))
    dataset = dataset.batch(1)
    return dataset.take(1)
# Example #2
0
def inspect_dataset(dataset_name: DatasetName, dataset_type: DatasetType,
                    inspect_fn, batch_size=1, skip_count=0, take_count=1):
    """Feeds the frame sequences of selected dataset examples to an inspection function.

    The dataset is batched and transformed with `transform_for_inspection` first. Skipping and taking happen after
    batching, so `skip_count` and `take_count` count batches; they only equal example counts for a `batch_size` of 1.

    Arguments:
        dataset_name: The name of the dataset to inspect (one of `DatasetName`).
        dataset_type: The type of the dataset to inspect (one of `DatasetType`).
        inspect_fn: The callable invoked once with each frame sequence of every fetched batch.
        batch_size: The number of consecutive dataset examples to combine into a batch.
        skip_count: The number of leading batches to skip; a falsy value skips nothing.
        take_count: The maximum number of batches to fetch.
    """
    batches = tf_record_dataset(dataset_name, dataset_type).batch(batch_size).map(transform_for_inspection)
    if skip_count:
        batches = batches.skip(skip_count)
    # Labels and signers are part of every element but are not needed for the inspection.
    for frames_batch, _labels, _signers in batches.take(take_count):
        for frames in frames_batch.numpy():
            inspect_fn(frames)
def _train_dataset():
    """Builds the input pipeline over the `MS-ASL` training split.

    The records are shuffled with `shuffle(2048)`, grouped into batches of 32 and then passed batch-wise through
    `transform_for_msasl_model`. The pipeline ends with `prefetch(2)`.

    Returns:
        The shuffled, batched, transformed and prefetched training dataset.
    """
    return (
        tf_record_dataset(DatasetName.MSASL, DatasetType.TRAIN)
        .shuffle(2048)  # shuffling happens on single records, i.e. before batching
        .batch(32)
        .map(transform_for_msasl_model)  # mapped over whole batches
        .prefetch(2)
    )
# Example #4
0
def _validation_dataset():
    """Builds the input pipeline over the `SIGNUM` validation split.

    The records are grouped into batches of 32 without shuffling and then passed batch-wise through
    `transform_for_signum_model`. The pipeline ends with `prefetch(2)`.

    Returns:
        The batched, transformed and prefetched validation dataset.
    """
    return (
        tf_record_dataset(DatasetName.SIGNUM, DatasetType.VALIDATION)
        .batch(32)
        .map(transform_for_signum_model)  # mapped over whole batches
        .prefetch(2)
    )
# Example #5
0
def _train_dataset():
    """Builds the input pipeline over the `SIGNUM` training split.

    The records are shuffled with `shuffle(2048)`, grouped into batches of 32 and then passed batch-wise through
    `transform_for_signum_model`. The pipeline ends with `prefetch(2)`.

    Returns:
        The shuffled, batched, transformed and prefetched training dataset.
    """
    return (
        tf_record_dataset(DatasetName.SIGNUM, DatasetType.TRAIN)
        .shuffle(2048)  # shuffling happens on single records, i.e. before batching
        .batch(32)
        .map(transform_for_signum_model)  # mapped over whole batches
        .prefetch(2)
    )
def _test_dataset():
    """Builds the input pipeline over the `SIGNUM` test split.

    The records are grouped into batches of 32 without shuffling and then passed batch-wise through
    `transform_for_signum_model`. The pipeline ends with `prefetch(2)`.

    Returns:
        The batched, transformed and prefetched test dataset.
    """
    return (
        tf_record_dataset(DatasetName.SIGNUM, DatasetType.TEST)
        .batch(32)
        .map(transform_for_signum_model)  # mapped over whole batches
        .prefetch(2)
    )