示例#1
0
 def test_process_final(self):
     """MetricList.process_final should delegate to its child and return the child's output."""
     inner = Metric('test')
     inner.process_final = Mock(return_value={'test': -1})
     wrapped = MetricList([inner])

     output = wrapped.process_final({'state': -1})

     self.assertEqual({'test': -1}, output)
     inner.process_final.assert_called_once_with({'state': -1})
示例#2
0
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
        """Initialise the state dictionary from the given model and training components.

        The device starts as ``'cpu'`` and the data type as ``torch.float32``. Generators,
        step counts and data entries all start as ``None``. Once the state is built,
        ``on_init`` is called on the callback list.

        Args:
            model: Model stored under ``torchbearer.MODEL``.
            optimizer: Optimizer stored under ``torchbearer.OPTIMIZER``. A ``MockOptimizer``
                is used when ``None``.
            criterion: Loss function stored under ``torchbearer.CRITERION``. When ``None``, a
                placeholder returning ``torch.zeros(1)`` (with ``requires_grad=True``) on the
                state's current device and dtype is used.
            metrics: Metrics wrapped in a ``MetricList``. ``None`` (the default) means no metrics.
            callbacks: Callbacks wrapped in a ``CallbackList``. ``None`` (the default) means
                no callbacks.
            verbose: Stored on ``self.verbose``.
        """
        # None sentinels avoid one mutable list being shared by every call that uses the default.
        if metrics is None:
            metrics = []
        if callbacks is None:
            callbacks = []

        if criterion is None:
            def criterion(_, __):
                # self.state is looked up at call time, after it has been built below.
                return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

        self.verbose = verbose

        self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
            torchbearer.INF_TRAIN_LOADING: False,
            torchbearer.LOADER: None
        })

        # Let callbacks act on the freshly built state.
        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
示例#3
0
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, pass_state=False):
        """Initialise the state dictionary from the given model and training components.

        The device starts as ``'cpu'`` and the data type as ``torch.float32``. Generators
        and step counts start as ``None``.

        Args:
            model: Model stored under ``torchbearer.MODEL``.
            optimizer: Optimizer stored under ``torchbearer.OPTIMIZER``. A ``MockOptimizer``
                is used when ``None``.
            criterion: Loss function stored under ``torchbearer.CRITERION``. When ``None``, a
                placeholder returning ``torch.zeros(1)`` on ``y_true``'s device is used.
            metrics: Metrics wrapped in a ``MetricList``. ``None`` (the default) means no metrics.
            callbacks: Callbacks wrapped in a ``CallbackList``. ``None`` (the default) means
                no callbacks.
            pass_state: Stored on ``self.pass_state``.
        """
        # None sentinels avoid one mutable list being shared by every call that uses the default.
        if metrics is None:
            metrics = []
        if callbacks is None:
            callbacks = []

        if criterion is None:
            def criterion(_, y_true):
                # Match the device of the targets so the placeholder loss can be combined with them.
                return torch.zeros(1, device=y_true.device)

        self.pass_state = pass_state

        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None
        })
示例#4
0
    def _add_metric(self, state):
        """Append an ``EpochLambda`` metric named 'pycm' to ``state[torchbearer.METRIC_LIST]``.

        The metric function takes the index of the maximum along dim 1 of the predictions
        and builds a pycm ``ConfusionMatrix`` from the true labels and those indices, passing
        ``self.kwargs`` through. It then calls every handler in ``self._handlers`` with the
        matrix and the ``state`` given here. It returns nothing.
        """
        from pycm import ConfusionMatrix

        def make_cm(y_pred, y_true):
            # Reduce scores to predicted class indices along dim 1.
            predicted_classes = torch.max(y_pred, 1)[1]
            actual = y_true.cpu().numpy()
            predicted = predicted_classes.cpu().numpy()
            matrix = ConfusionMatrix(actual, predicted, **self.kwargs)
            for handler in self._handlers:
                handler(matrix, state)

        pycm_metric = EpochLambda('pycm', make_cm, False)
        pycm_metric.reset(state)
        # Wrap the existing metric list together with the new metric.
        combined = [state[torchbearer.METRIC_LIST], pycm_metric]
        state[torchbearer.METRIC_LIST] = MetricList(combined)
示例#5
0
 def test_default_loss(self):
     """The string 'loss' should resolve to a metric whose name is 'loss'."""
     metric = MetricList(['loss'])
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(metric.metric_list[0].name, 'loss', msg='loss not in: ' + str(metric.metric_list))
示例#6
0
 def test_default_acc(self):
     """The string 'acc' should resolve to a metric whose name is 'acc'."""
     metric = MetricList(['acc'])
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(metric.metric_list[0].name, 'acc', msg='acc not in: ' + str(metric.metric_list))
示例#7
0
 def test_list_in_list(self):
     """A nested MetricList's metric is expected to appear directly in the outer metric_list, after 'acc'."""
     metric = MetricList(['acc', MetricList(['loss'])])
     self.assertEqual(metric.metric_list[0].name, 'acc')
     self.assertEqual(metric.metric_list[1].name, 'loss')
示例#8
0
 def test_reset(self):
     """MetricList.reset should forward the state to each child metric exactly once."""
     child = Metric('test')
     child.reset = Mock(return_value=None)
     container = MetricList([child])

     container.reset({'state': -1})

     child.reset.assert_called_once_with({'state': -1})
示例#9
0
 def test_eval(self):
     """MetricList.eval should call eval on each child metric exactly once."""
     my_mock = Metric('test')
     my_mock.eval = Mock(return_value=None)
     metric = MetricList([my_mock])
     metric.eval()
     # call_count works on all Python 3 versions; Mock.assert_called_once only exists from 3.6.
     self.assertEqual(my_mock.eval.call_count, 1)
示例#10
0
 def test_train(self):
     """MetricList.train should call train on each child metric exactly once."""
     my_mock = Metric('test')
     my_mock.train = Mock(return_value=None)
     metric = MetricList([my_mock])
     metric.train()
     # call_count works on all Python 3 versions; Mock.assert_called_once only exists from 3.6.
     self.assertEqual(my_mock.train.call_count, 1)
示例#11
0
 def test_eval(self):
     """Switching a MetricList to eval mode should call eval on its child once."""
     child = Metric('test')
     child.eval = Mock(return_value=None)
     container = MetricList([child])

     container.eval()

     calls = child.eval.call_count
     self.assertEqual(calls, 1)
示例#12
0
 def test_train(self):
     """Switching a MetricList to train mode should call train on its child once."""
     child = Metric('test')
     child.train = Mock(return_value=None)
     container = MetricList([child])

     container.train()

     calls = child.train.call_count
     self.assertEqual(calls, 1)
示例#13
0
    def test_default_roc(self):
        """Both 'roc_auc' and 'roc_auc_score' should resolve to a metric named 'roc_auc_score'."""
        for alias in ('roc_auc', 'roc_auc_score'):
            mlist = MetricList([alias])
            # assertEqual reports both values on failure; the msg identifies which alias failed.
            self.assertEqual(mlist.metric_list[0].name, 'roc_auc_score', msg='alias: ' + alias)