def construct_model(self):
        """Build the passport-private model and, when required, set up its keys.

        Instantiates ``AlexNetPassportPrivate`` when ``self.arch == 'alexnet'``
        and ``ResNet18Private`` otherwise, configured with the kwargs returned
        by ``construct_passport_kwargs(self)``. Those kwargs are stored on
        ``self.passport_kwargs``; the model, moved to ``self.device``, is
        stored on ``self.model``.

        When ``self.key_type`` is not ``'random'``, a plain (non-passport)
        network of the same architecture is built, its weights are loaded
        from ``self.pretrained_path``, and it is passed to
        ``self.setup_keys``. This happens last, after ``self.model`` is set.
        """
        def load_pretrained():
            """Return the plain counterpart network on ``self.device`` with saved weights loaded."""
            if self.arch == 'alexnet':
                net = AlexNetNormal(self.in_channels,
                                    self.num_classes,
                                    self.norm_type)
            else:
                net = ResNet18(num_classes=self.num_classes,
                               norm_type=self.norm_type)
            # map_location remaps tensors saved on another device (e.g. a GPU)
            # onto self.device, so CPU-only runs can still load the checkpoint.
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(self.pretrained_path,
                                    map_location=self.device)
            net.load_state_dict(state_dict)
            return net.to(self.device)

        passport_kwargs = construct_passport_kwargs(self)
        self.passport_kwargs = passport_kwargs

        if self.arch == 'alexnet':
            model = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                           passport_kwargs)
        else:
            model = ResNet18Private(num_classes=self.num_classes,
                                    passport_kwargs=passport_kwargs)

        self.model = model.to(self.device)

        # A pretrained reference network is only built (and handed to
        # self.setup_keys) when keys are not random.
        if self.key_type != 'random':
            self.setup_keys(load_pretrained())
Пример #2
0
    def construct_model(self):
        """Build the passport-private model and, when required, set up its keys.

        Instantiates ``AlexNetPassportPrivate`` when ``self.arch == 'alexnet'``
        and ``ResNet18Private`` otherwise, configured with the kwargs returned
        by ``construct_passport_kwargs(self)``. Those kwargs are stored on
        ``self.passport_kwargs``; the model, moved to ``self.device``, is
        stored on ``self.model`` and printed with ``pprint`` at the end.

        When ``self.key_type`` is not ``'random'``, a plain (non-passport)
        network of the same architecture is built and passed to
        ``self.setup_keys``. Its weights come from ``self.pretrained_path``
        when that is set; otherwise the network is constructed with
        ``pretrained=True``. This happens after ``self.model`` is set.
        """
        print('Construct Model')

        def build_pretrained():
            """Return the plain counterpart network on ``self.device`` with pretrained weights."""
            # No local checkpoint -> rely on the constructor's pretrained=True
            # path (reported below as "torch-pretrained").
            from_torch = self.pretrained_path is None
            if self.arch == 'alexnet':
                # NOTE(review): presumably 'none' matches the normalisation-free
                # layout of the torch-provided AlexNet weights -- confirm.
                norm_type = 'none' if from_torch else self.norm_type
                net = AlexNetNormal(
                    self.in_channels,
                    self.num_classes,
                    norm_type=norm_type,
                    pretrained=from_torch)
            else:
                # NOTE(review): presumably 'bn' matches the batch-norm layout of
                # the torch-provided ResNet18 weights -- confirm.
                norm_type = 'bn' if from_torch else self.norm_type
                net = ResNet18(
                    num_classes=self.num_classes,
                    norm_type=norm_type,
                    pretrained=from_torch)

            if from_torch:
                print('Loading pretrained from torch-pretrained model')
            else:
                print('Loading pretrained from self-trained model')
                # map_location remaps tensors saved on another device (e.g. a
                # GPU) onto self.device, so CPU-only runs can still load it.
                # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
                state_dict = torch.load(self.pretrained_path,
                                        map_location=self.device)
                net.load_state_dict(state_dict)

            return net.to(self.device)

        passport_kwargs = construct_passport_kwargs(self)
        self.passport_kwargs = passport_kwargs

        print('Loading arch: ' + self.arch)
        if self.arch == 'alexnet':
            model = AlexNetPassportPrivate(self.in_channels, self.num_classes,
                                           passport_kwargs)
        else:
            model = ResNet18Private(num_classes=self.num_classes,
                                    passport_kwargs=passport_kwargs)

        self.model = model.to(self.device)

        # A pretrained reference network is only built (and handed to
        # self.setup_keys) when keys are not random.
        if self.key_type != 'random':
            self.setup_keys(build_pretrained())

        pprint(self.model)