def __init__(self, opt, centers=None, *args, **kwargs):
    """
    Set up the three kernelized fully-connected layers fc1, fc2 and fc3.

    Args:
      opt: Options object. This method reads sigma1, sigma2, sigma3 and
        memory_efficient; n_centers1, n_centers2 and n_centers3 are read
        only when centers is given, and expert_size only when
        memory_efficient is truthy. opt is also forwarded to the parent
        constructor.
      centers: Optional tuple of (input, target). When given, each layer
        gets its own sample drawn with utils.supervised_sample; otherwise
        every layer is built with centers=None.
      *args, **kwargs: Forwarded unchanged to the parent constructor.
    """
    super(k3LeNet5, self).__init__(opt, *args, **kwargs)
    logger.warning(
        'k3LeNet5 does not have a purely linear layer between the '
        'kernelized components and the rest of the model, which may '
        'worsen performance')

    if centers is None:
        centers1 = centers2 = centers3 = None
    else:
        # centers is a tuple of (input, target)
        inputs, targets = centers[0], centers[1]
        # All samples are drawn before any layer is constructed, in the
        # order fc1, fc2, fc3.
        centers1, centers2, centers3 = [
            utils.supervised_sample(inputs, targets, n_centers).clone().detach()
            for n_centers in (opt.n_centers1, opt.n_centers2, opt.n_centers3)
        ]

    # self.feat_len, self.kernel and self.evaluation are not set here;
    # presumably the parent constructor provides them -- TODO confirm.
    shared = dict(kernel=self.kernel, evaluation=self.evaluation)
    self.fc1 = kLinear(in_features=self.feat_len, out_features=120,
                       centers=centers1, sigma=opt.sigma1, **shared)
    self.fc2 = kLinear(in_features=120, out_features=84,
                       centers=centers2, sigma=opt.sigma2, **shared)
    self.fc3 = kLinear(in_features=84, out_features=10,
                       centers=centers3, sigma=opt.sigma3, **shared)

    if opt.memory_efficient:
        # Convert only after all three layers exist, one layer at a time.
        for layer_name in ('fc1', 'fc2', 'fc3'):
            layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    utils.to_committee(layer, opt.expert_size))
def setUp(self):
    """
    Create the fixtures used by the tests: a random batch with a kLinear
    layer and its committee version, plus a hand-specified two-layer toy
    network.
    """
    # allow tests to be run on GPU if possible
    if torch.cuda.is_available():
        self.device = 'cuda'
    else:
        self.device = 'cpu'
    device = self.device

    # Random data; the inputs double as the centers of the layer.
    self.input = torch.randn(100, 15).to(device)
    self.target = torch.randint(0, 10, (100,)).to(device)
    klinear = kLinear(
        out_features=10,
        kernel='gaussian',
        evaluation='indirect',
        centers=self.input,
    )
    self.klinear = klinear.to(device)
    self.klinear_committee = utils.to_committee(self.klinear, 30)

    # toy two-layer model
    self.toy_input = torch.tensor([[1., 2], [3, 4]]).to(device)
    self.toy_target = torch.tensor([0, 1]).to(device)

    first = kLinear(
        out_features=2,
        kernel='gaussian',
        evaluation='indirect',
        centers=self.toy_input,
        sigma=3,
    )
    self.toy_klinear1 = first.to(device)
    self.toy_klinear1.linear.weight.data = \
        torch.tensor([[.1, .2], [.5, .7]]).to(device)
    self.toy_klinear1.linear.bias.data = torch.tensor([0., 0]).to(device)

    # Layer 2 is centered on layer 1's output, so layer 1's parameters
    # must already be set at this point.
    # stop grad flow through layer2 centers to layer1
    layer2_centers = self.toy_klinear1(self.toy_input).detach()
    second = kLinear(
        out_features=2,
        kernel='gaussian',
        evaluation='indirect',
        centers=layer2_centers,
        sigma=2,
    )
    self.toy_klinear2 = second.to(device)
    self.toy_klinear2.linear.weight.data = \
        torch.tensor([[1.2, .3], [.2, 1.7]]).to(device)
    self.toy_klinear2.linear.bias.data = torch.tensor([.1, .2]).to(device)

    self.toy_net = torch.nn.Sequential(self.toy_klinear1, self.toy_klinear2)
def __init__(self, opt, centers=None, *args, **kwargs):
    """
    Set up a plain linear fc1 followed by kernelized fc2 and fc3.

    Args:
      opt: Options object. This method reads sigma2, sigma3 and
        memory_efficient; n_centers2 and n_centers3 are read only when
        centers is given, and expert_size only when memory_efficient is
        truthy. opt is also forwarded to the parent constructor.
      centers: Optional tuple of (input, target). When given, fc2 and fc3
        each get their own sample drawn with utils.supervised_sample;
        otherwise both are built with centers=None.
      *args, **kwargs: Forwarded unchanged to the parent constructor.
    """
    super(k2LeNet5, self).__init__(opt, *args, **kwargs)

    # fc1 is an ordinary linear layer, created before any center sampling.
    self.fc1 = torch.nn.Linear(self.feat_len, 120)

    if centers is None:
        centers2 = centers3 = None
    else:
        # centers is a tuple of (input, target)
        inputs, targets = centers[0], centers[1]
        centers2, centers3 = [
            utils.supervised_sample(inputs, targets, n_centers).clone().detach()
            for n_centers in (opt.n_centers2, opt.n_centers3)
        ]

    shared = dict(kernel=self.kernel, evaluation=self.evaluation)
    self.fc2 = kLinear(in_features=120, out_features=84,
                       centers=centers2, sigma=opt.sigma2, **shared)
    self.fc3 = kLinear(in_features=84, out_features=10,
                       centers=centers3, sigma=opt.sigma3, **shared)

    if opt.memory_efficient:
        # Convert only after both kernelized layers exist.
        for layer_name in ('fc2', 'fc3'):
            layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    utils.to_committee(layer, opt.expert_size))

    self.print_network(self)
def __init__(self, opt, centers, block, num_blocks, num_classes=10):
    """
    Build the network with a custom layer5 and a kernelized fc layer.

    Args:
      opt: Options object. This method reads activation, sigma and
        memory_efficient; n_centers is read only when centers is given,
        and expert_size only when memory_efficient is truthy.
      centers: Tuple of (input, target) or None. When given, the centers
        of the fc layer are drawn with utils.supervised_sample.
      block: Block class; its expansion attribute sizes the fc input.
      num_blocks: Per-stage block counts; index 3 is used for layer5.
      num_classes: Number of outputs of the fc layer.

    Raises:
      NotImplementedError: If opt.activation is not one of 'tanh',
        'sigmoid', 'relu', 'reapen' or 'gaussian'.
    """
    super(kResNet, self).__init__(
        block, num_blocks, num_classes, skip_layer=['layer5', 'fc'])
    self.opt = opt

    # activation name -> (kernel, evaluation)
    kernel_table = {
        'tanh': ('nn_tanh', 'direct'),
        'sigmoid': ('nn_sigmoid', 'direct'),
        'relu': ('nn_relu', 'direct'),
        'reapen': ('nn_reapen', 'direct'),
        'gaussian': ('gaussian', 'indirect'),
    }
    kernel_spec = kernel_table.get(opt.activation)
    if kernel_spec is None:
        raise NotImplementedError()
    self.kernel, self.evaluation = kernel_spec

    self.layer5 = self._make_layer_no_output_relu(
        block, 512, num_blocks[3], stride=2)

    sampled_centers = None
    if centers is not None:
        # centers is a tuple of (input, target)
        sampled_centers = utils.supervised_sample(
            centers[0], centers[1], opt.n_centers).clone().detach()

    classifier = kLinear(
        in_features=512 * block.expansion,
        out_features=num_classes,
        kernel=self.kernel,
        evaluation=self.evaluation,
        centers=sampled_centers,
        sigma=opt.sigma,
    )
    if opt.memory_efficient:
        classifier = utils.to_committee(classifier, opt.expert_size)
    self.fc = classifier
def __init__(self, opt, *args, **kwargs):
    """
    Build an MLP whose first layer is linear and whose remaining layers
    are kLinear.

    Args:
      opt: Options object. This method reads activation and arch, where
        arch is a '_'-separated string of integer layer widths.
      *args, **kwargs: Accepted but not used by this constructor.

    Raises:
      NotImplementedError: If opt.activation is not one of 'tanh',
        'sigmoid' or 'relu'.
    """
    super(kMLP, self).__init__()

    # activation name -> (kernel, evaluation)
    kernel_table = {
        'tanh': ('nn_tanh', 'direct'),
        'sigmoid': ('nn_sigmoid', 'direct'),
        'relu': ('nn_relu', 'direct'),
    }
    kernel_spec = kernel_table.get(opt.activation)
    if kernel_spec is None:
        raise NotImplementedError()
    self.kernel, self.evaluation = kernel_spec

    self.flatten = Flatten()

    widths = list(map(int, opt.arch.split('_')))
    # Consecutive width pairs define the layers, registered as layer_1,
    # layer_2, ...
    for number, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
        if number == 1:
            layer = torch.nn.Linear(in_features=n_in, out_features=n_out)
        else:
            layer = kLinear(in_features=n_in, out_features=n_out,
                            kernel=self.kernel, evaluation=self.evaluation)
        setattr(self, f'layer_{number}', layer)

    self.opt = opt
    self.arch_list = widths
    self.n_layers = len(widths) - 1
    self.print_network(self)
def to_committee(model, expert_size):
    """
    Convert a kLinear model into a committee of experts (a kLinearCommittee
    model) with each expert except possibly the last being of size
    expert_size. The new model is numerically equivalent to the original.

    If needed, call this function RIGHT AFTER model initialization.

    This conversion preserves the device allocation of the original model,
    i.e., if model is on GPU, the returned committee will also be on GPU.

    Args:
      model: The kLinear layer to convert.
      expert_size: Positive number of centers assigned to each expert.

    Returns:
      A kLinearCommittee holding one expert per chunk of centers, or the
      original model unchanged (with a warning) when it has no centers.

    Raises:
      ValueError: If expert_size is not positive.
      TypeError: If model is not a kLinear instance.
    """
    # A non-positive expert size would never advance through the centers,
    # so reject it up front instead of looping forever.
    if expert_size <= 0:
        raise ValueError(
            'Expecting expert_size to be positive, '
            'got {} instead.'.format(expert_size))

    logger.info('Converting %s to committee w/ expert size %s...',
                model.__class__.__name__, expert_size)

    if not isinstance(model, kLinear):
        raise TypeError('Expecting the model to be of ' +
                        'kLinear type, got {} instead.'.format(type(model)))

    # Treat a missing attribute and an attribute set to None the same way.
    centers = getattr(model, 'centers', None)
    if centers is None:
        logger.warning('The given model does not have centers, ' +
                       'in which case the conversion to committee ' +
                       'was not performed. The original model ' +
                       'was returned instead.')
        return model

    has_bias = model.linear.bias is not None
    committee = kLinearCommittee()
    for index, start in enumerate(range(0, len(centers), expert_size)):
        stop = start + expert_size
        # Only the first expert carries the bias so that it is counted
        # once across the committee.
        bias = has_bias and index == 0
        expert = kLinear(
            out_features=model.out_features,
            in_features=model.in_features,
            kernel=model.kernel,
            bias=bias,
            evaluation=model.evaluation,
            centers=centers[start:stop].clone().detach(),
            trainable_centers=getattr(model, 'trainable_centers', False),
            sigma=model.phi.k_params['sigma']
        )
        # Each expert takes the weight columns matching its slice of centers.
        expert.linear.weight.data = \
            model.linear.weight[:, start:stop].clone().detach()
        if bias:
            expert.linear.bias.data = model.linear.bias.clone().detach()
        committee.add_expert(expert)

    return committee
def test_nn_sigmoid_kernel_consistency_with_native_pytorch(self, ):
    """
    A kLinear with the 'nn_sigmoid' kernel under direct evaluation should
    match Sigmoid -> Normalize -> Linear when both share the same
    parameters.
    """
    batch = torch.randn(30, 784).to(self.device)

    kernelized = kLinear(
        in_features=784,
        out_features=10,
        kernel='nn_sigmoid',
        evaluation='direct',
    ).to(self.device)

    # Reference linear layer reusing the kernelized layer's parameters.
    reference_linear = torch.nn.Linear(784, 10).to(self.device)
    reference_linear.weight.data = kernelized.linear.weight.data
    reference_linear.bias.data = kernelized.linear.bias.data
    reference = torch.nn.Sequential(
        torch.nn.Sigmoid(),
        Normalize(),
        reference_linear,
    )

    actual = kernelized(batch).detach().cpu().numpy()
    expected = reference(batch).detach().cpu().numpy()
    self.assertTrue(np.allclose(actual, expected))
def test_nn_reapen_kernel_consistency_with_native_pytorch(self, ):
    """
    A kLinear with the 'nn_reapen' kernel under direct evaluation should
    match ReLU -> AvgPool2d(4) -> Flatten -> Normalize -> Linear when both
    share the same parameters.
    """
    # 3 channels of 32x32 pooled by 4 give 3 * 8 * 8 = 192 features.
    batch = torch.randn(30, 3, 32, 32).to(self.device)

    kernelized = kLinear(
        in_features=192,
        out_features=10,
        kernel='nn_reapen',
        evaluation='direct',
    ).to(self.device)

    # Reference linear layer reusing the kernelized layer's parameters.
    reference_linear = torch.nn.Linear(192, 10).to(self.device)
    reference_linear.weight.data = kernelized.linear.weight.data
    reference_linear.bias.data = kernelized.linear.bias.data
    reference = torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.AvgPool2d(4),
        Flatten(),
        Normalize(),
        reference_linear,
    )

    actual = kernelized(batch)
    expected = reference(batch)
    self.assertTrue(torch.allclose(actual, expected))