def test_make_cm(self, emock_lambda):
    """Check the metric function that ``PyCM._add_metric`` passes to ``emock_lambda``.

    ``emock_lambda`` is a mock injected by a patch decorator that is not
    visible here (NOTE(review): confirm the patch target). The test captures
    the ``make_cm`` function given to it and calls it with a prediction
    tensor and a mocked target. It then checks how ``pycm.ConfusionMatrix``
    and the registered handler are invoked.
    """
    if sys.version_info[0] >= 3:
        with patch('pycm.ConfusionMatrix') as confusion_mocktrix:
            confusion_mocktrix.return_value = 'test'
            handler = MagicMock()
            callback = PyCM(test=10).with_handler(handler)
            state = {torchbearer.METRIC_LIST: None}
            callback._add_metric(state)
            emock_lambda.assert_called_once_with('pycm', ANY, False)
            # The second positional argument of that call is the metric function under test.
            make_cm = emock_lambda.call_args[0][1]

            import torch
            # Column 0 lies in [0, 0.5) and column 1 is fixed at 1. Assuming
            # make_cm takes an argmax over dim 1, every row predicts class 1,
            # so the predicted labels should sum to 5.
            y_pred = torch.rand(5, 2) / 2
            y_pred[:, 1] = 1
            y_true = MagicMock()
            make_cm(y_pred, y_true)

            # assertEqual shows the actual counts on failure, unlike assertTrue(a == b).
            self.assertEqual(y_true.cpu.call_count, 1)
            self.assertEqual(y_true.cpu().numpy.call_count, 1)
            confusion_mocktrix.assert_called_once_with(y_true.cpu().numpy(), ANY, test=10)
            self.assertEqual(confusion_mocktrix.call_args[0][1].sum(), 5)
            handler.assert_called_once_with('test', state)
def test_to_state(self):
    """``to_state(key)`` registers a handler that stores the matrix in the state under ``key``."""
    if sys.version_info[0] >= 3:
        callback = PyCM()
        callback.to_state('test')
        out = {}
        # Call the registered handler directly with a stand-in matrix and state dict.
        callback._handlers[0]('cm', out)
        self.assertIn('test', out)
        self.assertEqual(out['test'], 'cm')
def test_on_test(self):
    """After ``on_test()``, ``on_start_validation`` on test data fills METRIC_LIST."""
    if sys.version_info[0] >= 3:
        callback = PyCM().on_test()
        state = {
            torchbearer.METRIC_LIST: None,
            torchbearer.DATA: torchbearer.TEST_DATA,
        }
        callback.on_start_validation(state)
        self.assertIsNotNone(state[torchbearer.METRIC_LIST])
def run(hidden_size: int, file_prefix: str, epochs: int = 20):
    """Train a single-hidden-layer MLP on MNIST with torchbearer and return its history.

    A ``PyCM`` callback is registered with ``on_train()``. Each confusion
    matrix it produces goes to the handler built by
    ``to_pandas_seaborn(normalize=True, ...)``.

    Args:
        hidden_size: Width of the hidden layer of ``SingleHiddenLayerMLP``.
        file_prefix: Accepted but never referenced in this body.
            NOTE(review): confirm whether it was meant for saving outputs.
        epochs: Number of epochs passed to ``Trial.run``.

    Returns:
        Whatever ``Trial.run`` returns.
    """
    # Turn every 28x28 image into a flat vector.
    flatten = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ])
    train_data = MNIST(".", train=True, download=True, transform=flatten)
    test_data = MNIST(".", train=False, download=True, transform=flatten)
    input_size = torch.numel(train_data[0][0])

    # Both loaders share the same settings.
    # NOTE(review): the test loader is also shuffled. Combined with
    # val_steps=1 below, this presumably validates one random batch per epoch.
    # Confirm that is intended.
    loader_kwargs = dict(batch_size=128, shuffle=True, drop_last=False)
    train_loader = DataLoader(train_data, **loader_kwargs)
    test_loader = DataLoader(test_data, **loader_kwargs)

    net = SingleHiddenLayerMLP(input_size, hidden_size, 10)
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters())
    device = "cuda" if torch.cuda.is_available() else "cpu"

    seaborn_handler = to_pandas_seaborn(normalize=True, title="Confusion Matrix: {epoch}")
    confusion = PyCM().on_train().with_handler(seaborn_handler)

    trial = Trial(
        net,
        adam,
        criterion,
        metrics=["loss", "accuracy"],
        callbacks=[confusion],
    )
    trial.to(device)
    trial.with_generators(train_loader, val_generator=test_loader, val_steps=1)
    return trial.run(epochs=epochs)
def test_with_handler(self):
    """``with_handler`` adds the given handler to the callback's handler list."""
    if sys.version_info[0] >= 3:
        callback = PyCM()
        callback.with_handler('test')
        # assertIn reports the container contents on failure.
        self.assertIn('test', callback._handlers)
def test_on_train(self):
    """After ``on_train()``, ``on_start_training`` fills METRIC_LIST in the state."""
    if sys.version_info[0] >= 3:
        callback = PyCM().on_train()
        state = {torchbearer.METRIC_LIST: None}
        callback.on_start_training(state)
        self.assertIsNotNone(state[torchbearer.METRIC_LIST])
def test_to_pyplot(self, mock_to_pyplot):
    """``to_pyplot`` forwards its positional arguments to the pyplot handler factory as keywords."""
    if sys.version_info[0] < 3:
        return
    PyCM().to_pyplot(True, 'test', 'test2')
    mock_to_pyplot.assert_called_once_with(
        normalize=True,
        title='test',
        cmap='test2',
    )
def test_to_file(self):
    """Each ``to_*_file`` helper formats the path with the epoch and calls the matching save method."""
    if sys.version_info[0] < 3:
        return
    # (registration method, ConfusionMatrix save method, epoch, expected path)
    cases = (
        ('to_pycm_file', 'save_stat', 1, 'test 1'),
        ('to_html_file', 'save_html', 2, 'test 2'),
        ('to_csv_file', 'save_csv', 3, 'test 3'),
        ('to_obj_file', 'save_obj', 4, 'test 4'),
    )
    for register_name, save_name, epoch, expected_path in cases:
        cb = PyCM()
        getattr(cb, register_name)('test {epoch}')
        matrix = MagicMock()
        cb._handlers[0](matrix, {torchbearer.EPOCH: epoch})
        getattr(matrix, save_name).assert_called_once_with(expected_path)
def test_to_console(self, mock_print):
    """``to_console`` registers a handler that prints the matrix it receives."""
    if sys.version_info[0] < 3:
        return
    callback = PyCM()
    callback.to_console()
    first_handler = callback._handlers[0]
    first_handler('cm', {})
    mock_print.assert_called_once_with('cm')
def test_to_file(self):
    """Each ``to_*_file`` helper calls the matching save method with the epoch path and default options."""
    if sys.version_info[0] < 3:
        return
    # Keyword arguments shared by the stat, html and csv save calls.
    shared = dict(address=True, overall_param=None, class_param=None, class_name=None)
    # (registration method, ConfusionMatrix save method, epoch, expected path, expected kwargs)
    cases = (
        ('to_pycm_file', 'save_stat', 1, 'test 1', dict(shared)),
        ('to_html_file', 'save_html', 2, 'test 2',
         dict(shared, color=(0, 0, 0), normalize=False)),
        ('to_csv_file', 'save_csv', 3, 'test 3',
         dict(shared, matrix_save=True, normalize=False)),
        ('to_obj_file', 'save_obj', 4, 'test 4',
         dict(address=True, save_stat=False, save_vector=True)),
    )
    for register_name, save_name, epoch, expected_path, expected_kwargs in cases:
        cb = PyCM()
        getattr(cb, register_name)('test {epoch}')
        matrix = MagicMock()
        cb._handlers[0](matrix, {torchbearer.EPOCH: epoch})
        getattr(matrix, save_name).assert_called_once_with(expected_path, **expected_kwargs)