예제 #1
0
    def check_unsupported_shard_sizes(self):
        """Validate the shard size of every configured modality.

        Each modality in ``self.config.modalities`` has its shard size fetched
        through ``self.config.get_modality_shard_size`` and passed to both
        ``int_ceil`` and ``int_floor``. The two results must be equal
        (presumably this means the shard size has no fractional part --
        confirm against the ``int_ceil`` / ``int_floor`` helpers).

        Raises:
            ValueError: On the first modality for which the two results differ.
        """
        message = ("max_shard_size != initial_shard_size : "
                   "SubsetLoader doesn't support this case yet.")

        for modality in self.config.modalities:
            size = self.config.get_modality_shard_size(modality)
            upper_size = int_ceil(size)
            lower_size = int_floor(size)
            if upper_size == lower_size:
                # Both helpers agree: this modality is supported, check the next.
                continue
            raise ValueError(message)
예제 #2
0
    def make_tf_datasets_splits(self,
                                pattern: Pattern,
                                split: float,
                                subset_folders: List[str] = None,
                                ) -> Tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
        """Build a train dataset and a validation dataset from randomly split folders.

        Args:
            pattern: Forwarded unchanged to ``self.make_tf_dataset``.
            split: Fraction of the folders assigned to the train dataset; must
                lie strictly between 0.0 and 1.0.
            subset_folders: Folders to split. When ``None``,
                ``self.subset_folders`` is used. The list is copied before
                shuffling, so the caller's list is not reordered.

        Returns:
            ``(train_dataset, validation_dataset)``. With exactly one folder no
            split is possible and ``(dataset, None)`` is returned. Otherwise at
            least one folder is always held out for validation.

        Raises:
            ValueError: If ``split`` is not strictly between 0.0 and 1.0.
        """
        if split <= 0.0 or split >= 1.0:
            raise ValueError("Split must be strictly between 0.0 and 1.0, found {}.".format(split))

        # Work on a shallow copy so the in-place shuffle below stays local.
        source_folders = self.subset_folders if subset_folders is None else subset_folders
        folders = copy.copy(source_folders)
        folder_count = len(folders)

        if folder_count == 1:
            return self.make_tf_dataset(pattern, folders), None

        # NOTE(review): an empty folder list is not rejected here; it appears to
        # yield two datasets built from empty lists -- confirm this is intended.
        train_count = int_ceil(folder_count * split)
        if train_count == folder_count:
            # Never let the train part swallow every folder.
            train_count = folder_count - 1

        random.shuffle(folders)
        train_folders = folders[:train_count]
        validation_folders = folders[train_count:]

        train_dataset = self.make_tf_dataset(pattern, train_folders)
        validation_dataset = self.make_tf_dataset(pattern, validation_folders)
        return train_dataset, validation_dataset
예제 #3
0
def get_shard_count(sample_length: int, shard_size: int) -> int:
    """Return the shard count for a sample, never less than two.

    The count is ``1 + int_ceil((sample_length - 1) / shard_size)``; any result
    below 2 is replaced by 2.

    Args:
        sample_length: Length of the sample.
        shard_size: Size of one shard; used as a divisor.

    Returns:
        The computed shard count, with a minimum of 2.
    """
    spanned_shards = int_ceil((sample_length - 1) / shard_size)
    shard_count = 1 + spanned_shards
    if shard_count > 2:
        return shard_count
    return 2
예제 #4
0
 def get_modality_max_shard_size(self, modality: Modality) -> int:
     """Return ``int_ceil`` of the shard size configured for ``modality``.

     The shard size itself is obtained from ``self.get_modality_shard_size``.
     """
     shard_size = self.get_modality_shard_size(modality)
     return int_ceil(shard_size)
예제 #5
0
def get_max_frame_count(duration: float, frequency: Union[int, float]):
    """Return the maximum number of frames for a duration at a given frequency.

    Args:
        duration: Duration to cover.
        frequency: Frame rate; multiplied by ``duration``.

    Returns:
        ``int_ceil`` of ``duration * frequency``. ``EPSILON`` is passed as the
        second argument to ``int_ceil`` (presumably a rounding tolerance --
        confirm against the helper's definition).
    """
    frame_span = duration * frequency
    return int_ceil(frame_span, EPSILON)