def test_dataparallel_params(model, params):
    """Forward a DataParallel-wrapped model on cuda:0 with explicit ``params``.

    Asserts that the output of ``model(inputs, params=params)`` has shape
    (5, 1) and lives on the target device.
    """
    target = torch.device('cuda:0')
    wrapped = DataParallel(model)
    wrapped.to(device=target)

    batch = torch.rand(5, 2).to(device=target)
    result = wrapped(batch, params=params)

    assert result.shape == (5, 1)
    assert result.device == target
def __init__(self, model_dict, dataset, device):
    """Build the classification network and its optimizer.

    Args:
        model_dict: configuration mapping. ``'name'`` selects the network
            (``'resnet18'``, ``'resnet18_meta'``, ``'resnet18_meta_2'`` or
            ``'resnet18_meta_old'``). ``'pretrained'`` is required for
            ``'resnet18'`` and defaults to True for ``'resnet18_meta'``.
            ``'opt'`` holds ``'name'`` and ``'lr'`` plus, unless the name is
            ``'sps'``, ``'momentum'`` and ``'weight_decay'``.
        dataset: stored as ``self.dataset``; its ``n_classes`` sizes the
            output layer.
        device: target ``torch.device``; a CUDA device gets a
            ``DataParallel`` wrapper before the net is moved to it.

    Raises:
        ValueError: for an unknown network name.
    """
    super().__init__()
    self.dataset = dataset
    self.model_dict = model_dict

    arch = self.model_dict['name']
    if arch == 'resnet18':
        if not self.model_dict['pretrained']:
            self.net = models.resnet18(num_classes=self.dataset.n_classes)
        else:
            # Pretrained backbone; replace the final layer with a fresh
            # 512 -> n_classes head.
            self.net = models.resnet18(pretrained=True)
            self.net.fc = nn.Linear(512, self.dataset.n_classes)
    elif arch == 'resnet18_meta':
        if not self.model_dict.get('pretrained', True):
            self.net = resnet_meta.resnet18(
                num_classes=self.dataset.n_classes)
        else:
            self.net = resnet_meta.resnet18(pretrained=True)
            self.net.fc = MetaLinear(512, self.dataset.n_classes)
    elif arch in ('resnet18_meta_2', 'resnet18_meta_old'):
        # Both variants share the same constructor signature.
        impl = resnet_meta_2 if arch == 'resnet18_meta_2' else resnet_meta_old
        self.net = impl.ResNet18(nc=3, nclasses=self.dataset.n_classes)
    else:
        raise ValueError('network %s does not exist' % model_dict['name'])

    if device.type == 'cuda':
        self.net = DataParallel(self.net)
    self.net.to(device)

    # Optimizer setup.
    self.opt_dict = model_dict['opt']
    self.lr_init = self.opt_dict['lr']
    if self.model_dict['opt']['name'] != 'sps':
        self.opt = optim.SGD(self.net.parameters(),
                             lr=self.opt_dict['lr'],
                             momentum=self.opt_dict['momentum'],
                             weight_decay=self.opt_dict['weight_decay'])
    else:
        # NOTE(review): n_batches_per_epoch is hard-coded to 120 rather than
        # derived from the data loader -- confirm it matches the real epoch.
        self.opt = sps.Sps(self.net.parameters(),
                           n_batches_per_epoch=120,
                           c=0.5,
                           adapt_flag='smooth_iter',
                           eps=0,
                           eta_max=None)

    self.device = device
def test_dataparallel_params_maml(model):
    """MAML-style round trip through a DataParallel-wrapped model on cuda:0.

    An inner dummy loss is turned into adapted parameters with
    ``gradient_update_parameters``. A second forward pass uses those
    parameters, its output shape and device are checked, and ``backward()``
    is called on the outer dummy loss.
    """
    target = torch.device('cuda:0')
    wrapped = DataParallel(model)
    wrapped.to(device=target)

    # Inner step: dummy loss on a "support" batch.
    support = torch.rand(5, 2).to(device=target)
    support_loss = wrapped(support).sum()
    adapted = gradient_update_parameters(wrapped, support_loss)

    # Outer step: evaluate with the adapted params and back-propagate.
    query = torch.rand(5, 2).to(device=target)
    query_out = wrapped(query, params=adapted)
    assert query_out.shape == (5, 1)
    assert query_out.device == target

    query_out.sum().backward()
def __init__(self, model_dict, dataset, device):
    """Create the augmentation network and, if applicable, its optimizer.

    Args:
        model_dict: configuration mapping.
            - ``'name'`` selects the network (``'stn'``, ``'small_affine'``
              or ``'affine_color'``).
            - ``'factor'`` and ``'opt'`` (``'lr'``, ``'momentum'``,
              ``'weight_decay'``) are read afterwards.
            - ``'small_affine'`` also reads ``'transform'``.
        dataset: provides ``mean``/``std`` (plus ``image_size``/``nc`` for
            ``'stn'``).
        device: target ``torch.device``; a CUDA device gets a
            ``DataParallel`` wrapper before the net is moved to it.

    Raises:
        ValueError: for an unknown network name.
    """
    super().__init__()

    def dataset_stats():
        # Dataset mean/std keyword arguments passed to every network below.
        return {'datasetmean': dataset.mean, 'datasetstd': dataset.std}

    kind = model_dict['name']
    if kind == 'stn':
        self.net = stn.STN(isize=dataset.image_size,
                           n_channels=dataset.nc,
                           n_filters=64,
                           nz=100,
                           **dataset_stats())
    elif kind == 'small_affine':
        self.net = small_affine.smallAffine(
            nz=6, transformation=model_dict['transform'], **dataset_stats())
    elif kind == 'affine_color':
        self.net = affine_color.affineColor(nz=10, **dataset_stats())
    else:
        raise ValueError('network %s does not exist' % model_dict['name'])

    if device.type == 'cuda':
        self.net = DataParallel(self.net)
    self.net.to(device)

    self.device = device
    self.factor = model_dict['factor']
    self.name = model_dict['name']

    # Every name accepted above differs from 'random_augmenter', so this
    # early exit is never taken here.
    if self.name == 'random_augmenter':
        return
    self.opt_dict = model_dict['opt']
    self.lr_init = self.opt_dict['lr']
    self.opt = optim.SGD(self.net.parameters(),
                         lr=self.opt_dict['lr'],
                         momentum=self.opt_dict['momentum'],
                         weight_decay=self.opt_dict['weight_decay'])
# NOTE(review): collapsed fragment of a training-script entry point. The
# enclosing function and the `if` that the leading statement and the `elif`
# belong to are outside this view, so the nesting below is ambiguous.
# As written on one line, the inline `#` comments would comment out
# everything after them; the original multi-line layout must be restored
# before this can run.
# Visible steps:
#   - set cudnn.benchmark (under a condition not shown);
#   - raise RuntimeError on the `elif args.use_cuda` path;
#   - pick image_len (28 for omniglot, else 84);
#   - build a student `model` and a `teacher_model` ConvolutionalNeuralNetwork;
#   - optionally wrap both in DataParallel and move them to the GPU;
#   - create Adam (lr=args.lr, weight_decay=0.0005) over the student's
#     parameters only;
#   - when args.teacher_resume and args.teacher_resume_folder are set, load
#     teacher weights from
#     '<teacher_resume_folder>/teacher_<train_data>_<test_data>_<teacher_backbone>_max_acc.pt'.
# NOTE(review): `num_gpus` is assigned but not read in the visible text.
# NOTE(review): CUDA_DEVICE_ORDER is only honoured if set before CUDA is
# initialised. If the preceding (not visible) GPU check already touched
# CUDA, this assignment may have no effect -- confirm.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. A GPU-saved checkpoint may also need map_location on a
# CPU-only host.
torch.backends.cudnn.benchmark = True elif args.use_cuda: raise RuntimeError('You are using GPU mode, but GPUs are not available!') # construct model and optimizer args.image_len = 28 if args.train_data == 'omniglot' else 84 args.out_channels, _ = get_outputs_c_h(args.backbone, args.image_len) model = ConvolutionalNeuralNetwork(args.backbone, args.out_channels, args.num_ways) teacher_model = ConvolutionalNeuralNetwork(args.teacher_backbone, args.out_channels, args.num_ways) if args.use_cuda: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" num_gpus = torch.cuda.device_count() if args.multi_gpu: model = DataParallel(model) teacher_model = DataParallel(teacher_model) model = model.cuda() teacher_model = teacher_model.cuda() optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005) # download teacher trained model if args.teacher_resume and args.teacher_resume_folder is not None: teacher_model_path = os.path.join(args.teacher_resume_folder, ('_'.join(['teacher', args.train_data, args.test_data, args.teacher_backbone, 'max_acc']) + '.pt')) teacher_state = torch.load(teacher_model_path) if args.multi_gpu: teacher_model.module.load_state_dict(teacher_state) else: teacher_model.load_state_dict(teacher_state)
class Classifier(nn.Module):
    """Bundle a classification network with its optimizer.

    The network is chosen by ``model_dict['name']``. The optimizer is SGD,
    or ``sps.Sps`` when ``model_dict['opt']['name'] == 'sps'``.
    """

    def __init__(self, model_dict, dataset, device):
        """Build the classification network and its optimizer.

        Args:
            model_dict: configuration mapping. ``'name'`` selects the
                network (``'resnet18'``, ``'resnet18_meta'``,
                ``'resnet18_meta_2'`` or ``'resnet18_meta_old'``).
                ``'pretrained'`` is required for ``'resnet18'`` and defaults
                to True for ``'resnet18_meta'``. ``'opt'`` holds ``'name'``,
                ``'lr'`` and ``'sched'`` plus, unless the name is ``'sps'``,
                ``'momentum'`` and ``'weight_decay'``.
            dataset: stored as ``self.dataset``; its ``n_classes`` sizes the
                output layer and its ``name`` is used by the lr schedule.
            device: target ``torch.device``; a CUDA device gets a
                ``DataParallel`` wrapper before the net is moved to it.

        Raises:
            ValueError: for an unknown network name.
        """
        super().__init__()
        self.dataset = dataset
        self.model_dict = model_dict

        arch = self.model_dict['name']
        if arch == 'resnet18':
            if not self.model_dict['pretrained']:
                self.net = models.resnet18(num_classes=self.dataset.n_classes)
            else:
                # Pretrained backbone; replace the final layer with a fresh
                # 512 -> n_classes head.
                self.net = models.resnet18(pretrained=True)
                self.net.fc = nn.Linear(512, self.dataset.n_classes)
        elif arch == 'resnet18_meta':
            if not self.model_dict.get('pretrained', True):
                self.net = resnet_meta.resnet18(
                    num_classes=self.dataset.n_classes)
            else:
                self.net = resnet_meta.resnet18(pretrained=True)
                self.net.fc = MetaLinear(512, self.dataset.n_classes)
        elif arch in ('resnet18_meta_2', 'resnet18_meta_old'):
            # Both variants share the same constructor signature.
            impl = (resnet_meta_2 if arch == 'resnet18_meta_2'
                    else resnet_meta_old)
            self.net = impl.ResNet18(nc=3, nclasses=self.dataset.n_classes)
        else:
            raise ValueError('network %s does not exist' % model_dict['name'])

        if device.type == 'cuda':
            self.net = DataParallel(self.net)
        self.net.to(device)

        # Optimizer setup.
        self.opt_dict = model_dict['opt']
        self.lr_init = self.opt_dict['lr']
        if self.model_dict['opt']['name'] != 'sps':
            self.opt = optim.SGD(self.net.parameters(),
                                 lr=self.opt_dict['lr'],
                                 momentum=self.opt_dict['momentum'],
                                 weight_decay=self.opt_dict['weight_decay'])
        else:
            # NOTE(review): n_batches_per_epoch is hard-coded to 120 rather
            # than derived from the data loader -- confirm.
            self.opt = sps.Sps(self.net.parameters(),
                               n_batches_per_epoch=120,
                               c=0.5,
                               adapt_flag='smooth_iter',
                               eps=0,
                               eta_max=None)

        self.device = device

    def get_state_dict(self):
        """Return ``{'net': ..., 'opt': ...}`` with the net and optimizer state."""
        return {'net': self.net.state_dict(), 'opt': self.opt.state_dict()}

    def load_state_dict(self, state_dict):
        """Restore net, then optimizer, from a ``get_state_dict`` mapping.

        NOTE(review): this shadows ``nn.Module.load_state_dict`` with a
        different contract. It expects the ``{'net', 'opt'}`` wrapper and
        returns None.
        """
        self.net.load_state_dict(state_dict['net'])
        self.opt.load_state_dict(state_dict['opt'])

    def on_trainloader_start(self, epoch):
        """If ``opt['sched']`` is truthy, call ``ut.adjust_learning_rate_netC``."""
        if not self.opt_dict['sched']:
            return
        ut.adjust_learning_rate_netC(self.opt, epoch, self.lr_init,
                                     self.model_dict['name'],
                                     self.dataset.name)

    def train_on_batch(self, batch):
        """Take one optimizer step on ``batch`` and return the loss as a float.

        Args:
            batch: mapping with ``'images'`` and ``'labels'`` tensors.
        """
        def on_device(key):
            return batch[key].to(self.device, non_blocking=True)

        images = on_device('images')
        labels = on_device('labels')

        loss = F.cross_entropy(self.net(images), labels, reduction="mean")
        self.opt.zero_grad()
        loss.backward()

        # The SPS optimizer needs the current loss value for its step size.
        step_kwargs = {'loss': loss} if self.opt_dict['name'] == 'sps' else {}
        self.opt.step(**step_kwargs)
        return loss.item()
class Augmenter(nn.Module):
    """Learnable augmentation network trained alongside a classifier ``netC``.

    Supported ``model_dict['name']`` values are ``'stn'``, ``'small_affine'``
    and ``'affine_color'``; anything else raises ``ValueError``. The
    augmenter owns its own SGD optimizer (``self.opt``).

    For ``'small_affine'`` and ``'affine_color'``, ``train_on_batch`` updates
    the augmenter in three steps:
      1. build a one-step look-ahead of the classifier weights;
      2. evaluate that look-ahead on a validation batch;
      3. back-propagate the validation loss into the augmenter.
    """

    def __init__(self, model_dict, dataset, device):
        """Build the augmentation net and its SGD optimizer.

        Args:
            model_dict: config mapping. Reads ``'name'``, ``'factor'``, and
                ``'opt'`` (``'lr'``, ``'momentum'``, ``'weight_decay'``, and
                later ``'sched'``). ``'small_affine'`` also reads
                ``'transform'``.
            dataset: provides ``mean`` and ``std`` (plus ``image_size`` and
                ``nc`` for ``'stn'``).
            device: ``torch.device``; a CUDA device gets a ``DataParallel``
                wrapper before the net is moved to it.

        Raises:
            ValueError: if ``model_dict['name']`` is not supported.
        """
        super().__init__()
        if model_dict['name'] == 'stn':
            self.net = stn.STN(isize=dataset.image_size,
                               n_channels=dataset.nc,
                               n_filters=64,
                               nz=100,
                               datasetmean=dataset.mean,
                               datasetstd=dataset.std)
        elif model_dict['name'] == 'small_affine':
            self.net = small_affine.smallAffine(
                nz=6,
                transformation=model_dict['transform'],
                datasetmean=dataset.mean,
                datasetstd=dataset.std)
        elif model_dict['name'] == 'affine_color':
            self.net = affine_color.affineColor(nz=10,
                                                datasetmean=dataset.mean,
                                                datasetstd=dataset.std)
        else:
            raise ValueError('network %s does not exist' % model_dict['name'])

        if device.type == 'cuda':
            self.net = DataParallel(self.net)
        self.net.to(device)

        self.device = device
        self.factor = model_dict['factor']
        self.name = model_dict['name']

        # NOTE(review): every name accepted above differs from
        # 'random_augmenter', so this condition is always true, and the
        # 'random_augmenter' branch of train_on_batch is unreachable.
        if model_dict['name'] != 'random_augmenter':
            self.opt_dict = model_dict['opt']
            self.lr_init = self.opt_dict['lr']
            self.opt = optim.SGD(self.net.parameters(),
                                 lr=self.opt_dict['lr'],
                                 momentum=self.opt_dict['momentum'],
                                 weight_decay=self.opt_dict['weight_decay'])

    def cycle(self, iterable):
        """Yield items from ``iterable`` forever, restarting it when exhausted.

        Raises:
            ValueError: if a full pass over ``iterable`` yields nothing, for
                example an empty loader or a one-shot iterator that was
                already consumed. Without this check the generator would
                spin forever without ever yielding.
        """
        while True:
            produced = False
            for item in iterable:
                produced = True
                yield item
            if not produced:
                raise ValueError('cannot cycle over an empty iterable')

    def get_state_dict(self):
        """Return ``{'net', 'opt'}`` state, or ``{}`` if there is no optimizer."""
        state_dict = {}
        if hasattr(self, 'opt'):
            state_dict['net'] = self.net.state_dict()
            state_dict['opt'] = self.opt.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore net and optimizer state saved by ``get_state_dict``.

        Does nothing when this instance has no optimizer. NOTE(review): this
        shadows ``nn.Module.load_state_dict`` with a different contract.
        """
        if hasattr(self, 'opt'):
            self.net.load_state_dict(state_dict['net'])
            self.opt.load_state_dict(state_dict['opt'])

    def apply_augmentation(self, images, labels):
        """Augment a batch, first replicating it ``self.factor`` times if > 1.

        Returns:
            ``(augimages, labels, transformations)``. ``labels`` is repeated
            to match the replicated images. ``augimages`` and
            ``transformations`` are the two outputs of ``self.net``.
        """
        factor = self.factor
        if factor > 1:
            labels = labels.repeat(factor)
            images = images.repeat(factor, 1, 1, 1)
        # NOTE(review): PyTorch documents anomaly detection as a debugging
        # aid that slows execution; consider dropping it once stable.
        with torch.autograd.set_detect_anomaly(True):
            augimages, transformations = self.net(images)
        return augimages, labels, transformations

    def on_trainloader_start(self, epoch, valloader, netC):
        """Prepare per-epoch state before iterating the training loader.

        Steps:
          - update the slope, if ``slope_annealing`` has been set;
          - apply the augmenter lr schedule when ``opt['sched']`` is truthy;
          - reset momentum buffers mirroring ``netC``'s parameters when
            ``netC``'s optimizer has non-zero momentum;
          - record ``epoch``;
          - start an endless iterator over ``valloader``.
        """
        # Nothing in this class sets `slope_annealing`; it must be attached
        # externally for this branch to run.
        if hasattr(self, 'slope_annealing'):
            self.slope = get_slope(self.slope_annealing, epoch)

        # Bug fix: this used `self.optimizerA` / `self.lrA_init`, which are
        # never defined, so it raised AttributeError whenever 'sched' was
        # set. The augmenter's optimizer and initial lr are `self.opt` and
        # `self.lr_init`.
        if self.opt_dict['sched']:
            ut.adjust_learning_rate_netA(self.opt, epoch, self.lr_init)

        # Zeroed momentum buffers, one per classifier parameter, used by the
        # look-ahead step in train_on_batch.
        if netC.opt.defaults['momentum']:
            self.moms = OrderedDict()
            for name, p in netC.net.named_parameters():
                self.moms[name] = torch.zeros(p.shape, device=self.device)
        self.epoch = epoch

        # Endless stream of validation batches.
        self.val_gen = self.cycle(valloader)

    def train_on_batch(self, batch, netC):
        """Run one joint training step for the augmenter and ``netC``.

        Requires ``on_trainloader_start`` to have been called first. It sets
        ``self.epoch``, ``self.val_gen`` and, when ``netC`` uses momentum,
        ``self.moms``.

        Args:
            batch: mapping with ``'images'`` and ``'labels'`` tensors.
            netC: classifier wrapper exposing ``net``, ``opt``, ``lr_init``,
                ``model_dict`` and ``dataset`` (e.g. a ``Classifier``).

        Returns:
            For ``'stn'`` (and the unreachable ``'random_augmenter'`` path),
            the classifier loss tensor. Otherwise, a tuple
            ``(float classifier loss, transformations)``.
            NOTE(review): the differing return types are kept for
            compatibility with existing callers.
        """
        self.train()
        images = batch['images'].to(self.device, non_blocking=True)
        labels = batch['labels'].to(self.device, non_blocking=True)
        images, labels, transformations = self.apply_augmentation(
            images, labels)

        # Classifier loss on the augmented batch.
        logits = netC.net(images)
        loss_clf = F.cross_entropy(logits, labels, reduction="mean")
        netC.opt.zero_grad()

        if self.name in ['random_augmenter']:
            # Update the classifier only.
            loss_clf.backward()
            netC.opt.step()
            return loss_clf

        if self.name in ['stn']:
            # Update classifier and augmenter from the same loss.
            self.opt.zero_grad()
            loss_clf.backward()
            netC.opt.step()
            self.opt.step()
            return loss_clf

        # Update the augmenter through a validation batch: build the
        # classifier's look-ahead weights w^{t+1} and evaluate them there.
        batch_val = next(self.val_gen)
        valimages = batch_val['images'].to(self.device, non_blocking=True)
        vallabels = batch_val['labels'].to(self.device, non_blocking=True)

        # create_graph=True keeps p.grad differentiable so the validation
        # loss can be back-propagated through the look-ahead weights.
        loss_clf.backward(create_graph=True, retain_graph=True)

        self.w_t_1 = OrderedDict()
        # Step size for the look-ahead update.
        lr = ut.adjust_learning_rate_netC(netC.opt, self.epoch, netC.lr_init,
                                          netC.model_dict['name'],
                                          netC.dataset.name, return_lr=True)
        momentum = netC.opt.defaults['momentum']
        if momentum:
            for name, p in netC.net.named_parameters():
                p.requires_grad = False  # freeze C
                # Cut the graph carried over from the previous step.
                self.moms[name].detach_()
                self.moms[name] = momentum * self.moms[name] + p.grad
                self.w_t_1[name] = p - lr * self.moms[name]
        else:
            for name, p in netC.net.named_parameters():
                p.requires_grad = False  # freeze C
                self.w_t_1[name] = p - lr * p.grad

        # Validation loss under the look-ahead weights drives the augmenter.
        valoutput = netC.net(valimages, params=self.w_t_1)
        loss_aug = F.cross_entropy(valoutput, vallabels, reduction='mean')
        # Drop augmenter grads left by loss_clf.backward() so only the
        # validation loss contributes to this update.
        self.opt.zero_grad()
        loss_aug.backward()
        self.opt.step()
        del self.w_t_1

        # Unfreeze C and apply its step from the loss_clf gradients.
        for p in netC.net.parameters():
            p.requires_grad = True
        netC.opt.step()

        del images
        del labels
        gc.collect()
        return float(loss_clf.item()), transformations

    def __call__(self, img):
        """Augment a single sample by adding and then removing a batch dim.

        NOTE(review): this overrides ``nn.Module.__call__`` and calls
        ``self.net.forward`` directly, so calling the augmenter bypasses
        ``Module.__call__`` (and its hooks).
        """
        img = img.unsqueeze(0)
        img, _ = self.net.forward(img)
        img = img.squeeze(0)
        return img
# NOTE(review): collapsed fragment of a training-script entry point. The
# enclosing function and the `if` that this `elif` pairs with are outside
# this view, and the fragment ends before `state` is used, so the nesting
# below is ambiguous. As written on one line, the inline `#` comments would
# comment out everything after them; the original multi-line layout must be
# restored before this can run.
# Visible steps:
#   - raise RuntimeError on the `elif args.use_cuda` path;
#   - pick image_len (28 for omniglot, else 84);
#   - build one ConvolutionalNeuralNetwork;
#   - optionally wrap it in DataParallel and move it to the GPU;
#   - create Adam (lr=args.lr, weight_decay=0.0005);
#   - when args.resume and args.resume_folder are set, load a checkpoint from
#     '<resume_folder>/<model_name>_<train_data>_<test_data>_<backbone>_max_acc_checkpoint.pt.tar'.
# NOTE(review): `num_gpus` is assigned but not read in the visible text.
# NOTE(review): CUDA_DEVICE_ORDER is only honoured if set before CUDA is
# initialised; confirm it is not set too late here.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. A GPU-saved checkpoint may also need map_location on a
# CPU-only host.
elif args.use_cuda: raise RuntimeError( 'You are using GPU mode, but GPUs are not available!') # construct model and optimizer args.image_len = 28 if args.train_data == 'omniglot' else 84 args.out_channels, _ = get_outputs_c_h(args.backbone, args.image_len) model = ConvolutionalNeuralNetwork(args.backbone, args.out_channels, args.num_ways) if args.use_cuda: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" num_gpus = torch.cuda.device_count() if args.multi_gpu: model = DataParallel(model) model = model.cuda() optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005) # training from the checkpoint if args.resume and args.resume_folder is not None: # load checkpoint checkpoint_path = os.path.join(args.resume_folder, ('_'.join([ args.model_name, args.train_data, args.test_data, args.backbone, 'max_acc' ]) + '_checkpoint.pt.tar')) # tag='max_acc' can be changed state = torch.load(checkpoint_path)