Example #1
    def __init__(self, args):
        super().__init__(args)

        self.in_channels = 1 if self.dataset == 'mnist' else 3
        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.dataset]

        self.mean = torch.tensor([0.4914, 0.4822, 0.4465])
        self.std = torch.tensor([0.2023, 0.1994, 0.2010])

        self.train_data, self.valid_data = prepare_dataset(self.args)
        self.wm_data = None

        if self.use_trigger_as_passport:
            self.passport_data = prepare_wm('data/trigger_set/pics')
        else:
            self.passport_data = self.valid_data

        if self.train_backdoor:
            self.wm_data = prepare_wm('data/trigger_set/pics')

        self.construct_model()

        optimizer = optim.SGD(self.model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer, self.lr_config[self.lr_config['type']],
                self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(self.model, optimizer, scheduler, self.device)

        if self.is_tl:
            self.finetune_load()
        else:
            self.makedirs_or_load()
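
The scheduler branch above indexes self.lr_config twice: once with the literal key 'type', then with whatever that entry names. A minimal sketch of the dict shape this implies (every key except 'type' and 'gamma' is an assumption):

# Hypothetical lr_config consistent with the indexing above; an empty
# milestone list leaves scheduler = None.
lr_config = {
    'type': 'multistep',      # assumed value; names the entry holding the milestones
    'multistep': [100, 150],  # epochs at which MultiStepLR multiplies the lr by gamma
    'gamma': 0.1,
}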
Example #2
class ClassificationExperiment(Experiment):
    def __init__(self, args):
        super().__init__(args)

        self.in_channels = 1 if self.dataset == 'mnist' else 3
        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.dataset]

        self.mean = torch.tensor([0.4914, 0.4822, 0.4465])
        self.std = torch.tensor([0.2023, 0.1994, 0.2010])

        self.train_data, self.valid_data = prepare_dataset(self.args)
        self.wm_data = None

        if self.use_trigger_as_passport:
            self.passport_data = prepare_wm('data/trigger_set/pics')
        else:
            self.passport_data = self.valid_data

        if self.train_backdoor:
            self.wm_data = prepare_wm('data/trigger_set/pics')

        self.construct_model()

        optimizer = optim.SGD(self.model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(self.model, optimizer, scheduler, self.device)

        if self.is_tl:
            self.finetune_load()
        else:
            self.makedirs_or_load()

    def construct_model(self):
        def setup_keys():
            if self.key_type != 'random':
                if self.arch == 'alexnet':
                    pretrained_model = AlexNetNormal(self.in_channels, self.num_classes)
                else:
                    pretrained_model = ResNet18(num_classes=self.num_classes,
                                                norm_type=self.norm_type)

                pretrained_model.load_state_dict(torch.load(self.pretrained_path))
                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        def load_pretrained():
            if self.pretrained_path is not None:
                sd = torch.load(self.pretrained_path)
                model.load_state_dict(sd)

        if self.train_passport:
            passport_kwargs = construct_passport_kwargs(self)
            self.passport_kwargs = passport_kwargs

            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetPassport(self.in_channels, self.num_classes, passport_kwargs)
            else:
                model = ResNet18Passport(num_classes=self.num_classes,
                                         passport_kwargs=passport_kwargs)
            self.model = model.to(self.device)

            setup_keys()
        else:  # train normally or train backdoor
            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetNormal(self.in_channels, self.num_classes, self.norm_type)
            else:
                model = ResNet18(num_classes=self.num_classes, norm_type=self.norm_type)

            load_pretrained()
            self.model = model.to(self.device)

        pprint(self.model)

    def setup_keys(self, pretrained_model):
        if self.key_type != 'random':
            n = 1 if self.key_type == 'image' else 20  # any number will do

            key_x, x_inds = passport_generator.get_key(self.passport_data, n)
            key_x = key_x.to(self.device)
            key_y, y_inds = passport_generator.get_key(self.passport_data, n)
            key_y = key_y.to(self.device)

            passport_generator.set_key(pretrained_model, self.model,
                                       key_x, key_y)

    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            clone_model = AlexNetNormal(self.in_channels,
                                        self.num_classes,
                                        self.norm_type)
        else:
            clone_model = ResNet18(num_classes=self.num_classes,
                                   norm_type=self.norm_type)

        ##### load / reset weights of passport layers for clone model #####
        try:
            clone_model.load_state_dict(self.model.state_dict())
        except:
            print('Failed to load the state dict directly, loading it manually')
            if self.arch == 'alexnet':
                for clone_m, self_m in zip(clone_model.features, self.model.features):
                    try:
                        clone_m.load_state_dict(self_m.state_dict())
                    except:
                        print('Failed to load the state dict, usually due to missing keys; loading with strict=False')
                        clone_m.load_state_dict(self_m.state_dict(), False)  # loads conv weights and bn running stats
                        clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

            else:
                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock
                                clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: PassportBlock

                                try:
                                    clone_m.load_state_dict(self_m.state_dict())
                                except:
                                    print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                    clone_m.load_state_dict(self_m.state_dict(), False)
                                    clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                                    clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    else:
                        clone_m = clone_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            clone_m.load_state_dict(self_m.state_dict())
                        except:
                            print(f'{l_key} cannot load state dict directly')
                            clone_m.load_state_dict(self_m.state_dict(), False)
                            clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                            clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        clone_model.to(self.device)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                clone_model.classifier.reset_parameters()
            except:
                clone_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(clone_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(clone_model,
                               optimizer,
                               scheduler,
                               self.device)
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(clone_model.state_dict())
            except:
                if self.arch == 'alexnet':
                    for clone_m, self_m in zip(clone_model.features, self.model.features):
                        try:
                            self_m.load_state_dict(clone_m.state_dict())
                        except:
                            self_m.load_state_dict(clone_m.state_dict(), False)
                else:
                    passport_settings = self.passport_config
                    for l_key in passport_settings:  # layer
                        if isinstance(passport_settings[l_key], dict):
                            for i in passport_settings[l_key]:  # sequential
                                for m_key in passport_settings[l_key][i]:  # convblock
                                    clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                    self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                    try:
                                        self_m.load_state_dict(clone_m.state_dict())
                                    except:
                                        self_m.load_state_dict(clone_m.state_dict(), False)
                        else:
                            clone_m = clone_model.__getattr__(l_key)
                            self_m = self.model.__getattr__(l_key)

                            try:
                                self_m.load_state_dict(clone_m.state_dict())
                            except:
                                self_m.load_state_dict(clone_m.state_dict(), False)

            clone_model.to(self.device)
            self.model.to(self.device)

            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            if self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', clone_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', clone_model)

            self.save_last_model()

    def training(self):
        best_acc = float('-inf')

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True

        if self.save_interval > 0:
            self.save_model('epoch-0.pth')

        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data, self.wm_data)
            print(f'Sign Detection Accuracy: {train_metrics["sign_acc"] * 100:6.4f}')

            valid_metrics = self.trainer.test(self.valid_data, 'Testing Result')

            wm_metrics = {}

            if self.train_backdoor:
                wm_metrics = self.trainer.test(self.wm_data, 'WM Result')

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'wm_{key}'] = wm_metrics[key]

            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')

            self.save_last_model()

    def evaluate(self):
        self.trainer.test(self.valid_data)
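
Both fallback paths in transfer_learning above combine load_state_dict(..., strict=False) with a manual copy of the batch-norm affine terms. A self-contained sketch of that pattern (the BatchNorm2d modules are stand-ins, not the project's blocks):

import torch.nn as nn

src_bn = nn.BatchNorm2d(16)  # stands in for a passport block's bn
dst_bn = nn.BatchNorm2d(16)  # stands in for the clone's bn

# strict=False skips keys that exist on only one side instead of raising
dst_bn.load_state_dict(src_bn.state_dict(), strict=False)

# the affine terms are then written explicitly, detached from autograd
dst_bn.weight.data.copy_(src_bn.weight.detach().view(-1))
dst_bn.bias.data.copy_(src_bn.bias.detach().view(-1))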
Example #3
    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            clone_model = AlexNetNormal(self.in_channels,
                                        self.num_classes,
                                        self.norm_type)
        else:
            clone_model = ResNet18(num_classes=self.num_classes,
                                   norm_type=self.norm_type)

        ##### load / reset weights of passport layers for clone model #####
        try:
            clone_model.load_state_dict(self.model.state_dict())
        except:
            print('Failed to load the state dict directly, loading it manually')
            if self.arch == 'alexnet':
                for clone_m, self_m in zip(clone_model.features, self.model.features):
                    try:
                        clone_m.load_state_dict(self_m.state_dict())
                    except:
                        print('Failed to load the state dict, usually due to missing keys; loading with strict=False')
                        clone_m.load_state_dict(self_m.state_dict(), False)  # loads conv weights and bn running stats
                        clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

            else:
                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock
                                clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: PassportBlock

                                try:
                                    clone_m.load_state_dict(self_m.state_dict())
                                except:
                                    print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                    clone_m.load_state_dict(self_m.state_dict(), False)
                                    clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                                    clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    else:
                        clone_m = clone_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            clone_m.load_state_dict(self_m.state_dict())
                        except:
                            print(f'{l_key} cannot load state dict directly')
                            clone_m.load_state_dict(self_m.state_dict(), False)
                            clone_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                            clone_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        clone_model.to(self.device)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                clone_model.classifier.reset_parameters()
            except:
                clone_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(clone_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(clone_model,
                               optimizer,
                               scheduler,
                               self.device)
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(clone_model.state_dict())
            except:
                if self.arch == 'alexnet':
                    for clone_m, self_m in zip(clone_model.features, self.model.features):
                        try:
                            self_m.load_state_dict(clone_m.state_dict())
                        except:
                            self_m.load_state_dict(clone_m.state_dict(), False)
                else:
                    passport_settings = self.passport_config
                    for l_key in passport_settings:  # layer
                        if isinstance(passport_settings[l_key], dict):
                            for i in passport_settings[l_key]:  # sequential
                                for m_key in passport_settings[l_key][i]:  # convblock
                                    clone_m = clone_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                    self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                    try:
                                        self_m.load_state_dict(clone_m.state_dict())
                                    except:
                                        self_m.load_state_dict(clone_m.state_dict(), False)
                        else:
                            clone_m = clone_model.__getattr__(l_key)
                            self_m = self.model.__getattr__(l_key)

                            try:
                                self_m.load_state_dict(clone_m.state_dict())
                            except:
                                self_m.load_state_dict(clone_m.state_dict(), False)

            clone_model.to(self.device)
            self.model.to(self.device)

            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            if self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', clone_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', clone_model)

            self.save_last_model()
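
The rtal branch above only re-initializes the final classifier before fine-tuning every layer. A hedged sketch of that reset as a standalone helper (the attribute names mirror the try/except in the snippet):

import torch.nn as nn

def reset_last_layer(model: nn.Module) -> None:
    # rtal = reset last layer + train all layers; ftal skips this reset
    head = getattr(model, 'classifier', None)
    if head is None:
        head = model.linear  # ResNet-style models expose `linear` instead
    if isinstance(head, nn.Sequential):
        head = head[-1]
    head.reset_parameters()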
Example #4
def main(args):
    if args.dir is None:
        print(' ---------- Please specify the experiment directory ----------')
        return

    # Directory of current experiment
    base_dir = os.path.dirname(os.path.realpath(__file__))
    experiment_dir = os.path.join(base_dir, args.agent_type, args.dir)

    # load training/testing parameters
    params = load_parameters(file=os.path.join(experiment_dir, 'params.dat'))
    # params exposes: environment, action_repeat, agent, training_episodes,
    # training_steps_per_episode, testing_episodes, testing_steps_per_episode,
    # and hyperparameters (epsilon_start/end/decay/steps, use_cuda, learning_rate,
    # batch_size, discount_rate, target_network_update_frequency).

    # Initialize the environment
    env = gym.make(params.environment)
    state_size = env.observation_space.shape
    action_size = env.action_space.n

    # Initialize the agent
    agent = DQNAgent(state_size=state_size,
                     action_size=action_size,
                     hyperparameters=params.hyperparameters)

    if args.train:
        trainer = Trainer(env=env,
                          agent=agent,
                          params=params,
                          exp_dir=experiment_dir)

        try:
            trainer.train()

        except KeyboardInterrupt:
            trainer.close()
            sys.exit(0)

        finally:
            print('\ndone.')

    if args.retrain:
        trainer = Trainer(env=env,
                          agent=agent,
                          params=params,
                          exp_dir=experiment_dir,
                          retrain=True)

        try:
            trainer.retrain()

        except KeyboardInterrupt:
            trainer.close()
            sys.exit(0)

        finally:
            print('\ndone.')

    if args.test:
        tester = Tester(env=env,
                        agent=agent,
                        params=params,
                        exp_dir=experiment_dir)

        try:
            tester.test()

        except KeyboardInterrupt:
            try:
                tester.close()
                sys.exit(0)
            except SystemExit:
                tester.close()
                os._exit(0)
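
main expects an argparse namespace carrying dir, agent_type, and the train/retrain/test flags; a plausible parser for it (an assumption, since the source does not show one):

import argparse

def parse_args():
    # flag names inferred from the attributes main() reads; defaults are assumptions
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir', default=None,
                        help='name of the current experiment directory')
    parser.add_argument('--agent-type', dest='agent_type', default='dqn')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--retrain', action='store_true')
    parser.add_argument('--test', action='store_true')
    return parser.parse_args()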
Example #5
    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

        if self.tl_dataset == 'caltech-101':
            self.num_classes = 101
        elif self.tl_dataset == 'cifar100':
            self.num_classes = 100
        elif self.tl_dataset == 'caltech-256':
            self.num_classes = 257
        else:  # cifar10
            self.num_classes = 10

        # load clone model
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type)

        # # (author's own change, disabled): fine-tune the passport-private model directly
        # if self.arch == 'alexnet':
        #     tl_model = AlexNetPassportPrivate(self.in_channels, self.num_classes, passport_kwargs)
        # else:
        #     tl_model = ResNet18Private(num_classes=self.num_classes, passport_kwargs=passport_kwargs)

        ##### load / reset weights of passport layers for clone model #####

        try:
            tl_model.load_state_dict(self.model.state_dict())
            # tl_model.load_state_dict(self.copy_model.state_dict())
        except:
            print('Failed to load the state dict directly, loading it manually')
            if self.arch == 'alexnet':
                for tl_m, self_m in zip(tl_model.features, self.model.features):

                    try:
                        tl_m.load_state_dict(self_m.state_dict())
                    except:
                        print('Failed to load the state dict, usually due to missing keys; loading with strict=False')
                        tl_m.load_state_dict(self_m.state_dict(), False)  # loads conv weights and bn running stats
                        # original parameter loading:
                        # tl_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))
                        # tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                        # changed (note the bn values): get_scale() returns two
                        # scales here, and the first one is copied into bn
                        scale1, scale2 = self_m.get_scale()
                        tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                        tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

            else:

                passport_settings = self.passport_config
                for l_key in passport_settings:  # layer
                    if isinstance(passport_settings[l_key], dict):
                        for i in passport_settings[l_key]:  # sequential
                            for m_key in passport_settings[l_key][i]:  # convblock

                                tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)  # type: ConvBlock
                                self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                try:
                                    tl_m.load_state_dict(self_m.state_dict())
                                except:
                                    print(f'{l_key}.{i}.{m_key} cannot load state dict directly')
                                    tl_m.load_state_dict(self_m.state_dict(), False)

                                    scale1, scale2 = self_m.get_scale()
                                    tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                                    tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

                    else:
                        print("FFFFFFFFFFFFFFFFFFFFFFF")
                        tl_m = tl_model.__getattr__(l_key)
                        self_m = self.model.__getattr__(l_key)

                        try:
                            tl_m.load_state_dict(self_m.state_dict())
                        except:
                            print(f'{l_key} cannot load state dict directly')
                            tl_m.load_state_dict(self_m.state_dict(), False)
                            # tl_m.bn.weight.data.copy_(self_m.get_scale().detach().view(-1))

                            scale1, scale2 = self_m.get_scale()
                            tl_m.bn.weight.data.copy_(scale1.detach().view(-1))
                            tl_m.bn.bias.data.copy_(self_m.get_bias().detach().view(-1))

        tl_model.to(self.device)
        print('Loaded clone model')

        # tl scheme setup
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                tl_model.classifier.reset_parameters()
            except:
                tl_model.linear.reset_parameters()

        # (disabled) optionally freeze the fc and bn1 parameters during transfer learning:
        # for i in self.model.fc.parameters():
        #     i.requires_grad = False
        # for i in self.model.bn1.parameters():
        #     i.requires_grad = False

        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        tl_trainer = Trainer(tl_model,
                             optimizer,
                             scheduler,
                             self.device)

        tester = TesterPrivate(self.model,
                               self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0
        best_file = os.path.join(self.logdir, 'best.txt')
        best_ep = 1

        for ep in range(1, self.epochs + 1):
            train_metrics = tl_trainer.train(ep, self.train_data)
            valid_metrics = tl_trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            try:
                self.model.load_state_dict(tl_model.state_dict())
            except:
                if self.arch == 'alexnet':
                    for tl_m, self_m in zip(tl_model.features, self.model.features):
                        try:
                            self_m.load_state_dict(tl_m.state_dict())
                        except:
                            self_m.load_state_dict(tl_m.state_dict(), False)
                else:
                    passport_settings = self.passport_config
                    for l_key in passport_settings:  # layer
                        if isinstance(passport_settings[l_key], dict):
                            for i in passport_settings[l_key]:  # sequential
                                for m_key in passport_settings[l_key][i]:  # convblock
                                    tl_m = tl_model.__getattr__(l_key)[int(i)].__getattr__(m_key)
                                    self_m = self.model.__getattr__(l_key)[int(i)].__getattr__(m_key)

                                    try:
                                        self_m.load_state_dict(tl_m.state_dict())
                                    except:
                                        self_m.load_state_dict(tl_m.state_dict(), False)
                        else:
                            tl_m = tl_model.__getattr__(l_key)
                            self_m = self.model.__getattr__(l_key)

                            try:
                                self_m.load_state_dict(tl_m.state_dict())
                            except:
                                self_m.load_state_dict(tl_m.state_dict(), False)

            wm_metrics = tester.test_signature()
            # mean signature-detection accuracy across all passport layers
            pri_sign = sum(wm_metrics.values()) / len(wm_metrics)

            if self.train_backdoor:
                backdoor_metrics = tester.test(self.wm_data, 'Old WM Accuracy')

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            if self.train_backdoor:
                for key in backdoor_metrics: metrics[f'backdoor_{key}'] = backdoor_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)
                best_ep = ep

            self.save_last_model()
            with open(best_file, 'a') as f:
                print(str(wm_metrics) + '\n', file=f)
                print(str(metrics) + '\n', file=f)
                f.write('Best ACC %s\n' % str(best_acc))
                print('Private Sign Detection:', str(pri_sign) + '\n', file=f)
                f.write('\n')
                f.write('best epoch: %s\n' % str(best_ep))
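
The nested __getattr__(l_key)[int(i)].__getattr__(m_key) chains above walk a submodule path such as model.<layer>[<index>].<block>. The idiomatic spelling uses the getattr builtin; behavior is identical for nn.Module attributes:

# equivalent, idiomatic traversal of a nested submodule
def get_block(model, l_key, i, m_key):
    return getattr(getattr(model, l_key)[int(i)], m_key)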
Example #6
class ClassificationExperiment(Experiment):
    def __init__(self, args):
        super().__init__(args)

        self.in_channels = 1 if self.dataset == 'mnist' else 3
        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256,
            'imagenet1000': 1000
        }[self.dataset]

        self.train_data, self.valid_data = prepare_dataset(self.args)
        self.wm_data = None

        if self.use_trigger_as_passport:
            self.passport_data = prepare_wm('data/trigger_set/pics', crop=self.imgcrop)
        else:
            self.passport_data = self.valid_data

        if self.train_backdoor:
            self.wm_data = prepare_wm('data/trigger_set/pics', crop=self.imgcrop)

        self.construct_model()

        optimizer = optim.SGD(self.model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0001)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        self.trainer = Trainer(self.model, optimizer, scheduler, self.device)

        if self.is_tl:
            self.finetune_load()
        else:
            self.makedirs_or_load()

    def construct_model(self):
        print('Construct Model')

        def setup_keys():
            if self.key_type != 'random':
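                # no pretrained_path given: fall back to framework-pretrained weights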
                pretrained_from_torch = self.pretrained_path is None
                if self.arch == 'alexnet':
                    norm_type = 'none' if pretrained_from_torch else self.norm_type
                    pretrained_model = AlexNetNormal(self.in_channels,
                                                     self.num_classes,
                                                     norm_type=norm_type,
                                                     pretrained=pretrained_from_torch)
                else:
                    ResNetClass = ResNet18 if self.arch == 'resnet' else ResNet9
                    norm_type = 'bn' if pretrained_from_torch else self.norm_type
                    pretrained_model = ResNetClass(num_classes=self.num_classes,
                                                   norm_type=norm_type,
                                                   pretrained=pretrained_from_torch)

                if not pretrained_from_torch:
                    print('Loading pretrained from self-trained model')
                    pretrained_model.load_state_dict(torch.load(self.pretrained_path))
                else:
                    print('Loading pretrained from torch-pretrained model')

                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        def load_pretrained():
            if self.pretrained_path is not None:
                sd = torch.load(self.pretrained_path)
                model.load_state_dict(sd)

        if self.train_passport:
            passport_kwargs, plkeys = construct_passport_kwargs(self, True)
            self.passport_kwargs = passport_kwargs
            self.plkeys = plkeys
            self.is_baseline = False

            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetPassport(self.in_channels, self.num_classes, passport_kwargs)
            else:
                ResNetPassportClass = ResNet18Passport if self.arch == 'resnet' else ResNet9Passport
                model = ResNetPassportClass(num_classes=self.num_classes,
                                            passport_kwargs=passport_kwargs)
            self.model = model.to(self.device)

            setup_keys()
        else:  # train normally or train backdoor
            print('Loading arch: ' + self.arch)
            self.is_baseline = True

            if self.arch == 'alexnet':
                model = AlexNetNormal(self.in_channels, self.num_classes, self.norm_type)
            else:
                ResNetClass = ResNet18 if self.arch == 'resnet' else ResNet9
                model = ResNetClass(num_classes=self.num_classes, norm_type=self.norm_type)

            load_pretrained()
            self.model = model.to(self.device)

        pprint(self.model)

    def setup_keys(self, pretrained_model):
        if self.key_type != 'random':
            n = 1 if self.key_type == 'image' else 20  # any number will do

            key_x, x_inds = passport_generator.get_key(self.passport_data, n)
            key_x = key_x.to(self.device)
            key_y, y_inds = passport_generator.get_key(self.passport_data, n)
            key_y = key_y.to(self.device)

            passport_generator.set_key(pretrained_model, self.model,
                                       key_x, key_y)

    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

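        # num_classes still describes the source dataset here; it is remapped just below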
        is_imagenet = self.num_classes == 1000

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256,
            'imagenet1000': 1000
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type,
                                     imagenet=is_imagenet)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type,
                                imagenet=is_imagenet)

        ##### load / reset weights of passport layers for clone model #####
        tl_model.to(self.device)
        if self.is_baseline:  # baseline
            load_normal_model_to_normal_model(self.arch, tl_model, self.model)
        else:
            load_passport_model_to_normal_model(self.arch, self.plkeys, self.model, tl_model)

        print(tl_model)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                if isinstance(tl_model.classifier, nn.Sequential):
                    tl_model.classifier[-1].reset_parameters()
                else:
                    tl_model.classifier.reset_parameters()
            except:
                tl_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        ##### training runs on the finetuned model
        self.trainer = Trainer(tl_model,
                               optimizer,
                               scheduler,
                               self.device)

        ##### testers run on the original model
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            ##### transfer learning on new tasks #####
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            if self.is_baseline:
                load_normal_model_to_normal_model(self.arch, self.model, tl_model)
            else:
                load_normal_model_to_passport_model(self.arch, self.plkeys, self.model, tl_model)

            tl_model.to(self.device)
            self.model.to(self.device)

            ##### check whether the finetuned weights can still detect the trigger-set watermark #####
            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            ##### check whether the finetuned weights still allow the signature to be extracted correctly #####
            if not self.is_baseline and self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            ##### store results #####
            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)

            self.save_last_model()

    def training(self):
        best_acc = float('-inf')

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True

        if self.save_interval > 0:
            self.save_model('epoch-0.pth')

        print('Start training')

        for ep in range(1, self.epochs + 1):
            train_metrics = self.trainer.train(ep, self.train_data, self.wm_data)
            print(f'Sign Detection Accuracy: {train_metrics["sign_acc"] * 100:6.4f}')

            valid_metrics = self.trainer.test(self.valid_data, 'Testing Result')

            wm_metrics = {}

            if self.train_backdoor:
                wm_metrics = self.trainer.test(self.wm_data, 'WM Result')

            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'wm_{key}'] = wm_metrics[key]

            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')

            self.save_last_model()

    def evaluate(self):
        self.trainer.test(self.valid_data)
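
Every loop above funnels per-epoch results through self.append_history(history_file, metrics, first). The method itself is not shown; a plausible sketch of the behavior its call sites imply (header once, then one row per call):

import csv

def append_history(history_file, metrics, first):
    # assumed behavior inferred from the call sites, not the project's code
    with open(history_file, 'a', newline='') as f:
        writer = csv.writer(f)
        if first:
            writer.writerow(list(metrics.keys()))
        writer.writerow([metrics[k] for k in metrics])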
Example #7
    def transfer_learning(self):
        if not self.is_tl:
            raise Exception('Please run with --transfer-learning')

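        # num_classes still describes the source dataset here; it is remapped just below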
        is_imagenet = self.num_classes == 1000

        self.num_classes = {
            'cifar10': 10,
            'cifar100': 100,
            'caltech-101': 101,
            'caltech-256': 256,
            'imagenet1000': 1000
        }[self.tl_dataset]

        ##### load clone model #####
        print('Loading clone model')
        if self.arch == 'alexnet':
            tl_model = AlexNetNormal(self.in_channels,
                                     self.num_classes,
                                     self.norm_type,
                                     imagenet=is_imagenet)
        else:
            tl_model = ResNet18(num_classes=self.num_classes,
                                norm_type=self.norm_type,
                                imagenet=is_imagenet)

        ##### load / reset weights of passport layers for clone model #####
        tl_model.to(self.device)
        if self.is_baseline:  # baseline
            load_normal_model_to_normal_model(self.arch, tl_model, self.model)
        else:
            load_passport_model_to_normal_model(self.arch, self.plkeys, self.model, tl_model)

        print(tl_model)
        print('Loaded clone model')

        ##### dataset is created at constructor #####

        ##### tl scheme setup #####
        if self.tl_scheme == 'rtal':
            # rtal = reset last layer + train all layer
            # ftal = train all layer
            try:
                if isinstance(tl_model.classifier, nn.Sequential):
                    tl_model.classifier[-1].reset_parameters()
                else:
                    tl_model.classifier.reset_parameters()
            except:
                tl_model.linear.reset_parameters()

        ##### optimizer setup #####
        optimizer = optim.SGD(tl_model.parameters(),
                              lr=self.lr,
                              momentum=0.9,
                              weight_decay=0.0005)

        if len(self.lr_config[self.lr_config['type']]) != 0:  # if no steps are specified, scheduler stays None
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       self.lr_config[self.lr_config['type']],
                                                       self.lr_config['gamma'])
        else:
            scheduler = None

        ##### training runs on the finetuned model
        self.trainer = Trainer(tl_model,
                               optimizer,
                               scheduler,
                               self.device)

        ##### testers run on the original model
        tester = Tester(self.model,
                        self.device)
        tester_passport = TesterPrivate(self.model,
                                        self.device)

        history_file = os.path.join(self.logdir, 'history.csv')
        first = True
        best_acc = 0

        for ep in range(1, self.epochs + 1):
            ##### transfer learning on new tasks #####
            train_metrics = self.trainer.train(ep, self.train_data)
            valid_metrics = self.trainer.test(self.valid_data)

            ##### load transfer learning weights from clone model  #####
            if self.is_baseline:
                load_normal_model_to_normal_model(self.arch, self.model, tl_model)
            else:
                load_normal_model_to_passport_model(self.arch, self.plkeys, self.model, tl_model)

            tl_model.to(self.device)
            self.model.to(self.device)

            ##### check whether the finetuned weights can still detect the trigger-set watermark #####
            wm_metrics = {}
            if self.train_backdoor:
                wm_metrics = tester.test(self.wm_data, 'WM Result')

            ##### check whether the finetuned weights still allow the signature to be extracted correctly #####
            if not self.is_baseline and self.train_passport:
                res = tester_passport.test_signature()
                for key in res: wm_metrics['passport_' + key] = res[key]

            ##### store results #####
            metrics = {}
            for key in train_metrics: metrics[f'train_{key}'] = train_metrics[key]
            for key in valid_metrics: metrics[f'valid_{key}'] = valid_metrics[key]
            for key in wm_metrics: metrics[f'old_wm_{key}'] = wm_metrics[key]
            self.append_history(history_file, metrics, first)
            first = False

            if self.save_interval and ep % self.save_interval == 0:
                self.save_model(f'epoch-{ep}.pth')
                self.save_model(f'tl-epoch-{ep}.pth', tl_model)

            if best_acc < metrics['valid_acc']:
                print(f'Found best at epoch {ep}\n')
                best_acc = metrics['valid_acc']
                self.save_model('best.pth')
                self.save_model('tl-best.pth', tl_model)

            self.save_last_model()
Example #8
def main(args):

    # Directory of current experiment
    experiment_dir = 'experiments/dqn_lstm/test1'

    # Load configuration
    config = Config()

    config.env = args.env

    config.hyperparameters = {
        "learning_rate": 0.025,
        "batch_size": 32,
        "sequence_length": 1,
        "buffer_size": int(1e5),
        "update_every_n_steps": 1,
        "min_steps_before_learning": 1000,
        "epsilon_start": 1,
        "epsilon_end": 0.1,
        "epsilon_decay": 0.995,
        "discount_rate": 0.99,
        "tau": 0.01,
    }

    config.use_cuda = True

    config.number_of_episodes = 500
    config.steps_per_episode = 500
    config.previous_episode = 0
    config.total_steps = 160000
    config.pre_train_steps = 100
    config.learing_frequency = 1

    config.checkpoint = True
    config.checkpoint_interval = 1
    config.checkpoint_dir = experiment_dir + '/checkpoints'

    config.log_dir = experiment_dir + '/logs'

    config.model_dir = experiment_dir + '/model'

    # Initialize the environment
    env = gym.make('Urban-v0')
    config.state_dim = env.observation_space.shape
    config.action_dim = env.action_space.n

    # Initialize the agent
    agent = DDQNAgent(config)

    # Initialize spawner
    spawner = Spawner()

    if args.train:
        trainer = Trainer(env, agent, spawner, config)

        try:
            trainer.train()

        except KeyboardInterrupt:
            try:
                trainer.close()
                sys.exit(0)
            except SystemExit:
                trainer.close()
                os._exit(0)

    elif args.retrain:
        if args.checkpoint_file is None:
            print(' ---------- Please specify the checkpoint file name ----------')
            return

        trainer = Trainer(env, agent, spawner, config)
        trainer.load_checkpoint(args.checkpoint_file)

        try:
            trainer.retrain()

        except KeyboardInterrupt:
            try:
                trainer.close()
                sys.exit(0)
            except SystemExit:
                trainer.close()
                os._exit(0)

    elif args.test:
        # The original passed the undefined names (episodes, steps); mirroring
        # the Trainer call above is an assumption about Tester's signature.
        tester = Tester(env, agent, spawner, config)
        tester.load_checkpoint(args.checkpoint_file)

        try:
            tester.test()

        except KeyboardInterrupt:
            try:
                tester.close()
                sys.exit(0)
            except SystemExit:
                tester.close()
                os._exit(0)
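
Note that sys.exit(0) itself raises SystemExit inside the try block, so the except SystemExit arm always fires and each interrupted run ends with a second close() followed by os._exit(0). A minimal reproduction of the pattern:

import os
import sys

def hard_shutdown(close):
    # sys.exit raises SystemExit, which the handler below catches,
    # so close() runs twice and the process exits without cleanup
    try:
        close()
        sys.exit(0)
    except SystemExit:
        close()
        os._exit(0)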