Пример #1
0
    def build_sub_train_loader(self,
                               n_images,
                               batch_size,
                               num_worker=None,
                               num_replicas=None,
                               rank=None):
        """Return a cached list of training batches for the active image size.

        Used for resetting BN running statistics. On the first call for a
        given ``self.active_img_size`` a fixed-seed random subset of the
        training set is loaded, every batch is materialised, and the
        resulting list is stored on the instance under the attribute
        ``sub_train_<active_img_size>``.

        Note:
            The cache is keyed by the active image size only. Once a list
            exists for that size, later calls return it unchanged even if
            ``n_images``, ``batch_size`` or the other arguments differ.

        Args:
            n_images: Number of training samples to draw.
            batch_size: Batch size of the temporary loader.
            num_worker: Loader worker count; defaults to
                ``self.train.num_workers`` when ``None``.
            num_replicas: When not ``None``, a ``MyDistributedSampler`` is
                used instead of a ``SubsetRandomSampler``.
            rank: Passed to ``MyDistributedSampler`` together with
                ``num_replicas``.

        Returns:
            list: ``(images, labels)`` tuples, one per batch.
        """
        cache_key = "sub_train_%d" % self.active_img_size
        cached = self.__dict__.get(cache_key)
        if cached is not None:
            return cached

        if num_worker is None:
            num_worker = self.train.num_workers

        # A dedicated, fixed-seed generator keeps the permutation (and thus
        # the chosen subset) the same on every call.
        generator = torch.Generator()
        generator.manual_seed(DataProvider.SUB_SEED)
        shuffled = torch.randperm(len(self.train.dataset),
                                  generator=generator).tolist()
        chosen_indexes = shuffled[:n_images]

        sub_dataset = self.train_dataset(
            self.build_train_transform(image_size=self.active_img_size,
                                       print_log=False))

        if num_replicas is not None:
            sub_sampler = MyDistributedSampler(
                sub_dataset,
                num_replicas,
                rank,
                True,
                np.array(chosen_indexes),
            )
        else:
            sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                chosen_indexes)

        sub_data_loader = torch.utils.data.DataLoader(
            sub_dataset,
            batch_size=batch_size,
            sampler=sub_sampler,
            num_workers=num_worker,
            pin_memory=True,
        )

        # Collect all batches first and only then publish them to the cache:
        # if the loader raises midway, no partially filled list is left
        # behind for later calls to pick up.
        batches = [(images, labels) for images, labels in sub_data_loader]
        self.__dict__[cache_key] = batches
        return batches
Пример #2
0
    def __init__(self,
                 save_path=None,
                 train_batch_size=256,
                 test_batch_size=512,
                 valid_size=None,
                 n_worker=32,
                 resize_scale=0.08,
                 distort_color=None,
                 image_size=224,
                 num_replicas=None,
                 rank=None):
        """Create the ``train``, ``valid`` and ``test`` data loaders.

        Args:
            save_path: Stored as ``self._save_path``; not otherwise used in
                this method.
            train_batch_size: Batch size of ``self.train``.
            test_batch_size: Batch size of ``self.valid`` and ``self.test``.
            valid_size: ``None`` for no hold-out split (``self.valid`` then
                aliases ``self.test``), an int number of samples, or a float
                in (0, 1) interpreted as a fraction of the training set.
            n_worker: ``num_workers`` passed to every loader.
            resize_scale: Stored as ``self.resize_scale``.
            distort_color: Stored as ``self.distort_color``; ``None`` is
                replaced by the string ``'None'``.
            image_size: Either an int or a list of ints. A list is sorted
                in place and its maximum becomes ``self.active_img_size``.
            num_replicas: When not ``None``, distributed samplers are used.
            rank: Passed to the distributed samplers.
        """
        # NOTE: this silences warnings process-wide, not just for this object.
        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = (distort_color
                              if distort_color is not None else 'None')
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}

        # Resolve the active resolution, the evaluation transform and the
        # loader class used for training.
        if isinstance(self.image_size, int):
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        else:
            from ofa.utils.my_dataloader import MyDataLoader
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # ascending, e.g. 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            largest_size = max(self.image_size)
            MyRandomResizedCrop.ACTIVE_SIZE = largest_size

            for size in self.image_size:
                self._valid_transform_dict[size] = \
                    self.build_valid_transform(size)
            # The largest resolution is the active one for testing.
            self.active_img_size = largest_size
            valid_transforms = self._valid_transform_dict[largest_size]
            # Per the original author's note: samples a random image size
            # for each batch of training images.
            train_loader_class = MyDataLoader

        distributed = num_replicas is not None
        loader_opts = {'num_workers': n_worker, 'pin_memory': True}

        def subset_sampler(dataset, indexes):
            # Sampler restricted to `indexes`, distributed when requested.
            if distributed:
                return MyDistributedSampler(dataset, num_replicas, rank, True,
                                            np.array(indexes))
            return torch.utils.data.sampler.SubsetRandomSampler(indexes)

        # NOTE(review): the training set is a hard-coded CIFAR100 rooted at
        # './data', while the validation and test sets come from
        # self.train_dataset / self.test_dataset -- confirm these agree.
        train_transform = self.build_train_transform()
        train_dataset = datasets.CIFAR100(root='./data',
                                          train=True,
                                          download=True,
                                          transform=train_transform)

        if valid_size is None:
            # No hold-out split: train on the whole training set.
            if distributed:
                train_sampler = \
                    torch.utils.data.distributed.DistributedSampler(
                        train_dataset, num_replicas, rank)
                self.train = train_loader_class(train_dataset,
                                                batch_size=train_batch_size,
                                                sampler=train_sampler,
                                                **loader_opts)
            else:
                self.train = train_loader_class(train_dataset,
                                                batch_size=train_batch_size,
                                                shuffle=True,
                                                **loader_opts)
            self.valid = None
        else:
            if not isinstance(valid_size, int):
                # A float is a fraction of the training set.
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            # Validation samples use the evaluation transform.
            valid_dataset = self.train_dataset(valid_transforms)

            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size)

            train_sampler = subset_sampler(train_dataset, train_indexes)
            valid_sampler = subset_sampler(valid_dataset, valid_indexes)

            self.train = train_loader_class(train_dataset,
                                            batch_size=train_batch_size,
                                            sampler=train_sampler,
                                            **loader_opts)
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                **loader_opts)

        test_dataset = self.test_dataset(valid_transforms)
        if distributed:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                **loader_opts)
        else:
            # The test loader is shuffled here as well.
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                **loader_opts)

        # Without a hold-out split, validation falls back to the test loader.
        if self.valid is None:
            self.valid = self.test