Example #1
0
def test_mean_aggregator():
    """Check that ``Mean`` aggregation averages per-batch values.

    Four two-feature examples are served sequentially in two batches of
    two. ``y`` is the column-wise batch mean of the squared features and
    ``z`` is the sum of ``y``. Both are tagged with ``Mean(..., 1.)``, so
    the evaluator should report the average of the per-batch values:

    * batch 1 squares [[0, 9], [4, 81]] -> y = [2, 45],    z = 47
    * batch 2 squares [[4, 16], [25, 1]] -> y = [14.5, 8.5], z = 23
    * averages -> y = [8.25, 26.75], z = 35
    """
    float_type = theano.config.floatX
    n_examples, n_per_batch = 4, 2

    raw_features = numpy.array([[0, 3], [2, 9], [2, 4], [5, 1]],
                               dtype=float_type)
    stream = DataStream(
        IndexableDataset(OrderedDict([('features', raw_features)])),
        iteration_scheme=SequentialScheme(n_examples, n_per_batch))

    # The symbolic input's name must match the dataset source 'features'.
    inputs = tensor.matrix('features')
    y = (inputs ** 2).mean(axis=0)
    y.name = 'y'
    z = y.sum()
    z.name = 'z'

    for variable in (y, z):
        variable.tag.aggregation_scheme = Mean(variable, 1.)

    checks = [
        (y, numpy.array([8.25, 26.75], dtype=float_type)),
        (z, numpy.array([35], dtype=float_type)),
    ]
    # Each variable is evaluated with its own evaluator, y before z.
    for variable, expected in checks:
        results = DatasetEvaluator([variable]).evaluate(stream)
        assert_allclose(results[variable.name], expected)
    def _create_aggregators(self):
        """Create aggregators and collect updates.

        For every variable in ``self.variables`` the aggregation scheme is
        read from ``v.tag.aggregation_scheme``. A variable without one is
        given a default scheme, which is written back onto its tag:

        * if ``self._computation_graph.has_inputs(v)`` is false (the
          variable does not depend on the data), ``TakeLast`` when
          ``self.use_take_last`` is set, otherwise ``_DataIndependent``;
        * otherwise ``Mean(v, 1.0)``, i.e. the average over minibatches.

        Each scheme's aggregator contributes its initialization and
        accumulation updates to ``self.initialization_updates`` and
        ``self.accumulation_updates``. Its readout variable is stored in
        ``self.readout_variables`` (an ``OrderedDict`` preserving the order
        of ``self.variables``) under ``v.name``.
        """
        self.initialization_updates = []
        self.accumulation_updates = []
        self.readout_variables = OrderedDict()

        for v in self.variables:
            logger.debug('variable to evaluate: %s', v.name)
            if not hasattr(v.tag, 'aggregation_scheme'):
                # The chosen default is stored on the variable's tag, so it
                # stays attached to ``v`` after this method returns.
                if not self._computation_graph.has_inputs(v):
                    scheme = (TakeLast
                              if self.use_take_last else _DataIndependent)
                    logger.debug(
                        'Using %s aggregation scheme'
                        ' for %s since it does not depend on'
                        ' the data', scheme.__name__, v.name)
                    v.tag.aggregation_scheme = scheme(v)
                else:
                    logger.debug(
                        'Using the default'
                        ' (average over minibatches)'
                        ' aggregation scheme for %s', v.name)
                    v.tag.aggregation_scheme = Mean(v, 1.0)

            aggregator = v.tag.aggregation_scheme.get_aggregator()
            self.initialization_updates.extend(
                aggregator.initialization_updates)
            self.accumulation_updates.extend(aggregator.accumulation_updates)
            # NOTE(review): keyed by name, so two variables sharing a name
            # would overwrite each other here -- confirm uniqueness upstream.
            self.readout_variables[v.name] = aggregator.readout_variable