Exemple #1
0
 def __init__(self, data_stream, probs, labels, **kwargs):
     """Build and compile an evaluator for ``probs`` and ``labels``.

     Args:
         data_stream: stored as ``self.data_stream``; not consumed here
             (presumably evaluated later by the monitor -- confirm).
         probs: variable handed to the evaluator as a copy named ``'probs'``.
         labels: variable handed to the evaluator as a copy named
             ``'targets'``.
         **kwargs: forwarded unchanged to the parent class constructor.
     """
     self.data_stream = data_stream
     # Fixed names so evaluator results can be looked up by key
     # (assumes ``.copy(name)`` returns a copy carrying that name).
     monitored = [probs.copy('probs'), labels.copy('targets')]
     self.evaluator = DataStreamEvaluator(monitored)
     logger.info("compiling auc logger")
     # Trigger compilation during construction (private evaluator API).
     self.evaluator._compile()
     super(AUCMonitor, self).__init__(**kwargs)
Exemple #2
0
 def __init__(self, data_stream, prediction, targets, label_i_to_c,
              **kwargs):
     """Build and compile an evaluator for per-class accuracy monitoring.

     Args:
         data_stream: stored as ``self.data_stream``; not consumed here
             (presumably evaluated later by the monitor -- confirm).
         prediction: 1-d variable, handed to the evaluator as a copy named
             ``'prediction'``.
         targets: 1-d variable, handed to the evaluator as a copy named
             ``'targets'``.
         label_i_to_c: stored as ``self.label_i_to_c``; presumably maps a
             class index to a class label -- confirm against callers.
         **kwargs: forwarded unchanged to the parent class constructor.

     Raises:
         ValueError: if ``prediction`` or ``targets`` is not 1-dimensional.
     """
     # Validate first so bad input fails fast with a clear error, before
     # any evaluator is constructed from it.
     if prediction.ndim != 1 or targets.ndim != 1:
         raise ValueError("targets and predictions must be 1d vectors")

     self.data_stream = data_stream
     self.label_i_to_c = label_i_to_c
     # Fixed names so evaluator results can be looked up by key
     # (assumes ``.copy(name)`` returns a copy carrying that name).
     self.evaluator = DataStreamEvaluator(
         [prediction.copy('prediction'),
          targets.copy('targets')])
     logger.info("compiling perclass accuracy logger")
     # Trigger compilation during construction (private evaluator API).
     self.evaluator._compile()
     super(PerClassAccuracyMonitor, self).__init__(**kwargs)
Exemple #3
0
def test_datastream_evaluator():
    """DataStreamEvaluator yields the concatenated per-example outputs.

    Each of the 10 examples is a 4x9 matrix of ones. Summing along axis 1
    gives four 9s per example, so 40 values in total, all equal to 9.
    """
    n_examples, n_rows, n_cols = 10, 4, 9
    ones = np.ones((n_examples, n_rows, n_cols), dtype="float32")
    dataset = IndexableDataset(indexables=OrderedDict([("data", ones)]))
    stream = dataset.get_example_stream()

    data = T.matrix("data")
    row_sums = data.sum(axis=1)
    row_sums.name = "mon"  # results are keyed by this name

    results = DataStreamEvaluator([row_sums]).evaluate(stream)
    assert set(results) == {'mon'}

    expected = np.full(n_examples * n_rows, float(n_cols))
    assert_allclose(results['mon'], expected)
Exemple #4
0
    def __init__(self, data_stream, prediction, targets, dest_directory,
                 **kwargs):
        """Build and compile an evaluator for confusion-matrix monitoring.

        Args:
            data_stream: stored as ``self.data_stream``; not consumed here
                (presumably evaluated later by the monitor -- confirm).
            prediction: 1-d variable, handed to the evaluator as a copy
                named ``'prediction'``.
            targets: 1-d variable, handed to the evaluator as a copy named
                ``'targets'``.
            dest_directory: directory created (with any missing parents) if
                it does not exist; presumably where output is written --
                confirm.
            **kwargs: forwarded unchanged to the parent class constructor.

        Raises:
            ValueError: if ``prediction`` or ``targets`` is not
                1-dimensional. Raised before any directory is created.
            OSError: if ``dest_directory`` cannot be created, or exists but
                is not a directory.
        """
        # Validate before any side effects, so a rejected call does not
        # leave a freshly created ``dest_directory`` behind.
        if prediction.ndim != 1 or targets.ndim != 1:
            raise ValueError("targets and predictions must be 1d vectors")

        self.data_stream = data_stream
        # Fixed names so evaluator results can be looked up by key
        # (assumes ``.copy(name)`` returns a copy carrying that name).
        self.evaluator = DataStreamEvaluator(
            [prediction.copy('prediction'),
             targets.copy('targets')])
        self.dest_directory = dest_directory

        # Attempt creation and tolerate "already exists". This avoids the
        # race between a separate exists() check and mkdir(). ``exist_ok``
        # is not used, to stay compatible with Python 2.
        try:
            os.makedirs(self.dest_directory)
        except OSError:
            if not os.path.isdir(self.dest_directory):
                raise

        logger.info("compiling confusion matrix monitor")
        # Trigger compilation during construction (private evaluator API).
        self.evaluator._compile()
        super(ConfusionMatrixMonitor, self).__init__(**kwargs)