def test_forward_pass(self):
    """Smoke test: a single forward pass on random input returns something.

    Only checks that ``forward`` runs without raising and yields a
    non-None value; the output's contents are not validated.
    """
    model = DistributedMLP(10, (3, 3), 2)
    # One row of 10 random features, matching the model's first argument.
    inputs = torch.randn(1, 10)
    output = model.forward(inputs)
    print(f'The result of the forward pass is {output}')
    assert output is not None
Exemplo n.º 2
0
    def test_test_model(self):
        """Train briefly, then check that ``test`` reports a loss and an accuracy.

        Uses ``self.X``/``self.y_1`` for training and
        ``self.X_test``/``self.y_test_1`` for evaluation; only asserts that
        both values returned by ``test`` are non-None.
        """
        model = DistributedMLP(2, (3, 3), 1)
        sgd = SGD(model.parameters(), lr=0.001, momentum=0.9)
        criterion = torch.nn.L1Loss()

        # Default DataLoader settings (batch_size=1, no shuffling).
        train_loader = DataLoader(TensorDataset(self.X, self.y_1))
        eval_loader = DataLoader(TensorDataset(self.X_test, self.y_test_1))

        train(model, 100, train_loader, sgd, criterion)
        test_loss, accuracy = test(model, eval_loader, criterion)

        assert test_loss is not None
        assert accuracy is not None
Exemplo n.º 3
0
    def test_train_model_to_predict_1(self):
        """Train on ``(self.X, self.y_1)`` and check every test prediction is ~1.

        Presumably ``self.y_1`` is a target of all ones (set up elsewhere in
        the test class) -- TODO confirm. After 100 training epochs, each row
        of ``mlp(self.X_test)`` must lie within ``1e-2`` of 1 on either side.
        """
        mlp = DistributedMLP(2, (3, 3), 1)

        dataset = TensorDataset(self.X, self.y_1)  # create your dataset
        batch_size = 3
        data_loader = DataLoader(dataset, batch_size)

        optimizer = SGD(mlp.parameters(), lr=0.001, momentum=0.9)
        loss_function = torch.nn.L1Loss()

        train(mlp, 100, data_loader, optimizer, loss_function)

        y_pred = mlp(self.X_test)
        assert y_pred is not None
        for y in y_pred:
            # Two-sided tolerance. The previous `1 - y < 1e-2` was one-sided:
            # any prediction above 1 (e.g. y = 5 -> -4 < 1e-2) passed.
            assert abs(1 - y) < 1e-2
 def test_weights_are_initialised_with_std_kaiming_init(self):
     """Compare each layer's weight std with ``calc_expected_standard_deviation``.

     Iterates the model's direct child modules (every child is assumed to
     expose a 2-D ``.weight``), using ``weight.size(1)`` as the argument to
     the expected-std helper.
     """
     mlp = DistributedMLP(10, (5, 5), 2)
     for _name, layer in mlp._modules.items():
         weights = layer.weight
         actual = weights.std()
         expected = self.calc_expected_standard_deviation(weights.size(1))
         print(f'Actual Standard deviation {actual} | Expected {expected} ')
         # NOTE(review): a tolerance of 1e4 makes this check effectively
         # vacuous; likely a typo, but a very tight bound may be flaky for
         # layers with only 10-50 weights -- confirm the intended value.
         assert abs(actual - expected) < 1e4
 def test_weights_are_initialised_with_mean_close_to_zero(self):
     """Check every child layer's mean weight lies strictly within +/-0.12."""
     mlp = DistributedMLP(10, (5, 5), 2)
     for _name, layer in mlp._modules.items():
         mean = layer.weight.mean()
         print(f'The mean is {mean}')
         # Reuse the mean computed above rather than recomputing it.
         assert mean.abs() < 0.12
 def test_model_is_created_correctly(self):
     """Smoke test: constructing a DistributedMLP yields a non-None object."""
     model = DistributedMLP(10, (5, 5), 2)
     assert model is not None