def test_process_final(self):
    """MetricList.process_final should delegate to its child metric and return the child's result."""
    child = Metric('test')
    child.process_final = Mock(return_value={'test': -1})
    final_state = {'state': -1}
    outcome = MetricList([child]).process_final(final_state)
    self.assertEqual({'test': -1}, outcome)
    child.process_final.assert_called_once_with({'state': -1})
def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
    """Build the trial state around ``model`` and fire the callbacks' ``on_init`` hook.

    Args:
        model: The model to drive; stored under ``torchbearer.MODEL``.
        optimizer: Optimizer to use; a ``MockOptimizer`` is substituted when ``None``.
        criterion: Loss function called as ``criterion(y_pred, y_true)``. When ``None``, a
            stand-in that ignores its arguments and returns a one-element zero tensor is used.
        metrics: Metrics wrapped in a ``MetricList``. ``None`` (the default) means no metrics.
        callbacks: Callbacks wrapped in a ``CallbackList``. ``None`` (the default) means no callbacks.
        verbose: Verbosity level, stored on ``self.verbose``.
    """
    # Use fresh lists instead of mutable defaults (``[]`` in the signature would be shared by every call).
    metrics = [] if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    if criterion is None:
        def criterion(_, __):
            # Device and dtype are read from self.state when the loss is computed, not at construction.
            # requires_grad=True presumably lets a backward pass on this dummy loss succeed.
            return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

    self.verbose = verbose
    self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    self.state = State()
    self.state.update({
        torchbearer.MODEL: model,
        torchbearer.CRITERION: criterion,
        torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
        torchbearer.METRIC_LIST: MetricList(metrics),
        torchbearer.CALLBACK_LIST: CallbackList(callbacks),
        torchbearer.DEVICE: 'cpu',
        torchbearer.DATA_TYPE: torch.float32,
        torchbearer.SELF: self,
        torchbearer.HISTORY: [],
        torchbearer.BACKWARD_ARGS: {},
        # Data sources and step counts start unset.
        torchbearer.TRAIN_GENERATOR: None,
        torchbearer.VALIDATION_GENERATOR: None,
        torchbearer.TEST_GENERATOR: None,
        torchbearer.TRAIN_STEPS: None,
        torchbearer.VALIDATION_STEPS: None,
        torchbearer.TEST_STEPS: None,
        torchbearer.TRAIN_DATA: None,
        torchbearer.VALIDATION_DATA: None,
        torchbearer.TEST_DATA: None,
        torchbearer.INF_TRAIN_LOADING: False,
        torchbearer.LOADER: None,
    })

    # Let callbacks see the fully populated initial state.
    self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, pass_state=False):
    """Build the initial state dictionary for ``model``.

    Args:
        model: The model to drive; stored under ``torchbearer.MODEL``.
        optimizer: Optimizer to use; a ``MockOptimizer`` is substituted when ``None``.
        criterion: Loss function called as ``criterion(y_pred, y_true)``. When ``None``, a
            stand-in returning a one-element zero tensor on ``y_true``'s device is used.
        metrics: Metrics wrapped in a ``MetricList``. ``None`` (the default) means no metrics.
        callbacks: Callbacks wrapped in a ``CallbackList``. ``None`` (the default) means no callbacks.
        pass_state: Flag stored on ``self.pass_state``.
    """
    # Use fresh lists instead of mutable defaults (``[]`` in the signature would be shared by every call).
    metrics = [] if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    if criterion is None:
        def criterion(_, y_true):
            # Place the dummy loss on the same device as the targets.
            return torch.zeros(1, device=y_true.device)

    self.pass_state = pass_state
    self.state = State()
    self.state.update({
        torchbearer.MODEL: model,
        torchbearer.CRITERION: criterion,
        torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
        torchbearer.METRIC_LIST: MetricList(metrics),
        torchbearer.CALLBACK_LIST: CallbackList(callbacks),
        torchbearer.DEVICE: 'cpu',
        torchbearer.DATA_TYPE: torch.float32,
        torchbearer.SELF: self,
        torchbearer.HISTORY: [],
        torchbearer.BACKWARD_ARGS: {},
        # Generators and step counts start unset.
        torchbearer.TRAIN_GENERATOR: None,
        torchbearer.VALIDATION_GENERATOR: None,
        torchbearer.TEST_GENERATOR: None,
        torchbearer.TRAIN_STEPS: None,
        torchbearer.VALIDATION_STEPS: None,
        torchbearer.TEST_STEPS: None,
    })
def _add_metric(self, state):
    """Append an epoch-level 'pycm' metric that builds a confusion matrix and feeds it to the handlers.

    The new ``EpochLambda`` is reset with ``state`` and then combined with the existing
    ``state[torchbearer.METRIC_LIST]`` into a new ``MetricList`` stored back in ``state``.
    """
    # Imported here rather than at module level; presumably so pycm stays an optional dependency.
    from pycm import ConfusionMatrix

    def make_cm(y_pred, y_true):
        # Reduce the prediction scores to class indices along dimension 1.
        predicted = torch.max(y_pred, 1)[1]
        matrix = ConfusionMatrix(y_true.cpu().numpy(), predicted.cpu().numpy(), **self.kwargs)
        for callback in self._handlers:
            callback(matrix, state)

    cm_metric = EpochLambda('pycm', make_cm, False)
    cm_metric.reset(state)
    existing = state[torchbearer.METRIC_LIST]
    state[torchbearer.METRIC_LIST] = MetricList([existing, cm_metric])
def test_default_loss(self):
    """The string 'loss' should resolve to a default metric whose name is 'loss'."""
    metric = MetricList(['loss'])
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(metric.metric_list[0].name, 'loss', msg='loss not in: ' + str(metric.metric_list))
def test_default_acc(self):
    """The string 'acc' should resolve to a default metric whose name is 'acc'."""
    metric = MetricList(['acc'])
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(metric.metric_list[0].name, 'acc', msg='acc not in: ' + str(metric.metric_list))
def test_list_in_list(self):
    """A MetricList nested inside another should have its metrics spliced into the outer list in order."""
    metric = MetricList(['acc', MetricList(['loss'])])
    # assertEqual reports the actual name on failure, unlike assertTrue(a == b).
    self.assertEqual(metric.metric_list[0].name, 'acc')
    self.assertEqual(metric.metric_list[1].name, 'loss')
def test_reset(self):
    """MetricList.reset should forward the state to its child metric exactly once."""
    child = Metric('test')
    child.reset = Mock(return_value=None)
    container = MetricList([child])
    container.reset({'state': -1})
    child.reset.assert_called_once_with({'state': -1})
def test_eval(self):
    """MetricList.eval should call eval on its child metric exactly once."""
    child = Metric('test')
    child.eval = Mock(return_value=None)
    container = MetricList([child])
    container.eval()
    child.eval.assert_called_once()
def test_train(self):
    """MetricList.train should call train on its child metric exactly once."""
    child = Metric('test')
    child.train = Mock(return_value=None)
    container = MetricList([child])
    container.train()
    child.train.assert_called_once()
def test_eval(self):
    """MetricList.eval should call eval on its child metric exactly once (checked via call_count)."""
    child = Metric('test')
    child.eval = Mock(return_value=None)
    container = MetricList([child])
    container.eval()
    times_called = child.eval.call_count
    self.assertEqual(times_called, 1)
def test_train(self):
    """MetricList.train should call train on its child metric exactly once (checked via call_count)."""
    child = Metric('test')
    child.train = Mock(return_value=None)
    container = MetricList([child])
    container.train()
    times_called = child.train.call_count
    self.assertEqual(times_called, 1)
def test_default_roc(self):
    """Both 'roc_auc' and 'roc_auc_score' should resolve to the metric named 'roc_auc_score'."""
    for alias in ('roc_auc', 'roc_auc_score'):
        # subTest labels a failure with the alias that caused it.
        with self.subTest(alias=alias):
            mlist = MetricList([alias])
            self.assertEqual(mlist.metric_list[0].name, 'roc_auc_score')