Exemple #1
0
def load_pretrained_model(model_base, model_params_path):
    """
    Load saved weights from a checkpoint file into ``model_base``.

    The checkpoint is read with ``torch.load`` onto the device returned by
    ``TorchUtils.get_device()``. If the loaded object has a ``'state_dict'``
    entry, that entry is used as the weights and a leading ``'module.'`` is
    removed from every key (presumably the prefix added when the model was
    saved while wrapped in DataParallel -- confirm against the training code).
    Otherwise the loaded object is passed to ``load_state_dict`` unchanged.

    :param model_base: model instance that receives the weights.
    :param model_params_path: path of the checkpoint file to load.
    :return: ``model_base`` with the weights loaded, moved to the device.
    """
    device = TorchUtils.get_device()
    state_dict = torch.load(model_params_path, map_location=device)
    if 'state_dict' in state_dict:
        # Match the full 'module.' prefix (dot included); testing only for
        # 'module' would also truncate unrelated keys such as 'modules.0.weight'.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict['state_dict'].items()
        }
    model_base.load_state_dict(state_dict)
    model_base = model_base.to(device)
    return model_base
Exemple #2
0
    def __init__(self, attack, start_idx, end_idx, transform, **kwargs):
        """
        Build an adversarial version of a slice of the MNIST testset.

        Every sample in ``[start_idx, end_idx]`` is first passed through
        ``transform``; the transformed samples are then attacked in groups of
        200 and the result is stored in ``self.adv_data``.

        :param attack: the attack applied to the transformed test samples; must
                       provide ``create_adversarial_sample`` and ``name``
        :param start_idx: index of the first test sample to use
        :param end_idx: index of the last test sample to use (inclusive)
        :param transform: applied to each image (dataset normalization) before the attack
        :param kwargs: init arguments forwarded to the parent dataset (datasets.MNIST)
        """
        super(MnistAdversarialTest, self).__init__(**kwargs)
        assert (self.train is False)
        assert (start_idx >= 0 and end_idx < self.test_data.shape[0])
        num_samples = end_idx - start_idx + 1
        batch_size = 200
        assert (num_samples % batch_size == 0)
        preview_indices = list(range(0, 5))

        normalized = torch.zeros([num_samples, 1, 28, 28])
        from utilities import plt_img
        plt_img(self.test_data, preview_indices)
        # Apply the transform to every selected test sample
        for offset in range(int(num_samples)):
            raw_pixels = self.test_data[offset + start_idx].numpy()
            normalized[offset] = transform(Image.fromarray(raw_pixels, mode='L'))

        self.adv_data = normalized.clone()
        self.targets = torch.LongTensor(self.targets[start_idx:end_idx + 1])

        device = TorchUtils.get_device()
        started_at = time.time()
        # Overwrite the transformed samples with their adversarial versions, one group at a time
        for batch_no in range(int(num_samples / batch_size)):
            window = slice(batch_no * batch_size, (batch_no + 1) * batch_size)
            inputs = self.adv_data[window].to(device)
            labels = self.targets[window].to(device)
            self.adv_data[window] = attack.create_adversarial_sample(
                inputs, labels).detach()
            elapsed = time.time() - started_at
            print(
                "Create MNIST adversarial testset, index: {}, time passed: {}".
                format(batch_no, elapsed))

        # The stored data is already transformed, so later access uses null_transform
        self.transform = null_transform
        if attack.name != 'NoAttack':
            plt_img(self.adv_data, preview_indices, True)
Exemple #3
0
    def __init__(self, attack, start_idx, end_idx, transform, **kwargs):
        """
        Build an adversarial version of a slice of the CIFAR10 testset.

        Every sample in ``[start_idx, end_idx]`` is first passed through
        ``transform``; the transformed samples are then attacked in groups of
        50 and the result replaces ``self.data``.

        :param attack: the attack applied to the transformed test samples; must
                       provide ``create_adversarial_sample``
        :param start_idx: index of the first test sample to use
        :param end_idx: index of the last test sample to use (inclusive)
        :param transform: applied to each image (dataset normalization) before the attack
        :param kwargs: init arguments forwarded to the parent dataset class
        """
        super(CIFAR10AdversarialTest, self).__init__(**kwargs)
        assert (self.train is False)
        assert (start_idx >= 0 and end_idx < self.data.shape[0])
        num_samples = end_idx - start_idx + 1
        batch_size = 50
        assert (num_samples % batch_size == 0)

        normalized = torch.zeros([num_samples, 3, 32, 32])
        from utilities import plt_img
        # Apply the transform to every selected test sample
        for offset in range(num_samples):
            picture = Image.fromarray(self.data[offset + start_idx])
            normalized[offset] = transform(picture)
        plt_img(normalized, [0])
        self.data = normalized
        self.targets = torch.LongTensor(self.targets[start_idx:end_idx + 1])

        started_at = time.time()
        device = TorchUtils.get_device()
        # Overwrite the transformed samples with their adversarial versions, one group at a time
        for batch_no in range(int(num_samples / batch_size)):
            window = slice(batch_no * batch_size, (batch_no + 1) * batch_size)
            inputs = self.data[window].to(device)
            labels = self.targets[window].to(device)
            self.data[window] = attack.create_adversarial_sample(
                inputs, labels).detach()
            elapsed = time.time() - started_at
            print(
                "Create CIFAR adversarial testset, index: {}, time passed: {}".
                format(batch_no, elapsed))

        self.data = self.data.to("cpu")
        # The stored data is already transformed, so later access uses null_transform
        self.transform = null_transform
        plt_img(self.data, [0])
0
    def train_model(self,
                    model,
                    dataloaders,
                    num_epochs: int = 10,
                    acc_goal=None,
                    eval_test_every_n_epoch: int = 1,
                    sample_test_data=None,
                    sample_test_true_label=None):
        """
        Train DNN model using some trainset.
        :param model: the model which will be trained.
        :param dataloaders: dict read with the keys 'train', 'test', 'dataset_name' and
                            'adv_test_flag'; contains the trainset for training and the
                            testset for evaluation.
        :param num_epochs: number of epochs to train the model. With a value below 1 no
                           training happens and the train loss/accuracy are reported as NaN.
        :param acc_goal: stop training when getting to this accuracy rate on the trainset.
        :param eval_test_every_n_epoch: evaluate the testset only on epochs divisible by this
                                        value (and only if self.eval_test_during_train is True).
        :param sample_test_data: forwarded unchanged to the per-epoch training routine.
        :param sample_test_true_label: forwarded unchanged to the per-epoch training routine.
        :return: trained model (also the training of the models happen inplace)
                 and the loss of the trainset and testset.
        """
        # str() keeps the concatenation valid whether get_device() returns a str or a device object
        self.logger.info("Use device:" + str(TorchUtils.get_device()))
        model = TorchUtils.to_device(model)
        attack = get_attack(
            self.attack_params, model,
            get_dataset_min_max_val(dataloaders['dataset_name']))

        # If testset is already adversarial then do nothing else use the same attack to generate adversarial testset
        if dataloaders['adv_test_flag']:
            testset_attack = get_attack({"attack_type": "no_attack"})
        else:
            testset_attack = attack  # TODO: replace training attack with testing attack
        epoch_time = 0
        # Defaults keep the final summary and return value well defined when the loop never runs
        total_loss_in_epoch = natural_train_loss = train_acc = float('nan')

        # Loop on epochs
        for epoch in range(1, num_epochs + 1):

            epoch_start_time = time.time()
            total_loss_in_epoch, natural_train_loss, train_acc = self.__train(
                model, dataloaders['train'], attack, sample_test_data,
                sample_test_true_label)
            # Evaluate testset
            if self.eval_test_during_train is True and epoch % eval_test_every_n_epoch == 0:
                test_loss, test_acc = self.eval_model(model,
                                                      dataloaders['test'],
                                                      testset_attack)
            else:
                # Placeholder values logged for epochs where the testset is not evaluated
                test_loss, test_acc = torch.tensor([-1.]), torch.tensor([-1.])
            epoch_time = time.time() - epoch_start_time

            # Save model
            if epoch % self.save_model_every_n_epoch == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(self.logger.output_folder,
                                 'model_iter_%d.pt' % epoch))
            # Log (the learning rate reported is the one of the last parameter group)
            lr = self.optimizer.param_groups[-1]['lr']
            self.logger.info(
                '[%d/%d] [train test] loss =[%f %f] natural_train_loss=[%f], acc=[%f %f], lr=%f, epoch_time=%.2f'
                % (epoch, num_epochs, total_loss_in_epoch, test_loss,
                   natural_train_loss, train_acc, test_acc, lr, epoch_time))

            # Stop training if desired goal is achieved
            if acc_goal is not None and train_acc >= acc_goal:
                break

        # Always evaluate once after training so the returned test loss is a real measurement
        test_loss, test_acc = self.eval_model(model, dataloaders['test'],
                                              testset_attack)
        train_loss_output = float(total_loss_in_epoch)
        test_loss_output = float(test_loss)
        # Print and save
        self.logger.info(
            '----- [train test] loss =[%f %f], natural_train_loss=[%f], acc=[%f %f] epoch_time=%.2f'
            % (total_loss_in_epoch, test_loss, natural_train_loss, train_acc,
               test_acc, epoch_time))

        return model, train_loss_output, test_loss_output
Exemple #5
0
def _main():
    """
    Script entry point: parse the command line, build the experiment and evaluate its testset.

    Options under the ``general.`` prefix are split off and handed to
    ``Experiment`` separately from the remaining (attack / refinement) options.
    """
    mp.set_start_method('spawn')

    # (flags, keyword arguments) for every command-line option, in registration order
    option_table = [
        (('-t', '--general.experiment_type'),
         dict(default='imagenet_adversarial',
              help='Type of experiment to execute',
              type=str)),
        (('-p', '--general.param_file_path'),
         dict(default=os.path.join('./src/parameters',
                                   'eval_imagenet_param.json'),
              help='param file path used to load the parameters file containing default values to all '
                   'parameters',
              type=str)),
        (('-o', '--general.output_root'),
         dict(default='output',
              help='the output directory where results will be saved',
              type=str)),
        (('--adv_attack_test.attack_type',),
         dict(help='attack type', type=str, default="natural")),
        (('-f', '--adv_attack_test.test_start_idx'),
         dict(help='first test idx', type=int)),
        (('-l', '--adv_attack_test.test_end_idx'),
         dict(help='last test idx', type=int)),
        (('-e', '--adv_attack_test.epsilon'),
         dict(help='the epsilon strength of the attack', type=float)),
        (('-ts', '--adv_attack_test.pgd_step'),
         dict(help='the step size of the attack', type=float)),
        (('-ti', '--adv_attack_test.pgd_iter'),
         dict(help='the number of test pgd iterations', type=int)),
        (('-b', '--adv_attack_test.beta'),
         dict(help='the beta value for regret reduction regularization ',
              type=float)),
        (('-r', '--fit_to_sample.epsilon'),
         dict(help='the epsilon strength of the refinement (lambda)',
              type=float)),
        (('-i', '--fit_to_sample.pgd_iter'),
         dict(help='the number of PGD iterations of the refinement',
              type=int)),
        (('-s', '--fit_to_sample.pgd_step'),
         dict(help='the step size of the refinement', type=float)),
        (('-n', '--fit_to_sample.pgd_test_restart_num'),
         dict(help='the number of PGD restarts where 0 means no random start',
              type=int)),
    ]

    parser = jsonargparse.ArgumentParser(description='General arguments',
                                         default_meta=False)
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    cli_values = jsonargparse.namespace_to_dict(parser.parse_args())
    general_args = cli_values.pop('general')

    experiment_h = Experiment(general_args, cli_values)
    dataloaders = experiment_h.get_adv_dataloaders()

    # Create logger and save params to output folder
    logger_utilities.init_logger(logger_name=experiment_h.get_exp_name(),
                                 output_root=experiment_h.output_dir)
    logger = logger_utilities.get_logger()
    logger.info('OutputDirectory: %s' % logger.output_folder)
    logger.info('Device: %s' % TorchUtils.get_device())
    logger.info(experiment_h.get_params())
    # NOTE(review): the meaning of the literal 29 is defined by eval_dataset -- confirm there
    eval_dataset(dataloaders['test'], 29, logger, experiment_h)
    logger.info("Done")