def construct_model(self):
        """Build the private-passport model and store it on ``self.model``.

        Steps:
          1. ``construct_passport_kwargs(self)`` builds the passport
             configuration, which is also kept on ``self.passport_kwargs``.
          2. ``AlexNetPassportPrivate`` is instantiated when
             ``self.arch == 'alexnet'``, ``ResNet18Private`` otherwise, and
             the model is moved to ``self.device``.
          3. Unless ``self.key_type == 'random'``, a plain network of the
             same family is loaded from ``self.pretrained_path`` and passed to
             ``self.setup_keys`` (presumably to derive the passport keys from
             it — confirm against ``setup_keys``).
        """
        def setup_keys():
            # Random keys need no pretrained network.
            if self.key_type != 'random':
                if self.arch == 'alexnet':
                    pretrained_model = AlexNetNormal(self.in_channels,
                                                     self.num_classes,
                                                     self.norm_type)
                else:
                    pretrained_model = ResNet18(num_classes=self.num_classes,
                                                norm_type=self.norm_type)
                # Deserialize onto the CPU: the model is only moved to
                # self.device afterwards, and this lets checkpoints saved on a
                # GPU load on CPU-only hosts as well.
                pretrained_model.load_state_dict(
                    torch.load(self.pretrained_path, map_location='cpu'))
                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        passport_kwargs = construct_passport_kwargs(self)
        self.passport_kwargs = passport_kwargs

        if self.arch == 'alexnet':
            model = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                           passport_kwargs)
        else:
            model = ResNet18Private(num_classes=self.num_classes,
                                    passport_kwargs=passport_kwargs)

        self.model = model.to(self.device)

        # Keys are set up after self.model exists, matching the original order.
        setup_keys()
Beispiel #2
0
    def construct_model(self):
        """Build ``self.model`` for passport training or for normal/backdoor training.

        Passport training (``self.train_passport`` is true):
          - ``construct_passport_kwargs(self)`` builds the passport config,
            which is also stored on ``self.passport_kwargs``.
          - ``AlexNetPassport`` (``self.arch == 'alexnet'``) or
            ``ResNet18Passport`` (any other arch) is built and moved to
            ``self.device``.
          - Unless ``self.key_type == 'random'``, a plain network loaded from
            ``self.pretrained_path`` is passed to ``self.setup_keys``.

        Normal/backdoor training:
          - A plain ``AlexNetNormal`` / ``ResNet18`` is built with
            ``self.norm_type``.
          - If ``self.pretrained_path`` is set, the model is initialised from
            that checkpoint.
          - The model is then moved to ``self.device``.

        The resulting ``self.model`` is printed in both cases.
        """
        def setup_keys():
            if self.key_type != 'random':
                # Use the same norm_type as the normal-training path below, so
                # the network's layers match a checkpoint produced by that path
                # (assumed to be the source of self.pretrained_path).
                if self.arch == 'alexnet':
                    pretrained_model = AlexNetNormal(self.in_channels,
                                                     self.num_classes,
                                                     self.norm_type)
                else:
                    pretrained_model = ResNet18(num_classes=self.num_classes,
                                                norm_type=self.norm_type)

                # Deserialize onto the CPU; the model is moved to self.device
                # afterwards, and GPU-saved checkpoints still load on CPU hosts.
                pretrained_model.load_state_dict(
                    torch.load(self.pretrained_path, map_location='cpu'))
                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        def load_pretrained():
            # `model` is the enclosing-scope variable assigned in the else
            # branch below before this helper is called.
            if self.pretrained_path is not None:
                sd = torch.load(self.pretrained_path, map_location='cpu')
                model.load_state_dict(sd)

        if self.train_passport:
            passport_kwargs = construct_passport_kwargs(self)
            self.passport_kwargs = passport_kwargs

            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetPassport(self.in_channels, self.num_classes,
                                        passport_kwargs)
            else:
                model = ResNet18Passport(num_classes=self.num_classes,
                                         passport_kwargs=passport_kwargs)
            self.model = model.to(self.device)

            setup_keys()
        else:  # train normally or train backdoor
            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetNormal(self.in_channels, self.num_classes,
                                      self.norm_type)
            else:
                model = ResNet18(num_classes=self.num_classes,
                                 norm_type=self.norm_type)

            load_pretrained()
            self.model = model.to(self.device)

        pprint(self.model)
    def construct_model(self):
        """Build the private-passport model, set up its keys, and print it.

        Steps:
          1. ``construct_passport_kwargs(self)`` builds the passport config,
             which is also stored on ``self.passport_kwargs``.
          2. ``AlexNetPassportPrivate`` (``self.arch == 'alexnet'``) or
             ``ResNet18Private`` (any other arch) is built, moved to
             ``self.device``, and stored on ``self.model``.
          3. Unless ``self.key_type == 'random'``, a plain network of the same
             family is passed to ``self.setup_keys``. That network is
             initialised from ``self.pretrained_path`` when it is set. When it
             is ``None``, the constructor's ``pretrained=True`` option is used
             instead (labelled "torch-pretrained" in the log output).
        """
        print('Construct Model')

        def setup_keys():
            if self.key_type != 'random':
                # No self-trained checkpoint -> fall back to the model's own
                # pretrained weights via the `pretrained` constructor flag.
                pretrained_from_torch = self.pretrained_path is None
                if self.arch == 'alexnet':
                    norm_type = 'none' if pretrained_from_torch else self.norm_type
                    pretrained_model = AlexNetNormal(
                        self.in_channels,
                        self.num_classes,
                        norm_type=norm_type,
                        pretrained=pretrained_from_torch)
                else:
                    norm_type = 'bn' if pretrained_from_torch else self.norm_type
                    pretrained_model = ResNet18(
                        num_classes=self.num_classes,
                        norm_type=norm_type,
                        pretrained=pretrained_from_torch)

                if not pretrained_from_torch:
                    print('Loading pretrained from self-trained model')
                    # Deserialize onto the CPU; the model is moved to
                    # self.device below, and GPU-saved checkpoints still load
                    # on CPU-only hosts.
                    pretrained_model.load_state_dict(
                        torch.load(self.pretrained_path, map_location='cpu'))
                else:
                    print('Loading pretrained from torch-pretrained model')

                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        passport_kwargs = construct_passport_kwargs(self)
        self.passport_kwargs = passport_kwargs

        print('Loading arch: ' + self.arch)
        if self.arch == 'alexnet':
            model = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                           passport_kwargs)
        else:
            model = ResNet18Private(num_classes=self.num_classes,
                                    passport_kwargs=passport_kwargs)

        self.model = model.to(self.device)

        setup_keys()

        pprint(self.model)
Beispiel #4
0
    def construct_model(self):
        """Build ``self.model`` for passport training or for normal/backdoor training.

        Passport training (``self.train_passport`` is true):
          - ``construct_passport_kwargs(self, True)`` returns the passport
            config and ``plkeys``, stored on ``self.passport_kwargs`` and
            ``self.plkeys``.
          - ``self.is_baseline`` is set to ``False``.
          - ``AlexNetPassport``, ``ResNet18Passport`` (``self.arch ==
            'resnet'``) or ``ResNet9Passport`` (any other arch) is built and
            moved to ``self.device``.
          - Unless ``self.key_type == 'random'``, a plain network is passed to
            ``self.setup_keys``. It is initialised from
            ``self.pretrained_path``, or via the constructor's
            ``pretrained=True`` option when that path is ``None``.

        Normal/backdoor training:
          - ``self.is_baseline`` is set to ``True``.
          - A plain ``AlexNetNormal`` / ``ResNet18`` / ``ResNet9`` is built
            with ``self.norm_type``.
          - If ``self.pretrained_path`` is set, the model is initialised from
            that checkpoint.
          - The model is then moved to ``self.device``.

        The resulting ``self.model`` is printed in both cases.
        """
        print('Construct Model')

        def setup_keys():
            if self.key_type != 'random':
                # No self-trained checkpoint -> fall back to the model's own
                # pretrained weights via the `pretrained` constructor flag.
                pretrained_from_torch = self.pretrained_path is None
                if self.arch == 'alexnet':
                    norm_type = 'none' if pretrained_from_torch else self.norm_type
                    pretrained_model = AlexNetNormal(self.in_channels,
                                                     self.num_classes,
                                                     norm_type=norm_type,
                                                     pretrained=pretrained_from_torch)
                else:
                    # Any non-'alexnet', non-'resnet' arch selects ResNet9.
                    ResNetClass = ResNet18 if self.arch == 'resnet' else ResNet9
                    norm_type = 'bn' if pretrained_from_torch else self.norm_type
                    pretrained_model = ResNetClass(num_classes=self.num_classes,
                                                   norm_type=norm_type,
                                                   pretrained=pretrained_from_torch)

                if not pretrained_from_torch:
                    print('Loading pretrained from self-trained model')
                    # Deserialize onto the CPU; the model is moved to
                    # self.device below, and GPU-saved checkpoints still load
                    # on CPU-only hosts.
                    pretrained_model.load_state_dict(
                        torch.load(self.pretrained_path, map_location='cpu'))
                else:
                    print('Loading pretrained from torch-pretrained model')

                pretrained_model = pretrained_model.to(self.device)
                self.setup_keys(pretrained_model)

        def load_pretrained():
            # `model` is the enclosing-scope variable assigned in the else
            # branch below before this helper is called.
            if self.pretrained_path is not None:
                sd = torch.load(self.pretrained_path, map_location='cpu')
                model.load_state_dict(sd)

        if self.train_passport:
            passport_kwargs, plkeys = construct_passport_kwargs(self, True)
            self.passport_kwargs = passport_kwargs
            self.plkeys = plkeys
            self.is_baseline = False

            print('Loading arch: ' + self.arch)
            if self.arch == 'alexnet':
                model = AlexNetPassport(self.in_channels, self.num_classes, passport_kwargs)
            else:
                ResNetPassportClass = ResNet18Passport if self.arch == 'resnet' else ResNet9Passport
                model = ResNetPassportClass(num_classes=self.num_classes,
                                            passport_kwargs=passport_kwargs)
            self.model = model.to(self.device)

            setup_keys()
        else:  # train normally or train backdoor
            print('Loading arch: ' + self.arch)
            self.is_baseline = True

            if self.arch == 'alexnet':
                model = AlexNetNormal(self.in_channels, self.num_classes, self.norm_type)
            else:
                ResNetClass = ResNet18 if self.arch == 'resnet' else ResNet9
                model = ResNetClass(num_classes=self.num_classes, norm_type=self.norm_type)

            load_pretrained()
            self.model = model.to(self.device)

        pprint(self.model)