Esempio n. 1
0
    def test_mse_loss(self):
        """Testing class :class:`lib.backprop_functions.MSELossFunction`.

        The forward pass is first checked against a hand-computed value.
        Afterwards, forward value and gradient with respect to ``self.A2``
        are compared with the reference returned by ``self._pytorch_mse``.
        """
        places = 5

        # Forward method only, on a tiny example with a known result.
        activations = torch.from_numpy(np.array([[2., 3.]]))
        targets = torch.from_numpy(np.array([[1., 1.]]))
        expected = 0.5 * (1**2 + 2**2)  # 2.5
        computed = bf.mse_loss(activations, targets).detach().numpy()
        self.assertAlmostEqual(computed, expected, places=places,
                               msg='MSE value not correctly computed.')

        # Forward values of both implementations on the shared fixtures.
        ours = bf.mse_loss(self.A2, self.T)
        reference = self._pytorch_mse(self.A2, self.T)
        self.assertAlmostEqual(ours.detach().numpy(),
                               reference.detach().numpy(),
                               places=places,
                               msg='MSE value not correctly computed.')

        # Backward path: collect the gradient each implementation leaves in
        # ``self.A2.grad``. The buffer is cleared before every backward call
        # since gradients would otherwise accumulate.
        if self.A2.grad is not None:
            self.A2.grad.zero_()
        ours.backward()
        grad_ours = self.A2.grad.clone()

        self.A2.grad.zero_()
        reference.backward()
        grad_reference = self.A2.grad.clone()

        # Summed squared difference of the two gradients must vanish.
        squared_diff = (grad_ours - grad_reference)**2
        grad_error = torch.sum(squared_diff).detach().numpy()
        self.assertAlmostEqual(
            grad_error, 0., places=places,
            msg='MSE loss gradients not correctly computed.')
Esempio n. 2
0
def train(args, device, train_loader, net):
    """Fit ``net`` to a (regression) dataset with SGD on the MSE loss.

    Args:
        args (argparse.Namespace): The command-line arguments. The
            attributes ``lr``, ``momentum`` and ``epochs`` are read here.
        device: The PyTorch device every mini-batch is moved to.
        train_loader (torch.utils.data.DataLoader): The data handler
            yielding ``(inputs, targets)`` mini-batches of training data.
        net: The (student) neural network.
    """
    print('Training network ...')
    net.train()

    sgd = torch.optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = net.forward(inputs)

            sgd.zero_grad()

            loss = bf.mse_loss(outputs, targets)
            loss.backward()
            sgd.step()

        # Every tenth epoch, report the loss of the last mini-batch seen.
        if epoch % 10 == 0:
            print('Epoch {} -- loss = {}.'.format(epoch, loss))

    print('Training network ... Done')
Esempio n. 3
0
def train(args, device, train_loader, net):
    """Train the given network on the given (regression) dataset.

    If ``args.plot_matrix_angles`` is set, the values returned by
    ``utils.compute_matrix_angle`` are recorded after every epoch for each
    layer (forward weights transposed vs. feedback weights, pseudo-inverse
    of the forward weights vs. feedback weights, and pseudo-inverse vs.
    transpose) and the first two series are plotted after training.

    Args:
        args (argparse.Namespace): The command-line arguments. The
            attributes ``lr``, ``momentum``, ``epochs`` and
            ``plot_matrix_angles`` are read here.
        device: The PyTorch device to be used.
        train_loader (torch.utils.data.DataLoader): The data handler for
            training data.
        net: The (student) neural network. Must provide ``depth`` and
            ``linear_layers`` when angles are recorded.
    """
    print('Training network ...')
    net.train()

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    if args.plot_matrix_angles:
        # One row per epoch, one column per layer.
        angles_transpose = torch.empty(args.epochs,
                                       net.depth,
                                       requires_grad=False)
        angles_pinv = torch.empty(args.epochs, net.depth, requires_grad=False)
        angles_contr = torch.empty(args.epochs, net.depth, requires_grad=False)

    for epoch in range(args.epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            predictions = net.forward(inputs)

            optimizer.zero_grad()

            loss = bf.mse_loss(predictions, targets)
            loss.backward()
            optimizer.step()

        # Save the feedback alignment angles
        if args.plot_matrix_angles:
            # This is pure bookkeeping. Without ``no_grad``, writing values
            # that depend on grad-requiring weights into the buffers would
            # attach the buffers to the autograd graph and keep the graph of
            # every epoch's pseudo-inverse alive until training ends.
            with torch.no_grad():
                for layer_idx in range(net.depth):
                    # The first entry of ``linear_layers`` is skipped.
                    layer = net.linear_layers[layer_idx + 1]
                    W = layer.weights
                    B = layer.feedbackweights
                    W_pinv = torch.pinverse(W)
                    angles_transpose[epoch, layer_idx] = \
                        utils.compute_matrix_angle(W.t(), B)
                    angles_pinv[epoch, layer_idx] = \
                        utils.compute_matrix_angle(W_pinv, B)
                    angles_contr[epoch, layer_idx] = \
                        utils.compute_matrix_angle(W_pinv, W.t())

        # Every tenth epoch, report the loss of the last mini-batch seen.
        if (epoch + 1) % 10 == 0:
            print('Epoch {} -- loss = {}.'.format(epoch + 1, loss))

    print('Training network ... Done')

    if args.plot_matrix_angles:
        print('Plotting matrix angles ...')
        utils.plot_angles(angles_transpose, r'angle between $W_i^T$ and $B$',
                          r'$B_{%s} \angle W_{%s}^{T} [\circ]$')
        utils.plot_angles(angles_pinv,
                          r'angles between $W_i^{\dagger}$ and $B$',
                          r'$B_{%s} \angle W_{%s}^{\dagger} [\circ]$')
        # NOTE(review): ``angles_contr`` is recorded above but never plotted
        # in this function -- confirm whether a third plot is intended.
Esempio n. 4
0
def test(device, test_loader, net):
    """Test a trained network by computing the MSE on the test set.

    Args:
        (....): See docstring of function :func:`train`.
        test_loader (torch.utils.data.DataLoader): The data handler for
            test data.

    Returns:
        (float): The mean-squared error for the test set ``test_loader`` when
        using the network ``net``. Note, the ``Function``
        :func:`lib.backprop_functions.mse_loss` is used to compute the MSE
        value.
    """
    # NOTE: `mse_loss` divides by the current batch size. To obtain the MSE
    # across several mini-batches, every batch loss is therefore multiplied
    # by its batch size again and the accumulated sum is divided by the
    # total number of samples at the end.

    net.eval()

    with torch.no_grad():
        total_samples = 0
        mse = 0.

        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            n_batch = int(inputs.shape[0])
            total_samples += n_batch

            outputs = net.forward(inputs)
            batch_loss = bf.mse_loss(outputs, targets)

            mse += n_batch * batch_loss

        mse /= total_samples

    print('Test MSE: {}'.format(mse))

    # Convert the accumulated tensor into a plain Python float.
    mse_value = mse.cpu().detach().numpy()
    return float(mse_value)
Esempio n. 5
0
    def test_mse_loss_targets(self):
        """Testing if gradients of targets in class
        :class:`lib.backprop_functions.MSELossFunction` are correctly
        implemented.

        ``self.A2`` is passed as the second argument of the loss, so the
        gradient collected in ``self.A2.grad`` is the one with respect to
        that argument. It is compared with the reference returned by
        ``self._pytorch_mse``.
        """
        # Swap the argument order compared to the usual call, such that the
        # gradients with respect to the targets are exercised.
        ours = bf.mse_loss(self.T, self.A2)
        reference = self._pytorch_mse(self.T, self.A2)

        # Clear the gradient buffer before each backward call, since
        # gradients would otherwise accumulate.
        if self.A2.grad is not None:
            self.A2.grad.zero_()
        ours.backward()
        grad_ours = self.A2.grad.clone()

        self.A2.grad.zero_()
        reference.backward()
        grad_reference = self.A2.grad.clone()

        # Summed squared difference of the two gradients must vanish.
        squared_diff = (grad_ours - grad_reference) ** 2
        grad_error = torch.sum(squared_diff).detach().numpy()
        self.assertAlmostEqual(
            grad_error, 0., places=5,
            msg='MSE loss gradients for targets not correctly computed.')