def test_non_zero_parameter(self):
        """
        The updated parameters should match the hand-derived solution (see the
        update-rule sketch after this test)
        """
        weight = torch.Tensor([
            [-0.4884, -0.2566, 0.3548],
            [0.2883, -0.5463, 0.0184],
            [0.2392, 0.0000, 0.3523]
        ])
        weight = torch.nn.Parameter(weight)

        bias = torch.Tensor([-0.4757, -0.4825, -0.0000])
        bias = torch.nn.Parameter(bias)

        l1_regularization_step(params=[weight, bias], lr=0.15, weight_decay=0.25)

        expected_weight = torch.Tensor([
            [-0.4509, -0.2191, 0.3173],
            [0.2508, -0.5088, -0.0191],
            [0.2017, 0.0000, 0.3148]
        ])

        expected_bias = torch.Tensor([-0.4382, -0.4450, 0.0000])

        self.assertTrue(torch.allclose(weight, expected_weight))
        self.assertTrue(torch.allclose(bias, expected_bias))
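
# A minimal sketch of the update rule the expected values above were hand-derived
# from. This is an assumption inferred from the tests, not the project's actual
# implementation: `l1_regularization_step` appears to subtract
# `lr * weight_decay * sign(param)` in place from every parameter that requires
# gradients, leaving zero entries and frozen parameters unchanged.
import torch


def l1_regularization_step_sketch(params, lr, weight_decay):
    """One in-place L1 decay step: param -= lr * weight_decay * sign(param)."""
    with torch.no_grad():
        for param in params:
            if param.requires_grad:
                # sign(0) == 0, so zero-valued entries stay exactly at zero
                param.sub_(lr * weight_decay * torch.sign(param))
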
def train_dendrite_model(model,
                         loader,
                         optimizer,
                         device,
                         criterion,
                         concat=False,
                         l1_weight_decay=0.0):
    """
    Trains a dendrite model by iterating through all batches in the given
    dataloader (a usage sketch follows this function)

    :param model: a torch.nn.Module subclass that implements a dendrite module in
                  addition to a linear feed-forward module, and takes both feedforward
                  and context inputs to its `forward` method
    :param loader: a torch DataLoader that iterates over all training batches
    :param optimizer: optimizer object used to train the model
    :param device: device to use ('cpu' or 'cuda')
    :param criterion: loss function to minimize
    :param concat: if True, assumes input and context vectors are concatenated together
                   and model takes just a single input to its `forward`, otherwise
                   assumes input and context vectors are separate and model's `forward`
                   function takes a regular input and contextual input separately
    :param l1_weight_decay: L1 regularization coefficient
    """
    model.train()

    for item in loader:

        optimizer.zero_grad()
        if concat:
            data, target = item

            data = data.to(device)
            target = target.to(device)

            output = model(data)

        else:
            data, context, target = item

            data = data.to(device)
            context = context.to(device)
            target = target.to(device)

            output = model(data, context)

        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Perform L1 weight decay
        if l1_weight_decay > 0.0:
            l1_regularization_step(params=model.parameters(),
                                   lr=optimizer.param_groups[0]["lr"],
                                   weight_decay=l1_weight_decay)
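
# A hypothetical usage sketch for `train_dendrite_model`. The toy model, synthetic
# data, and hyperparameters below are illustrative assumptions, not part of the
# original code, and `l1_regularization_step` must be importable for the L1 step to
# run. With `concat=True` the loader yields (input, target) pairs and the model's
# `forward` takes a single concatenated input.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_model = torch.nn.Linear(16, 4)                        # stand-in for a dendrite model
toy_dataset = TensorDataset(torch.randn(128, 16),         # synthetic inputs
                            torch.randint(0, 4, (128,)))  # synthetic class labels
toy_loader = DataLoader(toy_dataset, batch_size=32, shuffle=True)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

train_dendrite_model(model=toy_model,
                     loader=toy_loader,
                     optimizer=toy_optimizer,
                     device="cpu",
                     criterion=torch.nn.functional.cross_entropy,
                     concat=True,
                     l1_weight_decay=1e-4)
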
    def test_all_zeros(self):
        """
        All-zero parameters should not be modified
        """
        weight = torch.zeros((10, 10))
        weight = torch.nn.Parameter(weight)

        bias = torch.zeros((10,))
        bias = torch.nn.Parameter(bias)

        l1_regularization_step(params=[weight, bias], lr=0.1, weight_decay=0.1)

        self.assertTrue((weight == 0.0).all().item())
        self.assertTrue((bias == 0.0).all().item())
    def test_zero_weight_decay(self):
        """
        No parameters should be modified if `weight_decay` is set to zero
        """
        weight = torch.randn((7, 7))
        weight = torch.nn.Parameter(weight)

        bias = torch.randn((7,))
        bias = torch.nn.Parameter(bias)

        # Make copy of original parameters before update
        weight_original = copy.deepcopy(weight)
        bias_original = copy.deepcopy(bias)

        l1_regularization_step(params=[weight, bias], lr=0.1, weight_decay=0.0)

        self.assertTrue(torch.allclose(weight, weight_original))
        self.assertTrue(torch.allclose(bias, bias_original))
    def test_requires_grad(self):
        """
        Any parameters whose `requires_grad` attribute is False should not be modified
        """
        weight = torch.randn((3, 11))
        weight = torch.nn.Parameter(weight)

        bias = torch.randn((11,))
        bias = torch.nn.Parameter(bias, requires_grad=False)

        # Make copy of original parameters before update
        weight_original = copy.deepcopy(weight)
        bias_original = copy.deepcopy(bias)

        l1_regularization_step(params=[weight, bias], lr=0.1, weight_decay=0.1)

        # Here, we assert that at least 1 weight has changed and the bias remains fixed
        self.assertFalse(torch.allclose(weight, weight_original))
        self.assertTrue(torch.allclose(bias, bias_original))