Example #1
    def update_weights(self, model, global_round, idx_user):
        # Set mode to train model
        #        model.to(self.device)
        #        model.train()
        epoch_loss = []
        total_norm = []
        loss_list = []
        conv_grad = []
        fc_grad = []
        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd_bench':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.args.lr,
                                         weight_decay=1e-4)
        elif self.args.optimizer == 'sgd_vc':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4,
                                        momentum=0.9)
        elif self.args.optimizer == 'sam':
            base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
            optimizer = SAM(model.parameters(),
                            base_optimizer,
                            lr=self.args.lr,
                            momentum=0.9,
                            weight_decay=1e-4)
        elif self.args.optimizer == 'no_weight_decay':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'clip':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4)
        elif self.args.optimizer == 'resnet':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9,
                                        weight_decay=5e-4)
        elif self.args.optimizer == 'no_momentum':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        weight_decay=1e-4)
        elif self.args.optimizer == 'clip_nf':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            if 'resnet' in self.args.model:
                optimizer = AGC(model.parameters(),
                                optimizer,
                                model=model,
                                ignore_agc=['fc'],
                                clipping=1e-3)
            else:
                optimizer = AGC(model.parameters(),
                                optimizer,
                                model=model,
                                ignore_agc=['fc1', 'fc2', 'fc3'],
                                clipping=1e-3)
            # optimizer = SGD_AGC(model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=5e-4, clipping=1e-3)

        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                if self.args.verbose == 0 and self.args.optimizer != 'sam':
                    # free the batch early to save memory; SAM still needs it
                    # below for its second forward pass
                    del images
                    del labels
                    torch.cuda.empty_cache()

                loss.backward()

                # for inspecting gradients - how does BN affect them
                conv_grad.append(model.conv1.weight.grad.clone().to('cpu'))
                if self.args.optimizer != 'clip':
                    total_norm.append(check_norm(model))

                if self.args.model == 'cnn' or self.args.model == 'cnn_ws':
                    fc_grad.append(model.fc3.weight.grad.clone().to('cpu'))
                else:
                    fc_grad.append(model.fc.weight.grad.clone().to('cpu'))

                if self.args.optimizer == 'sam':
                    optimizer.first_step(zero_grad=True)
                    log_probs = model(images)
                    loss = self.criterion(log_probs, labels)
                    loss.backward()
                    optimizer.second_step(zero_grad=True)
                elif self.args.optimizer == 'clip':
                    max_norm = 0.3
                    if self.args.lr == 5:
                        max_norm = 0.08
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_norm)
                    total_norm.append(check_norm(model))
                    optimizer.step()
                else:  # non-SAM optimizers
                    optimizer.step()
                # print(optimizer.param_groups[0]['lr'])  # for checking lr decay
                if self.args.verbose:
                    print(
                        '|Client : {} Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(idx_user, global_round + 1, iter + 1,
                                batch_idx * len(images),
                                len(self.trainloader.dataset),
                                100. * batch_idx / len(self.trainloader),
                                loss.item()))
                # self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
                # for inspecting per-iteration loss - how does BN affect it
                loss_list.append(loss.item())
            print(total_norm)  # for inspecting gradient norms
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(
            epoch_loss), loss_list, conv_grad, fc_grad, total_norm
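
Example #1 calls a check_norm helper that is not defined in the snippet. A minimal sketch of what such a gradient-norm helper could look like (the name comes from the snippet; the body is an assumption, not the repository's actual code):

import torch


def check_norm(model):
    # assumption: total L2 norm over all parameter gradients, the same quantity
    # torch.nn.utils.clip_grad_norm_ computes before clipping
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5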
Example #2
    # assumed setup (the start of this example is cut off): a SAM optimizer
    # wrapping SGD, consistent with the first_step/second_step calls below
    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                # the snippet is cut off here; a minimal evaluation body,
                # mirroring the training-side metrics above
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu())
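
All three examples drive the same external SAM wrapper through its first_step/second_step interface. For reference, a condensed sketch of that two-phase update (simplified, non-adaptive variant; rho and the internal names are assumptions, not necessarily the exact library code):

import torch


class SAM(torch.optim.Optimizer):
    # condensed sketch of Sharpness-Aware Minimization (non-adaptive variant)
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)                    # climb to the perturbed point w + e(w)
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])   # move back to the original weights
        self.base_optimizer.step()             # descend using the perturbed-point gradient
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        return torch.norm(torch.stack([
            p.grad.norm(2)
            for group in self.param_groups
            for p in group["params"] if p.grad is not None
        ]), 2)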
Example #3
# assumed context for this snippet: a global `device`, a loss `criterion`,
# an `evaluate` helper, and the external SAM optimizer wrapper
import copy

import torch
import torch.optim as optim
from tqdm import tqdm


def train(model,
          n_epochs,
          learningrate,
          train_loader,
          test_loader,
          use_sam=False):
    # optimizer
    if use_sam:
        optimizer = SAM(filter(lambda p: p.requires_grad, model.parameters()),
                        optim.Adam,
                        lr=learningrate)
    else:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=learningrate)
    # scheduler
    #scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    best_acc = 0
    best_model = None
    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        model.train()
        for data, label in tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            if use_sam:
                #optimizer.zero_grad()
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                output = model(data)
                loss = criterion(output, label)
                loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            epoch_Positive = 0
            epoch_Negative = 0
            epoch_TP = 0
            epoch_FP = 0
            epoch_TN = 0
            epoch_FN = 0
            for data, label in tqdm(test_loader):
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)
                c_True_Positive, c_False_Positive, c_True_Negative, c_False_Negative, c_Positive, c_Negative = evaluate(
                    val_output, label)
                epoch_TP += c_True_Positive
                epoch_FP += c_False_Positive
                epoch_TN += c_True_Negative
                epoch_FN += c_False_Negative
                epoch_Positive += c_Positive
                epoch_Negative += c_Negative
            Recall = (epoch_TP) / (epoch_TP + epoch_FN)
            Precision = (epoch_TP) / (epoch_TP + epoch_FP)
            F1 = (2 * (Recall * Precision)) / (Recall + Precision)

        print(
            f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
        print(
            f"Recall: {Recall:.4f},  Precision: {Precision:.4f}, F1 Score: {F1:.4f}"
        )
        if best_acc < epoch_val_accuracy:
            best_acc = epoch_val_accuracy
            best_model = copy.deepcopy(model.state_dict())
        #scheduler.step()

    if best_model is not None:
        model.load_state_dict(best_model)
        print(f"Best acc:{best_acc}")
        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss / len(test_loader)

        print(
            f"val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
    else:
        print(f"No best model Best acc:{best_acc}")