# Assumes module-level imports: torch, plus the SAM optimizer, AGC wrapper and
# check_norm helper used below.
def update_weights(self, model, global_round, idx_user):
    # Set mode to train model
    # model.to(self.device)
    # model.train()
    epoch_loss = []
    total_norm = []
    loss_list = []
    conv_grad = []
    fc_grad = []

    # Set optimizer for the local updates
    if self.args.optimizer == 'sgd_bench':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.9)
    elif self.args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
    elif self.args.optimizer == 'sgd_vc':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    weight_decay=1e-4, momentum=0.9)
    elif self.args.optimizer == 'sam':
        # define a base optimizer for the "sharpness-aware" update
        base_optimizer = torch.optim.SGD
        optimizer = SAM(model.parameters(), base_optimizer, lr=self.args.lr,
                        momentum=0.9, weight_decay=1e-4)
    elif self.args.optimizer == 'no_weight_decay':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr)
    elif self.args.optimizer == 'clip':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
    elif self.args.optimizer == 'resnet':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.9, weight_decay=5e-4)
    elif self.args.optimizer == 'no_momentum':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
    elif self.args.optimizer == 'clip_nf':
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.9, weight_decay=5e-4)
        # wrap the base optimizer with adaptive gradient clipping (AGC),
        # excluding the final fully-connected layer(s)
        if 'resnet' in self.args.model:
            optimizer = AGC(model.parameters(), optimizer, model=model,
                            ignore_agc=['fc'], clipping=1e-3)
        else:
            optimizer = AGC(model.parameters(), optimizer, model=model,
                            ignore_agc=['fc1', 'fc2', 'fc3'], clipping=1e-3)
        # optimizer = SGD_AGC(model.parameters(), lr=self.args.lr, momentum=0.9,
        #                     weight_decay=5e-4, clipping=1e-3)

    for iter in range(self.args.local_ep):
        batch_loss = []
        for batch_idx, (images, labels) in enumerate(self.trainloader):
            images, labels = images.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            log_probs = model(images)
            loss = self.criterion(log_probs, labels)
            if self.args.verbose == 0 and self.args.optimizer != 'sam':
                # free the batch early to save memory; keep it when using SAM,
                # which needs a second forward pass on the same batch below
                del images
                del labels
                torch.cuda.empty_cache()
            loss.backward()

            # record gradients for inspection (how does BN behave)
            conv_grad.append(model.conv1.weight.grad.clone().to('cpu'))
            if self.args.optimizer != 'clip':
                total_norm.append(check_norm(model))
            if self.args.model == 'cnn' or self.args.model == 'cnn_ws':
                fc_grad.append(model.fc3.weight.grad.clone().to('cpu'))
            else:
                fc_grad.append(model.fc.weight.grad.clone().to('cpu'))

            if self.args.optimizer == 'sam':
                # first ascent step, then a second forward-backward pass
                optimizer.first_step(zero_grad=True)
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.second_step(zero_grad=True)
            elif self.args.optimizer == 'clip':
                max_norm = 0.3
                if self.args.lr == 5:
                    max_norm = 0.08
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                total_norm.append(check_norm(model))
                optimizer.step()
            else:  # all other (non-SAM, non-clip) optimizers
                optimizer.step()
                # print(optimizer.param_groups[0]['lr'])  # for checking lr decay

            if self.args.verbose:
                print('|Client : {} Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    idx_user, global_round + 1, iter + 1,
                    batch_idx * len(images), len(self.trainloader.dataset),
                    100. * batch_idx / len(self.trainloader), loss.item()))
            # self.logger.add_scalar('loss', loss.item())
            batch_loss.append(loss.item())
            # record per-iteration loss for inspection (how does BN behave)
            loss_list.append(loss.item())

        print(total_norm)  # gradient norms, for inspection
        epoch_loss.append(sum(batch_loss) / len(batch_loss))

    return (model.state_dict(), sum(epoch_loss) / len(epoch_loss),
            loss_list, conv_grad, fc_grad, total_norm)
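# For reference: check_norm() is called above but not defined in this excerpt.
# Below is a minimal sketch of such a helper, assuming it returns the total L2
# norm of all parameter gradients after loss.backward() (the same quantity
# torch.nn.utils.clip_grad_norm_ measures); the repo's actual definition may differ.
import torch


def check_norm(model):
    # accumulate squared per-parameter gradient norms, then take the square root
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5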
# The start of this snippet is truncated in the original; a SAM optimizer wrapping
# SGD is assumed here, consistent with the first_step()/second_step() calls below.
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, lr=args.learning_rate,
                momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

for epoch in range(args.epochs):
    model.train()
    log.train(len_dataset=len(dataset.train))

    for batch in dataset.train:
        inputs, targets = (b.to(device) for b in batch)

        # first forward-backward step
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets)
        loss.mean().backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward step
        smooth_crossentropy(model(inputs), targets).mean().backward()
        optimizer.second_step(zero_grad=True)

        with torch.no_grad():
            correct = torch.argmax(predictions.data, 1) == targets
            log(model, loss.cpu(), correct.cpu(), scheduler.lr())
            scheduler(epoch)

    model.eval()
    log.eval(len_dataset=len(dataset.test))

    with torch.no_grad():
        for batch in dataset.test:
            # The evaluation body is truncated in the original; the lines below
            # restore it following the same pattern as the training forward pass.
            inputs, targets = (b.to(device) for b in batch)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu())
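# Both snippets above and the train() function below rely on a SAM optimizer that
# exposes first_step()/second_step(). The class is not shown in this excerpt; the
# sketch below is a minimal, non-adaptive version (after Foret et al., 2021) and is
# an assumption about its interface, not the repo's exact implementation.
import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # the wrapped optimizer performs the actual parameter update
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)                      # climb to w + e(w)
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])     # return to the original weights w
        self.base_optimizer.step()               # sharpness-aware update
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        # global L2 norm over all parameter gradients
        return torch.norm(torch.stack([
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"] if p.grad is not None
        ]), p=2)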
# Assumes module-level imports: torch, torch.optim as optim, copy, tqdm, plus
# globally defined device, criterion, and the evaluate() helper used below.
def train(model, n_epochs, learningrate, train_loader, test_loader, use_sam=False):
    # optimizer
    if use_sam:
        optimizer = SAM(filter(lambda p: p.requires_grad, model.parameters()),
                        optim.Adam, lr=learningrate)
    else:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=learningrate)
    # scheduler
    # scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    best_acc = 0
    best_model = None
    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_accuracy = 0

        model.train()
        for data, label in tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            if use_sam:
                # optimizer.zero_grad()
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                output = model(data)
                loss = criterion(output, label)
                loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(train_loader)
            # use .item() so the autograd graph is not kept alive across the epoch
            epoch_loss += loss.item() / len(train_loader)

        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            epoch_Positive = 0
            epoch_Negative = 0
            epoch_TP = 0
            epoch_FP = 0
            epoch_TN = 0
            epoch_FN = 0
            for data, label in tqdm(test_loader):
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)

                (c_True_Positive, c_False_Positive, c_True_Negative,
                 c_False_Negative, c_Positive, c_Negative) = evaluate(val_output, label)
                epoch_TP += c_True_Positive
                epoch_FP += c_False_Positive
                epoch_TN += c_True_Negative
                epoch_FN += c_False_Negative
                epoch_Positive += c_Positive
                epoch_Negative += c_Negative

        Recall = epoch_TP / (epoch_TP + epoch_FN)
        Precision = epoch_TP / (epoch_TP + epoch_FP)
        F1 = (2 * (Recall * Precision)) / (Recall + Precision)

        print(
            f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
        # f-string prefix was missing in the original print
        print(f"Recall: {Recall:.4f}, Precision: {Precision:.4f}, F1 Score: {F1:.4f}")

        if best_acc < epoch_val_accuracy:
            best_acc = epoch_val_accuracy
            best_model = copy.deepcopy(model.state_dict())
        # scheduler.step()

    if best_model is not None:
        # reload the best checkpoint and report its test performance
        model.load_state_dict(best_model)
        print(f"Best acc:{best_acc}")
        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)
                val_output = model(data)
                val_loss = criterion(val_output, label)
                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)
            print(f"val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n")
    else:
        print(f"No best model. Best acc:{best_acc}")
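# For reference: evaluate() is called in train() above but not defined in this
# excerpt. The sketch below assumes a binary classifier with positive class 1 and
# the return order (TP, FP, TN, FN, P, N) implied by the unpacking above; both the
# name and the positive-class convention are assumptions, not the repo's code.
import torch


def evaluate(output, label, positive_class=1):
    # confusion-matrix counts for one batch of logits and integer labels
    pred = output.argmax(dim=1)
    is_pos = label == positive_class
    is_pred_pos = pred == positive_class

    tp = int((is_pred_pos & is_pos).sum())
    fp = int((is_pred_pos & ~is_pos).sum())
    tn = int((~is_pred_pos & ~is_pos).sum())
    fn = int((~is_pred_pos & is_pos).sum())
    return tp, fp, tn, fn, int(is_pos.sum()), int((~is_pos).sum())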