示例#1
0
def split_dataset(
    dataset: tf.data.Dataset, dataset_size: int, train_ratio: float,
    validation_ratio: float
) -> typing.Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Shuffle *dataset* once and split it into train/validation/test parts.

    Args:
        dataset: Source dataset to split.
        dataset_size: Total number of elements in *dataset*.
        train_ratio: Fraction of elements for the training split.
        validation_ratio: Fraction of elements for the validation split;
            the test split receives the remainder.

    Returns:
        (train_dataset, validation_dataset, test_dataset).

    Raises:
        ValueError: If ``train_ratio + validation_ratio`` is not < 1.
    """
    # Explicit exception instead of `assert`, which is stripped under -O.
    if (train_ratio + validation_ratio) >= 1:
        raise ValueError("train_ratio + validation_ratio must be < 1")

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    # Freeze the shuffle order: with the default
    # reshuffle_each_iteration=True every iteration reshuffles, so the
    # take/skip windows below would select overlapping, inconsistent
    # element sets each time the splits are iterated.
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(train_count +
                                validation_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset
示例#2
0
def split_dataset(
    dataset: tf.data.Dataset,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Split *dataset* into (train, val, test) using the module's sizes.

    Sizes come from ``get_split_sizes(TOTAL_SAMPLES)``; the dataset is
    partitioned contiguously as [train | val | test].

    Returns:
        (train, val, test) datasets.
    """
    train_size, val_size, test_size = get_split_sizes(TOTAL_SAMPLES)
    train = dataset.take(train_size)
    remainder = dataset.skip(train_size)
    # Bug fix: the previous code did `val = remainder.skip(val_size)` and
    # `test = remainder.take(test_size)`, so val got the tail (test_size
    # elements, not val_size) and test overlapped the val region. The
    # validation split must come first in the remainder, then the test.
    val = remainder.take(val_size)
    test = remainder.skip(val_size).take(test_size)
    return train, val, test
示例#3
0
def tf_train_test_split(dataset: tf.data.Dataset,
                        dataset_size=None,
                        test_frac=0.2):
    """Split *dataset* into contiguous train and test parts.

    Args:
        dataset: Source dataset.
        dataset_size: Number of elements; inferred from the dataset's
            cardinality when falsy/omitted.
        test_frac: Fraction of elements reserved for the test split.

    Returns:
        (train_dataset, test_dataset).

    Raises:
        ValueError: If the cardinality is unknown or infinite and no
            explicit dataset_size was given.
    """
    if not dataset_size:
        dataset_size = tf.data.experimental.cardinality(dataset).numpy()
        # cardinality() returns negative sentinel values for UNKNOWN or
        # INFINITE datasets; a negative size would silently make take()
        # consume the whole dataset instead of failing loudly.
        if dataset_size < 0:
            raise ValueError(
                "dataset cardinality is unknown or infinite; "
                "pass dataset_size explicitly")

    train_size = int((1 - test_frac) * dataset_size)

    train_dataset = dataset.take(train_size)
    test_dataset = dataset.skip(train_size)

    return train_dataset, test_dataset
示例#4
0
def split_dataset(ds: tf.data.Dataset, val_num: int = 500, test_num: int = 500):
    """Split *ds* into train/test/val parts.

    The first ``val_num + test_num`` elements form the held-out pool
    (test first, then val); everything after them is training data.
    The hold-out sizes were previously hard-coded to 500/500 — they are
    now parameters with the same defaults, so existing callers are
    unaffected.

    Args:
        ds: Source dataset.
        val_num: Number of validation elements (default 500).
        test_num: Number of test elements (default 500).

    Returns:
        (train_dataset, test_dataset, val_dataset) — note the order.
    """
    train_dataset = ds.skip(val_num + test_num)

    # Held-out pool: first test_num elements are test, the rest are val.
    test_val_ds = ds.take(val_num + test_num)

    test_dataset = test_val_ds.take(test_num)
    val_dataset = test_val_ds.skip(test_num)

    print("dataset splitted")
    return train_dataset, test_dataset, val_dataset
示例#5
0
def dataset_split(dataset: tf.data.Dataset,
                  split_fraction: float,
                  fold: int = 0) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Splits the dataset into one chunk with split_fraction many elements of
    the original dataset and another chunk with size (1 - split_fraction)
    elements.

    Args:
        dataset (tf.data.Dataset): Dataset to be splitted.
        split_fraction (float): Fraction of the dataset split.
        fold (int): Which fold of the dataset, the validation set should be.

    Returns:
        Tuple[tf.data.Dataset, tf.data.Dataset]: Splitted datasets tuple.

    Raises:
        ValueError: If split_fraction is outside (0, 1] or fold does not
            address a fold inside the dataset.
    """
    if not 0.0 < split_fraction <= 1.0:
        raise ValueError("split_fraction must be in (0, 1]")

    split_size = int(len(dataset) * split_fraction)
    offset_idx = fold * split_size
    # Guard against a fold index past the end of the dataset: previously
    # this silently produced an EMPTY validation split and a training
    # split containing every element.
    if fold < 0 or offset_idx >= len(dataset):
        raise ValueError(
            "fold {0} is out of range for this split_fraction".format(fold))

    val_dataset = dataset.skip(offset_idx).take(split_size)
    # Training data is everything before and after the validation fold.
    first_train_folds = dataset.take(offset_idx)
    last_train_folds = dataset.skip(offset_idx + split_size)
    train_dataset = first_train_folds.concatenate(last_train_folds)
    return train_dataset, val_dataset
示例#6
0
def split_train_valid(
    tf_dataset: tf.data.Dataset,
    train_size: int = 55000,
    valid_size: int = 5000,
    shuffle=True,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Shuffle *tf_dataset* deterministically and split it into train/valid.

    Args:
        tf_dataset: Source dataset (expected to hold at least
            train_size + valid_size elements).
        train_size: Number of training elements.
        valid_size: Number of validation elements.
        shuffle: When True, additionally reshuffle each split on its own.

    Returns:
        (train_dataset, valid_dataset).
    """
    # Freeze the split order: with the default reshuffle_each_iteration=True
    # the dataset reshuffles every epoch, so the take/skip below would
    # yield overlapping train/valid sets on re-iteration.
    tf_dataset = tf_dataset.shuffle(
        train_size + valid_size, seed=SEED, reshuffle_each_iteration=False)

    train_dataset = tf_dataset.take(train_size)
    valid_dataset = tf_dataset.skip(train_size)

    # Per-split shuffling is safe here because the split boundary above
    # is already fixed.
    if shuffle:
        train_dataset = train_dataset.shuffle(train_size, seed=SEED)
        valid_dataset = valid_dataset.shuffle(valid_size, seed=SEED)

    return train_dataset, valid_dataset
示例#7
0
    def devide_train_validation(
        cls, dataset: tf.data.Dataset, length: int, ratio: float
    ) -> Tuple[Tuple[tf.data.Dataset, int], Tuple[tf.data.Dataset, int]]:
        """Split the dataset into training and validation parts.

        Arguments:
            dataset -- dataset to split
            length -- number of elements in the dataset
            ratio -- fraction of elements used for training

        Returns:
            ((train_dataset, train_size), (validation_dataset, validation_size))
        """

        train_size = int(length * ratio)
        validation_size = length - train_size
        # Bug fix: the previous code did
        # `dataset.take(train_size).skip(validation_size)`, which dropped
        # validation_size elements from the training split, leaving only
        # train_size - validation_size elements instead of train_size.
        train_set = dataset.take(train_size)
        validation_set = dataset.skip(train_size).take(validation_size)
        return ((train_set, train_size), (validation_set, validation_size))
示例#8
0
def split_ds(ds: tf.data.Dataset,
             val_percentage=None,
             test_percentage=None,
             buffer_size=None):
    """Shuffle *ds* once and carve it into train/val/test pieces.

    Depending on which percentages are supplied, returns the shuffled
    dataset itself (no split), a (train, val) pair, a (train, test)
    pair, or a (train, val, test) triple. Split sizes are printed as a
    side effect.

    Args:
        ds: Source dataset (must have a known length).
        val_percentage: Fraction for the validation split, or None/0.
        test_percentage: Fraction for the test split, or None/0.
        buffer_size: Shuffle buffer size; defaults to 128 * 128.

    Raises:
        ValueError: If a fraction is outside [0, 1) or their sum is >= 1.
    """
    val_frac = val_percentage or 0
    test_frac = test_percentage or 0
    shuffle_buffer = buffer_size or 128 * 128

    # Validate fractions up front with guard clauses.
    if val_frac < 0 or val_frac >= 1.0:
        raise ValueError("val_percentage must be between (0,1)")
    if test_frac < 0 or test_frac >= 1.0:
        raise ValueError("test_percentage must be between (0,1)")
    if (val_frac + test_frac) >= 1.0:
        raise ValueError(
            "val_percentage+test_percentage must be between (0,1)")

    total = len(ds)
    print("Full size: {0}".format(total))

    # One frozen shuffle shared by every branch (each original branch
    # shuffled with identical arguments).
    shuffled = ds.shuffle(shuffle_buffer, reshuffle_each_iteration=False)

    if val_frac == 0 and test_frac == 0:
        print("No split returning ds shuffled")
        return shuffled

    # Unified size arithmetic: a zero fraction contributes zero elements,
    # so train always gets the remainder — same numbers as before.
    n_val = int(total * val_frac)
    n_test = int(total * test_frac)
    n_train = total - n_val - n_test

    train_part = shuffled.take(n_train)
    rest = shuffled.skip(n_train)

    if test_frac == 0:
        val_part = rest
        print("Train size: {0}".format(len(train_part)))
        print("Val size: {0}".format(len(val_part)))
        return train_part, val_part

    if val_frac == 0:
        test_part = rest
        print("Train size: {0}".format(len(train_part)))
        print("Test size: {0}".format(len(test_part)))
        return train_part, test_part

    # Three-way split: test is taken from the remainder first, then val.
    test_part = rest.take(n_test)
    val_part = rest.skip(n_test)
    print("Train size: {0}".format(len(train_part)))
    print("Val size: {0}".format(len(val_part)))
    print("Test size: {0}".format(len(test_part)))
    return train_part, val_part, test_part,
示例#9
0
def split_dataset(dataset: tf.data.Dataset, left_size: int, buffer_size: int=1000) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Shuffle *dataset* and split it into (first left_size elements, rest).

    Args:
        dataset: Source dataset.
        left_size: Number of elements in the first returned split.
        buffer_size: Shuffle buffer size.

    Returns:
        (left, right) datasets of left_size and remaining elements.
    """
    # Bug fix: Dataset.shuffle is not in-place — the original discarded
    # its return value, so the splits were taken from the UNSHUFFLED
    # dataset. Also freeze the shuffle order so take/skip see the same
    # permutation and the two splits never overlap.
    dataset = dataset.shuffle(buffer_size=buffer_size,
                              reshuffle_each_iteration=False)
    return dataset.take(left_size), dataset.skip(left_size)