def test_forward_weights(self):
    """
    Feed-forward weights should not be modified when `mode == "dendrites"`, and
    should be modified when `mode == "all"` or `mode == "learn_context"`
    """
    for mode in ("dendrites", "all", "learn_context"):
        r, dendrite_layer, context_model, context_vectors, device = \
            init_test_scenario(
                mode=mode,
                dim_in=100,
                dim_out=100,
                num_contexts=10,
                dim_context=100,
                dendrite_module=AbsoluteMaxGatingDendriticLayer
            )

        dataloader = init_dataloader(routing_function=r,
                                     context_vectors=context_vectors,
                                     device=device,
                                     batch_size=64,
                                     x_min=-2.0,
                                     x_max=2.0)

        optimizer = init_optimizer(mode=mode, layer=dendrite_layer,
                                   context_model=context_model)

        # Snapshot the feed-forward weights so they can be compared after training
        forward_weights_before = copy.deepcopy(
            dendrite_layer.module.weight.data)

        # Perform a single training epoch
        train_dendrite_model(model=dendrite_layer,
                             context_model=context_model,
                             loader=dataloader,
                             optimizer=optimizer,
                             device=device,
                             criterion=F.l1_loss)

        forward_weights_after = copy.deepcopy(
            dendrite_layer.module.weight.data)

        unchanged = bool(
            (forward_weights_before == forward_weights_after).all())

        # When the feed-forward weights are trained ("all", "learn_context") we
        # expect them to change; in "dendrites" mode they must stay fixed.
        # A membership test is required here: `mode == "all" or "learn_context"`
        # is always truthy because a non-empty string is truthy.
        if mode in ("all", "learn_context"):
            self.assertFalse(
                unchanged,
                msg="feed-forward weights did not change in mode "
                    "'{}'".format(mode))
        else:
            self.assertTrue(
                unchanged,
                msg="feed-forward weights changed in mode '{}'".format(mode))
def test_dendrite_weights(self):
    """
    After one training epoch, at least one dendrite segment weight must differ
    from its initial value in each of the modes iterated over below
    """
    scenario_kwargs = dict(
        dim_in=100,
        dim_out=100,
        num_contexts=10,
        dim_context=100,
        dendrite_module=AbsoluteMaxGatingDendriticLayer,
    )

    for mode in ("dendrites", "all", "learn_context"):
        scenario = init_test_scenario(mode=mode, **scenario_kwargs)
        routing_fn, layer, ctx_model, ctx_vectors, dev = scenario

        loader = init_dataloader(
            routing_function=routing_fn,
            context_vectors=ctx_vectors,
            device=dev,
            batch_size=64,
            x_min=-2.0,
            x_max=2.0,
        )
        opt = init_optimizer(mode=mode, layer=layer, context_model=ctx_model)

        # Keep an independent copy of the segment weights prior to training
        weights_pre = copy.deepcopy(layer.segments.weights.data)

        # Run exactly one epoch of training
        train_dendrite_model(
            model=layer,
            context_model=ctx_model,
            loader=loader,
            optimizer=opt,
            device=dev,
            criterion=F.l1_loss,
        )

        weights_post = copy.deepcopy(layer.segments.weights.data)

        # Any single differing entry is enough to show the weights were updated
        self.assertTrue((weights_pre != weights_post).any())
def test_context_model(self):
    """
    Context model should be learned only when `mode == "learn_context"`; in the
    other modes no context model is expected to be created at all
    """
    for mode in ("dendrites", "all", "learn_context"):
        r, dendrite_layer, context_model, context_vectors, device = \
            init_test_scenario(
                mode=mode,
                dim_in=100,
                dim_out=100,
                num_contexts=10,
                dim_context=100,
                dendrite_module=AbsoluteMaxGatingDendriticLayer
            )

        dataloader = init_dataloader(routing_function=r,
                                     context_vectors=context_vectors,
                                     device=device,
                                     batch_size=64,
                                     x_min=-2.0,
                                     x_max=2.0)

        optimizer = init_optimizer(mode=mode, layer=dendrite_layer,
                                   context_model=context_model)

        if mode != "learn_context":
            # Use the TestCase assertion instead of a bare `assert`: bare asserts
            # are stripped under `python -O`, which would leave this branch
            # checking nothing.
            self.assertIsNone(context_model)
            continue

        context_weights_before = copy.deepcopy(
            context_model.linear1.weight.data)

        # Perform a single training epoch
        train_dendrite_model(model=dendrite_layer,
                             context_model=context_model,
                             loader=dataloader,
                             optimizer=optimizer,
                             device=device,
                             criterion=F.l1_loss)

        context_weights_after = copy.deepcopy(
            context_model.linear1.weight.data)

        # At least one weight of the context model's first linear layer must
        # have been updated by training
        changed = (context_weights_before != context_weights_after).any()
        self.assertTrue(changed)
def test_regular_network(dim_in, dim_out, num_contexts, dim_context, batch_size=64,
                         num_training_epochs=1000):
    """
    Trains and evaluates a feedforward network with no hidden layers to match an
    arbitrary routing function. Prints one CSV row (epoch, loss, mean absolute
    error) per epoch to stdout.

    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test network
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the network
    """

    # The context is concatenated with the regular input, so the model's input size
    # is the sum of both widths. This equals the former `2 * dim_in` whenever
    # `dim_context == dim_in`.
    # NOTE(review): assumes `concat=True` appends the context along the feature
    # axis -- confirm against RoutingDataset / train_dendrite_model.
    model = SingleLayerLinearNetwork(input_size=dim_in + dim_context,
                                     output_size=dim_out)

    # Initialize routing function that this task will try to learn, and set
    # `requires_grad=False` since the routing function is static
    r = RoutingFunction(dim_in=dim_in, dim_out=dim_out, k=num_contexts,
                        device=model.device, sparsity=0.7)
    r.sparse_weights.module.weight.requires_grad = False

    # Initialize context vectors, where each context vector corresponds to an output
    # mask in the routing function
    context_vectors = generate_context_vectors(num_contexts=num_contexts,
                                               n_dim=dim_context,
                                               percent_on=0.2)

    # Initialize datasets and dataloaders; training inputs use the range
    # [-2.0, 2.0] while evaluation inputs use [2.0, 6.0]
    train_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=-2.0,
        x_max=2.0)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

    test_dataset = RoutingDataset(
        routing_function=r,
        input_size=r.sparse_weights.module.in_features,
        context_vectors=context_vectors,
        device=model.device,
        concat=True,
        x_min=2.0,
        x_max=6.0)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

    # Place objects that inherit from torch.nn.Module on device
    model = model.to(model.device)
    r = r.to(r.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):

        train_dendrite_model(model=model,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=model.device,
                             criterion=F.l1_loss,
                             concat=True)

        # Validate model - note that we use a different dataset/dataloader as the
        # input distribution has changed
        results = evaluate_dendrite_model(model=model,
                                          loader=test_dataloader,
                                          device=model.device,
                                          criterion=F.l1_loss,
                                          concat=True)

        print("{},{},{}".format(
            epoch, results["loss"], results["mean_abs_err"]))
def learn_to_route(
    mode,
    dim_in,
    dim_out,
    num_contexts,
    dim_context,
    dendrite_module,
    batch_size=64,
    num_training_epochs=5000,
    sparse_context_model=True,
    onehot=False,
    plot=False,
    save_interval=100,
    save_path="./models/"
):
    """
    Trains a dendrite layer to match an arbitrary routing function. Prints one CSV
    row (epoch, loss, mean absolute error) per epoch to stdout, optionally saves
    periodic checkpoints, and optionally writes a loss plot to
    "training_curve.png".

    :param mode: must be one of ("dendrites", "all", "learn_context")
                 "dendrites" -> learn only dendrite weights while setting
                                feed-forward weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
                 "learn_context" -> learn feed-forward, dendrite, and
                                    context-generation weights
    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test layer
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendrite
                                layer
    :param sparse_context_model: whether to use a sparse MLP to generate the context;
                                 applicable if mode == "learn_context"
    :param onehot: whether the context integer should be encoded as a onehot vector
                   when input into the context generation model
    :param plot: whether to plot a loss curve
    :param save_interval: number of epochs between saving the model; a falsy value
                          (0 or None) disables checkpointing
    :param save_path: path to folder in which to save model checkpoints; created if
                      it does not exist
    """

    r, dendrite_layer, context_model, context_vectors, device = init_test_scenario(
        mode=mode,
        dim_in=dim_in,
        dim_out=dim_out,
        num_contexts=num_contexts,
        dim_context=dim_context,
        dendrite_module=dendrite_module,
        sparse_context_model=sparse_context_model,
        onehot=onehot)

    # Training inputs use the range [-2.0, 2.0]
    train_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=-2.0,
        x_max=2.0,
    )

    # Evaluation inputs use a different range, [2.0, 6.0]
    test_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=2.0,
        x_max=6.0,
    )

    optimizer = init_optimizer(mode=mode, layer=dendrite_layer,
                               context_model=context_model)

    # Select L1 weight decay penalty based on scenario. The choice depends only on
    # `mode`, so it is made once rather than on every epoch. Unrecognized modes keep
    # the previous behavior of passing `None` through.
    if mode == "dendrites":
        l1_weight_decay = 0.0
    elif mode in ("all", "learn_context"):
        l1_weight_decay = 1e-6
    else:
        l1_weight_decay = None

    print("epoch,mean_loss,mean_abs_err")
    losses = []

    for epoch in range(1, num_training_epochs + 1):

        train_dendrite_model(
            model=dendrite_layer,
            context_model=context_model,
            loader=train_dataloader,
            optimizer=optimizer,
            device=device,
            criterion=F.l1_loss,
            concat=False,
            l1_weight_decay=l1_weight_decay,
        )

        # Validate model
        results = evaluate_dendrite_model(model=dendrite_layer,
                                          context_model=context_model,
                                          loader=test_dataloader,
                                          device=device,
                                          criterion=F.l1_loss,
                                          concat=False)

        # Emit all three columns announced by the header above
        print("{},{},{}".format(
            epoch, results["loss"], results["mean_abs_err"]))

        # track loss for plotting
        if plot:
            losses.append(results["mean_abs_err"])

        # save models for future visualization or training
        if save_interval and epoch % save_interval == 0:
            os.makedirs(save_path, exist_ok=True)

            # os.path.join keeps the checkpoint inside `save_path` even when the
            # folder path is given without a trailing separator
            torch.save(dendrite_layer.state_dict(),
                       os.path.join(save_path, "dendrite_" + str(epoch)))
            if context_model is not None:
                torch.save(context_model.state_dict(),
                           os.path.join(save_path, "context_" + str(epoch)))

    if plot:
        losses = np.array(losses)
        plt.scatter(x=np.arange(1, num_training_epochs + 1), y=losses)
        plt.savefig("training_curve.png")
def learn_to_route(mode, dim_in, dim_out, num_contexts, dim_context,
                   dendrite_module, batch_size=64, num_training_epochs=5000):
    """
    Trains a dendrite layer to match an arbitrary routing function. Prints one CSV
    row (epoch, loss, mean absolute error) per epoch to stdout.

    NOTE(review): another `learn_to_route` definition with additional keyword
    parameters appears earlier in this file; being defined later, this one is the
    binding that remains after import. Confirm which definition is meant to be kept.

    :param mode: must be one of ("dendrites", "all", "learn_context")
                 "dendrites" -> learn only dendrite weights while setting
                                feed-forward weights to those of the routing function
                 "all" -> learn both feed-forward and dendrite weights
                 "learn_context" -> handled like "all" here, with whatever context
                                    model `init_test_scenario` returns passed on to
                                    the optimizer, training and evaluation calls
    :param dim_in: the number of dimensions in the input to the routing function and
                   test module
    :param dim_out: the number of dimensions in the sparse linear output of the
                    routing function and test layer
    :param num_contexts: the number of unique random binary vectors in the routing
                         function that can "route" the sparse linear output, and also
                         the number of unique context vectors
    :param dim_context: the number of dimensions in the context vectors
    :param dendrite_module: a torch.nn.Module subclass that implements a dendrite
                            module in addition to a linear feed-forward module
    :param batch_size: the batch size during training and evaluation
    :param num_training_epochs: the number of epochs for which to train the dendrite
                                layer
    :raises ValueError: if `mode` is not one of the supported values
    """

    # Select L1 weight decay penalty based on scenario. Done once, up front, so an
    # unsupported mode fails with a clear message instead of leaving
    # `l1_weight_decay` unbound inside the training loop.
    if mode == "dendrites":
        l1_weight_decay = 0.0
    elif mode in ("all", "learn_context"):
        l1_weight_decay = 1e-6
    else:
        raise ValueError(
            "mode must be one of ('dendrites', 'all', 'learn_context'), "
            "got {!r}".format(mode))

    # Every other call site in this file unpacks five values from
    # `init_test_scenario`, the third being the context model.
    r, dendrite_layer, context_model, context_vectors, device = init_test_scenario(
        mode=mode,
        dim_in=dim_in,
        dim_out=dim_out,
        num_contexts=num_contexts,
        dim_context=dim_context,
        dendrite_module=dendrite_module)

    # Training inputs use the range [-2.0, 2.0]
    train_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=-2.0,
        x_max=2.0,
    )

    # Evaluation inputs use a different range, [2.0, 6.0]
    test_dataloader = init_dataloader(
        routing_function=r,
        context_vectors=context_vectors,
        device=device,
        batch_size=batch_size,
        x_min=2.0,
        x_max=6.0,
    )

    optimizer = init_optimizer(mode=mode, layer=dendrite_layer,
                               context_model=context_model)

    print("epoch,mean_loss,mean_abs_err")
    for epoch in range(1, num_training_epochs + 1):

        train_dendrite_model(model=dendrite_layer,
                             context_model=context_model,
                             loader=train_dataloader,
                             optimizer=optimizer,
                             device=device,
                             criterion=F.l1_loss,
                             concat=False,
                             l1_weight_decay=l1_weight_decay)

        # Validate model
        results = evaluate_dendrite_model(model=dendrite_layer,
                                          context_model=context_model,
                                          loader=test_dataloader,
                                          device=device,
                                          criterion=F.l1_loss,
                                          concat=False)

        # Emit all three columns announced by the header above
        print("{},{},{}".format(
            epoch, results["loss"], results["mean_abs_err"]))