示例#1
0
    def build_train_transform(self, image_size=None, print_log=True):
        """Build the training-time image transform pipeline.

        Args:
            image_size: Output size for the crop. ``None`` falls back to
                ``self.image_size``. A ``list`` selects
                ``MyRandomResizedCrop``; any other value selects
                ``transforms.RandomResizedCrop``.
            print_log: If true, print the augmentation configuration and,
                when ``MyRandomResizedCrop`` is selected, its settings. If
                false, this method prints nothing.

        Returns:
            A ``transforms.Compose``. With ``self.subsample == 1`` it applies
            a random resized crop (scale ``(self.resize_scale, 1.0)``), a
            random horizontal flip and optional color jitter. Otherwise it
            applies a deterministic resize to ``ceil(image_size / 0.875)``
            followed by a center crop of ``image_size``. Both variants finish
            with ``ToTensor()`` and ``self.normalize``.
        """
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            # Gate this diagnostic on print_log as well, so that
            # print_log=False really keeps the call silent.
            if print_log:
                # The first message has two %s slots filled from the value
                # returned by get_candidate_image_size(); this assumes it is a
                # 2-tuple (NOTE(review): confirm against MyRandomResizedCrop).
                candidate_info = MyRandomResizedCrop.get_candidate_image_size()
                print('Use MyRandomResizedCrop: %s, \t %s' % candidate_info,
                      'sync=%s, continuous=%s' %
                      (MyRandomResizedCrop.SYNC_DISTRIBUTED,
                       MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        if self.subsample == 1:
            # random_resize_crop -> random_horizontal_flip
            train_transforms = [
                resize_transform_class(image_size,
                                       scale=(self.resize_scale, 1.0)),
                transforms.RandomHorizontalFlip(),
            ]

            # Optional color augmentation; any distort_color value other
            # than 'torch' or 'tf' disables it.
            color_transform = None
            if self.distort_color == 'torch':
                color_transform = transforms.ColorJitter(brightness=0.4,
                                                         contrast=0.4,
                                                         saturation=0.4,
                                                         hue=0.1)
            elif self.distort_color == 'tf':
                color_transform = transforms.ColorJitter(brightness=32. / 255.,
                                                         saturation=0.5)
            if color_transform is not None:
                train_transforms.append(color_transform)
        else:
            # Deterministic resize + center crop (no flip, no jitter): the
            # crop of image_size covers 87.5% of the resized size.
            # NOTE(review): image_size must be a number here; a list
            # (MyRandomResizedCrop mode) makes the division raise TypeError.
            train_transforms = [
                transforms.Resize(int(math.ceil(image_size / 0.875))),
                transforms.CenterCrop(image_size),
            ]

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        return transforms.Compose(train_transforms)
示例#2
0
    def build_train_transform(self, image_size=None, print_log=True):
        """Build the training-time image transform pipeline.

        Args:
            image_size: Output size for the crop. ``None`` falls back to
                ``self.image_size``. A ``list`` selects
                ``MyRandomResizedCrop``; any other value selects
                ``transforms.RandomResizedCrop``.
            print_log: If true, print the augmentation configuration and,
                when ``MyRandomResizedCrop`` is selected, its settings. If
                false, this method prints nothing.

        Returns:
            A ``transforms.Compose`` applying, in order:
            - a random resized crop with scale ``(self.resize_scale, 1.0)``
              and bicubic interpolation;
            - a random horizontal flip;
            - optional color jitter;
            - ``ToTensor()``;
            - ``self.normalize``.
        """
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print("Color jitter: %s, resize_scale: %s, img_size: %s" %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            # Gate this diagnostic on print_log as well, so that
            # print_log=False really keeps the call silent.
            if print_log:
                # The first message has two %s slots filled from the value
                # returned by get_candidate_image_size(); this assumes it is a
                # 2-tuple (NOTE(review): confirm against MyRandomResizedCrop).
                candidate_info = MyRandomResizedCrop.get_candidate_image_size()
                print(
                    "Use MyRandomResizedCrop: %s, \t %s" % candidate_info,
                    "sync=%s, continuous=%s" % (
                        MyRandomResizedCrop.SYNC_DISTRIBUTED,
                        MyRandomResizedCrop.CONTINUOUS,
                    ),
                )
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size,
                                   scale=(self.resize_scale, 1.0),
                                   interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ]

        # Optional color augmentation; any distort_color value other than
        # "torch" or "tf" disables it.
        color_transform = None
        if self.distort_color == "torch":
            color_transform = transforms.ColorJitter(brightness=0.4,
                                                     contrast=0.4,
                                                     saturation=0.4,
                                                     hue=0.1)
        elif self.distort_color == "tf":
            color_transform = transforms.ColorJitter(brightness=32.0 / 255.0,
                                                     saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        return transforms.Compose(train_transforms)