def __init__(
    self,
    n_head=4,
    hidden=32,
    weight_sparsity=0.9,
    attn_pdrop=0,
    resid_pdrop=0,
    block_size=512,
):
    super().__init__()
    assert hidden % n_head == 0

    # Key, query, value projections for all heads
    self.key = SparseWeights(nn.Linear(hidden, hidden), sparsity=weight_sparsity)
    self.query = SparseWeights(nn.Linear(hidden, hidden), sparsity=weight_sparsity)
    self.value = SparseWeights(nn.Linear(hidden, hidden), sparsity=weight_sparsity)

    # Regularization
    self.attn_drop = nn.Dropout(attn_pdrop)
    self.resid_drop = nn.Dropout(resid_pdrop)

    # Output projection
    self.proj = nn.Linear(hidden, hidden)

    # Causal mask to ensure that attention is only applied to the left in
    # the input sequence
    self.register_buffer(
        "mask",
        torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        ),
    )
    self.n_head = n_head

def setUp(self):
    self.model = nn.Sequential(
        SparseWeights(nn.Linear(100, 100), sparsity=0.25),
        nn.ReLU(),
        SparseWeights(nn.Linear(100, 100), sparsity=0.75),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    self.model.apply(set_weights_to_one)
    self.model.apply(rezero_weights)
    self.batch = torch.rand(10, 100)

def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder as well as the word embedding
    layer(s).
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Use `getattr` here for backwards compatibility with configs that
    # don't define this param.
    sparsify_all_embeddings = getattr(self.config, "sparsify_all_embeddings", False)

    def get_sparsity(name):
        if isinstance(sparsity, dict):
            if name in sparsity:
                return sparsity[name]
            else:
                raise KeyError(f"Layer {name} not included in sparsity dict.")
        else:
            return sparsity

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        layer_sparsity = get_sparsity("bert.encoder." + name)
        sparse_module = SparseWeights(
            module,
            sparsity=layer_sparsity,
            allow_extremes=True,  # allows the model to start fully dense
        )
        set_module_attr(self.encoder, name, sparse_module.to(device))

    # Replace the embedding layers in a similar fashion.
    if sparsify_all_embeddings:
        embeddings = [
            "word_embeddings", "position_embeddings", "token_type_embeddings"
        ]
    else:
        embeddings = ["word_embeddings"]

    for embedding_name in embeddings:
        dense_module = getattr(self.embeddings, embedding_name)
        layer_sparsity = get_sparsity(f"bert.embeddings.{embedding_name}")
        sparse_module = SparseEmbeddings(dense_module, sparsity=layer_sparsity)
        setattr(self.embeddings, embedding_name, sparse_module.to(device))

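# Hedged configuration sketch: when `self.config.sparsity` is a dict,
# `get_sparsity` above looks keys up under the "bert.encoder.<module path>"
# and "bert.embeddings.<name>" naming convention. The module paths below
# follow HuggingFace BERT layer names and are illustrative, not values
# taken from this codebase.
sparsity = {
    "bert.encoder.layer.0.intermediate.dense": 0.9,
    "bert.encoder.layer.0.output.dense": 0.9,
    "bert.embeddings.word_embeddings": 0.5,
}
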
def __init__(self, input_shape=INPUT_SHAPE, hidden_size=4, output_size=4):
    super().__init__()
    input_size = np.prod(input_shape)
    self.flatten = torch.nn.Flatten()
    self.lin1 = SparseWeights(
        torch.nn.Linear(input_size, hidden_size, bias=False), sparsity=0.5
    )
    self.lin2 = SparseWeights(
        torch.nn.Linear(hidden_size, hidden_size, bias=False), sparsity=0.5
    )
    self.lin3 = SparseWeights(
        torch.nn.Linear(hidden_size, output_size, bias=False), sparsity=0.5
    )

def __init__(self, input_size=28 * 28, n_hidden_units=1000, n_classes=10,
             is_sparse=False, sparsity=(0.75, 0.85), percent_on=0.1):
    """
    Initialize an MLP with two hidden layers.

    :param input_size: number of input features to the MLP
    :type input_size: int
    :param n_hidden_units: number of units in each of the two hidden layers
    :type n_hidden_units: int
    :param n_classes: number of output units
    :type n_classes: int
    :param is_sparse: whether to initialize the sparse network instead of a
                      dense one
    :type is_sparse: bool
    :param sparsity: a 2-element list/tuple specifying the sparsity in each
                     of the hidden layers
    :type sparsity: list/tuple of float
    :param percent_on: fraction of active units in the K-Winners layers
                       (only applies to sparse networks)
    :type percent_on: float
    """
    super().__init__()
    self.is_sparse = is_sparse
    self.flatten = Flatten()
    self.n_classes = n_classes

    self.fc1 = torch.nn.Linear(input_size, n_hidden_units)
    self.fc2 = torch.nn.Linear(n_hidden_units, n_hidden_units)
    self.fc3 = torch.nn.Linear(n_hidden_units, n_classes)

    if is_sparse:
        self.fc1_sparsity, self.fc2_sparsity = sparsity
        self.percent_on = percent_on

        self.fc1 = SparseWeights(self.fc1, sparsity=self.fc1_sparsity)
        self.kw1 = KWinners(n=n_hidden_units, percent_on=percent_on,
                            boost_strength=0.0)

        self.fc2 = SparseWeights(self.fc2, sparsity=self.fc2_sparsity)
        self.kw2 = KWinners(n=n_hidden_units, percent_on=percent_on,
                            boost_strength=0.0)

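# A minimal sketch (assuming nupic.torch is installed) of what the default
# sparsity=(0.75, 0.85) means for fc1: after rezeroing, each hidden unit
# keeps round(784 * (1 - 0.75)) = 196 nonzero input weights.
import torch
from nupic.torch.modules import SparseWeights

fc1 = SparseWeights(torch.nn.Linear(784, 1000), sparsity=0.75)
fc1.rezero_weights()
print((fc1.module.weight[0] != 0).sum().item())  # -> 196
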
def _create_representation_module(self, module_type, dims):
    if module_type is None:
        return None

    representation_module = nn.Sequential()
    inp_dim = self.input_size
    for i in range(len(dims)):
        output_dim = dims[i]
        layer = SparseWeights(
            torch.nn.Linear(inp_dim, output_dim, bias=True),
            sparsity=self.weight_sparsity,
            allow_extremes=True,
        )
        # Network input is dense (no sparsity constraints)
        DendriticMLP._init_sparse_weights(layer, 0.0)

        if module_type == "relu":
            nonlinearity = nn.ReLU()
        else:
            raise NotImplementedError

        representation_module.add_module("linear_layer_{}".format(i), layer)
        representation_module.add_module("nonlinearity_{}".format(i),
                                         nonlinearity)
        inp_dim = output_dim

    self.representation_dim = inp_dim
    return representation_module

def test_rezero_1d(self):
    in_features, out_features = 784, 10
    for sparsity in [0.1, 0.5, 0.9]:
        linear = torch.nn.Linear(in_features=in_features,
                                 out_features=out_features)
        sparse = SparseWeights(linear, sparsity=sparsity)

        # Ensure weights are not sparse
        sparse.module.weight.data.fill_(1.0)

        # Rezero, then verify the weights become sparse again
        sparse.rezero_weights()
        nonzeros = torch.nonzero(sparse.module.weight, as_tuple=True)[0]
        counts = torch.unique(nonzeros, return_counts=True)[1]
        expected = [round(in_features * (1.0 - sparsity))] * out_features
        self.assertSequenceEqual(counts.numpy().tolist(), expected)

def _create_preprocess_module(self, module_type, preprocess_output_dim,
                              kw_percent_on):
    if module_type is None:
        return None

    preprocess_module = nn.Sequential()
    linear_layer = SparseWeights(
        torch.nn.Linear(
            self.context_representation_dim + self.representation_dim,
            preprocess_output_dim,
            bias=True,
        ),
        sparsity=self.weight_sparsity,
        allow_extremes=True,
    )
    DendriticMLP._init_sparse_weights(linear_layer, 0.0)

    if module_type == "relu":
        nonlinearity = nn.ReLU()
    else:
        nonlinearity = KWinners(n=preprocess_output_dim,
                                percent_on=kw_percent_on,
                                k_inference_factor=1.0,
                                boost_strength=0.0,
                                boost_strength_factor=0.0)

    preprocess_module.add_module("linear_layer", linear_layer)
    preprocess_module.add_module("nonlinearity", nonlinearity)
    self.context_representation_dim = preprocess_output_dim
    return preprocess_module

def __init__(self, num_classes, input_shape):
    super().__init__()
    in_features = np.prod(input_shape)
    self.flatten = torch.nn.Flatten()
    self.classifier = SparseWeights(
        torch.nn.Linear(in_features, num_classes, bias=False),
        sparsity=0.5,
    )

def __init__(self, input_size, hidden, sparsity, percent_on, boost_strength):
    super().__init__()
    self.sparse_linear = SparseWeights(nn.Linear(input_size, hidden),
                                       sparsity=sparsity)
    self.kw = KWinners(n=hidden, percent_on=percent_on,
                       boost_strength=boost_strength)

def __init__(self, num_classes, input_shape):
    super().__init__()
    in_features = np.prod(input_shape)
    self.flatten = torch.nn.Flatten()
    self.kwinners = KWinners(n=in_features, percent_on=0.1)
    self.classifier = SparseWeights(
        nn.Linear(in_features, num_classes, bias=False), sparsity=0.9
    )

def __init__(self,
             cnn_out_channels=(64, 64),
             cnn_percent_on=(0.095, 0.125),
             linear_units=1000,
             linear_percent_on=0.1,
             linear_weight_sparsity=0.4,
             boost_strength=1.5,
             boost_strength_factor=0.9,
             k_inference_factor=1.5,
             duty_cycle_period=1000):
    super(GSCSparseCNN, self).__init__(OrderedDict([
        # First sparse CNN layer
        ("cnn1", nn.Conv2d(1, cnn_out_channels[0], 5)),
        ("cnn1_batchnorm", nn.BatchNorm2d(cnn_out_channels[0], affine=False)),
        ("cnn1_maxpool", nn.MaxPool2d(2)),
        ("cnn1_kwinner", KWinners2d(
            channels=cnn_out_channels[0],
            percent_on=cnn_percent_on[0],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        # Second sparse CNN layer
        ("cnn2", nn.Conv2d(cnn_out_channels[0], cnn_out_channels[1], 5)),
        ("cnn2_batchnorm", nn.BatchNorm2d(cnn_out_channels[1], affine=False)),
        ("cnn2_maxpool", nn.MaxPool2d(2)),
        ("cnn2_kwinner", KWinners2d(
            channels=cnn_out_channels[1],
            percent_on=cnn_percent_on[1],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        ("flatten", Flatten()),
        # Sparse linear layer
        ("linear", SparseWeights(
            nn.Linear(25 * cnn_out_channels[1], linear_units),
            weight_sparsity=linear_weight_sparsity)),
        ("linear_bn", nn.BatchNorm1d(linear_units, affine=False)),
        ("linear_kwinner", KWinners(
            n=linear_units,
            percent_on=linear_percent_on,
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        # Classifier
        ("output", nn.Linear(linear_units, 12)),
        ("softmax", nn.LogSoftmax(dim=1)),
    ]))

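# Forward-pass sketch for the model above, assuming GSC's 32x32 single-channel
# inputs: 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5, which
# is where the 25 * cnn_out_channels[1] flatten size comes from.
import torch

model = GSCSparseCNN()
x = torch.rand(8, 1, 32, 32)
log_probs = model(x)  # shape (8, 12): log-probabilities over 12 classes
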
def test_rezero_after_forward_1d(self):
    in_features, out_features = 784, 10
    # Here `percent_on` is the fraction of weights kept non-zero
    # (the `weight_sparsity` argument, passed positionally).
    for percent_on in [0.1, 0.5, 0.9]:
        linear = torch.nn.Linear(in_features=in_features,
                                 out_features=out_features)
        sparse = SparseWeights(linear, percent_on)

        # Ensure weights are not sparse
        sparse.module.weight.data.fill_(1.0)

        sparse.train()
        x = torch.ones((1,) + (in_features,))
        sparse(x)

        # When training, the forward pass should set pruned weights back to zero
        nonzeros = torch.nonzero(sparse.module.weight, as_tuple=True)[0]
        counts = torch.unique(nonzeros, return_counts=True)[1]
        expected = [round(in_features * percent_on)] * out_features
        self.assertSequenceEqual(counts.numpy().tolist(), expected)

def sparsify_model(self):
    """
    Sparsify all non-attention linear layers in the encoder.
    """
    encoder = self.encoder
    num_sparse_layers = self.config.num_sparse_layers
    sparsity = self.config.sparsity
    device = self.device

    for idx in range(num_sparse_layers):
        intermediate_layer = encoder.layer[idx].intermediate.dense
        encoder.layer[idx].intermediate.dense = \
            SparseWeights(intermediate_layer, sparsity=sparsity).to(device)

        output_layer = encoder.layer[idx].output.dense
        encoder.layer[idx].output.dense = \
            SparseWeights(output_layer, sparsity=sparsity).to(device)

def add_sparse_linear_layer(
    network,
    suffix,
    input_size,
    linear_n,
    dropout,
    use_batch_norm,
    weight_sparsity,
    percent_on,
    k_inference_factor,
    boost_strength,
    boost_strength_factor,
):
    """Add a sparse linear layer to the network.

    :param network: The network to add the sparse layer to
    :param suffix: Layer suffix, used to name its components
    :param input_size: Input size
    :param linear_n: Number of units
    :param dropout: Dropout value
    :param use_batch_norm: Whether or not to use batch norm
    :param weight_sparsity: Fraction of weights that are allowed to be non-zero
    :param percent_on: Fraction of ON (non-zero) units
    :param k_inference_factor: During inference, percent_on is increased by
        this factor
    :param boost_strength: Boost strength (0.0 implies no boosting)
    :param boost_strength_factor: Boost strength is multiplied by this factor
        after each epoch
    """
    linear = nn.Linear(input_size, linear_n)
    if 0 < weight_sparsity < 1.0:
        network.add_module(
            "linear{}".format(suffix), SparseWeights(linear, weight_sparsity)
        )
    else:
        network.add_module("linear{}".format(suffix), linear)

    if use_batch_norm:
        # Include the suffix in the batch norm name as well, so that repeated
        # calls don't overwrite each other's modules.
        network.add_module("linear{}_bn".format(suffix),
                           nn.BatchNorm1d(linear_n, affine=False))

    if dropout > 0.0:
        network.add_module("linear{}_dropout".format(suffix), nn.Dropout(dropout))

    if 0 < percent_on < 1.0:
        network.add_module(
            "linear{}_kwinners".format(suffix),
            KWinners(
                n=linear_n,
                percent_on=percent_on,
                k_inference_factor=k_inference_factor,
                boost_strength=boost_strength,
                boost_strength_factor=boost_strength_factor,
            ),
        )
    else:
        network.add_module("linear{}_relu".format(suffix), nn.ReLU())

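# Hedged usage sketch for the helper above: stacking two sparse linear layers
# onto an nn.Sequential. Module names follow the suffix convention
# ("linear1", "linear1_kwinners", ...); the sizes and hyperparameters are
# illustrative, not values prescribed by this codebase.
import torch.nn as nn

network = nn.Sequential()
add_sparse_linear_layer(
    network, suffix="1", input_size=784, linear_n=300, dropout=0.0,
    use_batch_norm=False, weight_sparsity=0.4, percent_on=0.1,
    k_inference_factor=1.5, boost_strength=1.5, boost_strength_factor=0.9,
)
add_sparse_linear_layer(
    network, suffix="2", input_size=300, linear_n=100, dropout=0.0,
    use_batch_norm=False, weight_sparsity=0.4, percent_on=0.1,
    k_inference_factor=1.5, boost_strength=1.5, boost_strength_factor=0.9,
)
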
def __init__(self, input_size, output_size, hidden_size, num_segments,
             dim_context, sparsity, kw=False, relu=False,
             dendritic_layer_class=AbsoluteMaxGatingDendriticLayer):

    # The nonlinearity can either be k-Winners or ReLU, but not both
    assert not (kw and relu)

    super().__init__()

    self.num_segments = num_segments
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dim_context = dim_context
    self.kw = kw
    self.relu = relu

    # Forward layers & k-winners
    self.dend1 = dendritic_layer_class(
        module=nn.Linear(input_size, hidden_size),
        num_segments=num_segments,
        dim_context=dim_context,
        module_sparsity=sparsity,
        dendrite_sparsity=sparsity,
    )
    self.dend2 = dendritic_layer_class(
        module=nn.Linear(hidden_size, hidden_size),
        num_segments=num_segments,
        dim_context=dim_context,
        module_sparsity=sparsity,
        dendrite_sparsity=sparsity,
    )

    if kw:
        self.kw1 = KWinners(n=hidden_size, percent_on=0.05,
                            k_inference_factor=1.0, boost_strength=0.0,
                            boost_strength_factor=0.0)
        self.kw2 = KWinners(n=hidden_size, percent_on=0.05,
                            k_inference_factor=1.0, boost_strength=0.0,
                            boost_strength_factor=0.0)
    if relu:
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    # Final classifier layer
    self.classifier = SparseWeights(nn.Linear(hidden_size, output_size),
                                    sparsity=sparsity)

def __init__(
    self,
    in_dim,
    n_dendrites,
    threshold=2,
    weight_sparsity=0.2,
):
    super(DendriteInput, self).__init__()
    self.threshold = threshold
    linear = nn.Linear(in_dim, n_dendrites)

    if weight_sparsity < 1:
        self.linear = SparseWeights(linear, weight_sparsity)
    else:
        self.linear = linear

def __init__(self, num_classes, input_shape):
    super().__init__()
    in_features = np.prod(input_shape)
    self.dendritic_gate = DendriticAbsoluteMaxGate1d()
    self.flatten = torch.nn.Flatten()
    self.kwinners = KWinners(n=16, percent_on=0.75, k_inference_factor=1)
    self.classifier = SparseWeights(
        torch.nn.Linear(in_features, num_classes, bias=False),
        sparsity=0.5,
    )

def __init__(self,
             cnn_out_channels=(32, 64),
             cnn_percent_on=(0.087, 0.293),
             linear_units=700,
             linear_percent_on=0.143,
             linear_weight_sparsity=0.3,
             boost_strength=1.5,
             boost_strength_factor=0.85,
             k_inference_factor=1.5,
             duty_cycle_period=1000):
    super(MNISTSparseCNN, self).__init__(OrderedDict([
        # First sparse CNN layer
        ("cnn1", nn.Conv2d(1, cnn_out_channels[0], 5)),
        ("cnn1_maxpool", nn.MaxPool2d(2)),
        ("cnn1_kwinner", KWinners2d(
            channels=cnn_out_channels[0],
            percent_on=cnn_percent_on[0],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        # Second sparse CNN layer
        ("cnn2", nn.Conv2d(cnn_out_channels[0], cnn_out_channels[1], 5)),
        ("cnn2_maxpool", nn.MaxPool2d(2)),
        ("cnn2_kwinner", KWinners2d(
            channels=cnn_out_channels[1],
            percent_on=cnn_percent_on[1],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        ("flatten", Flatten()),
        # Sparse linear layer (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8
        # -> pool -> 4, hence 16 units per channel after flattening)
        ("linear", SparseWeights(
            nn.Linear(16 * cnn_out_channels[1], linear_units),
            weight_sparsity=linear_weight_sparsity)),
        ("linear_kwinner", KWinners(
            n=linear_units,
            percent_on=linear_percent_on,
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period)),
        # Classifier
        ("output", nn.Linear(linear_units, 10)),
        ("softmax", nn.LogSoftmax(dim=1)),
    ]))

def _sparsify_linear(parent, linear_names, weight_sparsity):
    """Enforce weight sparsity on the given linear modules during training.

    :param parent: Parent layer containing the Linear modules to sparsify
    :param linear_names: List of Linear module names to sparsify
    :param weight_sparsity: Fraction of weights that are allowed to be
        non-zero, one value per module name
    """
    for i, name in enumerate(linear_names):
        if weight_sparsity[i] >= 1.0:
            continue

        module = getattr(parent, name)
        setattr(parent, name, SparseWeights(module, weight_sparsity[i]))

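# Hedged usage sketch for _sparsify_linear: wrap fc1 in SparseWeights while
# leaving fc2 dense (values >= 1.0 are skipped). TwoLayer is a hypothetical
# stand-in for a real parent layer.
import torch.nn as nn

class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)
        self.fc2 = nn.Linear(64, 10)

block = TwoLayer()
_sparsify_linear(block, ["fc1", "fc2"], [0.3, 1.0])
print(type(block.fc1).__name__)  # SparseWeights
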
def test_sparse_weights_1d(self):
    in_features, out_features = 784, 10
    with torch.no_grad():
        # As in test_rezero_after_forward_1d, `percent_on` is the fraction
        # of weights kept non-zero (the `weight_sparsity` argument, passed
        # positionally).
        for percent_on in [0.1, 0.5, 0.9]:
            linear = torch.nn.Linear(in_features=in_features,
                                     out_features=out_features)
            sparse = SparseWeights(linear, percent_on)
            nonzeros = torch.nonzero(sparse.module.weight, as_tuple=True)[0]
            counts = torch.unique(nonzeros, return_counts=True)[1]

            # Expected non-zeros per output feature
            expected = [round(in_features * percent_on)] * out_features
            self.assertSequenceEqual(counts.numpy().tolist(), expected)

def __init__(
    self,
    dpc=3,
    cnn_w_sparsity=0.05,
    linear_w_sparsity=0.5,
    cat_w_sparsity=0.01,
    n_classes=4,
):
    super(ToyNetwork, self).__init__()
    conv_channels = 128
    self.n_classes = n_classes
    self.conv1 = SparseWeights2d(
        nn.Conv2d(
            in_channels=1,
            out_channels=conv_channels,
            kernel_size=10,
            padding=0,
            stride=1,
        ),
        cnn_w_sparsity,
    )
    self.kwin1 = KWinners2d(conv_channels, percent_on=0.1)
    self.bn = nn.BatchNorm2d(conv_channels, affine=False)
    self.mp1 = nn.MaxPool2d(kernel_size=2)

    self.flatten = Flatten()

    self.d1 = DendriteLayer(
        in_dim=int(conv_channels / 64) * 7744,
        out_dim=1000,
        dendrites_per_neuron=dpc,
    )

    self.linear = SparseWeights(nn.Linear(1000, n_classes + 1),
                                linear_w_sparsity)
    self.cat = SparseWeights(nn.Linear(n_classes + 1, 1000 * dpc),
                             cat_w_sparsity)

def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder.
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
        set_module_attr(self.encoder, name, sparse_module)

def sparse_linear(in_features, out_features, bias=True, density=1.0):
    """
    Get an nn.Linear, possibly wrapped in a SparseWeights.

    :param density: Either a density (fraction of non-zero weights) or a
        function that returns a density.
    :type density: float or function(in_features, out_features)
    """
    layer = nn.Linear(in_features, out_features, bias=bias)

    if callable(density):
        density = density(in_features, out_features)
    if density < 1.0:
        layer = SparseWeights(layer, weight_sparsity=density)
    return layer

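# Hedged sketch of the callable-density option above: density computed per
# layer from its fan-in, so wider layers end up sparser. The 30/in_features
# rule is purely illustrative.
def fan_in_density(in_features, out_features):
    return min(1.0, 30.0 / in_features)

wide = sparse_linear(784, 100, density=fan_in_density)  # wrapped in SparseWeights
small = sparse_linear(16, 10, density=1.0)              # stays a plain nn.Linear
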
def __init__(
    self,
    input_size,
    output_size,
    kw_percent_on=0.05,
    boost_strength=0.0,
    weight_sparsity=0.95,
    duty_cycle_period=1000,
):
    super().__init__()
    self.linear = SparseWeights(nn.Linear(input_size, output_size),
                                sparsity=weight_sparsity,
                                allow_extremes=True)
    self.kw = KWinners(n=output_size, percent_on=kw_percent_on,
                       boost_strength=boost_strength,
                       duty_cycle_period=duty_cycle_period)

def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder as well as the word embedding
    layer.
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
        set_module_attr(self.encoder, name, sparse_module)

    # Replace the embedding layer in a similar fashion.
    dense_embeddings = self.embeddings.word_embeddings
    sparse_embeddings = SparseEmbeddings(dense_embeddings, sparsity=sparsity)
    self.embeddings.word_embeddings = sparse_embeddings

def __init__(self, dim_in, dim_out, k, device=None, sparsity=0.7):
    """
    :param dim_in: the number of dimensions in the input
    :type dim_in: int
    :param dim_out: the number of dimensions in the sparse linear output
    :type dim_out: int
    :param k: the number of unique random binary vectors that can "route"
        the sparse linear output
    :type k: int
    :param device: device to use ("cpu" or "cuda")
    :type device: :class:`torch.device`
    :param sparsity: the sparsity of the SparseWeights layer (see
        nupic.torch.modules.SparseWeights for more details)
    :type sparsity: float
    """
    super().__init__()
    self.sparse_weights = SparseWeights(
        torch.nn.Linear(in_features=dim_in, out_features=dim_out, bias=False),
        sparsity=sparsity,
    )
    self.output_masks = generate_random_binary_vectors(k, dim_out)
    self.device = device if device is not None else torch.device("cpu")

def _create_vgg_model(self):
    """
    Create a simple VGG-style CNN model, with options for sparsity.

    Expected configuration, one entry per block:

        block_sizes = [1, 1, 1]             # number of CNN layers in each block
        cnn_out_channels = [c1, c2, c3]     # out_channels in each layer of this block
        cnn_kernel_sizes = [k1, k2, k3]     # kernel_size in each layer of this block
        cnn_weight_sparsity = [w1, w2, w3]  # weight sparsity of each layer of this block
        cnn_percent_on = [p1, p2, p3]       # percent_on in each layer of this block
    """
    self.model = nn.Sequential()

    in_channels = 3
    output_size = 32 * 32
    output_units = output_size * in_channels
    for ly, block_size in enumerate(self.block_sizes):
        for b in range(block_size):
            self._add_cnn_layer(
                index_str=str(ly) + "_" + str(b),
                in_channels=in_channels,
                out_channels=self.cnn_out_channels[ly],
                kernel_size=self.cnn_kernel_sizes[ly],
                percent_on=self.cnn_percent_on[ly],
                weight_sparsity=self.cnn_weight_sparsity[ly],
                add_pooling=b == block_size - 1,
            )
            in_channels = self.cnn_out_channels[ly]
        # Only the last layer of each block pools (2x2), which quarters the
        # number of spatial positions
        output_size = int(output_size / 4)
        output_units = output_size * in_channels

    # Flatten CNN output before passing to the linear layers
    self.model.add_module("flatten", Flatten())

    # Linear layers
    input_size = output_units
    for ly, linear_n in enumerate(self.linear_n):
        linear = nn.Linear(input_size, linear_n)
        if self.linear_weight_sparsity[ly] < 1.0:
            self.model.add_module(
                "linear_" + str(ly),
                SparseWeights(linear, self.linear_weight_sparsity[ly]),
            )
        else:
            self.model.add_module("linear_" + str(ly), linear)

        if self.linear_percent_on[ly] < 1.0:
            self.model.add_module(
                "kwinners_linear_" + str(ly),
                KWinners(
                    n=linear_n,
                    percent_on=self.linear_percent_on[ly],
                    k_inference_factor=self.k_inference_factor,
                    boost_strength=self.boost_strength,
                    boost_strength_factor=self.boost_strength_factor,
                ),
            )
        else:
            self.model.add_module("Linear_ReLU_" + str(ly), nn.ReLU())

        input_size = self.linear_n[ly]

    # Output layer
    self.model.add_module("output", nn.Linear(input_size, self.output_size))

    print(self.model)
    self.model.to(self.device)
    self._initialize_weights()

def __init__(self,
             cnn_out_channels=(64, 64),
             cnn_percent_on=(0.095, 0.125),
             cnn_weight_sparsity=(0.5, 0.2),
             linear_units=1000,
             linear_percent_on=0.1,
             linear_weight_sparsity=0.1,
             boost_strength=1.5,
             boost_strength_factor=0.9,
             k_inference_factor=1.0,
             duty_cycle_period=1000,
             kwinner_local=False):
    super(GSCSparseCNN, self).__init__()
    # input_shape = (1, 32, 32)

    # First sparse CNN layer
    if cnn_weight_sparsity[0] < 1.0:
        self.add_module(
            "cnn1",
            SparseWeights2d(nn.Conv2d(1, cnn_out_channels[0], 5),
                            weight_sparsity=cnn_weight_sparsity[0]))
    else:
        self.add_module("cnn1", nn.Conv2d(1, cnn_out_channels[0], 5))
    self.add_module("cnn1_batchnorm",
                    nn.BatchNorm2d(cnn_out_channels[0], affine=False))
    self.add_module(
        "cnn1_kwinner",
        KWinners2d(
            channels=cnn_out_channels[0],
            percent_on=cnn_percent_on[0],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period,
            local=kwinner_local,
        ))
    self.add_module("cnn1_maxpool", nn.MaxPool2d(2))

    # Second sparse CNN layer
    if cnn_weight_sparsity[1] < 1.0:
        self.add_module(
            "cnn2",
            SparseWeights2d(nn.Conv2d(cnn_out_channels[0],
                                      cnn_out_channels[1], 5),
                            weight_sparsity=cnn_weight_sparsity[1]))
    else:
        self.add_module("cnn2",
                        nn.Conv2d(cnn_out_channels[0], cnn_out_channels[1], 5))
    self.add_module("cnn2_batchnorm",
                    nn.BatchNorm2d(cnn_out_channels[1], affine=False))
    self.add_module(
        "cnn2_kwinner",
        KWinners2d(
            channels=cnn_out_channels[1],
            percent_on=cnn_percent_on[1],
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period,
            local=kwinner_local,
        ))
    self.add_module("cnn2_maxpool", nn.MaxPool2d(2))

    self.add_module("flatten", Flatten())

    # Sparse linear layer
    self.add_module(
        "linear",
        SparseWeights(nn.Linear(25 * cnn_out_channels[1], linear_units),
                      weight_sparsity=linear_weight_sparsity))
    self.add_module("linear_bn", nn.BatchNorm1d(linear_units, affine=False))
    self.add_module(
        "linear_kwinner",
        KWinners(n=linear_units,
                 percent_on=linear_percent_on,
                 k_inference_factor=k_inference_factor,
                 boost_strength=boost_strength,
                 boost_strength_factor=boost_strength_factor,
                 duty_cycle_period=duty_cycle_period))

    # Classifier
    self.add_module("output", nn.Linear(linear_units, 12))
    self.add_module("softmax", nn.LogSoftmax(dim=1))

def __init__(
    self,
    input_size,
    output_size,
    hidden_sizes,
    num_segments,
    dim_context,
    kw,
    kw_percent_on=0.05,
    context_percent_on=1.0,
    dendrite_weight_sparsity=0.95,
    weight_sparsity=0.95,
    weight_init="modified",
    dendrite_init="modified",
    freeze_dendrites=False,
    output_nonlinearity=None,
    dendritic_layer_class=AbsoluteMaxGatingDendriticLayer,
):
    # Forward & dendritic weight initialization must be either "kaiming" or
    # "modified"
    assert weight_init in ("kaiming", "modified")
    assert dendrite_init in ("kaiming", "modified")
    assert kw_percent_on is None or (0.0 <= kw_percent_on < 1.0)
    assert context_percent_on >= 0.0

    if kw_percent_on == 0.0:
        kw = False

    super().__init__()

    if num_segments == 1:
        # Use the optimized single-segment class
        dendritic_layer_class = OneSegmentDendriticLayer

    self.num_segments = num_segments
    self.input_size = input_size
    self.hidden_sizes = hidden_sizes
    self.output_size = output_size
    self.dim_context = dim_context
    self.kw = kw
    self.kw_percent_on = kw_percent_on
    self.weight_sparsity = weight_sparsity
    self.dendrite_weight_sparsity = dendrite_weight_sparsity
    self.output_nonlinearity = output_nonlinearity
    self.hardcode_dendrites = (dendrite_init == "hardcoded")

    self._layers = nn.ModuleList()
    self._activations = nn.ModuleList()

    if self.hardcode_dendrites:
        dendrite_sparsity = 0.0
    else:
        dendrite_sparsity = self.dendrite_weight_sparsity

    for i in range(len(self.hidden_sizes)):
        curr_dend = dendritic_layer_class(
            module=nn.Linear(input_size, self.hidden_sizes[i], bias=True),
            num_segments=num_segments,
            dim_context=dim_context,
            module_sparsity=self.weight_sparsity,
            dendrite_sparsity=dendrite_sparsity,
        )
        if weight_init == "modified":
            # Scale weights to be sampled from the new initialization U(-h, h)
            # where h = sqrt(1 / (weight_density * previous_layer_percent_on))
            if i == 0:
                # The first hidden layer can't have k-Winners input
                self._init_sparse_weights(curr_dend, 0.0)
            else:
                self._init_sparse_weights(
                    curr_dend, 1 - kw_percent_on if kw else 0.0
                )

        if dendrite_init == "modified":
            self._init_sparse_dendrites(curr_dend, 1 - context_percent_on)

        if freeze_dendrites:
            # Dendritic weights will not be updated during the backward pass
            for name, param in curr_dend.named_parameters():
                if "segments" in name:
                    param.requires_grad = False

        if self.kw:
            curr_activation = KWinners(n=hidden_sizes[i],
                                       percent_on=kw_percent_on,
                                       k_inference_factor=1.0,
                                       boost_strength=0.0,
                                       boost_strength_factor=0.0)
        else:
            curr_activation = nn.ReLU()

        self._layers.append(curr_dend)
        self._activations.append(curr_activation)

        input_size = self.hidden_sizes[i]

    self._single_output_head = not isinstance(output_size, Iterable)
    if self._single_output_head:
        output_size = (output_size,)

    self._output_layers = nn.ModuleList()
    for out_size in output_size:
        output_layer = nn.Sequential()
        output_linear = SparseWeights(module=nn.Linear(input_size, out_size),
                                      sparsity=weight_sparsity,
                                      allow_extremes=True)
        if weight_init == "modified":
            self._init_sparse_weights(
                output_linear, 1 - kw_percent_on if kw else 0.0
            )
        output_layer.add_module("output_linear", output_linear)
        if self.output_nonlinearity is not None:
            output_layer.add_module("non_linearity", output_nonlinearity)
        self._output_layers.append(output_layer)

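# Hedged instantiation sketch for the constructor above (the enclosing class
# is assumed to be DendriticMLP, per the DendriticMLP._init_sparse_weights
# references elsewhere in this section); all sizes are illustrative.
net = DendriticMLP(
    input_size=784,
    output_size=10,
    hidden_sizes=(2048, 2048),
    num_segments=10,
    dim_context=100,
    kw=True,
)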