コード例 #1
0
    def __init__(self,
                 dataset,
                 partition_number,
                 horovod,
                 batch_size=128,
                 should_shuffle=True,
                 ignore_last=False):
        """Set up a batcher over one contiguous slice of ``dataset``.

        The ``dataset.size`` rows are split into ``horovod.size()`` contiguous
        index ranges of ``dataset.size // horovod.size()`` rows each; the last
        partition additionally absorbs the remainder rows left over by the
        floor division, so every row belongs to exactly one partition.

        Args:
            dataset: Object exposing ``size`` (row count) and
                ``get_dataset()`` (the underlying data, shuffled in place
                when ``should_shuffle`` is true).
            partition_number: Index of the partition this batcher serves;
                must satisfy ``0 <= partition_number < horovod.size()``.
                Presumably the caller passes the worker rank -- confirm.
            horovod: Object exposing ``size()``, used as the partition count.
            batch_size: Number of rows per batch; used to compute
                ``steps_per_epoch``.
            should_shuffle: If true, shuffle the whole dataset in place once
                during construction.
            ignore_last: Stored on the instance; not used by this constructor.

        Raises:
            ValueError: If ``partition_number`` is outside
                ``[0, horovod.size())``.
        """
        num_partitions = horovod.size()
        if not 0 <= partition_number < num_partitions:
            raise ValueError(
                'partition_number must be in [0, {}), got {}'.format(
                    num_partitions, partition_number))

        self.should_shuffle = should_shuffle

        # store our dataset as well
        partition_size = dataset.size // num_partitions
        start = partition_size * partition_number
        if partition_number == num_partitions - 1:
            # The last partition also takes the rows left over by the floor
            # division so that no row is dropped.
            end = dataset.size
        else:
            end = partition_size * (partition_number + 1)
        self.partition = (start, end)
        self.dataset = dataset

        self.ignore_last = ignore_last
        self.batch_size = batch_size
        self.total_size = self.partition[1] - self.partition[0]
        # Integer ceiling division: exact for any size, unlike rounding a
        # float quotient with math.ceil.
        self.steps_per_epoch = int(-(-self.total_size // self.batch_size))
        # The iteration cursor starts at the beginning of this slice and
        # must not pass its end.
        self.index = self.partition[0]
        self.max_index = self.partition[1]
        self.epoch = 0
        self.step = 0
        if should_shuffle:
            # NOTE(review): the entire dataset (not only this partition) is
            # shuffled in place here. Unless shuffle_inplace yields the same
            # permutation on every worker, the index-range partitions may
            # overlap across workers -- confirm seeding.
            shuffle_inplace(self.dataset.get_dataset())
コード例 #2
0
    def __init__(self, dataset, batch_size=128, should_shuffle=True,
                 ignore_last=False):
        """Set up a batcher over the whole of ``dataset``.

        Args:
            dataset: Object exposing ``size`` (row count) and
                ``get_dataset()`` (the underlying data, shuffled in place
                when ``should_shuffle`` is true).
            batch_size: Number of rows per batch; used to compute
                ``steps_per_epoch``.
            should_shuffle: If true, shuffle the dataset in place once during
                construction.
            ignore_last: Stored on the instance; not used by this constructor.
        """
        self.should_shuffle = should_shuffle

        # store our dataset as well
        self.dataset = dataset
        if should_shuffle:
            shuffle_inplace(self.dataset.get_dataset())

        self.ignore_last = ignore_last
        self.batch_size = batch_size
        self.total_size = dataset.size
        # Integer ceiling division: exact for any size, unlike rounding a
        # float quotient with math.ceil.
        self.steps_per_epoch = int(-(-self.total_size // self.batch_size))
        self.index = 0
        # Initialise the step/epoch counters here so they exist from
        # construction onward instead of only after some later assignment.
        self.step = 0
        self.epoch = 0
コード例 #3
0
 def reset(self):
     """Rewind the cursor and step counter to zero, reshuffling if enabled.

     When ``should_shuffle`` is true, the dataset returned by
     ``self.dataset.get_dataset()`` is shuffled in place afterwards.
     NOTE(review): the cursor goes back to 0, which assumes iteration starts
     at the beginning of the dataset rather than at a partition offset --
     confirm for partitioned batchers.
     """
     self.index = self.step = 0
     if not self.should_shuffle:
         return
     shuffle_inplace(self.dataset.get_dataset())