def load_cifar100(args, **kwargs):
    """Build CIFAR-100 train/test DataLoaders plus dataset metadata.

    :param args: namespace providing ``auto_augment``, ``cutout`` and
        ``batch_size``.
    :param kwargs: extra keyword arguments forwarded to both DataLoaders
        (e.g. ``pin_memory``). ``num_workers`` defaults to 4 but may now be
        overridden via kwargs (previously that raised a TypeError because
        ``num_workers=4`` was also passed explicitly after ``**kwargs``).
    :return: ``(train_loader, test_loader, metadata)`` where ``metadata``
        carries the input shape and the number of classes.
    """
    list_trans = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.auto_augment:
        list_trans.append(AutoAugment())
    if args.cutout:
        list_trans.append(Cutout())
    list_trans.append(transforms.ToTensor())
    # NOTE(review): these are the usual CIFAR-10 statistics; confirm they are
    # intentionally reused for CIFAR-100.
    list_trans.append(transforms.Normalize((0.4914, 0.4822, 0.4465),
                                           (0.2023, 0.1994, 0.2010)))
    transform_train = transforms.Compose(list_trans)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    # Default worker count, overridable through kwargs.
    kwargs.setdefault('num_workers', 4)

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('data', train=True, download=True,
                          transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    # Fix: the test set must not be shuffled — evaluation order should be
    # deterministic (the original passed shuffle=True here).
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('data', train=False, download=True,
                          transform=transform_test),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    metadata = {"input_shape": (3, 32, 32), "n_classes": 100}
    return train_loader, test_loader, metadata
def data_loader(args):
    """Build CIFAR-10 train/val DataLoaders.

    :param args: namespace providing ``dataset``, ``augment``,
        ``batch_size`` and ``num_workers``.
    :return: ``(train_loader, val_loader)``.
    :raises ValueError: if ``args.dataset`` is not ``'cifar10'`` or
        ``args.augment`` names an unknown augmentation policy.
    """
    if args.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    else:
        # Fix: the original interpolated the undefined name `dataset`, which
        # raised a NameError instead of the intended ValueError.
        raise ValueError('Unavailable dataset "%s"' % (args.dataset))

    transform_train = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.augment == 'AutoAugment':
        transform_train.append(AutoAugment())
    elif args.augment == 'Basic':
        transform_train.extend([
            transforms.RandomApply(
                [transforms.ColorJitter(0.3, 0.3, 0.3, 0.1)], 0.8),
            transforms.RandomGrayscale(0.1),
        ])
    else:
        raise ValueError('No such augmentation policy is set!')
    transform_train.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_train = transforms.Compose(transform_train)

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if args.dataset == 'cifar10':
        train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True,
                                                 download=True,
                                                 transform=transform_train)
        val_set = torchvision.datasets.CIFAR10(root='./dataset', train=False,
                                               download=True,
                                               transform=transform_val)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    # Validation uses a fixed batch size of 64 and no shuffling.
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=64, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    return train_loader, val_loader
def AutoAug(img: tf.Tensor):
    """Apply AutoAugment per image to a batched image tensor.

    :param img: batched image tensor; converted to numpy and fed image-by-
        image through PIL (assumes uint8 HWC images — TODO confirm caller).
    :return: augmented batch as a ``tf.float16`` tensor.
    """
    img = img.numpy()
    autoaug = AutoAugment()
    auto_aug_im = np.zeros_like(img)
    for i in range(img.shape[0]):
        im = Image.fromarray(img[i])
        im = autoaug(im)
        auto_aug_im[i] = im
    # Fix: the original wrote `Auto_aug_im < -tf.convert_to_tensor(...)`,
    # a no-op comparison instead of an assignment, so the function returned
    # the raw numpy array rather than the tf.float16 tensor.
    auto_aug_im = tf.convert_to_tensor(auto_aug_im, dtype=tf.float16)
    return auto_aug_im
def __init__(self, data_cfg, multi=1, nl=False):
    """
    Dataset for training.

    :param data_cfg: CfgNode for CityFlow NL (JSON paths, crop size, and the
        motion/nseg/all3/pad/semi flags read below).
    :param multi: stored on the instance; semantics not visible in this
        method — presumably a sampling multiplier, TODO confirm.
    :param nl: stored on the instance; presumably toggles natural-language
        query handling — TODO confirm.
    """
    self.nl = nl
    self.multi = multi
    self.motion = data_cfg.motion
    self.nseg = data_cfg.nseg
    self.all3 = data_cfg.all3
    self.pad = data_cfg.pad
    self.data_cfg = data_cfg
    # AutoAugment with the 'v0r' policy for train-time image augmentation.
    self.aug = AutoAugment(auto_augment_policy(name='v0r', hparams=None))
    # Labeled tracks: JSON mapping of uuid -> track dict.
    with open(self.data_cfg.JSON_PATH) as f:
        tracks = json.load(f)
    f.close()  # redundant: the with-statement already closed the file
    self.list_of_uuids = list(tracks.keys())
    self.list_of_tracks = list(tracks.values())
    self.list_of_crops = list()
    # Count of labeled tracks; indices >= train_num are unlabeled extras
    # appended below in semi-supervised mode.
    train_num = len(self.list_of_uuids)
    self.transform = transforms.Compose([
        transforms.Pad(10),
        transforms.RandomCrop(
            (data_cfg.CROP_SIZE, self.data_cfg.CROP_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        RandomErasing(probability=0.5)
    ])
    # Semi-supervised mode: also pull in unlabeled eval tracks and the
    # test query texts.
    if data_cfg.semi:
        #cv
        with open(self.data_cfg.EVAL_TRACKS_JSON_PATH) as f:
            unlabel_tracks = json.load(f)
        f.close()  # redundant, as above
        self.list_of_uuids.extend(unlabel_tracks.keys())
        self.list_of_tracks.extend(unlabel_tracks.values())
        #nl
        with open("data/test-queries.json", "r") as f:
            unlabel_nl = json.load(f)
        unlabel_nl_key = list(unlabel_nl.keys())
    print('#track id (class): %d ' % len(self.list_of_tracks))
    count = 0
    # add id and nl, -1 for unlabeled data
    for track_idx, track in enumerate(self.list_of_tracks):
        track["track_id"] = track_idx
        track["nl_id"] = track_idx
        # from 0 to train_num-1 is the id of the original training set.
        if track_idx >= train_num:
            track["nl_id"] = -1
            track["nl"] = unlabel_nl[unlabel_nl_key[count]]
            count = count + 1
    self._logger = get_logger()
def my_transform(train=True, resize=224, use_cutout=False, n_holes=1,
                 length=8, auto_aug=False, rand_aug=False):
    """Build a torchvision transform pipeline.

    :param train: build the stochastic training pipeline instead of the
        deterministic eval pipeline (resize + center crop).
    :param resize: target square edge length in pixels.
    :param use_cutout: append ``Cutout()`` after ``ToTensor`` (train only).
    :param n_holes: currently unused — presumably meant for Cutout; kept for
        interface compatibility, TODO confirm.
    :param length: currently unused — presumably meant for Cutout; kept for
        interface compatibility, TODO confirm.
    :param auto_aug: append ``AutoAugment()`` (train only).
    :param rand_aug: append ``Rand_Augment()`` (train only).
    :return: a composed transform (``T.Compose``).
    """
    # Renamed from `transforms` so the local list does not shadow the
    # conventional torchvision module name; also dropped the unused
    # `interpolations` list the original built and never read.
    ops = []
    if train:
        # Crop slightly larger than the target, then center-crop to `resize`.
        ops.append(
            T.RandomResizedCrop(resize + 5, scale=(0.2, 2.0),
                                interpolation=PIL.Image.BICUBIC))
        ops.append(T.RandomHorizontalFlip())
        ops.append(T.ColorJitter(0.2, 0.2, 0.3, 0.))
        ops.append(T.CenterCrop(resize))
        if auto_aug:
            ops.append(AutoAugment())
        if rand_aug:
            ops.append(Rand_Augment())
    else:
        ops.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
        ops.append(T.CenterCrop(resize))
    ops.append(T.ToTensor())
    # Dataset-specific normalization statistics.
    ops.append(
        T.Normalize(mean=[0.507, 0.522, 0.500], std=[0.213, 0.207, 0.212]))
    if train and use_cutout:
        ops.append(Cutout())
    return T.Compose(ops)
def train_data_loader(data_path, img_size, use_augment=False):
    """Return an ImageFolder dataset with train-time transforms applied.

    With ``use_augment`` set, AutoAugment + Cutout replace the plain
    random horizontal flip; both variants end with ToTensor and ImageNet
    normalization.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    if use_augment:
        pipeline = [
            transforms.RandomResizedCrop(img_size),
            AutoAugment(),
            Cutout(),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        pipeline = [
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    return datasets.ImageFolder(data_path, transforms.Compose(pipeline))
def my_transform(train=True, resize=224, use_cutout=False, n_holes=1,
                 length=8, auto_aug=False, raug=False, N=0, M=0):
    """Compose train/eval image transforms.

    Training: rotation, oversized random-resized crop, horizontal flip,
    color jitter, center crop, then optional AutoAugment / Randaugment(N, M).
    Eval: bicubic resize + center crop. Both paths end with ToTensor and
    ImageNet normalization; Cutout is appended when training with
    ``use_cutout``.
    """
    if train:
        ops = [
            T.RandomRotation(90),
            T.RandomResizedCrop(resize + 20, scale=(0.2, 1.0),
                                interpolation=PIL.Image.BICUBIC),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.3, 0.2, 0.2, 0.2),
            T.CenterCrop(resize),
        ]
        if auto_aug:
            ops.append(AutoAugment())
        if raug:
            ops.append(Randaugment(N, M))
    else:
        ops = [
            T.Resize(resize, interpolation=PIL.Image.BICUBIC),
            T.CenterCrop(resize),
        ]
    ops.append(T.ToTensor())
    ops.append(T.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]))
    if train and use_cutout:
        ops.append(Cutout())
    return T.Compose(ops)
def __init__(self, mdlParams, indSet):
    """
    Dataset over an index subset of a melanoma-style image collection.

    Builds labels, image paths, crop positions and torchvision transforms
    for one of four regimes: per-batch class balancing (balancing == 3,
    train), ordered multi-crop eval, random multi-crop / single-crop eval,
    and plain random-augmentation training.

    Args:
        mdlParams (dict): Configuration for loading
        indSet (string): Indicates train, val, test ('trainInd'/'valInd'/'testInd')
    """
    # Mdlparams
    self.wb = WhiteBalancer()
    self.mdlParams = mdlParams
    # Number of classes
    self.numClasses = mdlParams['numClasses']
    # Model input size
    self.input_size = (np.int32(mdlParams['input_size'][0]),np.int32(mdlParams['input_size'][1]))
    # Whether or not to use ordered cropping
    self.orderedCrop = mdlParams['orderedCrop']
    # Number of crops for multi crop eval
    self.multiCropEval = mdlParams['multiCropEval']
    # Whether during training same-sized crops should be used
    self.same_sized_crop = mdlParams['same_sized_crops']
    # Only downsample
    self.only_downsmaple = mdlParams.get('only_downsmaple',False)
    # Potential class balancing option
    self.balancing = mdlParams['balance_classes']
    # Whether data should be preloaded
    self.preload = mdlParams['preload']
    # Potentially subtract a mean
    self.subtract_set_mean = mdlParams['subtract_set_mean']
    # Potential switch for evaluation on the training set
    self.train_eval_state = mdlParams['trainSetState']
    # Potential setMean to deduce from channels
    self.setMean = mdlParams['setMean'].astype(np.float32)
    # Current indSet = 'trainInd'/'valInd'/'testInd'
    self.indices = mdlParams[indSet]
    self.indSet = indSet
    # feature scaling for meta
    if mdlParams.get('meta_features',None) is not None and mdlParams['scale_features']:
        self.feature_scaler = mdlParams['feature_scaler_meta']
    if self.balancing == 3 and indSet == 'trainInd':
        # Sample classes equally for each batch
        # First, split set by classes
        not_one_hot = np.argmax(mdlParams['labels_array'],1)
        self.class_indices = []
        for i in range(mdlParams['numClasses']):
            self.class_indices.append(np.where(not_one_hot==i)[0])
            # Kick out non-trainind indices
            self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['valInd'])
            # And test indices
            if 'testInd' in mdlParams:
                self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['testInd'])
        # Now sample indices equally for each batch by repeating all of them to have the same amount as the max number
        indices = []
        max_num = np.max([len(x) for x in self.class_indices])
        # Go thourgh all classes
        for i in range(mdlParams['numClasses']):
            count = 0
            class_count = 0
            max_num_curr_class = len(self.class_indices[i])
            # Add examples until we reach the maximum
            while(count < max_num):
                # Start at the beginning, if we are through all available examples
                if class_count == max_num_curr_class:
                    class_count = 0
                indices.append(self.class_indices[i][class_count])
                count += 1
                class_count += 1
        print("Largest class",max_num,"Indices len",len(indices))
        print("Intersect val",np.intersect1d(indices,mdlParams['valInd']),"Intersect Testind",np.intersect1d(indices,mdlParams['testInd']))
        # Set labels/inputs
        self.labels = mdlParams['labels_array'][indices,:]
        self.im_paths = np.array(mdlParams['im_paths'])[indices].tolist()
        # Normal train proc
        if self.same_sized_crop:
            cropping = transforms.RandomCrop(self.input_size)
        elif self.only_downsmaple:
            cropping = transforms.Resize(self.input_size)
        else:
            cropping = transforms.RandomResizedCrop(self.input_size[0])
        # All transforms
        self.composed = transforms.Compose([
            cropping,
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(brightness=32. / 255.,saturation=0.5),
            transforms.ToTensor(),
            # Normalize subtracts the set mean only; std is fixed at 1.
            transforms.Normalize(torch.from_numpy(self.setMean).float(),torch.from_numpy(np.array([1.,1.,1.])).float())
        ])
    elif self.orderedCrop and (indSet == 'valInd' or self.train_eval_state == 'eval' or indSet == 'testInd'):
        # Ordered multi-crop evaluation: each sample is repeated once per
        # crop position (and, below, per flip state).
        # Also flip on top
        if mdlParams.get('eval_flipping',0) > 1:
            # Complete labels array, only for current indSet, repeat for multiordercrop
            inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval']*mdlParams['eval_flipping'])
            self.labels = mdlParams['labels_array'][inds_rep,:]
            # meta
            if mdlParams.get('meta_features',None) is not None:
                self.meta_data = mdlParams['meta_array'][inds_rep,:]
            # Path to images for loading, only for current indSet, repeat for multiordercrop
            self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
            print("len im path",len(self.im_paths))
            if self.mdlParams.get('var_im_size',False):
                # Per-image crop positions (variable image sizes).
                self.cropPositions = np.tile(mdlParams['cropPositions'][mdlParams[indSet],:,:],(1,mdlParams['eval_flipping'],1))
                self.cropPositions = np.reshape(self.cropPositions,[mdlParams['multiCropEval']*mdlParams['eval_flipping']*mdlParams[indSet].shape[0],2])
                #self.cropPositions = np.repeat(self.cropPositions, (mdlParams['eval_flipping'],1))
                #print("CP examples",self.cropPositions[:50,:])
            else:
                self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams['eval_flipping']*mdlParams[indSet].shape[0],1))
            # Flip states
            if mdlParams['eval_flipping'] == 2:
                self.flipPositions = np.array([0,1])
            elif mdlParams['eval_flipping'] == 3:
                self.flipPositions = np.array([0,1,2])
            elif mdlParams['eval_flipping'] == 4:
                self.flipPositions = np.array([0,1,2,3])
            self.flipPositions = np.repeat(self.flipPositions, mdlParams['multiCropEval'])
            self.flipPositions = np.tile(self.flipPositions, mdlParams[indSet].shape[0])
            print("Crop positions shape",self.cropPositions.shape,"flip pos shape",self.flipPositions.shape)
            print("Flip example",self.flipPositions[:30])
        else:
            # Complete labels array, only for current indSet, repeat for multiordercrop
            inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
            self.labels = mdlParams['labels_array'][inds_rep,:]
            # meta
            if mdlParams.get('meta_features',None) is not None:
                self.meta_data = mdlParams['meta_array'][inds_rep,:]
            # Path to images for loading, only for current indSet, repeat for multiordercrop
            self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
            print("len im path",len(self.im_paths))
            # Set up crop positions for every sample
            if self.mdlParams.get('var_im_size',False):
                self.cropPositions = np.reshape(mdlParams['cropPositions'][mdlParams[indSet],:,:],[mdlParams['multiCropEval']*mdlParams[indSet].shape[0],2])
                #print("CP examples",self.cropPositions[:50,:])
            else:
                self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams[indSet].shape[0],1))
            print("CP",self.cropPositions.shape)
            #print("CP Example",self.cropPositions[0:len(mdlParams['cropPositions']),:])
        # Set up transforms
        self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
        self.trans = transforms.ToTensor()
    elif indSet == 'valInd' or indSet == 'testInd':
        # Unordered (random/center) cropping evaluation.
        if self.multiCropEval == 0:
            # Single deterministic crop per sample.
            if self.only_downsmaple:
                self.cropping = transforms.Resize(self.input_size)
            else:
                self.cropping = transforms.Compose([transforms.CenterCrop(np.int32(self.input_size[0]*1.5)),transforms.Resize(self.input_size)])
            # Complete labels array, only for current indSet
            self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
            # meta
            if mdlParams.get('meta_features',None) is not None:
                self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]
            # Path to images for loading, only for current indSet
            self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
        else:
            # Deterministic processing
            if self.mdlParams.get('deterministic_eval',False):
                total_len_per_im = mdlParams['numCropPositions']*len(mdlParams['cropScales'])*mdlParams['cropFlipping']
                # Actual transforms are functionally applied at forward pass
                # Enumerate every (position, scale, flip) combination.
                self.cropPositions = np.zeros([total_len_per_im,3])
                ind = 0
                for i in range(mdlParams['numCropPositions']):
                    for j in range(len(mdlParams['cropScales'])):
                        for k in range(mdlParams['cropFlipping']):
                            self.cropPositions[ind,0] = i
                            self.cropPositions[ind,1] = mdlParams['cropScales'][j]
                            self.cropPositions[ind,2] = k
                            ind += 1
                # Complete labels array, only for current indSet, repeat for multiordercrop
                print("crops per image",total_len_per_im)
                self.cropPositions = np.tile(self.cropPositions, (mdlParams[indSet].shape[0],1))
                inds_rep = np.repeat(mdlParams[indSet], total_len_per_im)
                self.labels = mdlParams['labels_array'][inds_rep,:]
                # meta
                if mdlParams.get('meta_features',None) is not None:
                    self.meta_data = mdlParams['meta_array'][inds_rep,:]
                # Path to images for loading, only for current indSet, repeat for multiordercrop
                self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
            else:
                # Random multi-crop evaluation.
                self.cropping = transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0))
                # Complete labels array, only for current indSet, repeat for multiordercrop
                inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
                self.labels = mdlParams['labels_array'][inds_rep,:]
                # meta
                if mdlParams.get('meta_features',None) is not None:
                    self.meta_data = mdlParams['meta_array'][inds_rep,:]
                # Path to images for loading, only for current indSet, repeat for multiordercrop
                self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
                print(len(self.im_paths))
        # Set up transforms
        self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
        self.trans = transforms.ToTensor()
    else:
        # Plain training regime: random crop/flip/rotation/color transforms.
        all_transforms = []
        # Normal train proc
        if self.same_sized_crop:
            all_transforms.append(transforms.RandomCrop(self.input_size))
        elif self.only_downsmaple:
            all_transforms.append(transforms.Resize(self.input_size))
        else:
            all_transforms.append(transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0)))
        if mdlParams.get('flip_lr_ud',False):
            all_transforms.append(transforms.RandomHorizontalFlip())
            all_transforms.append(transforms.RandomVerticalFlip())
        # Full rot
        if mdlParams.get('full_rot',0) > 0:
            # RandomChoice picks one resampling filter per call.
            if mdlParams.get('scale',False):
                all_transforms.append(transforms.RandomChoice([transforms.RandomAffine(mdlParams['full_rot'], scale=mdlParams['scale'], shear=mdlParams.get('shear',0), resample=Image.NEAREST),
                                                               transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BICUBIC),
                                                               transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BILINEAR)]))
            else:
                all_transforms.append(transforms.RandomChoice([transforms.RandomRotation(mdlParams['full_rot'], resample=Image.NEAREST),
                                                               transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BICUBIC),
                                                               transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BILINEAR)]))
        # Color distortion
        if mdlParams.get('full_color_distort') is not None:
            all_transforms.append(transforms.ColorJitter(brightness=mdlParams.get('brightness_aug',32. / 255.),saturation=mdlParams.get('saturation_aug',0.5), contrast = mdlParams.get('contrast_aug',0.5), hue = mdlParams.get('hue_aug',0.2)))
        else:
            all_transforms.append(transforms.ColorJitter(brightness=32. / 255.,saturation=0.5))
        # Autoaugment
        if self.mdlParams.get('autoaugment',False):
            all_transforms.append(AutoAugment())
        # Cutout
        if self.mdlParams.get('cutout',0) > 0:
            all_transforms.append(Cutout_v0(n_holes=1,length=self.mdlParams['cutout']))
        # Normalize
        all_transforms.append(transforms.ToTensor())
        all_transforms.append(transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd'])))
        # All transforms
        self.composed = transforms.Compose(all_transforms)
        # Complete labels array, only for current indSet
        self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
        # meta
        if mdlParams.get('meta_features',None) is not None:
            self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]
        # Path to images for loading, only for current indSet
        self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
    # Potentially preload
    # NOTE(review): Image.open is lazy; the file handles stay open until the
    # images are actually decoded — confirm this is intended for large sets.
    if self.preload:
        self.im_list = []
        for i in range(len(self.im_paths)):
            self.im_list.append(Image.open(self.im_paths[i]))
def _build_cifar_transforms(args, mean, std):
    """Return (train, test) transform pipelines for CIFAR with the given
    normalization statistics, honoring args.auto_augment / args.cutout."""
    train_tf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if args.auto_augment:
        train_tf.append(AutoAugment())
    if args.cutout:
        train_tf.append(Cutout())
    train_tf.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return transforms.Compose(train_tf), test_tf


def main():
    """Train a WideResNet on CIFAR-10/100 with optional AutoAugment/Cutout,
    logging per-epoch metrics to CSV and checkpointing the best model."""
    args = parse_args()

    if args.name is None:
        args.name = '%s_WideResNet%s-%s' % (args.dataset, args.depth, args.width)
        if args.cutout:
            args.name += '_wCutout'
        if args.auto_augment:
            args.name += '_wAutoAugment'

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('%s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # Data loading. One shared builder keeps the train/test normalization
    # statistics consistent — the original normalized the CIFAR-10 test set
    # with CIFAR-100 statistics while training with CIFAR-10 statistics.
    if args.dataset == 'cifar10':
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        dataset_cls, num_classes = datasets.CIFAR10, 10
    elif args.dataset == 'cifar100':
        # NOTE(review): these are CIFAR-10 statistics; kept to preserve the
        # original CIFAR-100 behavior (train and test already matched).
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        dataset_cls, num_classes = datasets.CIFAR100, 100
    else:
        # Fail fast instead of hitting a NameError on num_classes below.
        raise ValueError('unknown dataset %r' % args.dataset)

    transform_train, transform_test = _build_cifar_transforms(args, mean, std)

    train_set = dataset_cls(
        root='~/data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=128, shuffle=True, num_workers=8)
    test_set = dataset_cls(
        root='~/data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=128, shuffle=False, num_workers=8)

    # create model
    model = WideResNet(args.depth, args.width, num_classes=num_classes)
    model = model.cuda()

    optimizer = optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(e) for e in args.milestones.split(',')],
        gamma=args.gamma)

    log = pd.DataFrame(index=[], columns=[
        'epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'
    ])

    best_acc = 0
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        val_log = validate(args, test_loader, model, criterion)
        # Step the scheduler AFTER the epoch's optimizer updates (required
        # order since PyTorch 1.1; stepping at the top of the loop shifted
        # every milestone one epoch early).
        scheduler.step()

        print('loss %.4f - acc %.4f - val_loss %.4f - val_acc %.4f'
              % (train_log['loss'], train_log['acc'], val_log['loss'], val_log['acc']))

        tmp = pd.Series([
            epoch,
            # The LR actually used this epoch (get_lr() is deprecated and
            # unreliable outside of step()).
            optimizer.param_groups[0]['lr'],
            train_log['loss'],
            train_log['acc'],
            val_log['loss'],
            val_log['acc'],
        ], index=['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'])

        # DataFrame.append was removed in pandas 2.0; concat is the
        # supported equivalent.
        log = pd.concat([log, tmp.to_frame().T], ignore_index=True)
        log.to_csv('models/%s/log.csv' % args.name, index=False)

        if val_log['acc'] > best_acc:
            torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
            best_acc = val_log['acc']
            print("=> saved best model")
def train(self):
    """Run source-free domain adaptation training on the target domain.

    Phase 1 (epoch < generator_epoch): train a class-conditional prototype
    generator against the frozen source classifier.
    Phase 2 (epoch >= generator_epoch): generate class prototypes, compute
    pseudo labels + confidence weights for the target set, and run the
    adaptation step (feature extractor + discriminator + contrastor) batch
    by batch, checkpointing the best model by test accuracy.

    Reads paths/hyperparameters from ``self.args``; writes scalars to
    ``self.writer`` and checkpoints under ``./model_<dataset>/``.
    """
    # Avoid "too many open files" with many DataLoader workers.
    torch.multiprocessing.set_sharing_strategy('file_system')
    path = self.args.data_path
    label_file = self.args.label_path
    self.logger.info('original train process')
    time_stamp_launch = time.strftime('%Y%m%d') + '-' + time.strftime(
        '%H%M')
    self.logger.info(path.split('/')[-2] + time_stamp_launch)
    best_acc = 0
    # Checkpoint directory named after the dataset folder.
    model_root = './model_' + path.split('/')[-2]
    if not os.path.exists(model_root):
        os.mkdir(model_root)
    cuda = True
    cudnn.benchmark = True
    batch_size = self.args.batchsize
    # Generator is trained with twice the batch size.
    batch_size_g = batch_size * 2
    image_size = (224, 224)
    num_cls = self.args.num_class
    self.generator_epoch = self.args.generator_epoch
    self.warm_epoch = 10
    n_epoch = self.args.max_epoch
    weight_decay = 1e-6
    momentum = 0.9
    # Random seed is itself randomized per run (not reproducible on purpose?
    # NOTE(review): consider logging manual_seed for reproducibility).
    manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    #######################
    # load data           #
    #######################
    target_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        AutoAugment(),
        transforms.ToTensor(),
        transforms.Normalize((0.435, 0.418, 0.396),
                             (0.284, 0.308, 0.335)),  # grayscale mean/std
    ])
    dataset_train = visDataset_target(path, label_file, train=True,
                                      transform=target_train)
    dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=3)
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.435, 0.418, 0.396),
                             (0.284, 0.308, 0.335)),  # grayscale mean/std
    ])
    # NOTE(review): the "test" loader also uses train=True — it iterates the
    # training split with eval transforms (used for pseudo-labeling below).
    test_dataset = visDataset_target(path, label_file, train=True,
                                     transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=3)

    #####################
    #  load model       #
    #####################
    # Memory bank over 2048-d features, one slot per target sample.
    self.lemniscate = LinearAverage(2048, test_dataset.__len__(), 0.05,
                                    0.00).cuda()
    self.elr_loss = elr_loss(num_examp=test_dataset.__len__(),
                             num_classes=12).cuda()
    generator = generator_fea_deconv(class_num=num_cls)
    discriminator = Discriminator_fea()
    source_net = torch.load(self.args.source_model_path)
    source_classifier = Classifier(num_classes=num_cls)
    fea_contrastor = contrastor()
    # load pre-trained source classifier
    fc_dict = source_classifier.state_dict()
    pre_dict = source_net.state_dict()
    # Copy only the weights whose names exist in the classifier.
    pre_dict = {k: v for k, v in pre_dict.items() if k in fc_dict}
    fc_dict.update(pre_dict)
    source_classifier.load_state_dict(fc_dict)
    generator = DataParallel(generator, device_ids=[0, 1])
    discriminator = DataParallel(discriminator, device_ids=[0, 1])
    fea_contrastor = DataParallel(fea_contrastor, device_ids=[0, 1])
    source_net = DataParallel(source_net, device_ids=[0, 1])
    source_classifier = DataParallel(source_classifier, device_ids=[0, 1])
    source_classifier.eval()
    for p in generator.parameters():
        p.requires_grad = True
    for p in source_net.parameters():
        p.requires_grad = True
    # freezing the source classifier
    for name, value in source_net.named_parameters():
        if name[:9] == 'module.fc':
            value.requires_grad = False
    # setup optimizer
    params = filter(lambda p: p.requires_grad, source_net.parameters())
    # Per-module learning rates: discriminator 3x, contrastor 5x base lr.
    discriminator_group = []
    for k, v in discriminator.named_parameters():
        discriminator_group += [{'params': v, 'lr': self.lr * 3}]
    model_params = []
    for v in params:
        model_params += [{'params': v, 'lr': self.lr}]
    contrastor_para = []
    for k, v in fea_contrastor.named_parameters():
        contrastor_para += [{'params': v, 'lr': self.lr * 5}]

    #####################
    # setup optimizer   #
    #####################
    # only train the extractor
    optimizer = optim.SGD(model_params + discriminator_group +
                          contrastor_para,
                          momentum=momentum,
                          weight_decay=weight_decay)
    optimizer_g = optim.SGD(generator.parameters(), lr=self.lr,
                            momentum=momentum, weight_decay=weight_decay)
    loss_gen_ce = torch.nn.CrossEntropyLoss()
    if cuda:
        source_net = source_net.cuda()
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        fea_contrastor = fea_contrastor.cuda()
        loss_gen_ce = loss_gen_ce.cuda()
        source_classifier = source_classifier.cuda()

    #############################
    # training network          #
    #############################
    len_dataloader = len(dataloader_train)
    self.logger.info('the step of one epoch: ' + str(len_dataloader))
    current_step = 0
    for epoch in range(n_epoch):
        source_net.train()
        discriminator.train()
        fea_contrastor.train()
        data_train_iter = iter(dataloader_train)
        if epoch < self.generator_epoch:
            # Phase 1: train the prototype generator only.
            generator.train()
            self.train_prototype_generator(epoch, batch_size_g, num_cls,
                                           optimizer_g, generator,
                                           source_classifier, loss_gen_ce)
        if epoch >= self.generator_epoch:
            if epoch == self.generator_epoch:
                # Persist the generator once, at the phase transition.
                torch.save(
                    generator,
                    model_root + '/generator_' + path.split('/')[-2] + '.pkl')
            # prototype generation
            generator.eval()
            z = Variable(torch.rand(self.args.num_class * 2, 100)).cuda()
            # Get labels ranging from 0 to n_classes for n rows
            label_t = torch.linspace(0, num_cls - 1, steps=num_cls).long()
            for ti in range(self.args.num_class * 2 // num_cls - 1):
                label_t = torch.cat([
                    label_t,
                    torch.linspace(0, num_cls - 1, steps=num_cls).long()
                ])
            labels = Variable(label_t).cuda()
            z = z.contiguous()
            labels = labels.contiguous()
            images = generator(z, labels)
            # Mixing coefficient decays linearly from 0.9 to 0.7 over phase 2.
            self.alpha = 0.9 - (epoch - self.generator_epoch) / (
                n_epoch - self.generator_epoch) * 0.2
            # obtain the target pseudo label and confidence weight
            pseudo_label, pseudo_label_acc, all_indx, confidence_weight = self.obtain_pseudo_label_and_confidence_weight(
                test_loader, source_net)
            i = 0
            while i < len_dataloader:
                ###################################
                #        prototype adaptation     #
                ###################################
                # Progress in [0, 1) over all phase-2 steps, mapped through a
                # sigmoid-style ramp to self.p (GRL-style schedule).
                p = float(i + (epoch - self.generator_epoch) * len_dataloader
                          ) / (n_epoch - self.generator_epoch) / len_dataloader
                self.p = 2. / (1. + np.exp(-10 * p)) - 1
                # NOTE(review): iterator.next() is Python 2 / old-PyTorch
                # style; modern PyTorch requires next(data_train_iter).
                data_target_train = data_train_iter.next()
                s_img, s_label, s_indx = data_target_train
                batch_size_s = len(s_label)
                input_img_s = torch.FloatTensor(batch_size_s, 3,
                                                image_size[0], image_size[1])
                class_label_s = torch.LongTensor(batch_size_s)
                if cuda:
                    s_img = s_img.cuda()
                    s_label = s_label.cuda()
                    input_img_s = input_img_s.cuda()
                    class_label_s = class_label_s.cuda()
                input_img_s.resize_as_(s_img).copy_(s_img)
                class_label_s.resize_as_(s_label).copy_(s_label)
                target_inputv_img = Variable(input_img_s)
                target_classv_label = Variable(class_label_s)
                # learning rate decay
                optimizer = self.exp_lr_scheduler(optimizer=optimizer,
                                                  step=current_step)
                loss, contrastive_loss = self.adaptation_step(
                    target_inputv_img, pseudo_label, images.detach(), labels,
                    s_indx.numpy(), source_net, discriminator, fea_contrastor,
                    optimizer, epoch, confidence_weight.float())
                # visualization on tensorboard
                self.writer.add_scalar('contrastive_loss', contrastive_loss,
                                       global_step=current_step)
                self.writer.add_scalar('overall_loss', loss,
                                       global_step=current_step)
                self.writer.add_scalar('pseudo_label_acc', pseudo_label_acc,
                                       global_step=current_step)
                i += 1
                current_step += 1
            self.logger.info('epoch: %d' % epoch)
            self.logger.info('contrastive_loss: %f' % (contrastive_loss))
            self.logger.info('loss: %f' % loss)
            # Per-class evaluation; checkpoint on best overall accuracy.
            accu, ac_list = val_pclass(source_net, test_loader)
            self.writer.add_scalar('test_acc', accu,
                                   global_step=current_step)
            self.logger.info(ac_list)
            if accu >= best_acc:
                self.logger.info('saving the best model!')
                torch.save(
                    source_net, model_root + '/' + time_stamp_launch +
                    '_best_model_' + path.split('/')[-2] + '.pkl')
                best_acc = accu
            self.logger.info('acc is : %.04f, best acc is : %.04f' %
                             (accu, best_acc))
            self.logger.info(
                '================================================')
    self.logger.info('training done! ! !')