class PytorchLearner(MachineLearningInterface):
    """Pytorch learner implementation of machine learning interface"""

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        vote_loader: torch.utils.data.DataLoader,
        test_loader: Optional[torch.utils.data.DataLoader] = None,
        need_reset_optimizer: bool = True,
        device=_DEFAULT_DEVICE,
        criterion: Optional[_Loss] = None,
        minimise_criterion=True,
        vote_criterion: Optional[Callable[[torch.Tensor, torch.Tensor], float]] = None,
        num_train_batches: Optional[int] = None,
        num_test_batches: Optional[int] = None,
        diff_priv_config: Optional[DiffPrivConfig] = None,
    ):
        """
        :param model: Pytorch model used for training
        :param optimizer: Training optimizer
        :param train_loader: Train dataset
        :param vote_loader: Vote dataset - used to score the current and proposed weights
        :param test_loader: Optional test dataset - subset of training set will be used if not specified
        :param need_reset_optimizer: True to clear optimizer history before training, False to keep history.
        :param device: Pytorch device - CPU or GPU
        :param criterion: Loss function
        :param minimise_criterion: True to minimise value of criterion, False to maximise
        :param vote_criterion: Function to measure model performance for voting
        :param num_train_batches: Number of training batches
        :param num_test_batches: Number of testing batches
        :param diff_priv_config: Contains differential privacy (dp) budget related configuration
        """

        # Model has to be on same device as data
        self.model: torch.nn.Module = model.to(device)
        self.optimizer: torch.optim.Optimizer = optimizer
        self.criterion = criterion
        self.train_loader: torch.utils.data.DataLoader = train_loader
        self.vote_loader: torch.utils.data.DataLoader = vote_loader
        self.test_loader: Optional[torch.utils.data.DataLoader] = test_loader
        self.need_reset_optimizer = need_reset_optimizer
        self.device = device
        self.num_train_batches = num_train_batches or len(train_loader)
        self.num_test_batches = num_test_batches
        self.minimise_criterion = minimise_criterion
        self.vote_criterion = vote_criterion

        self.dp_config = diff_priv_config
        self.dp_privacy_engine = PrivacyEngine()

        if diff_priv_config is not None:
            # Opacus wraps the model, optimizer and loader so per-sample
            # gradients are clipped and noised during training
            (
                self.model,
                self.optimizer,
                self.train_loader,
            ) = self.dp_privacy_engine.make_private(
                module=self.model,
                optimizer=self.optimizer,
                data_loader=self.train_loader,
                max_grad_norm=diff_priv_config.max_grad_norm,
                noise_multiplier=diff_priv_config.noise_multiplier,
            )

        # Baseline score of the initial model, used by vote()
        self.vote_score = self.test(self.vote_loader)

    def mli_get_current_weights(self) -> Weights:
        """
        :return: The current weights of the model
        """
        # state_dict() rebuilds the mapping on every call, so fetch it once
        # and clone each tensor so callers get an independent snapshot
        state_dict = self.model.state_dict()
        current_state_dict = OrderedDict(
            (key, tensor.clone()) for key, tensor in state_dict.items()
        )
        w = Weights(
            weights=current_state_dict, training_summary=self.get_training_summary()
        )
        return w

    def mli_get_current_model(self) -> ColearnModel:
        """
        :return: The current model and its format
        """
        return ColearnModel(
            model_format=ModelFormat(ModelFormat.ONNX),
            model_file="",
            model=convert_model_to_onnx(self.model),
        )

    def set_weights(self, weights: Weights):
        """
        Rewrites weight of current model
        :param weights: Weights to be stored
        """
        self.model.load_state_dict(weights.weights)

    def reset_optimizer(self):
        """
        Clear optimizer state, such as number of iterations, momentums.
        This way, the outdated history can be erased.
        """
        self.optimizer.__setstate__({"state": defaultdict(dict)})

    def train(self):
        """
        Trains the model on the training dataset
        """
        if self.need_reset_optimizer:
            # erase the outdated optimizer memory (momentums mostly)
            self.reset_optimizer()

        self.model.train()

        for batch_idx, (data, labels) in enumerate(self.train_loader):
            if batch_idx == self.num_train_batches:
                break
            self.optimizer.zero_grad()

            # Data needs to be on same device as model
            data = data.to(self.device)
            labels = labels.to(self.device)

            output = self.model(data)

            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()

    def mli_propose_weights(self) -> Weights:
        """
        Trains model on training set and returns new weights after training
        - Current model is reverted to original state after training
        :return: Weights after training
        """
        current_weights = self.mli_get_current_weights()
        training_summary = current_weights.training_summary
        # Refuse to train further once the differential privacy budget is spent
        if (
            training_summary is not None
            and training_summary.error_code is not None
            and training_summary.error_code == ErrorCodes.DP_BUDGET_EXCEEDED
        ):
            return current_weights

        self.train()
        new_weights = self.mli_get_current_weights()
        self.set_weights(current_weights)

        training_summary = new_weights.training_summary
        # If this round of training exhausted the budget, report it on the
        # unchanged weights rather than proposing the new ones
        if (
            training_summary is not None
            and training_summary.error_code is not None
            and training_summary.error_code == ErrorCodes.DP_BUDGET_EXCEEDED
        ):
            current_weights.training_summary = training_summary
            return current_weights

        return new_weights

    def mli_test_weights(self, weights: Weights) -> ProposedWeights:
        """
        Tests given weights on training and test set and returns weights with score values
        :param weights: Weights to be tested
        :return: ProposedWeights - Weights with vote and test score
        """
        current_weights = self.mli_get_current_weights()
        self.set_weights(weights)

        vote_score = self.test(self.vote_loader)

        if self.test_loader:
            test_score = self.test(self.test_loader)
        else:
            test_score = 0
        vote = self.vote(vote_score)

        # Restore the model that was active before scoring
        self.set_weights(current_weights)

        return ProposedWeights(
            weights=weights, vote_score=vote_score, test_score=test_score, vote=vote
        )

    def vote(self, new_score) -> bool:
        """
        Compares current model score with proposed model score and returns vote
        :param new_score: Proposed score
        :return: bool positive or negative vote
        """
        if self.minimise_criterion:
            return new_score < self.vote_score
        else:
            return new_score > self.vote_score

    def test(self, loader: torch.utils.data.DataLoader) -> float:
        """
        Tests performance of the model on specified dataset
        :param loader: Dataset for testing
        :return: Value of performance metric
        :raises Exception: if no criterion is set or the loader yields no batches
        """
        if not self.criterion:
            raise Exception("Criterion is unspecified so test method cannot be used")

        self.model.eval()
        total_score = 0
        all_labels = []
        all_outputs = []
        total_samples = 0
        with torch.no_grad():
            for batch_idx, (data, labels) in enumerate(loader):
                # Stop before counting samples from a batch that will not be
                # scored; counting first would inflate the denominator below
                if self.num_test_batches and batch_idx == self.num_test_batches:
                    break
                total_samples += labels.shape[0]
                data = data.to(self.device)
                labels = labels.to(self.device)
                output = self.model(data)
                if self.vote_criterion is not None:
                    all_labels.append(labels)
                    all_outputs.append(output)
                else:
                    total_score += self.criterion(output, labels).item()
        # total_samples == 0 means the loader produced no batches at all
        # (checking batch_idx == 0 would wrongly reject a one-batch loader)
        if total_samples == 0:
            raise Exception("No batches in loader")
        if self.vote_criterion is None:
            return float(total_score / total_samples)
        else:
            return self.vote_criterion(
                torch.cat(all_outputs, dim=0), torch.cat(all_labels, dim=0)
            )

    def mli_accept_weights(self, weights: Weights):
        """
        Updates the model with the proposed set of weights
        :param weights: The new weights
        """
        self.set_weights(weights)
        self.vote_score = self.test(self.vote_loader)

    def get_training_summary(self) -> Optional[TrainingSummary]:
        """
        Differential Privacy Budget
        :return: the target and consumed epsilon so far
        """
        if self.dp_config is None:
            return None

        delta = self.dp_config.target_delta
        target_epsilon = self.dp_config.target_epsilon
        consumed_epsilon = self.dp_privacy_engine.get_epsilon(delta)

        budget = DiffPrivBudget(
            target_epsilon=target_epsilon,
            consumed_epsilon=consumed_epsilon,
            target_delta=delta,
            consumed_delta=delta,  # delta is constant per training
        )

        err = (
            ErrorCodes.DP_BUDGET_EXCEEDED
            if consumed_epsilon >= target_epsilon
            else None
        )

        return TrainingSummary(
            dp_budget=budget,
            error_code=err,
        )
def main():
    """Train a CIFAR10 classifier, optionally with differential privacy (Opacus).

    Parses CLI args, builds the data loaders, model and optimizer, wraps them
    with a PrivacyEngine unless DP is disabled, then runs the train/test loop,
    checkpointing the best top-1 accuracy and logging timing metrics.
    """
    args = parse_args()

    if args.debug >= 1:
        logger.setLevel(level=logging.DEBUG)

    device = args.device

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")
    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    # Random augmentations are only applied when DP is disabled
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root, train=True,
                            download=True, transform=train_transform)

    # Batch size is derived from the DP sampling rate
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=int(args.sample_rate * len(train_dataset)),
        generator=generator,
        num_workers=args.workers,
        pin_memory=True,
    )

    test_dataset = CIFAR10(root=args.data_root, train=False,
                           download=True, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0

    # BatchNorm is replaced by GroupNorm: DP-SGD cannot use cross-sample stats
    model = models.__dict__[args.architecture](
        pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c)))
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    privacy_engine = None
    if not args.disable_dp:
        if args.clip_per_layer:
            # Each layer has the same clipping threshold. The total grad norm
            # is still bounded by `args.max_per_sample_grad_norm`.
            n_layers = len([(n, p) for n, p in model.named_parameters()
                            if p.requires_grad])
            max_grad_norm = [
                args.max_per_sample_grad_norm / np.sqrt(n_layers)
            ] * n_layers
        else:
            max_grad_norm = args.max_per_sample_grad_norm

        privacy_engine = PrivacyEngine(secure_mode=args.secure_rng, )
        clipping = "per_layer" if args.clip_per_layer else "flat"
        model, optimizer, train_loader = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            noise_multiplier=args.sigma,
            max_grad_norm=max_grad_norm,
            clipping=clipping,
        )

    # Store some logs
    accuracy_per_epoch = []
    time_per_epoch = []

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            # Cosine learning-rate decay over the full run
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train_duration = train(args, model, train_loader,
                               optimizer, privacy_engine, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        time_per_epoch.append(train_duration)
        accuracy_per_epoch.append(float(top1_acc))

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "Convnet",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )

    time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch]
    # Guard against start_epoch > epochs (no epochs ran): avoid ZeroDivisionError
    avg_time_per_epoch = (
        sum(time_per_epoch_seconds) / len(time_per_epoch_seconds)
        if time_per_epoch_seconds
        else 0.0
    )
    metrics = {
        "accuracy": best_acc1,
        "accuracy_per_epoch": accuracy_per_epoch,
        "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))),
        "time_per_epoch": time_per_epoch_seconds,
    }

    logger.info(
        "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n"
    )
    logger.info(metrics)