Exemplo n.º 1
0
def evaluate(network, iter, scorers=(), out_name='', targets_name='targets',
             mask_name=None):
    """Compute the losses and scores of a network over a whole dataset.

    Intended for measuring how a trained network performs on test data.
    Every batch produced by ``iter`` is run forward through ``network``. The
    network's loss values plus all given scorers are collected per batch and
    aggregated at the end.

    Args:
        network (brainstorm.structure.Network): The network to evaluate.
        iter (brainstorm.DataIterator): Data iterator yielding the data on
            which the scores are computed. It is called with
            ``handler=network.handler`` to obtain the actual iterator.
        scorers (tuple[brainstorm.scorers.Scorer]): List or tuple of scorers.
            Each is collected under its ``__name__``.
        out_name (Optional[str]): Name of the network output that is scored
            against the targets.
        targets_name (Optional[str]): Name of the target data provided by the
            data iterator (``iter``).
        mask_name (Optional[str]): Name of the mask data provided by the data
            iterator (``iter``).

    Returns:
        The result of ``aggregate_losses_and_scores`` applied to the values
        gathered under each scorer name and each name yielded by
        ``network.get_loss_values()``.
    """
    batches = iter(handler=network.handler)

    # One accumulator list per scorer, then one per network loss. A loss name
    # equal to a scorer name simply resets that entry to a fresh list.
    collected = dict((scorer.__name__, []) for scorer in scorers)
    collected.update((loss_name, []) for loss_name in network.get_loss_values())

    for _ in run_network(network, batches):
        network.forward_pass()
        gather_losses_and_scores(network, scorers, collected,
                                 out_name=out_name,
                                 targets_name=targets_name,
                                 mask_name=mask_name)

    return aggregate_losses_and_scores(collected, network, scorers)
Exemplo n.º 2
0
    def train(self, net, training_data_iter, **named_data_iters):
        """Train a network using a data iterator and further named iterators.

        Training runs epoch after epoch until ``_emit_hooks`` reports a stop.
        Each epoch iterates over the training data and, for every batch:
        performs one stepper update, gathers the training losses/scores, and
        applies the weight modifiers. The aggregated values are logged under
        ``'rolling_training'``.

        Args:
            net: The network to train. Its input names
                (``net.buffer.Input.outputs``) must match the data names of
                ``training_data_iter``.
            training_data_iter: Data iterator for the training data. It is
                called with ``handler=net.handler`` once per epoch.
            **named_data_iters: Further named data iterators. They are passed
                to ``_start_hooks`` together with ``training_data_iter``,
                which is added under the key ``'training_data_iter'``.

        Raises:
            AssertionError: If the data names of ``training_data_iter`` do not
                match the network input names. This only happens while
                assertions are enabled, i.e. not under ``python -O``.
        """
        if self.verbose:
            # Plain ``print`` output gets two leading blank lines so the banner
            # stands out; any other logging function gets the bare banner.
            spacer = '\n\n' if self.logging_function is print else ''
            self.logging_function(spacer + 10 * '- ' + "Before Training" +
                                  10 * ' -')
        # NOTE(review): being an ``assert``, this check is stripped under -O.
        assert set(training_data_iter.data_shapes.keys()) == set(
            net.buffer.Input.outputs.keys()), \
            "The data names provided by the training data iterator {} do not "\
            "map to the network input names {}".format(
                training_data_iter.data_shapes.keys(),
                net.buffer.Input.outputs.keys())
        self.stepper.start(net)
        named_data_iters['training_data_iter'] = training_data_iter
        self._start_hooks(net, named_data_iters)
        # A truthy result from the initial hook emission ends training before
        # any update is made ('epoch' hooks are skipped if 'update' stops).
        if self._emit_hooks(net, 'update') or self._emit_hooks(net, 'epoch'):
            return

        should_stop = False
        while not should_stop:
            self.current_epoch_nr += 1
            # Push out any pending console output before the epoch starts.
            sys.stdout.flush()
            # Fresh per-epoch accumulators: one list per scorer and per loss.
            train_scores = {s.__name__: [] for s in self.train_scorers}
            train_scores.update({n: [] for n in net.get_loss_values()})

            if self.verbose:
                spacer = '\n\n' if self.logging_function is print else ''
                # Keep a space between "Epoch" and its number ("Epoch 3").
                self.logging_function(spacer + 12 * '- ' + "Epoch " +
                                      str(self.current_epoch_nr) + 12 * ' -')
            iterator = training_data_iter(handler=net.handler)
            for _ in run_network(net, iterator):
                self.current_update_nr += 1
                self.stepper.run()
                gather_losses_and_scores(net, self.train_scorers, train_scores)
                net.apply_weight_modifiers()
                # An 'update' hook may stop training mid-epoch; the partial
                # epoch's scores are still logged and 'epoch' hooks still run.
                if self._emit_hooks(net, 'update'):
                    should_stop = True
                    break

            self._add_log('rolling_training',
                          aggregate_losses_and_scores(train_scores, net,
                                                      self.train_scorers))

            should_stop |= self._emit_hooks(net, 'epoch')