def test_milestones(self):
    self.assertLrEquals(0.1)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        fields = [
            'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
            'weight_decay', 'lr_gamma', 'lr_milestone_steps'
        ]
        params = ['SGD', 'LambdaLR', 0.1, 0.5, 0.0, 0.1, '2ep,4ep,7ep,8ep']
        Config().trainer = namedtuple('trainer', fields)(*params)

        self.assertLrEquals(0.1)
        lrs = optimizers.get_lr_schedule(self.optimizer, 10)
        self.assertLrEquals(0.1)

        # With milestones at 2, 4, 7, and 8 epochs and 10 iterations per
        # epoch, the learning rate is multiplied by lr_gamma = 0.1 at
        # steps 20, 40, 70, and 80, and stays constant afterwards.
        for _ in range(19):
            lrs.step()
        self.assertLrEquals(1e-1)
        for _ in range(1):
            lrs.step()
        self.assertLrEquals(1e-2)

        for _ in range(19):
            lrs.step()
        self.assertLrEquals(1e-2)
        for _ in range(1):
            lrs.step()
        self.assertLrEquals(1e-3)

        for _ in range(29):
            lrs.step()
        self.assertLrEquals(1e-3)
        for _ in range(1):
            lrs.step()
        self.assertLrEquals(1e-4)

        for _ in range(9):
            lrs.step()
        self.assertLrEquals(1e-4)
        for _ in range(1):
            lrs.step()
        self.assertLrEquals(1e-5)

        for _ in range(100):
            lrs.step()
        self.assertLrEquals(1e-5)
def test_vanilla(self):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        fields = [
            'optimizer', 'lr_schedule', 'learning_rate', 'momentum',
            'weight_decay', 'lr_gamma'
        ]
        params = ['SGD', 'LambdaLR', 0.1, 0.5, 0.0, 0.0]
        Config().trainer = namedtuple('trainer', fields)(*params)
        lrs = optimizers.get_lr_schedule(self.optimizer, 10)

        # With no milestones configured, the schedule should leave the
        # learning rate unchanged no matter how many steps are taken.
        self.assertLrEquals(0.1)
        for _ in range(100):
            lrs.step()
        self.assertLrEquals(0.1)
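# The tests above rely on an assertLrEquals helper that is not part of this
# excerpt. A minimal sketch of the scaffolding they are assumed to run on
# (the class name and setUp body here are hypothetical):
import unittest

import torch


class LrScheduleTest(unittest.TestCase):
    """Hypothetical scaffolding for the learning rate schedule tests."""

    def setUp(self):
        self.model = torch.nn.Linear(2, 2)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def assertLrEquals(self, lr):
        # Compare against the learning rate of the first parameter group,
        # tolerating floating-point error from repeated multiplications.
        self.assertAlmostEqual(self.optimizer.param_groups[0]['lr'], lr, 10)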
def train_process(self, config, trainset, evalset, sampler,
                  blending_weights):
    log_interval = config.log_config["interval"]
    batch_size = config.trainer['batch_size']

    if 'use_wandb' in config:
        run = wandb.init(project="plato",
                         group=str(config['run_id']),
                         reinit=True)

    logging.info("[Client #%d] Loading the dataset.", self.client_id)

    # Preparing the data loaders for training and evaluation
    train_loader = torch.utils.data.DataLoader(
        dataset=trainset,
        shuffle=False,
        batch_size=batch_size,
        sampler=sampler.get(),
        num_workers=config.data.get('workers_per_gpu', 1))
    eval_loader = torch.utils.data.DataLoader(
        dataset=evalset,
        shuffle=False,
        batch_size=batch_size,
        sampler=sampler.get(),
        num_workers=config.data.get('workers_per_gpu', 1))

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
    epochs = config['epochs']

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the optimizer
    get_optimizer = getattr(self, "get_optimizer",
                            optimizers.get_optimizer)
    optimizer = get_optimizer(self.model)

    # Initializing the learning rate schedule, if necessary
    if hasattr(config, 'lr_schedule'):
        lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                 iterations_per_epoch,
                                                 train_loader)
    else:
        lr_schedule = None

    # Performing the local training. Since the gradients are blended on
    # the server side, the train/eval losses of the first and last local
    # epochs are recorded.
    supported_modalities = trainset.supported_modalities

    for epoch in range(1, epochs + 1):
        epoch_train_losses = {
            modl_nm: 0.0
            for modl_nm in supported_modalities
        }
        total_batches = 0
        total_epoch_loss = 0

        for batch_id, (multimodal_examples,
                       labels) in enumerate(train_loader):
            labels = labels.to(self.device)

            optimizer.zero_grad()

            losses = self.model.forward(data_container=multimodal_examples,
                                        label=labels,
                                        return_loss=True)
            weighted_losses = self.reweight_losses(blending_weights, losses)

            # Summing up the weighted per-modality losses
            weighted_global_loss = 0
            for modl_nm in supported_modalities:
                epoch_train_losses[modl_nm] += weighted_losses[modl_nm]
                weighted_global_loss += weighted_losses[modl_nm]

            total_epoch_loss += weighted_global_loss

            weighted_global_loss.backward()
            optimizer.step()

            if lr_schedule is not None:
                lr_schedule.step()

            if batch_id % log_interval == 0:
                if self.client_id == 0:
                    logging.info(
                        "[Server #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}".
                        format(os.getpid(), epoch, epochs, batch_id,
                               len(train_loader),
                               weighted_global_loss.data.item()))
                else:
                    if 'use_wandb' in config:
                        wandb.log({
                            "batch loss":
                            weighted_global_loss.data.item()
                        })
                    logging.info(
                        "[Client #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}".
                        format(self.client_id, epoch, epochs, batch_id,
                               len(train_loader),
                               weighted_global_loss.data.item()))

            total_batches = batch_id

        if hasattr(optimizer, "params_state_update"):
            optimizer.params_state_update()

        # Only recording the performance of the first and last local epochs
        if epoch == 1 or epoch == epochs:
            epoch_avg_train_loss = total_epoch_loss / (total_batches + 1)
            eval_avg_losses = self.eval_step(eval_data_loader=eval_loader)
            weighted_eval_losses = self.reweight_losses(
                blending_weights, eval_avg_losses)

            total_eval_loss = 0
            for modl_nm in supported_modalities:
                modl_train_avg_loss = epoch_train_losses[modl_nm] / (
                    total_batches + 1)
                modl_eval_avg_loss = eval_avg_losses[modl_nm]

                if modl_nm not in self.mm_train_losses_trajectory:
                    self.mm_train_losses_trajectory[modl_nm] = [
                        modl_train_avg_loss
                    ]
                else:
                    self.mm_train_losses_trajectory[modl_nm].append(
                        modl_train_avg_loss)

                if modl_nm not in self.mm_val_losses_trajectory:
                    self.mm_val_losses_trajectory[modl_nm] = [
                        modl_eval_avg_loss
                    ]
                else:
                    self.mm_val_losses_trajectory[modl_nm].append(
                        modl_eval_avg_loss)

                total_eval_loss += weighted_eval_losses[modl_nm]

            # Storing the global losses
            self.global_mm_train_losses_trajectory.append(
                epoch_avg_train_loss)
            self.global_mm_val_losses_trajectory.append(total_eval_loss)

    self.model.cpu()

    model_type = config['model_name']
    filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
    self.save_model(filename)

    if 'use_wandb' in config:
        run.finish()
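# reweight_losses is used above but not defined in this excerpt. A minimal
# sketch of its assumed semantics: scale each modality's loss by the
# blending weight computed on the server (the dict layout is an assumption):
def reweight_losses(self, blending_weights, losses):
    """Scale each modality's loss by its blending weight."""
    return {
        modality: blending_weights[modality] * loss
        for modality, loss in losses.items()
    }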
def train_model(self, config, trainset, sampler, cut_layer=None):
    """A custom trainer reporting training loss."""
    batch_size = config['batch_size']
    log_interval = 10
    tic = time.perf_counter()

    logging.info("[Client #%d] Loading the dataset.", self.client_id)
    _train_loader = getattr(self, "train_loader", None)

    if callable(_train_loader):
        train_loader = self.train_loader(batch_size, trainset, sampler,
                                         cut_layer)
    else:
        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   shuffle=False,
                                                   batch_size=batch_size,
                                                   sampler=sampler)

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
    epochs = config['epochs']

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the loss criterion
    _loss_criterion = getattr(self, "loss_criterion", None)
    if callable(_loss_criterion):
        loss_criterion = self.loss_criterion(self.model)
    else:
        loss_criterion = torch.nn.CrossEntropyLoss()

    # Initializing the optimizer
    get_optimizer = getattr(self, "get_optimizer",
                            optimizers.get_optimizer)
    optimizer = get_optimizer(self.model)

    # Initializing the learning rate schedule, if necessary
    if 'lr_schedule' in config:
        lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                 iterations_per_epoch,
                                                 train_loader)
    else:
        lr_schedule = None

    if 'differential_privacy' in config and config['differential_privacy']:
        privacy_engine = PrivacyEngine(accountant='rdp', secure_mode=False)

        self.model, optimizer, train_loader = \
            privacy_engine.make_private_with_epsilon(
                module=self.model,
                optimizer=optimizer,
                data_loader=train_loader,
                target_epsilon=config['dp_epsilon']
                if 'dp_epsilon' in config else 10.0,
                target_delta=config['dp_delta']
                if 'dp_delta' in config else 1e-5,
                epochs=epochs,
                max_grad_norm=config['dp_max_grad_norm']
                if 'dp_max_grad_norm' in config else 1.0,
            )

    for epoch in range(1, epochs + 1):
        # Use a default training loop
        for batch_id, (examples, labels) in enumerate(train_loader):
            examples, labels = examples.to(self.device), labels.to(
                self.device)
            if 'differential_privacy' in config and config[
                    'differential_privacy']:
                optimizer.zero_grad(set_to_none=True)
            else:
                optimizer.zero_grad()

            if cut_layer is None:
                outputs = self.model(examples)
            else:
                outputs = self.model.forward_from(examples, cut_layer)

            loss = loss_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if batch_id % log_interval == 0:
                if self.client_id == 0:
                    logging.info(
                        "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                        os.getpid(), epoch, epochs, batch_id,
                        len(train_loader), loss.data.item())
                else:
                    logging.info(
                        "[Client #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                        self.client_id, epoch, epochs, batch_id,
                        len(train_loader), loss.data.item())

        if lr_schedule is not None:
            lr_schedule.step()

        if hasattr(optimizer, "params_state_update"):
            optimizer.params_state_update()

        # Simulate the client's speed
        if self.client_id != 0 and hasattr(
                Config().clients,
                "speed_simulation") and Config().clients.speed_simulation:
            self.simulate_sleep_time()

        # Saving the model at the end of this epoch to a file so that
        # it can later be retrieved to respond to server requests
        # in asynchronous mode when the wall clock time is simulated
        if hasattr(Config().server,
                   'request_update') and Config().server.request_update:
            self.model.cpu()
            training_time = time.perf_counter() - tic
            filename = f"{self.client_id}_{epoch}_{training_time}.pth"
            self.save_model(filename)
            self.model.to(self.device)

    # Save the training loss of the last epoch in this round
    model_name = config['model_name']
    filename = f'{model_name}_{self.client_id}.loss'
    Trainer.save_loss(loss.data.item(), filename)
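# The differential privacy branch above reads several optional keys from the
# training configuration. A hypothetical configuration fragment showing the
# keys and the defaults used when they are omitted:
dp_config = {
    'batch_size': 32,
    'epochs': 5,
    'model_name': 'lenet5',
    'differential_privacy': True,
    'dp_epsilon': 10.0,       # target privacy budget (epsilon)
    'dp_delta': 1e-5,         # target privacy budget (delta)
    'dp_max_grad_norm': 1.0,  # per-sample gradient clipping bound
}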
def train_model(self, config, trainset, sampler, cut_layer=None):
    batch_size = config['batch_size']

    logging.info("[Client #%d] Loading the dataset.", self.client_id)

    _train_loader = getattr(self, "train_loader", None)
    if callable(_train_loader):
        train_loader = self.train_loader(batch_size, trainset, sampler,
                                         cut_layer)
    else:
        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   shuffle=False,
                                                   batch_size=batch_size,
                                                   sampler=sampler)

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the loss criterion
    _loss_criterion = getattr(self, "loss_criterion", None)
    if callable(_loss_criterion):
        loss_criterion = self.loss_criterion(self.model)
    else:
        loss_criterion = nn.CrossEntropyLoss()

    # Initializing the optimizer
    get_optimizer = getattr(self, "get_optimizer",
                            optimizers.get_optimizer)
    optimizer = get_optimizer(self.model)

    # Initializing the learning rate schedule, if necessary
    if hasattr(Config().trainer, 'lr_schedule'):
        lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                 iterations_per_epoch,
                                                 train_loader)
    else:
        lr_schedule = None

    logging.info("[Client #%d] Beginning to train.", self.client_id)

    for __, (examples, labels) in enumerate(train_loader):
        examples, labels = examples.to(self.device), labels.to(self.device)
        optimizer.zero_grad()

        # Tracking gradients on the inputs so that the gradients within
        # the cut layer can be recorded after the backward pass
        examples = examples.detach().requires_grad_(True)

        if cut_layer is None:
            outputs = self.model(examples)
        else:
            outputs = self.model.forward_from(examples, cut_layer)

        loss = loss_criterion(outputs, labels)
        logging.info("[Client #{}] \tLoss: {:.6f}".format(
            self.client_id, loss.data.item()))

        loss.backward()

        # Record gradients within the cut layer
        self.cut_layer_grad.append(examples.grad.clone().detach())

        optimizer.step()

        if lr_schedule is not None:
            lr_schedule.step()

    if hasattr(optimizer, "params_state_update"):
        optimizer.params_state_update()

    self.save_gradients()
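# forward_from is assumed to run the model from the given cut layer onward,
# as in split learning. A minimal sketch of a model exposing this interface
# (the layer names and structure are hypothetical):
import collections

import torch.nn as nn


class SplitModel(nn.Module):
    """A hypothetical model exposing the forward_from interface used above."""

    def __init__(self):
        super().__init__()
        # Named layers so that a cut layer can be addressed by name
        self.layerdict = collections.OrderedDict([
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(784, 128)),
            ('relu', nn.ReLU()),
            ('fc2', nn.Linear(128, 10)),
        ])
        self.layers = nn.Sequential(self.layerdict)

    def forward(self, x):
        return self.layers(x)

    def forward_from(self, x, cut_layer):
        # Run only the layers from the cut layer onward; x is expected to
        # be the intermediate output produced by the preceding layers.
        names = list(self.layerdict.keys())
        for name in names[names.index(cut_layer):]:
            x = self.layerdict[name](x)
        return x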
def train_process(self, config, trainset, sampler, cut_layer=None):
    """The main training loop in a federated learning workload, run in
    a separate process with a new CUDA context, so that CUDA memory
    can be released after the training completes.

    Arguments:
    self: the trainer itself.
    config: a dictionary of configuration parameters.
    trainset: The training dataset.
    sampler: the sampler that extracts a partition for this client.
    cut_layer (optional): The layer which training should start from.
    """
    if 'use_wandb' in config:
        import wandb

        run = wandb.init(project="plato",
                         group=str(config['run_id']),
                         reinit=True)

    try:
        custom_train = getattr(self, "train_model", None)

        if callable(custom_train):
            self.train_model(config, trainset, sampler.get(), cut_layer)
        else:
            log_interval = 10
            batch_size = config['batch_size']

            logging.info("[Client #%d] Loading the dataset.",
                         self.client_id)
            _train_loader = getattr(self, "train_loader", None)

            if callable(_train_loader):
                train_loader = self.train_loader(batch_size, trainset,
                                                 sampler.get(), cut_layer)
            else:
                train_loader = torch.utils.data.DataLoader(
                    dataset=trainset,
                    shuffle=False,
                    batch_size=batch_size,
                    sampler=sampler.get())

            iterations_per_epoch = np.ceil(len(trainset) /
                                           batch_size).astype(int)
            epochs = config['epochs']

            # Sending the model to the device used for training
            self.model.to(self.device)
            self.model.train()

            # Initializing the loss criterion
            _loss_criterion = getattr(self, "loss_criterion", None)
            if callable(_loss_criterion):
                loss_criterion = self.loss_criterion(self.model)
            else:
                loss_criterion = nn.CrossEntropyLoss()

            # Initializing the optimizer
            get_optimizer = getattr(self, "get_optimizer",
                                    optimizers.get_optimizer)
            optimizer = get_optimizer(self.model)

            # Initializing the learning rate schedule, if necessary
            if 'lr_schedule' in config:
                lr_schedule = optimizers.get_lr_schedule(
                    optimizer, iterations_per_epoch, train_loader)
            else:
                lr_schedule = None

            if 'differential_privacy' in config and config[
                    'differential_privacy']:
                privacy_engine = PrivacyEngine(accountant='rdp',
                                               secure_mode=False)

                self.model, optimizer, train_loader = \
                    privacy_engine.make_private_with_epsilon(
                        module=self.model,
                        optimizer=optimizer,
                        data_loader=train_loader,
                        target_epsilon=config['dp_epsilon']
                        if 'dp_epsilon' in config else 10.0,
                        target_delta=config['dp_delta']
                        if 'dp_delta' in config else 1e-5,
                        epochs=epochs,
                        max_grad_norm=config['dp_max_grad_norm']
                        if 'dp_max_grad_norm' in config else 1.0,
                    )

            for epoch in range(1, epochs + 1):
                for batch_id, (examples,
                               labels) in enumerate(train_loader):
                    examples, labels = examples.to(
                        self.device), labels.to(self.device)
                    if 'differential_privacy' in config and config[
                            'differential_privacy']:
                        optimizer.zero_grad(set_to_none=True)
                    else:
                        optimizer.zero_grad()

                    if cut_layer is None:
                        outputs = self.model(examples)
                    else:
                        outputs = self.model.forward_from(
                            examples, cut_layer)

                    loss = loss_criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    if batch_id % log_interval == 0:
                        if self.client_id == 0:
                            logging.info(
                                "[Server #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                                .format(os.getpid(), epoch, epochs,
                                        batch_id, len(train_loader),
                                        loss.data.item()))
                        else:
                            if 'use_wandb' in config:
                                wandb.log(
                                    {"batch loss": loss.data.item()})
                            logging.info(
                                "[Client #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                                .format(self.client_id, epoch, epochs,
                                        batch_id, len(train_loader),
                                        loss.data.item()))

                if lr_schedule is not None:
                    lr_schedule.step()

                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

    except Exception as training_exception:
        logging.info("Training on client #%d failed.", self.client_id)
        raise training_exception

    if 'max_concurrency' in config:
        self.model.cpu()
        model_type = config['model_name']
        filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
        self.save_model(filename)

    if 'use_wandb' in config:
        run.finish()
def train_model(self, config, trainset, sampler, cut_layer):
    """The custom training loop for Sub-FedAvg(Un)."""
    batch_size = config['batch_size']
    log_interval = 10

    logging.info("[Client #%d] Loading the dataset.", self.client_id)
    _train_loader = getattr(self, "train_loader", None)

    if callable(_train_loader):
        train_loader = self.train_loader(batch_size, trainset, sampler,
                                         cut_layer)
    else:
        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   shuffle=False,
                                                   batch_size=batch_size,
                                                   sampler=sampler)

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
    epochs = config['epochs']

    if not self.made_init_mask:
        self.mask = pruning_processor.make_init_mask(self.model)
        self.made_init_mask = True

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the loss criterion
    _loss_criterion = getattr(self, "loss_criterion", None)
    if callable(_loss_criterion):
        loss_criterion = self.loss_criterion(self.model)
    else:
        loss_criterion = torch.nn.CrossEntropyLoss()

    # Initializing the optimizer
    get_optimizer = getattr(self, "get_optimizer",
                            optimizers.get_optimizer)
    optimizer = get_optimizer(self.model)

    # Initializing the learning rate schedule, if necessary
    if 'lr_schedule' in config:
        lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                 iterations_per_epoch,
                                                 train_loader)
    else:
        lr_schedule = None

    for epoch in range(1, epochs + 1):
        # Use a default training loop
        for batch_id, (examples, labels) in enumerate(train_loader):
            examples, labels = examples.to(self.device), labels.to(
                self.device)
            optimizer.zero_grad()

            if cut_layer is None:
                outputs = self.model(examples)
            else:
                outputs = self.model.forward_from(examples, cut_layer)

            loss = loss_criterion(outputs, labels)
            loss.backward()

            # Freezing the pruned weights by zeroing out their gradients
            step = 0
            for name, parameter in self.model.named_parameters():
                if 'weight' in name:
                    grad_tensor = parameter.grad.data.cpu().numpy()
                    grad_tensor = grad_tensor * self.mask[step]
                    parameter.grad.data = torch.from_numpy(grad_tensor).to(
                        self.device)
                    step = step + 1

            optimizer.step()

            if batch_id % log_interval == 0:
                if self.client_id == 0:
                    logging.info(
                        "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                        os.getpid(), epoch, epochs, batch_id,
                        len(train_loader), loss.data.item())

        if lr_schedule is not None:
            lr_schedule.step()

        if epoch == 1:
            first_epoch_mask = pruning_processor.fake_prune(
                self.pruning_amount, copy.deepcopy(self.model),
                copy.deepcopy(self.mask))
        if epoch == epochs:
            last_epoch_mask = pruning_processor.fake_prune(
                self.pruning_amount, copy.deepcopy(self.model),
                copy.deepcopy(self.mask))

    self.process_pruning(first_epoch_mask, last_epoch_mask)
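# make_init_mask and fake_prune come from a pruning processor that is not
# part of this excerpt. A hedged sketch of their assumed semantics: one mask
# of ones per weight tensor, and magnitude-based pruning that zeroes the
# lowest-magnitude surviving weights without modifying the model itself:
import numpy as np


def make_init_mask(model):
    """One all-ones mask per weight tensor (assumed semantics)."""
    return [
        np.ones_like(param.data.cpu().numpy())
        for name, param in model.named_parameters() if 'weight' in name
    ]


def fake_prune(pruning_amount, model, mask):
    """Prune the lowest-magnitude surviving weights in each weight tensor,
    returning the updated mask (assumed semantics)."""
    step = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            alive = tensor[np.nonzero(tensor * mask[step])]
            if alive.size > 0:
                threshold = np.percentile(np.abs(alive), pruning_amount)
                mask[step] = np.where(
                    np.abs(tensor) < threshold, 0, mask[step])
            step += 1
    return mask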
def train_process(self, config, trainset, sampler, cut_layer=None):
    """The main training loop in a federated learning workload."""
    if 'use_wandb' in config:
        import wandb

        run = wandb.init(project="plato",
                         group=str(config['run_id']),
                         reinit=True)

    try:
        custom_train = getattr(self, "train_model", None)

        if callable(custom_train):
            self.train_model(config, trainset, sampler.get(), cut_layer)
        else:
            log_interval = 10
            batch_size = config['batch_size']

            logging.info("[Client #%d] Loading the dataset.",
                         self.client_id)
            _train_loader = getattr(self, "train_loader", None)

            if callable(_train_loader):
                train_loader = self.train_loader(batch_size, trainset,
                                                 sampler.get(), cut_layer)
            else:
                train_loader = torch.utils.data.DataLoader(
                    dataset=trainset,
                    shuffle=False,
                    batch_size=batch_size,
                    sampler=sampler.get())

            iterations_per_epoch = np.ceil(len(trainset) /
                                           batch_size).astype(int)
            epochs = config['epochs']

            # Sending the model to the device used for training
            self.model.to(self.device)
            self.model.train()

            # Initializing the loss criterion
            _loss_criterion = getattr(self, "loss_criterion", None)
            if callable(_loss_criterion):
                loss_criterion = self.loss_criterion(self.model)
            else:
                loss_criterion = nn.CrossEntropyLoss()

            # Initializing the optimizer for the second stage of MAML
            # The learning rate here is the meta learning rate (beta)
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=Config().trainer.meta_learning_rate,
                momentum=Config().trainer.momentum,
                weight_decay=Config().trainer.weight_decay)

            # Initializing the schedule for the meta learning rate,
            # if necessary
            if 'meta_lr_schedule' in config:
                meta_lr_schedule = optimizers.get_lr_schedule(
                    optimizer, iterations_per_epoch, train_loader)
            else:
                meta_lr_schedule = None

            for epoch in range(1, epochs + 1):
                # Copying the current model since MAML is used
                current_model = copy.deepcopy(self.model)

                # Sending this model to the device used for training
                current_model.to(self.device)
                current_model.train()

                # Initializing the optimizer for the first stage of MAML
                # The learning rate here is the alpha in the paper
                temp_optimizer = torch.optim.SGD(
                    current_model.parameters(),
                    lr=Config().trainer.learning_rate,
                    momentum=Config().trainer.momentum,
                    weight_decay=Config().trainer.weight_decay)

                # Initializing the learning rate schedule, if necessary
                if 'lr_schedule' in config:
                    lr_schedule = optimizers.get_lr_schedule(
                        temp_optimizer, iterations_per_epoch, train_loader)
                else:
                    lr_schedule = None

                # The first stage of MAML: adapting the copied model
                # on half of the training dataset
                self.training_per_stage(1, temp_optimizer, lr_schedule,
                                        train_loader, cut_layer,
                                        current_model, loss_criterion,
                                        log_interval, config, epoch,
                                        epochs)

                # The second stage of MAML: the meta update on the
                # other half of the training dataset
                self.training_per_stage(2, optimizer, meta_lr_schedule,
                                        train_loader, cut_layer,
                                        self.model, loss_criterion,
                                        log_interval, config, epoch,
                                        epochs)

                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

    except Exception as training_exception:
        logging.info("Training on client #%d failed.", self.client_id)
        raise training_exception

    if 'max_concurrency' in config:
        self.model.cpu()
        model_type = config['model_name']
        filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
        self.save_model(filename)

    if 'use_wandb' in config:
        run.finish()
def train_model(self, config, trainset, sampler, cut_layer=None):
    """A custom training loop for personalized FL."""
    batch_size = config['batch_size']
    log_interval = 10

    logging.info("[Client #%d] Loading the dataset.", self.client_id)
    _train_loader = getattr(self, "train_loader", None)

    if callable(_train_loader):
        train_loader = self.train_loader(batch_size, trainset, sampler,
                                         cut_layer)
    else:
        train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                                   shuffle=False,
                                                   batch_size=batch_size,
                                                   sampler=sampler)

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
    epochs = config['epochs']

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the loss criterion
    _loss_criterion = getattr(self, "loss_criterion", None)
    if callable(_loss_criterion):
        loss_criterion = self.loss_criterion(self.model)
    else:
        loss_criterion = nn.CrossEntropyLoss()

    # Initializing the optimizer for the second stage of MAML
    # The learning rate here is the meta learning rate (beta)
    optimizer = torch.optim.SGD(self.model.parameters(),
                                lr=Config().trainer.meta_learning_rate,
                                momentum=Config().trainer.momentum,
                                weight_decay=Config().trainer.weight_decay)

    # Initializing the schedule for the meta learning rate, if necessary
    if 'meta_lr_schedule' in config:
        meta_lr_schedule = optimizers.get_lr_schedule(
            optimizer, iterations_per_epoch, train_loader)
    else:
        meta_lr_schedule = None

    for epoch in range(1, epochs + 1):
        # Copying the current model since MAML is used
        current_model = copy.deepcopy(self.model)

        # Sending this model to the device used for training
        current_model.to(self.device)
        current_model.train()

        # Initializing the optimizer for the first stage of MAML
        # The learning rate here is the alpha in the paper
        temp_optimizer = torch.optim.SGD(
            current_model.parameters(),
            lr=Config().trainer.learning_rate,
            momentum=Config().trainer.momentum,
            weight_decay=Config().trainer.weight_decay)

        # Initializing the learning rate schedule, if necessary
        if 'lr_schedule' in config:
            lr_schedule = optimizers.get_lr_schedule(
                temp_optimizer, iterations_per_epoch, train_loader)
        else:
            lr_schedule = None

        # The first stage of MAML: adapting the copied model
        # on half of the training dataset
        self.training_per_stage(1, temp_optimizer, lr_schedule,
                                train_loader, cut_layer, current_model,
                                loss_criterion, log_interval, epoch,
                                epochs)

        # The second stage of MAML: the meta update on the other
        # half of the training dataset
        self.training_per_stage(2, optimizer, meta_lr_schedule,
                                train_loader, cut_layer, self.model,
                                loss_criterion, log_interval, epoch,
                                epochs)

        if hasattr(optimizer, "params_state_update"):
            optimizer.params_state_update()
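# training_per_stage is not defined in this excerpt. A hedged sketch of its
# assumed behavior, consistent with the comments above: stage 1 trains the
# given model on the first half of the batches, stage 2 on the second half
# (the split mechanics are an assumption):
def training_per_stage(self, stage, optimizer, lr_schedule, train_loader,
                       cut_layer, model, loss_criterion, log_interval,
                       epoch, epochs):
    half = len(train_loader) // 2

    for batch_id, (examples, labels) in enumerate(train_loader):
        # Skip the batches that belong to the other stage
        if (stage == 1) != (batch_id < half):
            continue

        examples, labels = examples.to(self.device), labels.to(self.device)
        optimizer.zero_grad()

        if cut_layer is None:
            outputs = model(examples)
        else:
            outputs = model.forward_from(examples, cut_layer)

        loss = loss_criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if lr_schedule is not None:
            lr_schedule.step()

        if batch_id % log_interval == 0:
            logging.info(
                "[Client #%d] Stage %d Epoch: [%d/%d]\tLoss: %.6f",
                self.client_id, stage, epoch, epochs, loss.data.item())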
def train_model(self, config, trainset, sampler, cut_layer=None):
    """A custom trainer reporting training loss."""
    log_interval = 10
    batch_size = config['batch_size']

    logging.info("[Client #%d] Loading the dataset.", self.client_id)

    train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                               shuffle=False,
                                               batch_size=batch_size,
                                               sampler=sampler)

    iterations_per_epoch = np.ceil(len(trainset) / batch_size).astype(int)
    epochs = config['epochs']

    # Sending the model to the device used for training
    self.model.to(self.device)
    self.model.train()

    # Initializing the loss criterion
    _loss_criterion = getattr(self, "loss_criterion", None)
    if callable(_loss_criterion):
        loss_criterion = self.loss_criterion(self.model)
    else:
        loss_criterion = nn.CrossEntropyLoss()

    # Initializing the optimizer
    get_optimizer = getattr(self, "get_optimizer",
                            optimizers.get_optimizer)
    optimizer = get_optimizer(self.model)

    # Initializing the learning rate schedule, if necessary
    if 'lr_schedule' in config:
        lr_schedule = optimizers.get_lr_schedule(optimizer,
                                                 iterations_per_epoch,
                                                 train_loader)
    else:
        lr_schedule = None

    try:
        for epoch in range(1, epochs + 1):
            for batch_id, (examples, labels) in enumerate(train_loader):
                examples, labels = examples.to(self.device), labels.to(
                    self.device)
                optimizer.zero_grad()

                if cut_layer is None:
                    outputs = self.model(examples)
                else:
                    outputs = self.model.forward_from(examples, cut_layer)

                loss = loss_criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                if lr_schedule is not None:
                    lr_schedule.step()

                if batch_id % log_interval == 0:
                    if self.client_id == 0:
                        logging.info(
                            "[Server #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                            .format(os.getpid(), epoch, epochs, batch_id,
                                    len(train_loader), loss.data.item()))
                    else:
                        logging.info(
                            "[Client #{}] Epoch: [{}/{}][{}/{}]\tLoss: {:.6f}"
                            .format(self.client_id, epoch, epochs,
                                    batch_id, len(train_loader),
                                    loss.data.item()))

            if hasattr(optimizer, "params_state_update"):
                optimizer.params_state_update()

    except Exception as training_exception:
        logging.info("Training on client #%d failed.", self.client_id)
        raise training_exception

    if 'max_concurrency' in config:
        self.model.cpu()
        model_type = config['model_name']
        filename = f"{model_type}_{self.client_id}_{config['run_id']}.pth"
        self.save_model(filename)

    # Save the training loss of the last epoch in this round
    model_name = config['model_name']
    filename = f"{model_name}_{self.client_id}_{config['run_id']}.loss"
    Trainer.save_loss(loss.data.item(), filename)
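# Trainer.save_loss is called in two of the trainers above but not shown. A
# minimal sketch of such a save/load pair, writing the scalar loss to a file
# so that it can be reported back later (the directory is an assumption):
import os


class Trainer:
    """Sketch of the loss persistence helpers assumed above."""

    @staticmethod
    def save_loss(loss, filename):
        model_dir = './models/'
        os.makedirs(model_dir, exist_ok=True)
        with open(os.path.join(model_dir, filename), 'w',
                  encoding='utf-8') as file:
            file.write(str(loss))

    @staticmethod
    def load_loss(filename):
        with open(os.path.join('./models/', filename), 'r',
                  encoding='utf-8') as file:
            return float(file.read())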