def setup_experiment(self, config):
    super().setup_experiment(config)

    # Generate a random context vector for dendritic networks. Note that
    # `model_type` was undefined in the original; reading it from the config
    # under the key "model_type" is an assumption.
    model_type = config.get("model_type")
    if model_type == "dendriticMLP":
        dim_context = config.get("model_args").get("dim_context")
        self.context = generate_context_vectors(num_contexts=1,
                                                n_dim=dim_context,
                                                percent_on=0.05)
        self.context = self.context.to(self.device)
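# A minimal sketch of the experiment config consumed by the mixin above. Only
# `model_args.dim_context` appears in the original code; the `model_type` key
# is the assumption noted there.
config = dict(
    model_type="dendriticMLP",
    model_args=dict(dim_context=100),
)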
def test_mean_abs_error(self):
    """
    Ensure the mean absolute error retrieved by the hardcoded test is no larger
    than a chosen epsilon
    """
    # Set `epsilon` to a value that the mean absolute error between the routing
    # output and dendritic network (with hardcoded dendritic weights) output
    # should never exceed
    epsilon = 0.01

    # These hyperparameters control the size of the input and output to the
    # routing function (and dendritic network), the number of dendritic
    # weights, the size of the context vector, and the batch size over which
    # the mean absolute error is computed
    dim_in = 100
    dim_out = 100
    num_contexts = 10
    dim_context = 100
    batch_size = 100

    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=num_contexts,
                        sparsity=0.7)
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    module = AbsoluteMaxGatingDendriticLayer(module=r.sparse_weights.module,
                                             num_segments=num_contexts,
                                             dim_context=dim_context,
                                             module_sparsity=0.7,
                                             dendrite_sparsity=0.0)
    module.register_buffer("zero_mask",
                           deepcopy(r.sparse_weights.zero_mask.half()))

    hardcoded_weights = get_gating_context_weights(
        output_masks=r.output_masks,
        context_vectors=context_vectors,
        num_dendrites=num_contexts)
    module.segments.weights.data = hardcoded_weights

    x_test = 4.0 * torch.rand((batch_size, dim_in)) - 2.0  # sampled from U(-2, 2)
    context_inds_test = randint(low=0, high=num_contexts,
                                size=batch_size).tolist()
    context_test = torch.stack(
        [context_vectors[j, :] for j in context_inds_test], dim=0
    )

    target = r(context_inds_test, x_test)
    actual = module(x_test, context_test)

    result = torch.abs(target - actual).mean().item()  # Mean absolute error
    self.assertLess(result, epsilon)
def __init__(self, num_classes, num_tasks, training_examples_per_class,
             validation_examples_per_class, dim_x, dim_context, root=None,
             dataset_name=None, train=True):
    self.num_classes = num_classes
    self.num_tasks = num_tasks

    if train:
        examples_per_class = training_examples_per_class
    else:
        examples_per_class = validation_examples_per_class

    # Initialize distributions only if the `GaussianDataset` class attributes
    # have not yet been populated, so that the training and validation splits
    # share the same means and covariances
    if GaussianDataset.means is None:
        assert GaussianDataset.covs is None
        assert GaussianDataset._contexts is None

        GaussianDataset.means = {
            class_id: torch.rand((dim_x,))
            for class_id in range(self.num_classes)
        }
        GaussianDataset.covs = {
            class_id: uniform(0.1, 2.5) * torch.eye(dim_x)
            for class_id in range(self.num_classes)
        }

    self.distributions = {
        class_id: MultivariateNormal(
            loc=GaussianDataset.means[class_id],
            covariance_matrix=GaussianDataset.covs[class_id]
        )
        for class_id in range(self.num_classes)
    }

    # Sample i.i.d. from each distribution
    self.data = {}
    for class_id in range(self.num_classes):
        self.data[class_id] = self.distributions[class_id].sample(
            sample_shape=torch.Size([examples_per_class])
        )
    self.data = torch.cat(
        [self.data[class_id] for class_id in range(self.num_classes)], dim=0
    )

    self.targets = torch.tensor(
        [[class_id for n in range(examples_per_class)]
         for class_id in range(self.num_classes)]
    )
    self.targets = self.targets.flatten()

    # Context vectors
    if GaussianDataset._contexts is None:
        GaussianDataset._contexts = generate_context_vectors(
            num_contexts=num_tasks,
            n_dim=dim_context,
            percent_on=0.05)
    num_repeats = int(num_classes * examples_per_class / num_tasks)
    self.contexts = torch.repeat_interleave(GaussianDataset._contexts,
                                            repeats=num_repeats,
                                            dim=0)
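# Usage sketch with hypothetical hyperparameters: because the means,
# covariances, and context vectors are cached as class attributes, the train
# and validation splits below draw from identical per-class Gaussians.
train_set = GaussianDataset(num_classes=10, num_tasks=5,
                            training_examples_per_class=2500,
                            validation_examples_per_class=500,
                            dim_x=2048, dim_context=2048, train=True)
val_set = GaussianDataset(num_classes=10, num_tasks=5,
                          training_examples_per_class=2500,
                          validation_examples_per_class=500,
                          dim_x=2048, dim_context=2048, train=False)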
def __init__(self, num_classes, num_examples, context_sparsity, input_dim,
             context_dim, train=True):
    # Register attributes
    self.num_classes = num_classes
    self.num_examples_per_class = int(num_examples / num_classes)
    self.num_examples = self.num_examples_per_class * num_classes
    self.context_sparsity = context_sparsity
    self.input_dim = input_dim
    self.context_dim = context_dim

    # Generate random input vectors
    self.data = 2.0 * torch.rand((self.num_examples, input_dim)) - 1.0

    # Generate targets
    self.targets = [[class_id for n in range(self.num_examples_per_class)]
                    for class_id in range(num_classes)]
    self.targets = torch.tensor(self.targets).flatten()

    # Generate binary context vectors with the desired sparsity
    percent_on = 1.0 - context_sparsity
    self.contexts = generate_context_vectors(num_contexts=num_classes,
                                             n_dim=context_dim,
                                             percent_on=percent_on)
    assert (self.contexts.sum(dim=1) == int(percent_on * context_dim)).all()
    self.contexts = torch.repeat_interleave(
        self.contexts, repeats=self.num_examples_per_class, dim=0
    )

    assert self.data.size(0) == self.contexts.size(0) == self.targets.size(0)
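# Usage sketch: the class name `SparseContextDataset` is hypothetical, since
# only the __init__ is shown above. With context_sparsity=0.9 and
# context_dim=512, each of the 10 binary context vectors has
# int(0.1 * 512) = 51 bits on, as enforced by the assert in the constructor.
dataset = SparseContextDataset(num_classes=10, num_examples=1000,
                               context_sparsity=0.9, input_dim=512,
                               context_dim=512)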
def run_hardcoded_routing_test(dim_in, dim_out, k, dim_context, dendrite_module,
                               context_weights_fn=None, batch_size=100,
                               verbose=False):
    """
    Runs the hardcoded routing test for a specific type of dendritic network

    :param dim_in: the number of dimensions in the input to the routing function
                   and test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test network
    :param k: the number of unique random binary vectors in the routing function
              that can "route" the sparse linear output, and also the number of
              unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param context_weights_fn: a function that returns a 3D torch Tensor that
                               gives the near-optimal dendrite values for the
                               specified dendrite_module, and has parameters
                               `output_masks`, `context_vectors`, and
                               `num_dendrites`
    :param batch_size: the number of test inputs
    :param verbose: if True, prints target and output values on the first 15
                    dimensions of batch item 1
    """
    # Initialize routing function that this task will try to hardcode
    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=k, sparsity=0.7)

    # Initialize context vectors, where each context vector corresponds to an
    # output mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=k,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize dendrite module using the same feed-forward sparse weights as
    # the routing function; also note that the value passed to
    # `dendrite_sparsity` is irrelevant since the context weights are
    # subsequently overwritten
    dendritic_network = dendrite_module(module=r.sparse_weights.module,
                                        num_segments=k,
                                        dim_context=dim_context,
                                        module_sparsity=0.7,
                                        dendrite_sparsity=0.0)
    dendritic_network.register_buffer(
        "zero_mask", copy.deepcopy(r.sparse_weights.zero_mask.half())
    )

    # Choose the context weights specifically so that they can gate the outputs
    # of the forward module
    if context_weights_fn is not None:
        dendritic_network.segments.weights.data = context_weights_fn(
            output_masks=r.output_masks,
            context_vectors=context_vectors,
            num_dendrites=k)

    # Sample a random batch of inputs and a random batch of context vectors,
    # and perform the hardcoded routing test
    x_test = 4.0 * torch.rand((batch_size, dim_in)) - 2.0  # sampled from U(-2, 2)
    context_inds_test = randint(low=0, high=k, size=batch_size).tolist()
    context_test = torch.stack(
        [context_vectors[j, :] for j in context_inds_test], dim=0
    )

    target = r(context_inds_test, x_test)
    actual = dendritic_network(x_test, context_test)

    if verbose:
        # Print targets and outputs on the first 15 dimensions
        print("")
        print(" Element-wise outputs along the first 15 dimensions:")
        print("")
        print(" {}{}".format("target".ljust(24), "actual".ljust(24)))
        for target_i, actual_i in zip(target[0, :15], actual[0, :15]):
            target_i = str(target_i.item()).ljust(24)
            actual_i = str(actual_i.item()).ljust(24)
            print(" {}{}".format(target_i, actual_i))
        print(" ...")
        print("")

    # Report mean absolute error
    mean_abs_error = torch.abs(target - actual).mean().item()
    return {"mean_abs_error": mean_abs_error}
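# Usage sketch: `AbsoluteMaxGatingDendriticLayer` and
# `get_gating_context_weights` are the dendrite layer and weight-hardcoding
# function used in the test above, and are assumed to be importable from this
# repo.
results = run_hardcoded_routing_test(
    dim_in=100, dim_out=100, k=10, dim_context=100,
    dendrite_module=AbsoluteMaxGatingDendriticLayer,
    context_weights_fn=get_gating_context_weights,
    verbose=True,
)
print(results["mean_abs_error"])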
def __init__(self, num_classes, num_tasks, training_examples_per_class,
             validation_examples_per_class, dim_x, dim_context, seed,
             root=None, dataset_name=None, train=True):
    self.num_classes = num_classes
    self.num_tasks = num_tasks

    if train:
        examples_per_class = training_examples_per_class
    else:
        examples_per_class = validation_examples_per_class

    # Use a generator object to manually set the seed and generate the same
    # means and covariances for both training and validation datasets
    g = torch.manual_seed(seed)

    self.means = {
        class_id: torch.rand((dim_x,), generator=g)
        for class_id in range(self.num_classes)
    }
    self.covs = {
        class_id: (2.4 * torch.rand(1, generator=g) + 0.1) * torch.eye(dim_x)
        for class_id in range(self.num_classes)
    }
    self.distributions = {
        class_id: MultivariateNormal(loc=self.means[class_id],
                                     covariance_matrix=self.covs[class_id])
        for class_id in range(self.num_classes)
    }

    # Sample i.i.d. from each distribution
    self.data = {}
    for class_id in range(self.num_classes):
        self.data[class_id] = self.distributions[class_id].sample(
            sample_shape=torch.Size([examples_per_class])
        )
    self.data = torch.cat(
        [self.data[class_id] for class_id in range(self.num_classes)], dim=0
    )

    self.targets = torch.tensor(
        [[class_id for n in range(examples_per_class)]
         for class_id in range(self.num_classes)]
    )
    self.targets = self.targets.flatten()

    # Context vectors
    self._contexts = generate_context_vectors(num_contexts=num_tasks,
                                              n_dim=dim_context,
                                              percent_on=0.05,
                                              seed=seed)
    num_repeats = int(num_classes * examples_per_class / num_tasks)
    self.contexts = torch.repeat_interleave(self._contexts,
                                            repeats=num_repeats,
                                            dim=0)
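# Usage sketch: with a fixed seed, the train and validation splits of this
# seeded variant draw from identical per-class Gaussians without relying on
# cached class attributes; the class name and hyperparameters below are
# hypothetical. Note that the covariance scale (2.4 * rand + 0.1) reproduces
# U(0.1, 2.5) from the unseeded variant above, but through a seeded generator.
train_set = GaussianDataset(num_classes=10, num_tasks=5,
                            training_examples_per_class=2500,
                            validation_examples_per_class=500,
                            dim_x=2048, dim_context=2048, seed=42, train=True)
val_set = GaussianDataset(num_classes=10, num_tasks=5,
                          training_examples_per_class=2500,
                          validation_examples_per_class=500,
                          dim_x=2048, dim_context=2048, seed=42, train=False)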
def test_regular_network(dim_in, dim_out, num_contexts, dim_context,
                         batch_size=64, num_training_epochs=1000):
    """
    Trains and evaluates a feed-forward network with no hidden layers to match
    an arbitrary routing function

    :param dim_in: the number of dimensions in the input to the routing function
                   and test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test network
    :param num_contexts: the number of unique random binary vectors in the
                         routing function that can "route" the sparse linear
                         output, and also the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the
                                network
    """
    # Input size to the model is 2 * dim_in since the context is concatenated
    # with the regular input (this assumes `dim_context == dim_in`)
    model = SingleLayerLinearNetwork(input_size=2 * dim_in, output_size=dim_out)

    # Initialize routing function that this task will try to learn, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=num_contexts,
                        device=model.device, sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an
    # output mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize datasets and dataloaders
    train_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=-2.0,
        x_max=2.0)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

    test_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=2.0,
        x_max=6.0)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

    # Place objects that inherit from torch.nn.Module on device
    model = model.to(model.device)
    r = r.to(r.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):
        train_dendrite_model(model=model,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=model.device,
                             criterion=F.l1_loss,
                             concat=True)

        # Validate the model; note that we use a different dataset/dataloader
        # since the input distribution has changed
        results = evaluate_dendrite_model(model=model,
                                          loader=test_dataloader,
                                          device=model.device,
                                          criterion=F.l1_loss,
                                          concat=True)
        print("{},{},{}".format(epoch, results["loss"],
                                results["mean_abs_err"]))
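# Usage sketch (hypothetical hyperparameters): train a plain linear layer on
# concatenated (input, context) pairs and emit per-epoch results as CSV.
if __name__ == "__main__":
    test_regular_network(dim_in=100, dim_out=100, num_contexts=10,
                         dim_context=100)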
def init_test_scenario(mode, dim_in, dim_out, num_contexts, dim_context,
                       dendrite_module):
    """
    Returns the routing function, dendrite layer, context vectors, and device to
    use in the "learn to route" experiment

    :param mode: must be one of ("dendrites", "all")
                 "dendrites" -> learn only dendrite weights while setting
                 feed-forward weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
    :param dim_in: the number of dimensions in the input to the routing function
                   and test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test layer
    :param num_contexts: the number of unique random binary vectors in the
                         routing function that can "route" the sparse linear
                         output, and also the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize routing function that this task will try to learn, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=num_contexts,
                        device=device, sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an
    # output mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # If only training the dendrite weights, initialize the dendrite module
    # using the same feed-forward sparse weights as the routing function;
    # otherwise, if also learning feed-forward weights, use `torch.nn.Linear`.
    # Note that the value passed to `dendrite_sparsity` is irrelevant since the
    # zero mask on the dendrite weights is subsequently overwritten
    if mode == "dendrites":
        dendrite_layer_forward_module = r.sparse_weights.module
    elif mode == "all":
        dendrite_layer_forward_module = torch.nn.Linear(dim_in, dim_out,
                                                        bias=False)
    else:
        raise Exception("Invalid value for `mode`: {}".format(mode))

    dendrite_layer = dendrite_module(module=dendrite_layer_forward_module,
                                     num_segments=num_contexts,
                                     dim_context=dim_context,
                                     module_sparsity=0.7,
                                     dendrite_sparsity=0.0)

    # In this version of learning to route, there is no sparsity constraint on
    # the dendrite weights
    dendrite_layer.register_buffer(
        "zero_mask", torch.ones(dendrite_layer.zero_mask.shape).half()
    )

    # Place objects that inherit from torch.nn.Module on device
    r = r.to(device)
    dendrite_layer = dendrite_layer.to(device)
    context_vectors = context_vectors.to(device)

    return r, dendrite_layer, context_vectors, device
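# Usage sketch: set up the "learn only dendrite weights" scenario with the
# gating layer used elsewhere in these tests; the hyperparameters are
# hypothetical.
r, dendrite_layer, context_vectors, device = init_test_scenario(
    mode="dendrites",
    dim_in=100,
    dim_out=100,
    num_contexts=10,
    dim_context=100,
    dendrite_module=AbsoluteMaxGatingDendriticLayer,
)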