Пример #1
0
class DistributedDataLoader(object):
    """Wrap a ``ComposeDataset`` in a ``DataLoader`` sharded across ranks.

    Each process of the (already initialized) ``torch.distributed`` process
    group gets its own ``DistributedSampler`` shard and a per-rank batch size
    of ``opt.batchSize // world_size``.
    """

    def initialize(self, opt):
        """Build the dataset, the per-rank sampler and the data loader.

        Args:
            opt: options object. It is passed to ``ComposeDataset``; this
                method reads ``batchSize``, ``nThreads``, ``isTrain`` and,
                if present, ``serial_batches``.

        Raises:
            ValueError: if ``opt.batchSize`` is not divisible by the world
                size.
        """
        self.dataset = ComposeDataset(opt)

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and the batch size would then be silently floored.
        if opt.batchSize % world_size != 0:
            raise ValueError(
                "batchSize (%s) must be divisible by the world size (%s)"
                % (opt.batchSize, world_size))
        batch_size = opt.batchSize // world_size

        # Honour ``serial_batches`` like the non-distributed loaders; when the
        # option is absent the sampler keeps shuffling, as before.
        sampler_shuffle = not getattr(opt, 'serial_batches', False)
        self.train_sampler = DistributedSampler(
            self.dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=sampler_shuffle)

        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            # Ordering is owned by the sampler; DataLoader rejects
            # ``shuffle=True`` together with a custom sampler.
            shuffle=False,
            num_workers=int(opt.nThreads),
            sampler=self.train_sampler,
            # Drop the incomplete final batch only while training.
            drop_last=opt.isTrain,
            pin_memory=False)

    def load_data(self):
        """Return the ``DataLoader`` built by ``initialize``."""
        return self.data_loader

    def shuffle_data(self):
        """Delegate to the dataset's own ``shuffle_data`` method.

        NOTE(review): this does not call ``self.train_sampler.set_epoch``.
        Per the PyTorch docs, a shuffling ``DistributedSampler`` repeats the
        same order every epoch unless ``set_epoch`` is called. Confirm that
        callers do so.
        """
        self.dataset.shuffle_data()

    def __len__(self):
        """Return the length of the whole (unsharded) dataset.

        This is not the number of samples this rank iterates over.
        """
        return len(self.dataset)
Пример #2
0
 def initialize(self, opt):
     """Remember ``opt`` and build a ``ComposeDataset`` with its ``DataLoader``.

     Shuffling is turned off when ``opt.serial_batches`` is truthy. The
     incomplete final batch is dropped only when ``opt.isTrain`` is truthy.
     """
     self.opt = opt
     self.dataset = ComposeDataset(opt)
     loader_options = dict(
         batch_size=opt.batchSize,
         shuffle=not opt.serial_batches,
         num_workers=int(opt.nThreads),
         drop_last=opt.isTrain,
     )
     self.dataloader = DataLoader(self.dataset, **loader_options)
Пример #3
0
class CustomDatasetDataLoader(object):
    """Single-process loader: a ``ComposeDataset`` wrapped in a ``DataLoader``."""

    @property
    def name(self):
        """Identifier of this loader type (read as an attribute, not called)."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Remember ``opt`` and build the dataset together with its loader.

        Shuffling is turned off when ``opt.serial_batches`` is truthy. The
        incomplete final batch is dropped only when ``opt.isTrain`` is truthy.
        """
        self.opt = opt
        self.dataset = ComposeDataset(opt)
        loader_options = {
            'batch_size': opt.batchSize,
            'shuffle': not opt.serial_batches,
            'num_workers': int(opt.nThreads),
            'drop_last': opt.isTrain,
        }
        self.dataloader = DataLoader(self.dataset, **loader_options)

    def load_data(self):
        """Return the ``DataLoader`` created by ``initialize``."""
        loader = self.dataloader
        return loader

    def shuffle_data(self):
        """Delegate to the wrapped dataset's ``shuffle_data`` method."""
        wrapped = self.dataset
        wrapped.shuffle_data()

    def __len__(self):
        """Return ``len()`` of the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)
def CreateDataset(opt):
    """Instantiate and initialize the dataset selected by ``opt.dataset_mode``.

    Supported modes are ``comp_decomp_unaligned``, ``comp_decomp_aligned``,
    ``AFN`` and ``AFNCompose``. A dataset module is imported only when its
    mode is requested. The chosen dataset's ``name()`` is printed, then
    ``initialize(opt)`` is called on it before it is returned.

    Raises:
        ValueError: if ``opt.dataset_mode`` matches none of the supported modes.
    """
    def _comp_decomp_unaligned():
        from data.compose_dataset import ComposeDataset
        return ComposeDataset()

    def _comp_decomp_aligned():
        from data.compose_dataset import ComposeAlignedDataset
        return ComposeAlignedDataset()

    def _afn():
        from data.AFN_dataset import AFNDataset
        return AFNDataset()

    def _afn_compose():
        from data.AFN_compose_dataset import AFNComposeDataset
        return AFNComposeDataset()

    # Checked in order with ``==``, exactly like the former if/elif chain.
    registry = (
        ('comp_decomp_unaligned', _comp_decomp_unaligned),
        ('comp_decomp_aligned', _comp_decomp_aligned),
        ('AFN', _afn),
        ('AFNCompose', _afn_compose),
    )
    mode = opt.dataset_mode
    make_dataset = next(
        (factory for key, factory in registry if mode == key), None)
    if make_dataset is None:
        raise ValueError("Dataset [%s] not recognized." % mode)

    created = make_dataset()
    print("dataset [%s] was created" % (created.name()))
    created.initialize(opt)
    return created