Exemplo n.º 1
0
    def generate(self,
                 train_loader=None,
                 valid_loader=None,
                 defense_enhanced_saver=None):
        """Train ``self.model`` and keep the checkpoint with the best validation accuracy.

        Each epoch calls ``self.train`` then ``self.valid`` and afterwards
        ``adjust_learning_rate`` on ``self.optimizer``. Whenever the validation
        accuracy (rounded to 4 decimals) matches or beats the best seen so far,
        the model's state dict is written to ``defense_enhanced_saver``; ties
        go to the later epoch.

        :param train_loader: data loader passed to ``self.train``.
        :param valid_loader: data loader passed to ``self.valid``.
        :param defense_enhanced_saver: file path for the best checkpoint; its
            parent directory (including missing intermediates) is created.
        :return: ``(best_model_weights, best_val_acc)`` -- an independent copy
            of the best state dict and its validation accuracy. If
            ``self.num_epochs`` is 0, the initial weights and ``None``.
        :raises ValueError: if ``defense_enhanced_saver`` is None.
        """
        import copy

        if defense_enhanced_saver is None:
            raise ValueError('defense_enhanced_saver must be a file path')
        dir_path = os.path.dirname(defense_enhanced_saver)
        # dirname() is '' for a bare filename (save into the CWD), and
        # os.mkdir('') would fail; makedirs also creates missing parents.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        best_val_acc = None
        # state_dict() holds references to the live parameter tensors, which
        # later optimizer steps update in place -- snapshot a real copy.
        best_model_weights = copy.deepcopy(self.model.state_dict())

        for epoch in range(self.num_epochs):
            self.train(train_loader, epoch)
            val_acc = self.valid(valid_loader)
            adjust_learning_rate(epoch=epoch, optimizer=self.optimizer)
            if best_val_acc is None or round(val_acc, 4) >= round(best_val_acc, 4):
                best_val_acc = val_acc
                best_model_weights = copy.deepcopy(self.model.state_dict())
                # torch.save overwrites an existing file, so no os.remove is
                # needed beforehand (which would briefly leave no checkpoint).
                torch.save(self.model.state_dict(), defense_enhanced_saver)
            else:
                print(
                    'Train Epoch{:>3}: validation dataset accuracy did not improve from {:.4f}\n'
                    .format(epoch, best_val_acc))
        # Guard: formatting None with {:.4f} would raise when no epoch ran.
        if best_val_acc is not None:
            print('Best val Acc: {:.4f}'.format(best_val_acc))
        return best_model_weights, best_val_acc
Exemplo n.º 2
0
    def generate(self,
                 train_loader=None,
                 valid_loader=None,
                 defense_enhanced_saver=None):
        """Train ``self.model`` with the external model group and keep the best checkpoint.

        If ``self.train_externals`` is truthy, the external model group is
        first trained into ``self.external_model_path``. The group is then
        loaded from that directory and handed to ``self.train`` every epoch.
        After each epoch ``self.valid`` scores ``self.model`` and
        ``adjust_learning_rate`` is applied to ``self.optimizer``. Whenever the
        validation accuracy (rounded to 4 decimals) matches or beats the best
        so far, the state dict is written to ``defense_enhanced_saver``; ties
        go to the later epoch.

        :param train_loader: data loader used for training.
        :param valid_loader: data loader used for validation.
        :param defense_enhanced_saver: file path for the best checkpoint; its
            parent directory (including missing intermediates) is created.
        :return: ``(best_model_weights, best_val_acc)`` -- an independent copy
            of the best state dict and its validation accuracy. If
            ``self.num_epochs`` is 0, the initial weights and ``None``.
        :raises ValueError: if ``defense_enhanced_saver`` is None.
        """
        import copy

        if defense_enhanced_saver is None:
            raise ValueError('defense_enhanced_saver must be a file path')
        dir_path = os.path.dirname(defense_enhanced_saver)
        # dirname() is '' for a bare filename (save into the CWD), and
        # os.mkdir('') would fail; makedirs also creates missing parents.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        if self.train_externals:
            print('\nStart to train the external models ......\n')
            self.train_external_model_group(
                train_loader=train_loader,
                validation_loader=valid_loader,
                model_dir_path=self.external_model_path)

        # Load the external models from disk (whether just trained or not).
        pre_train_models = self.load_external_model_group(
            model_dir=self.external_model_path)

        best_val_acc = None
        # state_dict() holds references to the live parameter tensors, which
        # later optimizer steps update in place -- snapshot a real copy.
        best_model_weights = copy.deepcopy(self.model.state_dict())

        for epoch in range(self.num_epochs):
            self.train(pre_train_models, train_loader, epoch)
            val_acc = self.valid(self.model, valid_loader)

            adjust_learning_rate(epoch=epoch, optimizer=self.optimizer)

            if best_val_acc is None or round(val_acc, 4) >= round(best_val_acc, 4):
                best_val_acc = val_acc
                best_model_weights = copy.deepcopy(self.model.state_dict())
                # torch.save overwrites an existing file, so no os.remove is
                # needed beforehand (which would briefly leave no checkpoint).
                torch.save(self.model.state_dict(), defense_enhanced_saver)
            else:
                print(
                    'Train Epoch{:>3}: validation dataset accuracy did not improve from {:.4f}\n'
                    .format(epoch, best_val_acc))

        # Guard: formatting None with {:.4f} would raise when no epoch ran.
        if best_val_acc is not None:
            print('Best val Acc: {:.4f}'.format(best_val_acc))
        return best_model_weights, best_val_acc
Exemplo n.º 3
0
    def train_external_model_group(self,
                                   train_loader=None,
                                   validation_loader=None,
                                   model_dir_path=None):
        """Train the four static external models one by one and checkpoint each.

        For ``self.dataset`` "CIFAR10" or "ImageNet", four models (``*_A`` ..
        ``*_D``) are moved to ``self.device`` and each is trained for
        ``self.num_epochs`` epochs. After every epoch the model is validated
        and ``adjust_learning_rate`` is applied to its optimizer. When its
        validation accuracy (rounded to 4 decimals) matches or beats its own
        best so far, it is saved via ``model.save`` to
        ``<model_dir_path>/<dataset>_EAT_<index>.pt``; ties go to the later epoch.

        :param train_loader: data loader passed to ``self.train_one_epoch``.
        :param validation_loader: data loader passed to ``self.valid``.
        :param model_dir_path: existing directory that receives the checkpoints.
        :raises FileNotFoundError: if ``model_dir_path`` is not an existing
            directory (checked up front, before any training).
        :raises ValueError: if ``self.dataset`` is not "CIFAR10" or "ImageNet".
        """
        # Fail fast instead of discovering a bad path after a full epoch, and
        # use an explicit raise rather than an assert (stripped under -O).
        if model_dir_path is None or not os.path.isdir(model_dir_path):
            raise FileNotFoundError(
                'model_dir_path {!r} is not an existing directory'.format(
                    model_dir_path))

        # Set up the model group with 4 static external models.
        if self.dataset == "CIFAR10":
            model_group = [CIFAR10_A(), CIFAR10_B(), CIFAR10_C(), CIFAR10_D()]
        elif self.dataset == "ImageNet":
            model_group = [
                ImageNet_A(),
                ImageNet_B(),
                ImageNet_C(),
                ImageNet_D()
            ]
        else:
            raise ValueError(
                'unsupported dataset {!r}: expected "CIFAR10" or "ImageNet"'
                .format(self.dataset))

        model_group = [model.to(self.device) for model in model_group]

        # Train the models in model_group one by one.
        for i, model in enumerate(model_group):

            # Model D (index 3) uses SGD with fixed hyper-parameters; the
            # others use Adam at self.learn_rate (for either dataset).
            if i == 3:
                optimizer_external = optim.SGD(model.parameters(),
                                               lr=0.001,
                                               momentum=0.9,
                                               weight_decay=1e-6)
            else:
                optimizer_external = optim.Adam(model.parameters(),
                                                lr=self.learn_rate)

            # The checkpoint path depends only on the model index.
            defense_external_saver = os.path.join(
                model_dir_path,
                "{}_EAT_{}.pt".format(self.dataset, str(i)))

            print('\nwe are training the {}-th static external model ......'.
                  format(i))
            best_val_acc = None
            for index_epoch in range(self.num_epochs):
                self.train_one_epoch(model=model,
                                     train_loader=train_loader,
                                     optimizer=optimizer_external,
                                     epoch=index_epoch,
                                     device=self.device)
                val_acc = self.valid(model=model,
                                     valid_loader=validation_loader)

                adjust_learning_rate(epoch=index_epoch,
                                     optimizer=optimizer_external)

                if best_val_acc is None or round(val_acc, 4) >= round(
                        best_val_acc, 4):
                    # Drop this run's previous best before writing the new one.
                    if best_val_acc is not None:
                        os.remove(defense_external_saver)
                    best_val_acc = val_acc
                    model.save(name=defense_external_saver)
                else:
                    print(
                        'Train Epoch {:>3}: validation dataset accuracy did not improve from {:.4f}\n'
                        .format(index_epoch, best_val_acc))