Exemple #1
0
    def test_make_cm(self, emock_lambda):
        """Check the metric factory that PyCM registers in ``_add_metric``.

        ``emock_lambda`` is injected by a patch decorator outside this view. It is
        expected to be called once as ``('pycm', <fn>, False)``; the captured
        ``<fn>`` is then invoked directly to verify that it:

        * converts ``y_true`` exactly once via ``.cpu().numpy()``,
        * builds ``pycm.ConfusionMatrix`` from that array plus the predictions,
          forwarding PyCM's constructor kwargs (``test=10``),
        * passes the resulting matrix and the state to the registered handler.
        """
        if sys.version_info[0] >= 3:
            with patch('pycm.ConfusionMatrix') as confusion_mocktrix:
                confusion_mocktrix.return_value = 'test'
                handler = MagicMock()
                callback = PyCM(test=10).with_handler(handler)
                state = {torchbearer.METRIC_LIST: None}

                callback._add_metric(state)
                emock_lambda.assert_called_once_with('pycm', ANY, False)

                # Second positional argument of the registration call is the
                # function mapping (y_pred, y_true) to a confusion matrix.
                make_cm = emock_lambda.call_args[0][1]

                import torch

                # Column 0 lies in [0, 0.5) and column 1 is exactly 1, so every
                # one of the 5 rows should be assigned class index 1 (hence the
                # expected prediction sum of 5 below).
                y_pred = torch.rand(5, 2) / 2
                y_pred[:, 1] = 1
                y_true = MagicMock()

                make_cm(y_pred, y_true)

                # Reach the inner mocks through return_value instead of calling
                # y_true.cpu(): calling a mock increments the very call counts
                # asserted on here, which made the old checks order-dependent.
                cpu_result = y_true.cpu.return_value
                true_array = cpu_result.numpy.return_value
                self.assertEqual(y_true.cpu.call_count, 1)
                self.assertEqual(cpu_result.numpy.call_count, 1)
                confusion_mocktrix.assert_called_once_with(true_array, ANY, test=10)
                self.assertEqual(confusion_mocktrix.call_args[0][1].sum(), 5)
                handler.assert_called_once_with('test', state)
Exemple #2
0
 def test_to_state(self):
     """``to_state(key)`` should register a handler that stores the matrix under ``key``."""
     if sys.version_info[0] >= 3:
         callback = PyCM()
         callback.to_state('test')
         out = {}
         # Invoke the registered handler with a stand-in matrix and an output dict.
         callback._handlers[0]('cm', out)
         # assertIn/assertEqual report the offending values on failure,
         # unlike assertTrue(expr).
         self.assertIn('test', out)
         self.assertEqual(out['test'], 'cm')
 def test_on_test(self):
     """An ``on_test()`` PyCM should populate METRIC_LIST at validation start on test data."""
     if sys.version_info[0] >= 3:
         callback = PyCM().on_test()
         state = {
             torchbearer.METRIC_LIST: None,
             torchbearer.DATA: torchbearer.TEST_DATA
         }
         callback.on_start_validation(state)
         # METRIC_LIST starts as None; the callback is expected to fill it in.
         self.assertIsNotNone(state[torchbearer.METRIC_LIST])
Exemple #4
0
def run(hidden_size: int, file_prefix: str, epochs: int = 20,
        root: str = ".", batch_size: int = 128):
    """Train a single-hidden-layer MLP on MNIST and return the training history.

    A PyCM callback attached to the training data (``on_train()``) renders a
    normalised confusion matrix through ``to_pandas_seaborn``, presumably once
    per epoch given the ``{epoch}`` placeholder in its title.

    Args:
        hidden_size: Width of the MLP's hidden layer.
        file_prefix: Accepted but currently unused by this function; kept so
            existing call sites keep working. NOTE(review): presumably meant
            as an output filename prefix — confirm and wire it up or drop it.
        epochs: Number of training epochs.
        root: Directory MNIST is downloaded to / read from (formerly
            hard-coded to the current directory).
        batch_size: Mini-batch size for both the train and test loaders
            (formerly hard-coded to 128).

    Returns:
        The value returned by ``Trial.run``.
    """
    # Flatten 28*28 images to a 784 vector for each image so they can feed an MLP.
    transform = transforms.Compose([
        transforms.ToTensor(),  # convert to tensor
        transforms.Lambda(lambda x: x.view(-1)),  # flatten into vector
    ])
    trainset = MNIST(root, train=True, download=True, transform=transform)
    testset = MNIST(root, train=False, download=True, transform=transform)
    # Derive the input width from a real sample instead of hard-coding 784.
    data_size = torch.numel(trainset[0][0])

    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=False)
    # NOTE(review): shuffle=True here combined with val_steps=1 below
    # presumably means each epoch validates on a different random test batch.
    testloader = DataLoader(testset,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=False)

    # 10 outputs, presumably one per MNIST digit class.
    model = SingleHiddenLayerMLP(data_size, hidden_size, 10)

    loss_function = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters())

    device = "cuda" if torch.cuda.is_available() else "cpu"
    cm = PyCM().on_train().with_handler(
        to_pandas_seaborn(normalize=True, title="Confusion Matrix: {epoch}"))
    trial = Trial(
        model,
        optimiser,
        loss_function,
        metrics=["loss", "accuracy"],
        callbacks=[cm],
    )
    trial.to(device)
    trial.with_generators(trainloader, val_generator=testloader, val_steps=1)
    return trial.run(epochs=epochs)
Exemple #5
0
 def test_with_handler(self):
     """``with_handler`` should register the given handler in ``_handlers``."""
     if sys.version_info[0] >= 3:
         callback = PyCM()
         callback.with_handler('test')
         # assertIn shows the actual handler list on failure.
         self.assertIn('test', callback._handlers)
Exemple #6
0
 def test_on_train(self):
     """An ``on_train()`` PyCM should populate METRIC_LIST when training starts."""
     if sys.version_info[0] >= 3:
         callback = PyCM().on_train()
         state = {torchbearer.METRIC_LIST: None}
         callback.on_start_training(state)
         # METRIC_LIST starts as None; the callback is expected to fill it in.
         self.assertIsNotNone(state[torchbearer.METRIC_LIST])
Exemple #7
0
 def test_to_pyplot(self, mock_to_pyplot):
     """``PyCM.to_pyplot`` should hand its three positional args to the patched helper.

     ``mock_to_pyplot`` is injected by a patch decorator outside this view; it
     must receive them as the ``normalize``/``title``/``cmap`` keywords.
     """
     if sys.version_info[0] < 3:
         return
     callback = PyCM()
     callback.to_pyplot(True, 'test', 'test2')
     mock_to_pyplot.assert_called_once_with(normalize=True, title='test', cmap='test2')
Exemple #8
0
    def test_to_file(self):
        """Each ``to_*_file`` helper should save via the matching ConfusionMatrix method.

        The filename template ``'test {epoch}'`` must be filled from the
        state's EPOCH entry before being passed to the save method.
        """
        if sys.version_info[0] < 3:
            return

        # (PyCM registration method, epoch in state, expected save method, expected filename)
        cases = (
            ('to_pycm_file', 1, 'save_stat', 'test 1'),
            ('to_html_file', 2, 'save_html', 'test 2'),
            ('to_csv_file', 3, 'save_csv', 'test 3'),
            ('to_obj_file', 4, 'save_obj', 'test 4'),
        )
        for register_name, epoch, save_name, expected_name in cases:
            cb = PyCM()
            getattr(cb, register_name)('test {epoch}')
            matrix = MagicMock()
            cb._handlers[0](matrix, {torchbearer.EPOCH: epoch})
            getattr(matrix, save_name).assert_called_once_with(expected_name)
Exemple #9
0
 def test_to_console(self, mock_print):
     """``to_console`` should register a handler that prints the matrix.

     ``mock_print`` is injected by a patch decorator outside this view.
     """
     if sys.version_info[0] < 3:
         return
     callback = PyCM()
     callback.to_console()
     first_handler = callback._handlers[0]
     first_handler('cm', {})
     mock_print.assert_called_once_with('cm')
    def test_to_file(self):
        """Each ``to_*_file`` helper should call the matching save method with its defaults.

        The filename template ``'test {epoch}'`` is filled from the state's
        EPOCH entry, and the default keyword arguments listed per case below
        must be forwarded unchanged.
        """
        if sys.version_info[0] < 3:
            return

        # Default keywords shared by the stat/html/csv savers.
        shared = {
            'address': True,
            'overall_param': None,
            'class_param': None,
            'class_name': None,
        }
        # (registration method, epoch, save method, expected filename, expected kwargs)
        expectations = [
            ('to_pycm_file', 1, 'save_stat', 'test 1', dict(shared)),
            ('to_html_file', 2, 'save_html', 'test 2',
             dict(shared, color=(0, 0, 0), normalize=False)),
            ('to_csv_file', 3, 'save_csv', 'test 3',
             dict(shared, matrix_save=True, normalize=False)),
            ('to_obj_file', 4, 'save_obj', 'test 4',
             {'address': True, 'save_stat': False, 'save_vector': True}),
        ]
        for register_name, epoch, save_name, expected_name, expected_kwargs in expectations:
            cb = PyCM()
            getattr(cb, register_name)('test {epoch}')
            matrix = MagicMock()
            cb._handlers[0](matrix, {torchbearer.EPOCH: epoch})
            getattr(matrix, save_name).assert_called_once_with(expected_name, **expected_kwargs)