Пример #1
0
def load_cifar100(args, **kwargs):
    """Build CIFAR-100 train/test DataLoaders.

    Args:
        args: namespace providing ``auto_augment`` (bool), ``cutout`` (bool)
            and ``batch_size`` (int).
        **kwargs: extra keyword arguments forwarded to both
            ``torch.utils.data.DataLoader`` calls. ``num_workers`` defaults to
            4 but may be overridden here.

    Returns:
        ``(train_loader, test_loader, metadata)`` where ``metadata`` holds the
        per-sample ``input_shape`` and ``n_classes``.
    """
    # NOTE(review): these mean/std values are the commonly quoted CIFAR-10
    # statistics, not CIFAR-100's. They are applied identically to train and
    # test, so they are kept for compatibility with existing checkpoints.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    list_trans = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.auto_augment:
        list_trans.append(AutoAugment())
    if args.cutout:
        list_trans.append(Cutout())
    list_trans.append(transforms.ToTensor())
    list_trans.append(transforms.Normalize(mean, std))
    transform_train = transforms.Compose(list_trans)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Caller-supplied kwargs take precedence over the default worker count.
    # The previous ``**kwargs, num_workers = 4`` raised TypeError whenever a
    # caller passed ``num_workers`` itself.
    loader_kwargs = {'num_workers': 4, **kwargs}

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('data', train=True, download=True, transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **loader_kwargs)

    # Evaluation data is iterated in a fixed order; shuffling it only makes
    # per-batch results non-reproducible.
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('data', train=False, download=True, transform=transform_test),
        batch_size=args.batch_size, shuffle=False, **loader_kwargs)

    metadata = {
        "input_shape": (3, 32, 32),
        "n_classes": 100,
    }

    return train_loader, test_loader, metadata
Пример #2
0
def data_loader(args):
    """Build CIFAR-10 train/validation DataLoaders.

    Args:
        args: namespace providing ``dataset`` (only ``'cifar10'`` is
            supported), ``augment`` (``'AutoAugment'`` or ``'Basic'``),
            ``batch_size`` and ``num_workers``.

    Returns:
        ``(train_loader, val_loader)``. Validation uses a fixed batch size of
        64 and no shuffling.

    Raises:
        ValueError: if ``args.dataset`` or ``args.augment`` is unsupported.
    """
    if args.dataset != 'cifar10':
        # Previously this message referenced an undefined name ``dataset``,
        # which turned the intended ValueError into a NameError.
        raise ValueError('Unavailable dataset "%s"' % (args.dataset,))

    # CIFAR-10 per-channel mean/std.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    transform_train = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    if args.augment == 'AutoAugment':
        transform_train.append(AutoAugment())
    elif args.augment == 'Basic':
        transform_train.extend([
            # Colour jitter applied with probability 0.8.
            transforms.RandomApply(
                [transforms.ColorJitter(0.3, 0.3, 0.3, 0.1)], 0.8),
            transforms.RandomGrayscale(0.1),
        ])
    else:
        raise ValueError('No such augmentation policy is set!')

    transform_train.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_train = transforms.Compose(transform_train)

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # The guard above guarantees CIFAR-10 here.
    train_set = torchvision.datasets.CIFAR10(root='./dataset',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    val_set = torchvision.datasets.CIFAR10(root='./dataset',
                                           train=False,
                                           download=True,
                                           transform=transform_val)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=64,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    return train_loader, val_loader
def AutoAug(img: tf.Tensor):
    """Apply AutoAugment independently to every image of a batch.

    Args:
        img: batch tensor indexed along axis 0; each ``img[i]`` must be
            accepted by ``PIL.Image.fromarray`` (presumably a uint8 HxWxC
            array -- TODO confirm against the caller).

    Returns:
        A ``tf.float16`` tensor with the same shape as ``img``.
    """
    batch = img.numpy()
    # A single policy object is reused for the whole batch.
    autoaug = AutoAugment()
    augmented = np.zeros_like(batch)
    for i, im in enumerate(batch):
        augmented[i] = np.asarray(autoaug(Image.fromarray(im)))
    # The original line ``Auto_aug_im < -tf.convert_to_tensor(...)`` was an
    # R-style ``<-`` typo: it evaluated and discarded a comparison and
    # returned the raw NumPy array. Cast in NumPy first so the conversion
    # does not depend on TF's implicit dtype-conversion rules.
    return tf.convert_to_tensor(augmented.astype(np.float16))
Пример #4
0
    def __init__(self, data_cfg, multi=1, nl=False):
        """
        Dataset for training.

        :param data_cfg: CfgNode for CityFlow NL. This constructor reads its
            ``motion``, ``nseg``, ``all3``, ``pad``, ``JSON_PATH``,
            ``CROP_SIZE`` and ``semi`` attributes, plus
            ``EVAL_TRACKS_JSON_PATH`` when ``semi`` is truthy.
        :param multi: stored unchanged as ``self.multi``.
        :param nl: stored unchanged as ``self.nl``.
        """
        self.nl = nl
        self.multi = multi
        self.motion = data_cfg.motion
        self.nseg = data_cfg.nseg
        self.all3 = data_cfg.all3
        self.pad = data_cfg.pad
        self.data_cfg = data_cfg
        self.aug = AutoAugment(auto_augment_policy(name='v0r', hparams=None))
        # ``with`` already closes the file (the former extra ``f.close()``
        # calls were redundant). JSON is UTF-8 by specification, so decode it
        # explicitly rather than relying on the platform locale.
        with open(self.data_cfg.JSON_PATH, encoding='utf-8') as f:
            tracks = json.load(f)
        self.list_of_uuids = list(tracks.keys())
        self.list_of_tracks = list(tracks.values())
        self.list_of_crops = list()
        # Tracks at positions >= train_num are the unlabeled ones appended below.
        train_num = len(self.list_of_uuids)
        self.transform = transforms.Compose([
            transforms.Pad(10),
            transforms.RandomCrop(
                (data_cfg.CROP_SIZE, self.data_cfg.CROP_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            RandomErasing(probability=0.5)
        ])

        if data_cfg.semi:
            # cv: append the unlabeled evaluation tracks after the labeled ones
            with open(self.data_cfg.EVAL_TRACKS_JSON_PATH, encoding='utf-8') as f:
                unlabel_tracks = json.load(f)
            self.list_of_uuids.extend(unlabel_tracks.keys())
            self.list_of_tracks.extend(unlabel_tracks.values())
            # nl: natural-language queries used for the unlabeled tracks
            with open("data/test-queries.json", "r", encoding='utf-8') as f:
                unlabel_nl = json.load(f)
            unlabel_nl_key = list(unlabel_nl.keys())

        print('#track id (class): %d ' % len(self.list_of_tracks))
        # add id and nl, -1 for unlabeled data
        for track_idx, track in enumerate(self.list_of_tracks):
            track["track_id"] = track_idx
            track["nl_id"] = track_idx
            # from 0 to train_num-1 is the id of the original training set.
            if track_idx >= train_num:
                # The k-th unlabeled track (k = track_idx - train_num) gets the
                # k-th query of test-queries.json.
                # NOTE(review): assumes both JSON files list entries in
                # matching order -- confirm.
                track["nl_id"] = -1
                track["nl"] = unlabel_nl[unlabel_nl_key[track_idx - train_num]]
        self._logger = get_logger()
Пример #5
0
def my_transform(train=True,
                 resize=224,
                 use_cutout=False,
                 n_holes=1,
                 length=8,
                 auto_aug=False,
                 rand_aug=False):
    """Build a torchvision preprocessing pipeline.

    Args:
        train: build the augmenting training pipeline when True, otherwise a
            deterministic resize + center-crop pipeline.
        resize: final square crop size.
        use_cutout: append ``Cutout()`` (training only) after normalization.
        n_holes, length: accepted for interface compatibility but NOT
            forwarded -- ``Cutout`` is constructed with its own defaults.
        auto_aug: insert ``AutoAugment()`` after cropping (training only).
        rand_aug: insert ``Rand_Augment()`` after cropping (training only).

    Returns:
        A ``T.Compose`` of the selected transforms.
    """
    # Named ``pipeline`` (not ``transforms``) so the local list does not
    # shadow a ``torchvision.transforms`` import of that name.
    pipeline = []

    if train:
        # NOTE(review): a scale upper bound of 2.0 requests crops larger than
        # the image; torchvision retries such samples and may fall back to a
        # center crop -- confirm this is intended.
        pipeline.append(
            T.RandomResizedCrop(resize + 5,
                                scale=(0.2, 2.0),
                                interpolation=PIL.Image.BICUBIC))
        pipeline.append(T.RandomHorizontalFlip())
        pipeline.append(T.ColorJitter(0.2, 0.2, 0.3, 0.))
        # Trim the (resize + 5) random crop back to ``resize``.
        pipeline.append(T.CenterCrop(resize))
        if auto_aug:
            pipeline.append(AutoAugment())
        if rand_aug:
            pipeline.append(Rand_Augment())
    else:
        pipeline.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
        pipeline.append(T.CenterCrop(resize))

    pipeline.append(T.ToTensor())
    pipeline.append(
        T.Normalize(mean=[0.507, 0.522, 0.500], std=[0.213, 0.207, 0.212]))

    # Cutout operates on the already normalized tensor.
    if train and use_cutout:
        pipeline.append(Cutout())

    return T.Compose(pipeline)
def train_data_loader(data_path, img_size, use_augment=False):
    """Return an ``ImageFolder`` training dataset rooted at ``data_path``.

    Every sample is random-resized-cropped to ``img_size``. With
    ``use_augment`` it then goes through AutoAugment and Cutout; otherwise it
    gets a random horizontal flip. Both variants finish with ToTensor and
    ImageNet mean/std normalization.
    """
    steps = [transforms.RandomResizedCrop(img_size)]
    if use_augment:
        steps += [AutoAugment(), Cutout()]
    else:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return datasets.ImageFolder(data_path, transforms.Compose(steps))
Пример #7
0
def my_transform(train=True, resize=224, use_cutout=False, n_holes=1,
                 length=8, auto_aug=False, raug=False, N=0, M=0):
    """Build a torchvision preprocessing pipeline.

    Training: random rotation (up to 90 degrees), a random resized crop of
    ``resize + 20``, horizontal flip, colour jitter, a center crop to
    ``resize``, and optionally AutoAugment and/or ``Randaugment(N, M)``.
    Evaluation: resize and center crop to ``resize``.
    Both end with ToTensor and ImageNet normalization. ``Cutout()`` is
    appended last when ``train`` and ``use_cutout`` are both set.
    ``n_holes`` and ``length`` are accepted but not used.
    """
    if not train:
        ops = [
            T.Resize(resize, interpolation=PIL.Image.BICUBIC),
            T.CenterCrop(resize),
        ]
    else:
        ops = [
            T.RandomRotation(90),
            T.RandomResizedCrop(resize + 20,
                                scale=(0.2, 1.0),
                                interpolation=PIL.Image.BICUBIC),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.3, 0.2, 0.2, 0.2),
            T.CenterCrop(resize),
        ]
        if auto_aug:
            ops.append(AutoAugment())
        if raug:
            ops.append(Randaugment(N, M))

    ops += [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    if use_cutout and train:
        ops.append(Cutout())

    return T.Compose(ops)
Пример #8
0
 def __init__(self, mdlParams, indSet):
     """
     Build the image list, labels and transforms for one data split.

     Args:
         mdlParams (dict): Configuration for loading. Besides the keys read
             directly below, it must contain the index array
             ``mdlParams[indSet]`` plus ``labels_array`` and ``im_paths``;
             ``meta_array`` is read when ``meta_features`` is set.
         indSet (string): Indicates train, val, test
             ('trainInd' / 'valInd' / 'testInd').

     Raises:
         ValueError: if ordered-crop evaluation is requested with an
             ``eval_flipping`` value other than 2, 3 or 4.
     """
     # Mdlparams
     self.wb = WhiteBalancer()
     self.mdlParams = mdlParams
     # Number of classes
     self.numClasses = mdlParams['numClasses']
     # Model input size
     self.input_size = (np.int32(mdlParams['input_size'][0]),np.int32(mdlParams['input_size'][1]))
     # Whether or not to use ordered cropping
     self.orderedCrop = mdlParams['orderedCrop']
     # Number of crops for multi crop eval
     self.multiCropEval = mdlParams['multiCropEval']
     # Whether during training same-sized crops should be used
     self.same_sized_crop = mdlParams['same_sized_crops']
     # Only downsample (the config key is spelled 'only_downsmaple')
     self.only_downsmaple = mdlParams.get('only_downsmaple',False)
     # Potential class balancing option
     self.balancing = mdlParams['balance_classes']
     # Whether data should be preloaded
     self.preload = mdlParams['preload']
     # Potentially subtract a mean
     self.subtract_set_mean = mdlParams['subtract_set_mean']
     # Potential switch for evaluation on the training set
     self.train_eval_state = mdlParams['trainSetState']
     # Potential setMean to deduce from channels
     self.setMean = mdlParams['setMean'].astype(np.float32)
     # Current indSet = 'trainInd'/'valInd'/'testInd'
     self.indices = mdlParams[indSet]
     self.indSet = indSet
     # feature scaling for meta
     if mdlParams.get('meta_features',None) is not None and mdlParams['scale_features']:
         self.feature_scaler = mdlParams['feature_scaler_meta']
     if self.balancing == 3 and indSet == 'trainInd':
         # Sample classes equally for each batch
         # First, split set by classes
         not_one_hot = np.argmax(mdlParams['labels_array'],1)
         self.class_indices = []
         for i in range(mdlParams['numClasses']):
             self.class_indices.append(np.where(not_one_hot==i)[0])
             # Kick out validation indices
             self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['valInd'])
             # And test indices
             if 'testInd' in mdlParams:
                 self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['testInd'])
         # Now sample indices equally for each batch by repeating all of them to have the same amount as the max number
         indices = []
         max_num = np.max([len(x) for x in self.class_indices])
         # Go through all classes
         # NOTE(review): a class with zero remaining samples would make the
         # indexing below raise IndexError -- confirm every class is present.
         for i in range(mdlParams['numClasses']):
             count = 0
             class_count = 0
             max_num_curr_class = len(self.class_indices[i])
             # Add examples until we reach the maximum
             while(count < max_num):
                 # Start at the beginning, if we are through all available examples
                 if class_count == max_num_curr_class:
                     class_count = 0
                 indices.append(self.class_indices[i][class_count])
                 count += 1
                 class_count += 1
         print("Largest class",max_num,"Indices len",len(indices))
         # 'testInd' is optional (see the check above), so only report the
         # test overlap when it exists; indexing it unconditionally raised
         # KeyError for configurations without a test split.
         if 'testInd' in mdlParams:
             print("Intersect val",np.intersect1d(indices,mdlParams['valInd']),"Intersect Testind",np.intersect1d(indices,mdlParams['testInd']))
         else:
             print("Intersect val",np.intersect1d(indices,mdlParams['valInd']))
         # Set labels/inputs
         self.labels = mdlParams['labels_array'][indices,:]
         self.im_paths = np.array(mdlParams['im_paths'])[indices].tolist()
         # Normal train proc
         if self.same_sized_crop:
             cropping = transforms.RandomCrop(self.input_size)
         elif self.only_downsmaple:
             cropping = transforms.Resize(self.input_size)
         else:
             cropping = transforms.RandomResizedCrop(self.input_size[0])
         # All transforms; this branch normalizes with setMean and unit std
         self.composed = transforms.Compose([
                 cropping,
                 transforms.RandomHorizontalFlip(),
                 transforms.RandomVerticalFlip(),
                 transforms.ColorJitter(brightness=32. / 255.,saturation=0.5),
                 transforms.ToTensor(),
                 transforms.Normalize(torch.from_numpy(self.setMean).float(),torch.from_numpy(np.array([1.,1.,1.])).float())
                 ])
     elif self.orderedCrop and (indSet == 'valInd' or self.train_eval_state  == 'eval' or indSet == 'testInd'):
         # Also flip on top
         if mdlParams.get('eval_flipping',0) > 1:
             # Complete labels array, only for current indSet, repeat for multiordercrop
             inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval']*mdlParams['eval_flipping'])
             self.labels = mdlParams['labels_array'][inds_rep,:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][inds_rep,:]
             # Path to images for loading, only for current indSet, repeat for multiordercrop
             self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             print("len im path",len(self.im_paths))
             if self.mdlParams.get('var_im_size',False):
                 self.cropPositions = np.tile(mdlParams['cropPositions'][mdlParams[indSet],:,:],(1,mdlParams['eval_flipping'],1))
                 self.cropPositions = np.reshape(self.cropPositions,[mdlParams['multiCropEval']*mdlParams['eval_flipping']*mdlParams[indSet].shape[0],2])
             else:
                 self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams['eval_flipping']*mdlParams[indSet].shape[0],1))
             # Flip states
             if mdlParams['eval_flipping'] == 2:
                 self.flipPositions = np.array([0,1])
             elif mdlParams['eval_flipping'] == 3:
                 self.flipPositions = np.array([0,1,2])
             elif mdlParams['eval_flipping'] == 4:
                 self.flipPositions = np.array([0,1,2,3])
             else:
                 # Previously fell through and failed on the next line with an
                 # AttributeError for the unset self.flipPositions.
                 raise ValueError("eval_flipping must be 2, 3 or 4, got %r" % (mdlParams['eval_flipping'],))
             self.flipPositions = np.repeat(self.flipPositions, mdlParams['multiCropEval'])
             self.flipPositions = np.tile(self.flipPositions, mdlParams[indSet].shape[0])
             print("Crop positions shape",self.cropPositions.shape,"flip pos shape",self.flipPositions.shape)
             print("Flip example",self.flipPositions[:30])
         else:
             # Complete labels array, only for current indSet, repeat for multiordercrop
             inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
             self.labels = mdlParams['labels_array'][inds_rep,:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][inds_rep,:]
             # Path to images for loading, only for current indSet, repeat for multiordercrop
             self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             print("len im path",len(self.im_paths))
             # Set up crop positions for every sample
             if self.mdlParams.get('var_im_size',False):
                 self.cropPositions = np.reshape(mdlParams['cropPositions'][mdlParams[indSet],:,:],[mdlParams['multiCropEval']*mdlParams[indSet].shape[0],2])
             else:
                 self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams[indSet].shape[0],1))
             print("CP",self.cropPositions.shape)
         # Set up transforms
         self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
         self.trans = transforms.ToTensor()
     elif indSet == 'valInd' or indSet == 'testInd':
         if self.multiCropEval == 0:
             if self.only_downsmaple:
                 self.cropping = transforms.Resize(self.input_size)
             else:
                 self.cropping = transforms.Compose([transforms.CenterCrop(np.int32(self.input_size[0]*1.5)),transforms.Resize(self.input_size)])
             # Complete labels array, only for current indSet
             self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]
             # Path to images for loading, only for current indSet
             self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
         else:
             # Deterministic processing
             if self.mdlParams.get('deterministic_eval',False):
                 total_len_per_im = mdlParams['numCropPositions']*len(mdlParams['cropScales'])*mdlParams['cropFlipping']
                 # Actual transforms are functionally applied at forward pass
                 # Each row: (crop position index, crop scale, flip index)
                 self.cropPositions = np.zeros([total_len_per_im,3])
                 ind = 0
                 for i in range(mdlParams['numCropPositions']):
                     for j in range(len(mdlParams['cropScales'])):
                         for k in range(mdlParams['cropFlipping']):
                             self.cropPositions[ind,0] = i
                             self.cropPositions[ind,1] = mdlParams['cropScales'][j]
                             self.cropPositions[ind,2] = k
                             ind += 1
                 # Complete labels array, only for current indSet, repeat for multiordercrop
                 print("crops per image",total_len_per_im)
                 self.cropPositions = np.tile(self.cropPositions, (mdlParams[indSet].shape[0],1))
                 inds_rep = np.repeat(mdlParams[indSet], total_len_per_im)
                 self.labels = mdlParams['labels_array'][inds_rep,:]
                 # meta
                 if mdlParams.get('meta_features',None) is not None:
                     self.meta_data = mdlParams['meta_array'][inds_rep,:]
                 # Path to images for loading, only for current indSet, repeat for multiordercrop
                 self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             else:
                 self.cropping = transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0))
                 # Complete labels array, only for current indSet, repeat for multiordercrop
                 inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
                 self.labels = mdlParams['labels_array'][inds_rep,:]
                 # meta
                 if mdlParams.get('meta_features',None) is not None:
                     self.meta_data = mdlParams['meta_array'][inds_rep,:]
                 # Path to images for loading, only for current indSet, repeat for multiordercrop
                 self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
         print(len(self.im_paths))
         # Set up transforms
         self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
         self.trans = transforms.ToTensor()
     else:
         all_transforms = []
         # Normal train proc
         if self.same_sized_crop:
             all_transforms.append(transforms.RandomCrop(self.input_size))
         elif self.only_downsmaple:
             all_transforms.append(transforms.Resize(self.input_size))
         else:
             all_transforms.append(transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0)))
         if mdlParams.get('flip_lr_ud',False):
             all_transforms.append(transforms.RandomHorizontalFlip())
             all_transforms.append(transforms.RandomVerticalFlip())
         # Full rot: pick one of three resampling filters at random per sample.
         # NOTE(review): newer torchvision releases renamed ``resample=`` to
         # ``interpolation=``; confirm the installed version accepts it.
         if mdlParams.get('full_rot',0) > 0:
             if mdlParams.get('scale',False):
                 all_transforms.append(transforms.RandomChoice([transforms.RandomAffine(mdlParams['full_rot'], scale=mdlParams['scale'], shear=mdlParams.get('shear',0), resample=Image.NEAREST),
                                                             transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BICUBIC),
                                                             transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BILINEAR)]))
             else:
                 all_transforms.append(transforms.RandomChoice([transforms.RandomRotation(mdlParams['full_rot'], resample=Image.NEAREST),
                                                             transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BICUBIC),
                                                             transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BILINEAR)]))
         # Color distortion
         if mdlParams.get('full_color_distort') is not None:
             all_transforms.append(transforms.ColorJitter(brightness=mdlParams.get('brightness_aug',32. / 255.),saturation=mdlParams.get('saturation_aug',0.5), contrast = mdlParams.get('contrast_aug',0.5), hue = mdlParams.get('hue_aug',0.2)))
         else:
             all_transforms.append(transforms.ColorJitter(brightness=32. / 255.,saturation=0.5))
         # Autoaugment
         if self.mdlParams.get('autoaugment',False):
             all_transforms.append(AutoAugment())
         # Cutout
         if self.mdlParams.get('cutout',0) > 0:
             all_transforms.append(Cutout_v0(n_holes=1,length=self.mdlParams['cutout']))
         # Normalize
         all_transforms.append(transforms.ToTensor())
         all_transforms.append(transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd'])))
         # All transforms
         self.composed = transforms.Compose(all_transforms)
         # Complete labels array, only for current indSet
         self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
         # meta
         if mdlParams.get('meta_features',None) is not None:
             self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]
         # Path to images for loading, only for current indSet
         self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
     # Potentially preload
     if self.preload:
         self.im_list = []
         for i in range(len(self.im_paths)):
             self.im_list.append(Image.open(self.im_paths[i]))
Пример #9
0
def _cifar_loaders(args):
    """Return ``(train_loader, test_loader, num_classes)`` for ``args.dataset``.

    Supports ``'cifar10'`` and ``'cifar100'``. ``args.auto_augment`` and
    ``args.cutout`` toggle the extra training augmentations.

    Raises:
        ValueError: for any other dataset name.
    """
    if args.dataset == 'cifar10':
        dataset_cls, num_classes = datasets.CIFAR10, 10
    elif args.dataset == 'cifar100':
        dataset_cls, num_classes = datasets.CIFAR100, 100
    else:
        # Previously fell through and crashed later on the undefined num_classes.
        raise ValueError('Unsupported dataset "%s"' % (args.dataset,))

    # The same statistics normalize train and test for both datasets.
    # Previously the CIFAR-10 test transform used different (CIFAR-100)
    # statistics than its training transform.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    transform_train = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.auto_augment:
        transform_train.append(AutoAugment())
    if args.cutout:
        transform_train.append(Cutout())
    transform_train.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_train = transforms.Compose(transform_train)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = dataset_cls(
        root='~/data',
        train=True,
        download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=128,
        shuffle=True,
        num_workers=8)

    test_set = dataset_cls(
        root='~/data',
        train=False,
        download=True,
        transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=128,
        shuffle=False,
        num_workers=8)

    return train_loader, test_loader, num_classes


def main():
    """Train a WideResNet on CIFAR-10/100 and keep the best checkpoint.

    Reads options via ``parse_args()``. Writes ``args.txt``, ``args.pkl``, a
    per-epoch ``log.csv`` and the best-validation-accuracy ``model.pth``
    under ``models/<args.name>/``.
    """
    args = parse_args()

    if args.name is None:
        args.name = '%s_WideResNet%s-%s' % (args.dataset, args.depth, args.width)
        if args.cutout:
            args.name += '_wCutout'
        if args.auto_augment:
            args.name += '_wAutoAugment'

    os.makedirs('models/%s' % args.name, exist_ok=True)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('%s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # data loading code
    train_loader, test_loader, num_classes = _cifar_loaders(args)

    # create model
    model = WideResNet(args.depth, args.width, num_classes=num_classes)
    model = model.cuda()

    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay)

    scheduler = lr_scheduler.MultiStepLR(optimizer,
            milestones=[int(e) for e in args.milestones.split(',')], gamma=args.gamma)

    columns = ['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc']
    rows = []

    best_acc = 0
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        # LR actually used for this epoch, read before the scheduler advances.
        # scheduler.get_lr() is not meant to be called outside step() and
        # reports a doubly-decayed value at milestones.
        lr = optimizer.param_groups[0]['lr']

        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        val_log = validate(args, test_loader, model, criterion)

        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        # Stepping at the start of the epoch applied every milestone one epoch
        # early.
        scheduler.step()

        print('loss %.4f - acc %.4f - val_loss %.4f - val_acc %.4f'
            % (train_log['loss'], train_log['acc'], val_log['loss'], val_log['acc']))

        rows.append([
            epoch,
            lr,
            train_log['loss'],
            train_log['acc'],
            val_log['loss'],
            val_log['acc'],
        ])
        # DataFrame.append was removed in pandas 2.0; rebuild from the rows.
        pd.DataFrame(rows, columns=columns).to_csv('models/%s/log.csv' % args.name, index=False)

        if val_log['acc'] > best_acc:
            torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
            best_acc = val_log['acc']
            print("=> saved best model")
Пример #10
0
    def train(self):
        torch.multiprocessing.set_sharing_strategy('file_system')

        path = self.args.data_path
        label_file = self.args.label_path
        self.logger.info('original train process')
        time_stamp_launch = time.strftime('%Y%m%d') + '-' + time.strftime(
            '%H%M')
        self.logger.info(path.split('/')[-2] + time_stamp_launch)
        best_acc = 0
        model_root = './model_' + path.split('/')[-2]
        if not os.path.exists(model_root):
            os.mkdir(model_root)
        cuda = True
        cudnn.benchmark = True
        batch_size = self.args.batchsize
        batch_size_g = batch_size * 2
        image_size = (224, 224)
        num_cls = self.args.num_class

        self.generator_epoch = self.args.generator_epoch
        self.warm_epoch = 10
        n_epoch = self.args.max_epoch
        weight_decay = 1e-6
        momentum = 0.9

        manual_seed = random.randint(1, 10000)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        #######################
        # load data           #
        #######################
        target_train = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            AutoAugment(),
            transforms.ToTensor(),
            transforms.Normalize((0.435, 0.418, 0.396),
                                 (0.284, 0.308, 0.335)),  # grayscale mean/std
        ])

        dataset_train = visDataset_target(path,
                                          label_file,
                                          train=True,
                                          transform=target_train)

        dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=3)
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.435, 0.418, 0.396),
                                 (0.284, 0.308, 0.335)),  # grayscale mean/std
        ])

        test_dataset = visDataset_target(path,
                                         label_file,
                                         train=True,
                                         transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=3)

        #####################
        #  load model       #
        #####################
        self.lemniscate = LinearAverage(2048, test_dataset.__len__(), 0.05,
                                        0.00).cuda()
        self.elr_loss = elr_loss(num_examp=test_dataset.__len__(),
                                 num_classes=12).cuda()

        generator = generator_fea_deconv(class_num=num_cls)

        discriminator = Discriminator_fea()
        source_net = torch.load(self.args.source_model_path)
        source_classifier = Classifier(num_classes=num_cls)
        fea_contrastor = contrastor()

        # load pre-trained source classifier
        fc_dict = source_classifier.state_dict()
        pre_dict = source_net.state_dict()
        pre_dict = {k: v for k, v in pre_dict.items() if k in fc_dict}
        fc_dict.update(pre_dict)
        source_classifier.load_state_dict(fc_dict)

        generator = DataParallel(generator, device_ids=[0, 1])
        discriminator = DataParallel(discriminator, device_ids=[0, 1])
        fea_contrastor = DataParallel(fea_contrastor, device_ids=[0, 1])
        source_net = DataParallel(source_net, device_ids=[0, 1])
        source_classifier = DataParallel(source_classifier, device_ids=[0, 1])
        source_classifier.eval()

        for p in generator.parameters():
            p.requires_grad = True
        for p in source_net.parameters():
            p.requires_grad = True

        # freezing the source classifier
        for name, value in source_net.named_parameters():
            if name[:9] == 'module.fc':
                value.requires_grad = False

        # setup optimizer
        params = filter(lambda p: p.requires_grad, source_net.parameters())
        discriminator_group = []
        for k, v in discriminator.named_parameters():
            discriminator_group += [{'params': v, 'lr': self.lr * 3}]

        model_params = []
        for v in params:
            model_params += [{'params': v, 'lr': self.lr}]

        contrastor_para = []
        for k, v in fea_contrastor.named_parameters():
            contrastor_para += [{'params': v, 'lr': self.lr * 5}]

        #####################
        # setup optimizer   #
        #####################

        # only train the extractor
        optimizer = optim.SGD(model_params + discriminator_group +
                              contrastor_para,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer_g = optim.SGD(generator.parameters(),
                                lr=self.lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

        loss_gen_ce = torch.nn.CrossEntropyLoss()

        if cuda:
            source_net = source_net.cuda()
            generator = generator.cuda()
            discriminator = discriminator.cuda()
            fea_contrastor = fea_contrastor.cuda()
            loss_gen_ce = loss_gen_ce.cuda()
            source_classifier = source_classifier.cuda()

        #############################
        # training network          #
        #############################

        len_dataloader = len(dataloader_train)
        self.logger.info('the step of one epoch: ' + str(len_dataloader))

        current_step = 0
        for epoch in range(n_epoch):
            source_net.train()
            discriminator.train()
            fea_contrastor.train()

            data_train_iter = iter(dataloader_train)

            if epoch < self.generator_epoch:
                generator.train()
                self.train_prototype_generator(epoch, batch_size_g, num_cls,
                                               optimizer_g, generator,
                                               source_classifier, loss_gen_ce)

            if epoch >= self.generator_epoch:
                if epoch == self.generator_epoch:
                    torch.save(
                        generator, model_root + '/generator_' +
                        path.split('/')[-2] + '.pkl')

                # prototype generation
                generator.eval()
                z = Variable(torch.rand(self.args.num_class * 2, 100)).cuda()

                # Get labels ranging from 0 to n_classes for n rows
                label_t = torch.linspace(0, num_cls - 1, steps=num_cls).long()
                for ti in range(self.args.num_class * 2 // num_cls - 1):
                    label_t = torch.cat([
                        label_t,
                        torch.linspace(0, num_cls - 1, steps=num_cls).long()
                    ])
                labels = Variable(label_t).cuda()
                z = z.contiguous()
                labels = labels.contiguous()
                images = generator(z, labels)

                self.alpha = 0.9 - (epoch - self.generator_epoch) / (
                    n_epoch - self.generator_epoch) * 0.2

                # obtain the target pseudo label and confidence weight
                pseudo_label, pseudo_label_acc, all_indx, confidence_weight = self.obtain_pseudo_label_and_confidence_weight(
                    test_loader, source_net)

                i = 0
                while i < len_dataloader:
                    ###################################
                    #        prototype adaptation         #
                    ###################################
                    p = float(i +
                              (epoch - self.generator_epoch) * len_dataloader
                              ) / (n_epoch -
                                   self.generator_epoch) / len_dataloader
                    self.p = 2. / (1. + np.exp(-10 * p)) - 1
                    data_target_train = data_train_iter.next()
                    s_img, s_label, s_indx = data_target_train

                    batch_size_s = len(s_label)

                    input_img_s = torch.FloatTensor(batch_size_s, 3,
                                                    image_size[0],
                                                    image_size[1])
                    class_label_s = torch.LongTensor(batch_size_s)

                    if cuda:
                        s_img = s_img.cuda()
                        s_label = s_label.cuda()
                        input_img_s = input_img_s.cuda()
                        class_label_s = class_label_s.cuda()

                    input_img_s.resize_as_(s_img).copy_(s_img)
                    class_label_s.resize_as_(s_label).copy_(s_label)
                    target_inputv_img = Variable(input_img_s)
                    target_classv_label = Variable(class_label_s)

                    # learning rate decay
                    optimizer = self.exp_lr_scheduler(optimizer=optimizer,
                                                      step=current_step)

                    loss, contrastive_loss = self.adaptation_step(
                        target_inputv_img, pseudo_label, images.detach(),
                        labels, s_indx.numpy(), source_net, discriminator,
                        fea_contrastor, optimizer, epoch,
                        confidence_weight.float())

                    # visualization on tensorboard
                    self.writer.add_scalar('contrastive_loss',
                                           contrastive_loss,
                                           global_step=current_step)
                    self.writer.add_scalar('overall_loss',
                                           loss,
                                           global_step=current_step)
                    self.writer.add_scalar('pseudo_label_acc',
                                           pseudo_label_acc,
                                           global_step=current_step)

                    i += 1
                    current_step += 1

                self.logger.info('epoch: %d' % epoch)
                self.logger.info('contrastive_loss: %f' % (contrastive_loss))
                self.logger.info('loss: %f' % loss)
                accu, ac_list = val_pclass(source_net, test_loader)
                self.writer.add_scalar('test_acc',
                                       accu,
                                       global_step=current_step)
                self.logger.info(ac_list)
                if accu >= best_acc:
                    self.logger.info('saving the best model!')
                    torch.save(
                        source_net, model_root + '/' + time_stamp_launch +
                        '_best_model_' + path.split('/')[-2] + '.pkl')
                    best_acc = accu

                self.logger.info('acc is : %.04f, best acc is : %.04f' %
                                 (accu, best_acc))
                self.logger.info(
                    '================================================')

        self.logger.info('training done! ! !')