    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Generate a random context vector; `model_type` is assumed to live in
        # `model_args` alongside `dim_context`
        model_args = config.get("model_args", {})
        model_type = model_args.get("model_type")
        if model_type == "dendriticMLP":
            dim_context = model_args.get("dim_context")
            self.context = generate_context_vectors(num_contexts=1, n_dim=dim_context,
                                                    percent_on=0.05)
            self.context = self.context.to(self.device)
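
    # Illustrative config sketch (assumed shape, not from the original source): a
    # config like the following would trigger the dendritic branch above.
    #
    #   config = {
    #       "model_args": {
    #           "model_type": "dendriticMLP",   # hypothetical key
    #           "dim_context": 100,
    #       },
    #   }
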
    def test_mean_abs_error(self):
        """
        Ensure mean absolute error retrieved by the hardcoded test is no larger than a
        chosen epsilon
        """

        # Set `epsilon` to a value that the mean absolute error between the routing
        # function's output and the hardcoded dendritic network's output should
        # never exceed
        epsilon = 0.01

        # These hyperparameters control the size of the input and output to the routing
        # function (and dendritic network), the number of dendritic weights, the size
        # of the context vector, and batch size over which the mean absolute error is
        # computed
        dim_in = 100
        dim_out = 100
        num_contexts = 10
        dim_context = 100
        batch_size = 100

        r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=num_contexts,
                            sparsity=0.7)

        context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                                   n_dim=dim_context,
                                                   percent_on=0.2)

        module = AbsoluteMaxGatingDendriticLayer(module=r.sparse_weights.module,
                                                 num_segments=num_contexts,
                                                 dim_context=dim_context,
                                                 module_sparsity=0.7,
                                                 dendrite_sparsity=0.0)

        module.register_buffer("zero_mask",
                               deepcopy(r.sparse_weights.zero_mask.half()))

        hardcoded_weights = get_gating_context_weights(output_masks=r.output_masks,
                                                       context_vectors=context_vectors,
                                                       num_dendrites=num_contexts)
        module.segments.weights.data = hardcoded_weights

        x_test = 4.0 * torch.rand((batch_size, dim_in)) - 2.0  # sampled from U(-2, 2)
        context_inds_test = randint(low=0, high=num_contexts, size=batch_size).tolist()
        context_test = torch.stack(
            [context_vectors[j, :] for j in context_inds_test],
            dim=0
        )

        target = r(context_inds_test, x_test)
        actual = module(x_test, context_test)

        result = torch.abs(target - actual).mean().item()  # Mean absolute error
        self.assertLess(result, epsilon)
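
    # Illustrative note (not part of the original test): `generate_context_vectors`
    # is assumed to return a binary tensor of shape (num_contexts, n_dim) with a
    # `percent_on` fraction of ones in each row, e.g.
    #
    #   contexts = generate_context_vectors(num_contexts=10, n_dim=100, percent_on=0.2)
    #   contexts.shape         # torch.Size([10, 100])
    #   contexts.sum(dim=1)    # tensor of 20s: 20% of 100 dimensions are active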
Example #3
    def __init__(self, num_classes, num_tasks, training_examples_per_class,
                 validation_examples_per_class, dim_x, dim_context,
                 root=None, dataset_name=None, train=True):

        self.num_classes = num_classes
        self.num_tasks = num_tasks
        if train:
            examples_per_class = training_examples_per_class
        else:
            examples_per_class = validation_examples_per_class

        # Initialize the class-level distribution parameters only once, i.e. the
        # first time a `GaussianDataset` object is created
        if GaussianDataset.means is None:

            assert GaussianDataset.covs is None
            assert GaussianDataset._contexts is None

            GaussianDataset.means = {class_id: torch.rand((dim_x,)) for class_id in
                                     range(self.num_classes)}
            GaussianDataset.covs = {class_id: uniform(0.1, 2.5) * torch.eye(dim_x) for
                                    class_id in range(self.num_classes)}

        self.distributions = {class_id: MultivariateNormal(
            loc=GaussianDataset.means[class_id],
            covariance_matrix=GaussianDataset.covs[class_id]
        ) for class_id in range(self.num_classes)}

        # Sample i.i.d. from each distribution
        self.data = {}
        for class_id in range(self.num_classes):
            self.data[class_id] = self.distributions[class_id].sample(
                sample_shape=torch.Size([examples_per_class])
            )
        self.data = torch.cat([self.data[class_id] for class_id in
                               range(self.num_classes)], dim=0)

        self.targets = torch.tensor([[class_id for n in range(examples_per_class)]
                                     for class_id in range(self.num_classes)])
        self.targets = self.targets.flatten()

        # Context vectors
        if GaussianDataset._contexts is None:
            GaussianDataset._contexts = generate_context_vectors(num_contexts=num_tasks,
                                                                 n_dim=dim_context,
                                                                 percent_on=0.05)

        num_repeats = int(num_classes * examples_per_class / num_tasks)
        self.contexts = torch.repeat_interleave(GaussianDataset._contexts,
                                                repeats=num_repeats, dim=0)
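
    # Shape sketch (illustrative, hypothetical sizes): with num_classes=10,
    # examples_per_class=5, and num_tasks=5, all three attributes share the same
    # leading dimension:
    #
    #   self.data.shape      -> (50, dim_x)
    #   self.targets.shape   -> (50,)
    #   self.contexts.shape  -> (50, dim_context)  # each of 5 contexts repeated 10x
    #
    # (This assumes num_tasks evenly divides num_classes * examples_per_class.)
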
    def __init__(self,
                 num_classes,
                 num_examples,
                 context_sparsity,
                 input_dim,
                 context_dim,
                 train=True):

        # Register attributes
        self.num_classes = num_classes
        self.num_examples_per_class = num_examples // num_classes
        self.num_examples = self.num_examples_per_class * num_classes
        self.context_sparsity = context_sparsity
        self.input_dim = input_dim
        self.context_dim = context_dim

        # Generate random input vectors
        self.data = 2.0 * torch.rand((self.num_examples, input_dim)) - 1.0

        # Generate targets
        self.targets = [[class_id for n in range(self.num_examples_per_class)]
                        for class_id in range(num_classes)]
        self.targets = torch.tensor(self.targets).flatten()

        # Generate binary context vectors with the desired sparsity
        percent_on = 1.0 - context_sparsity
        self.contexts = generate_context_vectors(num_contexts=num_classes,
                                                 n_dim=context_dim,
                                                 percent_on=percent_on)
        assert (self.contexts.sum(dim=1) == int(percent_on * context_dim)).all()

        self.contexts = torch.repeat_interleave(
            self.contexts, repeats=self.num_examples_per_class, dim=0)

        assert self.data.size(0) == self.contexts.size(0) == self.targets.size(0)
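
    # Worked example (illustrative) of the `repeat_interleave` pattern above: each
    # per-class context row is repeated consecutively so that every example of a
    # class carries its class's context vector:
    #
    #   >>> c = torch.tensor([[1, 0], [0, 1]])
    #   >>> torch.repeat_interleave(c, repeats=3, dim=0)
    #   tensor([[1, 0],
    #           [1, 0],
    #           [1, 0],
    #           [0, 1],
    #           [0, 1],
    #           [0, 1]])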
Example #5
def run_hardcoded_routing_test(dim_in,
                               dim_out,
                               k,
                               dim_context,
                               dendrite_module,
                               context_weights_fn=None,
                               batch_size=100,
                               verbose=False):
    """
    Runs the hardcoded routing test for a specific type of dendritic network

    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test network
    :param k: the number of unique random binary vectors in the routing function that
              can "route" the sparse linear output, and also the number of unique
              context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param context_weights_fn: a function that returns a 3D torch Tensor that gives the
                               near-optimal dendrite values for the specified
                               dendrite_module, and has parameters `output_masks`,
                               `context_vectors`, and `num_dendrites`
    :param batch_size: the number of test inputs
    :param verbose: if True, prints target and output values on the first 15 dimensions
                    of the first batch item
    """

    # Initialize routing function that this task will try to hardcode
    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=k, sparsity=0.7)

    # Initialize context vectors, where each context vector corresponds to an output
    # mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=k,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize dendrite module using the same feed-forward sparse weights as the
    # routing function; also note that the value passed to `dendrite_sparsity` is
    # irrelevant since the context weights are subsequently overwritten
    dendritic_network = dendrite_module(module=r.sparse_weights.module,
                                        num_segments=k,
                                        dim_context=dim_context,
                                        module_sparsity=0.7,
                                        dendrite_sparsity=0.0)

    dendritic_network.register_buffer(
        "zero_mask", copy.deepcopy(r.sparse_weights.zero_mask.half()))

    # Choose the context weights specifically so that they can gate the outputs of the
    # forward module
    if context_weights_fn is not None:
        dendritic_network.segments.weights.data = context_weights_fn(
            output_masks=r.output_masks,
            context_vectors=context_vectors,
            num_dendrites=k)

    # Sample a random batch of inputs and random batch of context vectors, and perform
    # hardcoded routing test
    x_test = 4.0 * torch.rand((batch_size, dim_in)) - 2.0  # sampled from U(-2, 2)
    context_inds_test = randint(low=0, high=k, size=batch_size).tolist()
    context_test = torch.stack(
        [context_vectors[j, :] for j in context_inds_test], dim=0)

    target = r(context_inds_test, x_test)
    actual = dendritic_network(x_test, context_test)

    if verbose:

        # Print targets and outputs on the first 15 dimensions
        print("")
        print(" Element-wise outputs along the first 15 dimensions:")
        print("")
        print(" {}{}".format("target".ljust(24), "actual".ljust(24)))
        for target_i, actual_i in zip(target[0, :15], actual[0, :15]):

            target_i = str(target_i.item()).ljust(24)
            actual_i = str(actual_i.item()).ljust(24)

            print(" {}{}".format(target_i, actual_i))
        print(" ...")
        print("")

    # Report mean absolute error
    mean_abs_error = torch.abs(target - actual).mean().item()
    return {"mean_abs_error": mean_abs_error}
Example #6
    def __init__(self,
                 num_classes,
                 num_tasks,
                 training_examples_per_class,
                 validation_examples_per_class,
                 dim_x,
                 dim_context,
                 seed,
                 root=None,
                 dataset_name=None,
                 train=True):

        self.num_classes = num_classes
        self.num_tasks = num_tasks
        if train:
            examples_per_class = training_examples_per_class
        else:
            examples_per_class = validation_examples_per_class

        # Use a generator object to manually set the seed and generate the same means
        # and covariances for both training and validation datasets
        g = torch.manual_seed(seed)

        self.means = {
            class_id: torch.rand((dim_x, ), generator=g)
            for class_id in range(self.num_classes)
        }
        self.covs = {
            class_id:
            (2.4 * torch.rand(1, generator=g) + 0.1) * torch.eye(dim_x)
            for class_id in range(self.num_classes)
        }

        self.distributions = {
            class_id: MultivariateNormal(loc=self.means[class_id],
                                         covariance_matrix=self.covs[class_id])
            for class_id in range(self.num_classes)
        }

        # Sample i.i.d. from each distribution
        self.data = {}
        for class_id in range(self.num_classes):
            self.data[class_id] = self.distributions[class_id].sample(
                sample_shape=torch.Size([examples_per_class]))
        self.data = torch.cat(
            [self.data[class_id] for class_id in range(self.num_classes)],
            dim=0)

        self.targets = torch.tensor(
            [[class_id for n in range(examples_per_class)]
             for class_id in range(self.num_classes)])
        self.targets = self.targets.flatten()

        # Context vectors
        self._contexts = generate_context_vectors(num_contexts=num_tasks,
                                                  n_dim=dim_context,
                                                  percent_on=0.05,
                                                  seed=seed)
        num_repeats = int(num_classes * examples_per_class / num_tasks)
        self.contexts = torch.repeat_interleave(self._contexts,
                                                repeats=num_repeats,
                                                dim=0)
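
    # Illustrative note on the seeding above: `torch.manual_seed(seed)` returns the
    # re-seeded default Generator, so constructing the training and validation
    # datasets with the same seed reproduces identical means and covariances:
    #
    #   g = torch.manual_seed(42)
    #   a = torch.rand(3, generator=g)
    #   torch.manual_seed(42)
    #   b = torch.rand(3)
    #   assert torch.equal(a, b)
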
def test_regular_network(dim_in,
                         dim_out,
                         num_contexts,
                         dim_context,
                         batch_size=64,
                         num_training_epochs=1000):
    """
    Trains and evalutes a feedforward network with no hidden layers to match an
    arbitrary routing function

    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test network
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendritic
                                network
    """

    # Input size to the model is 2 * dim_in since the context is concatenated with
    # the regular input (note that this requires dim_context to equal dim_in)
    model = SingleLayerLinearNetwork(input_size=2 * dim_in,
                                     output_size=dim_out)

    # Initialize routing function that this task will try to learn, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in,
                        dim_out=dim_out,
                        k=num_contexts,
                        device=model.device,
                        sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an output
    # mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize datasets and dataloaders
    train_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=-2.0,
        x_max=2.0)

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

    test_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=2.0,
        x_max=6.0)

    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

    # Place objects that inherit from torch.nn.Module on device
    model = model.to(model.device)
    r = r.to(r.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):

        train_dendrite_model(model=model,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=model.device,
                             criterion=F.l1_loss,
                             concat=True)

        # Validate model - note that we use a different dataset/dataloader as the input
        # distribution has changed
        results = evaluate_dendrite_model(model=model,
                                          loader=test_dataloader,
                                          device=model.device,
                                          criterion=F.l1_loss,
                                          concat=True)

        print("{},{},{}".format(epoch, results["loss"],
                                results["mean_abs_err"]))
Example #8
def init_test_scenario(mode, dim_in, dim_out, num_contexts, dim_context,
                       dendrite_module):
    """
    Returns the routing function, dendrite layer, context vectors, and device to use
    in the "learn to route" experiment

    :param mode: must be one of ("dendrites", "all")
                 "dendrites" -> learn only dendrite weights while setting feed-forward
                 weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test layer
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize routing function that this task will try to hardcode, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in,
                        dim_out=dim_out,
                        k=num_contexts,
                        device=device,
                        sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an output
    # mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # If only training the dendrite weights, initialize the dendrite module using the
    # same feed-forward sparse weights as the routing function; otherwise, if also
    # learning the feed-forward weights, use `torch.nn.Linear`. Note that the value
    # passed to `dendrite_sparsity` is irrelevant since the dendrite sparsity mask
    # (`zero_mask`) is subsequently overwritten.
    if mode == "dendrites":
        dendrite_layer_forward_module = r.sparse_weights.module
    elif mode == "all":
        dendrite_layer_forward_module = torch.nn.Linear(dim_in,
                                                        dim_out,
                                                        bias=False)
    else:
        raise ValueError("Invalid value for `mode`: {}".format(mode))

    dendrite_layer = dendrite_module(module=dendrite_layer_forward_module,
                                     num_segments=num_contexts,
                                     dim_context=dim_context,
                                     module_sparsity=0.7,
                                     dendrite_sparsity=0.0)

    # In this version of learning to route, there is no sparsity constraint on the
    # dendrite weights
    dendrite_layer.register_buffer(
        "zero_mask",
        torch.ones(dendrite_layer.zero_mask.shape).half())

    # Place objects that inherit from torch.nn.Module on device
    r = r.to(device)
    dendrite_layer = dendrite_layer.to(device)
    context_vectors = context_vectors.to(device)

    return r, dendrite_layer, context_vectors, device
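
# Usage sketch (illustrative, not from the original source): unpack the scenario
# and run one gated forward pass, mirroring the hardcoded test above. The layer
# class and sizes here are hypothetical choices.
r, dendrite_layer, context_vectors, device = init_test_scenario(
    mode="dendrites",
    dim_in=100,
    dim_out=100,
    num_contexts=10,
    dim_context=100,
    dendrite_module=AbsoluteMaxGatingDendriticLayer)

x = torch.rand((32, 100), device=device)
indices = torch.randint(low=0, high=10, size=(32,), device=device)
output = dendrite_layer(x, context_vectors[indices])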