Example #1
    def test_forward_weights(self):
        """
        Feed-forward weights should not be modified when `mode == "dendrites"`, and
        should be modified when `mode == "all"`
        """

        for mode in ("dendrites", "all", "learn_context"):

            r, dendrite_layer, context_model, context_vectors, device = \
                init_test_scenario(
                    mode=mode,
                    dim_in=100,
                    dim_out=100,
                    num_contexts=10,
                    dim_context=100,
                    dendrite_module=AbsoluteMaxGatingDendriticLayer
                )

            dataloader = init_dataloader(routing_function=r,
                                         context_vectors=context_vectors,
                                         device=device,
                                         batch_size=64,
                                         x_min=-2.0,
                                         x_max=2.0)

            optimizer = init_optimizer(mode=mode,
                                       layer=dendrite_layer,
                                       context_model=context_model)

            forward_weights_before = copy.deepcopy(
                dendrite_layer.module.weight.data)

            # Perform a single training epoch
            train_dendrite_model(model=dendrite_layer,
                                 context_model=context_model,
                                 loader=dataloader,
                                 optimizer=optimizer,
                                 device=device,
                                 criterion=F.l1_loss)

            forward_weights_after = copy.deepcopy(
                dendrite_layer.module.weight.data)

            expected = (forward_weights_before == forward_weights_after).all()

            # If training the feed-forward weights (i.e., mode is "all" or
            # "learn_context"), we expect them to change
            if mode in ("all", "learn_context"):
                expected = not expected

            self.assertTrue(expected)
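The snapshot-and-compare pattern above is worth isolating. The following is a minimal, self-contained sketch of the same technique, assuming only `torch`: parameters omitted from the optimizer's parameter groups keep their values even though they still receive gradients, which is how a mode-specific `init_optimizer` can freeze weights. `frozen` and `trainable` are illustrative names, not identifiers from the repo.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

frozen = nn.Linear(8, 8)     # plays the role of the frozen feed-forward weights
trainable = nn.Linear(8, 8)  # plays the role of the weights being learned

# Only `trainable`'s parameters are handed to the optimizer; `frozen` still
# receives gradients, but no update is ever applied to it
optimizer = torch.optim.SGD(trainable.parameters(), lr=0.1)

frozen_before = copy.deepcopy(frozen.weight.data)
trainable_before = copy.deepcopy(trainable.weight.data)

x = torch.randn(4, 8)
loss = F.l1_loss(trainable(frozen(x)), torch.zeros(4, 8))
loss.backward()
optimizer.step()

assert torch.equal(frozen.weight.data, frozen_before)
assert not torch.equal(trainable.weight.data, trainable_before)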
Example #2
    def test_dendrite_weights(self):
        """
        Dendrite weights should be modified both when `mode == "dendrites"` and when
        `mode == "all"`
        """

        for mode in ("dendrites", "all", "learn_context"):

            r, dendrite_layer, context_model, context_vectors, device = \
                init_test_scenario(
                    mode=mode,
                    dim_in=100,
                    dim_out=100,
                    num_contexts=10,
                    dim_context=100,
                    dendrite_module=AbsoluteMaxGatingDendriticLayer
                )

            dataloader = init_dataloader(routing_function=r,
                                         context_vectors=context_vectors,
                                         device=device,
                                         batch_size=64,
                                         x_min=-2.0,
                                         x_max=2.0)

            optimizer = init_optimizer(mode=mode,
                                       layer=dendrite_layer,
                                       context_model=context_model)

            dendrite_weights_before = copy.deepcopy(
                dendrite_layer.segments.weights.data)

            # Perform a single training epoch
            train_dendrite_model(model=dendrite_layer,
                                 context_model=context_model,
                                 loader=dataloader,
                                 optimizer=optimizer,
                                 device=device,
                                 criterion=F.l1_loss)

            dendrite_weights_after = copy.deepcopy(
                dendrite_layer.segments.weights.data)

            expected = (dendrite_weights_before !=
                        dendrite_weights_after).any()
            self.assertTrue(expected)
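The assertions above only touch `dendrite_layer.segments.weights`; they do not depend on how the layer uses those weights. For intuition, here is a toy illustration of the absolute-max gating idea (a sketch of the concept, not the actual `AbsoluteMaxGatingDendriticLayer` implementation; all names below are hypothetical):

import torch
import torch.nn as nn


class ToyAbsoluteMaxGatingLayer(nn.Module):
    """Gate each output unit by its strongest-magnitude dendrite activation."""

    def __init__(self, dim_in, dim_out, num_segments, dim_context):
        super().__init__()
        self.module = nn.Linear(dim_in, dim_out)  # feed-forward weights
        # One bank of segments per output unit: (dim_out, num_segments, dim_context)
        self.segment_weights = nn.Parameter(
            0.1 * torch.randn(dim_out, num_segments, dim_context))

    def forward(self, x, context):
        y = self.module(x)
        # Dendrite activations for every (unit, segment) pair:
        # (batch, dim_out, num_segments)
        activations = torch.einsum("bc,usc->bus", context, self.segment_weights)
        # For each unit, keep the activation with the largest absolute value
        index = activations.abs().argmax(dim=2, keepdim=True)
        selected = activations.gather(dim=2, index=index).squeeze(2)
        # Gate the feed-forward output with the sigmoid of that activation
        return y * torch.sigmoid(selected)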
Example #3
    def test_context_model(self):
        """
        Context model should be learned only when `mode == "learn_context"`
        """
        for mode in ("dendrites", "all", "learn_context"):

            r, dendrite_layer, context_model, context_vectors, device = \
                init_test_scenario(
                    mode=mode,
                    dim_in=100,
                    dim_out=100,
                    num_contexts=10,
                    dim_context=100,
                    dendrite_module=AbsoluteMaxGatingDendriticLayer
                )

            dataloader = init_dataloader(routing_function=r,
                                         context_vectors=context_vectors,
                                         device=device,
                                         batch_size=64,
                                         x_min=-2.0,
                                         x_max=2.0)

            optimizer = init_optimizer(mode=mode,
                                       layer=dendrite_layer,
                                       context_model=context_model)

            if mode != "learn_context":
                assert context_model is None
            else:
                context_weights_before = copy.deepcopy(
                    context_model.linear1.weight.data)
                # Perform a single training epoch
                train_dendrite_model(model=dendrite_layer,
                                     context_model=context_model,
                                     loader=dataloader,
                                     optimizer=optimizer,
                                     device=device,
                                     criterion=F.l1_loss)

                context_weights_after = copy.deepcopy(
                    context_model.linear1.weight.data)

                expected = (context_weights_before !=
                            context_weights_after).any()
                self.assertTrue(expected)
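Note that this test assumes nothing about the context model beyond a `linear1` attribute. A minimal stand-in consistent with that interface is a small MLP mapping a one-hot context id to a dense context vector (an illustrative sketch, not the repo's actual context-generation model):

import torch
import torch.nn as nn


class ToyContextModel(nn.Module):
    def __init__(self, num_contexts, dim_context, hidden_size=64):
        super().__init__()
        self.linear1 = nn.Linear(num_contexts, hidden_size)
        self.linear2 = nn.Linear(hidden_size, dim_context)

    def forward(self, onehot_context):
        # onehot_context: (batch, num_contexts), one-hot rows
        return self.linear2(torch.relu(self.linear1(onehot_context)))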
Example #4
def test_regular_network(dim_in,
                         dim_out,
                         num_contexts,
                         dim_context,
                         batch_size=64,
                         num_training_epochs=1000):
    """
    Trains and evalutes a feedforward network with no hidden layers to match an
    arbitrary routing function

    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test network
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendritic
                                network
    """

    # Input size to the model is 2 * dim_in since the context is concatenated with the
    # regular input
    model = SingleLayerLinearNetwork(input_size=2 * dim_in,
                                     output_size=dim_out)

    # Initialize routing function that this task will try to learn, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in,
                        dim_out=dim_out,
                        k=num_contexts,
                        device=model.device,
                        sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an output
    # mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize datasets and dataloaders
    train_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=-2.0,
        x_max=2.0)

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

    test_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=2.0,
        x_max=6.0)

    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

    # Place objects that inherit from torch.nn.Module on device
    model = model.to(model.device)
    r = r.to(r.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):

        train_dendrite_model(model=model,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=model.device,
                             criterion=F.l1_loss,
                             concat=True)

        # Validate model - note that we use a different dataset/dataloader as the input
        # distribution has changed
        results = evaluate_dendrite_model(model=model,
                                          loader=test_dataloader,
                                          device=model.device,
                                          criterion=F.l1_loss,
                                          concat=True)

        print("{},{},{}".format(epoch, results["loss"],
                                results["mean_abs_err"]))
Example #5
def learn_to_route(mode,
                   dim_in,
                   dim_out,
                   num_contexts,
                   dim_context,
                   dendrite_module,
                   batch_size=64,
                   num_training_epochs=5000,
                   sparse_context_model=True,
                   onehot=False,
                   plot=False,
                   save_interval=100,
                   save_path="./models/"):
    """
    Trains a dendrite layer to match an arbitrary routing function

    :param mode: must be one of ("dendrites", "all", "learn_context")
                 "dendrites" -> learn only dendrite weights while setting feed-forward
                 weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
                 "learn_context" -> learn feed-forward, dendrite, and context-generation
                 weights
    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test layer
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendrite
                                layer
    :param sparse_context_model: whether to use a sparse MLP to generate the context;
                                 applicable if mode == "learn_context"
    :param onehot: whether the context integer should be encoded as a onehot vector
                   when input into the context generation model
    :param plot: whether to plot a loss curve
    :param save_interval: number of epochs between saving the model
    :param save_path: path to folder in which to save model checkpoints.
    """

    r, dendrite_layer, context_model, context_vectors, device = init_test_scenario(
        mode=mode,
        dim_in=dim_in,
        dim_out=dim_out,
        num_contexts=num_contexts,
        dim_context=dim_context,
        dendrite_module=dendrite_module,
        sparse_context_model=sparse_context_model,
        onehot=onehot)

    train_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=-2.0,
        x_max=2.0,
    )

    test_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=2.0,
        x_max=6.0,
    )

    optimizer = init_optimizer(mode=mode,
                               layer=dendrite_layer,
                               context_model=context_model)

    print("epoch,mean_loss,mean_abs_err")
    losses = []

    for epoch in range(1, num_training_epochs + 1):

        l1_weight_decay = None
        # Select L1 weight decay penalty based on scenario
        if mode == "dendrites":
            l1_weight_decay = 0.0
        elif mode == "all" or mode == "learn_context":
            l1_weight_decay = 1e-6

        train_dendrite_model(
            model=dendrite_layer,
            context_model=context_model,
            loader=train_dataloader,
            optimizer=optimizer,
            device=device,
            criterion=F.l1_loss,
            concat=False,
            l1_weight_decay=l1_weight_decay,
        )

        # Validate model
        results = evaluate_dendrite_model(model=dendrite_layer,
                                          context_model=context_model,
                                          loader=test_dataloader,
                                          device=device,
                                          criterion=F.l1_loss,
                                          concat=False)

        print("{},{}".format(epoch, results["mean_abs_err"]))
        # track loss for plotting
        if plot:
            losses.append(results["mean_abs_err"])
        # save models for future visualization or training
        if save_interval and epoch % save_interval == 0:
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            torch.save(dendrite_layer.state_dict(),
                       save_path + "dendrite_" + str(epoch))
            if context_model:
                torch.save(context_model.state_dict(),
                           save_path + "context_" + str(epoch))

    if plot:
        losses = np.array(losses)
        plt.scatter(x=np.arange(1, num_training_epochs + 1), y=losses)
        plt.savefig("training_curve.png")
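An illustrative invocation, reusing the dendrite module from the tests above (all argument values are examples):

learn_to_route(mode="learn_context",
               dim_in=100,
               dim_out=100,
               num_contexts=10,
               dim_context=100,
               dendrite_module=AbsoluteMaxGatingDendriticLayer,
               num_training_epochs=500,
               plot=True)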
Example #6
def learn_to_route(mode,
                   dim_in,
                   dim_out,
                   num_contexts,
                   dim_context,
                   dendrite_module,
                   batch_size=64,
                   num_training_epochs=5000):
    """
    Trains a dendrite layer to match an arbitrary routing function

    :param mode: must be one of ("dendrites", "all")
                 "dendrites" -> learn only dendrite weights while setting feed-forward
                 weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the routing
                    function and test layer
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendrite
                                layer
    """

    r, dendrite_layer, context_vectors, device = init_test_scenario(
        mode=mode,
        dim_in=dim_in,
        dim_out=dim_out,
        num_contexts=num_contexts,
        dim_context=dim_context,
        dendrite_module=dendrite_module)

    train_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=-2.0,
        x_max=2.0,
    )

    test_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=2.0,
        x_max=6.0,
    )

    optimizer = init_optimizer(mode=mode, layer=dendrite_layer)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):

        # Select L1 weight decay penalty based on scenario
        l1_weight_decay = None
        if mode == "dendrites":
            l1_weight_decay = 0.0
        elif mode == "all":
            l1_weight_decay = 1e-6

        train_dendrite_model(model=dendrite_layer,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=device,
                             criterion=F.l1_loss,
                             concat=False,
                             l1_weight_decay=l1_weight_decay)

        # Validate model
        results = evaluate_dendrite_model(model=dendrite_layer,
                                          loader=test_dataloader,
                                          device=device,
                                          criterion=F.l1_loss,
                                          concat=False)

        print("{},{}".format(epoch, results["mean_abs_err"]))