def _set_model(self):
    """Instantiate ``self.ori_model`` and ``self.pruned_model`` from ``self.settings``.

    Both attributes receive independent instances of the same architecture,
    built with ``depth=self.settings.depth`` and
    ``num_classes=self.settings.n_classes``:

        cifar10 / cifar100: md.PreResNet  (requires net_type "preresnet")
        imagenet:           md.ResNet     (requires net_type "resnet")

    Raises:
        AssertionError: if the dataset is unsupported or ``net_type`` does not
            match it. Raised explicitly instead of via ``assert False`` so the
            check still fires under ``python -O``; otherwise the method would
            silently return without setting either model.
    """
    settings = self.settings
    if settings.dataset in ["cifar10", "cifar100"]:
        expected_net = "preresnet"
    elif settings.dataset in ["imagenet"]:
        expected_net = "resnet"
    else:
        raise AssertionError(
            "unsupported data set: {}".format(settings.dataset))

    if settings.net_type != expected_net:
        raise AssertionError(
            "use {} data while network is {}".format(
                settings.dataset, settings.net_type))

    model_cls = md.PreResNet if expected_net == "preresnet" else md.ResNet
    # Build ori_model before pruned_model, as before, so any random weight
    # initialisation in the constructors happens in the same order.
    self.ori_model = model_cls(
        depth=settings.depth, num_classes=settings.n_classes)
    self.pruned_model = model_cls(
        depth=settings.depth, num_classes=settings.n_classes)
def get_model(dataset, net_type, depth, n_classes):
    """Build a model for ``dataset`` together with a matching dummy input.

    Supported combinations:

        cifar10 / cifar100:      preresnet -> md.PreResNet, input 1x3x32x32
        imagenet / sub_imagenet: resnet    -> md.ResNet,    input 1x3x224x224

    Args:
        dataset: dataset name (see table above).
        net_type: network family; must match the dataset.
        depth: forwarded as ``depth=`` to the model constructor.
        n_classes: forwarded as ``num_classes=`` to the model constructor.

    Returns:
        ``(model, test_input)`` where ``test_input`` is
        ``torch.randn(1, 3, H, W)`` moved to the GPU with ``.cuda()``.

    Raises:
        AssertionError: if the dataset is unsupported or ``net_type`` does not
            match it. Raised explicitly instead of via ``assert False`` so the
            check still fires under ``python -O``; otherwise the function would
            fail later with ``UnboundLocalError``.
    """
    if dataset in ["cifar10", "cifar100"]:
        expected_net, image_size = "preresnet", 32
    elif dataset in ["imagenet", "sub_imagenet"]:
        expected_net, image_size = "resnet", 224
    else:
        raise AssertionError("unsupported data set: {}".format(dataset))

    # Validate before allocating anything on the GPU.
    if net_type != expected_net:
        raise AssertionError(
            "use {} data while network is {}".format(dataset, net_type))

    # Draw the dummy input before constructing the model, matching the
    # original order of random-number consumption (the constructor may
    # itself use the RNG for weight init).
    test_input = torch.randn(1, 3, image_size, image_size).cuda()
    if net_type == "preresnet":
        model = md.PreResNet(depth=depth, num_classes=n_classes)
    else:
        model = md.ResNet(depth=depth, num_classes=n_classes)
    return model, test_input
def _set_model(self):
    """Build ``self.pruned_model`` and swap its block convolutions for ``MaskConv2d``.

    The architecture is chosen from ``self.settings`` and built with
    ``depth=self.settings.depth`` and ``num_classes=self.settings.n_classes``:

        cifar10 / cifar100:      md.PreResNet  (requires net_type "preresnet")
        imagenet / imagenet_mio: md.ResNet     (requires net_type "resnet")

    After construction, every ``PreBasicBlock`` / ``BasicBlock`` /
    ``Bottleneck`` module has its ``conv2`` replaced by a ``MaskConv2d``. For
    ``Bottleneck`` modules, ``conv3`` is replaced as well. Each replacement
    takes the original conv's in/out channels, kernel size, stride and
    padding, plus a copy of its weight and (if present) bias.

    Raises:
        AssertionError: if the dataset is unsupported or ``net_type`` does not
            match it. Raised explicitly instead of via ``assert False`` so the
            check still fires under ``python -O``; otherwise
            ``self.pruned_model`` would be left unset (or stale) before the
            replacement loop below uses it.
    """
    settings = self.settings
    if settings.dataset in ["cifar10", "cifar100"]:
        expected_net = "preresnet"
    elif settings.dataset in ["imagenet", "imagenet_mio"]:
        expected_net = "resnet"
    else:
        raise AssertionError(
            "unsupported data set: {}".format(settings.dataset))

    if settings.net_type != expected_net:
        raise AssertionError(
            "use {} data while network is {}".format(
                settings.dataset, settings.net_type))

    model_cls = md.PreResNet if expected_net == "preresnet" else md.ResNet
    self.pruned_model = model_cls(
        depth=settings.depth, num_classes=settings.n_classes)

    def _to_mask_conv(conv):
        """Return a MaskConv2d mirroring ``conv``'s geometry and parameters."""
        has_bias = conv.bias is not None
        mask_conv = MaskConv2d(
            in_channels=conv.in_channels,
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=has_bias)
        # NOTE(review): dilation/groups of ``conv`` are not forwarded to
        # MaskConv2d; confirm the replaced convs only use the defaults.
        mask_conv.weight.data.copy_(conv.weight.data)
        if has_bias:
            mask_conv.bias.data.copy_(conv.bias.data)
        return mask_conv

    # replace the conv layer in resnet with mask_conv
    if settings.net_type in ["preresnet", "resnet"]:
        for module in self.pruned_model.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                module.conv2 = _to_mask_conv(module.conv2)
            if isinstance(module, Bottleneck):
                module.conv3 = _to_mask_conv(module.conv3)