def evaluate(network, iter, scorers=(), out_name='', targets_name='targets',
             mask_name=None):
    """Evaluate one or more scores for a network.

    Useful for measuring a trained network on test data: the network's
    forward pass is executed once per step of ``run_network`` over the data
    produced by ``iter``. The requested scores plus all of the network's
    loss values are then collected and aggregated.

    Args:
        network (brainstorm.structure.Network):
            Network to be evaluated.
        iter (brainstorm.DataIterator):
            A data iterator which produces the data on which the scores are
            computed. It is called with ``handler=network.handler``.
        scorers (tuple[brainstorm.scorers.Scorer]):
            A list or tuple of Scorers; results are keyed by each scorer's
            ``__name__``.
        out_name (Optional[str]):
            Name of the network output which is scored against the targets.
        targets_name (Optional[str]):
            Name of the targets data provided by the data iterator
            (``iter``).
        mask_name (Optional[str]):
            Name of the mask data provided by the data iterator (``iter``).

    Returns:
        The result of ``aggregate_losses_and_scores`` over the collected
        score/loss lists.
    """
    data_stream = iter(handler=network.handler)

    # One accumulator list per scorer, followed by one per loss value.
    # A loss whose name equals a scorer's name shares that single entry.
    tracked_names = [scorer.__name__ for scorer in scorers]
    tracked_names.extend(network.get_loss_values())
    collected = {name: [] for name in tracked_names}

    for _step in run_network(network, data_stream):
        network.forward_pass()
        gather_losses_and_scores(network, scorers, collected,
                                 out_name=out_name,
                                 targets_name=targets_name,
                                 mask_name=mask_name)

    return aggregate_losses_and_scores(collected, network, scorers)
def train(self, net, training_data_iter, **named_data_iters):
    """
    Train a network using a data iterator and further named data iterators.

    Runs epoch after epoch until a hook asks to stop. In each epoch the
    training iterator is re-created, ``self.stepper.run()`` is executed once
    per step of ``run_network``, and the gathered losses/scores are logged
    under ``'rolling_training'``.

    Args:
        net: Network to be trained. The keys of ``net.buffer.Input.outputs``
            must match the data names of ``training_data_iter``.
        training_data_iter: Data iterator for the training data. It is called
            with ``handler=net.handler`` at the start of every epoch.
        **named_data_iters: Further named data iterators passed to the hooks,
            together with the training iterator under the key
            ``'training_data_iter'``.

    Raises:
        AssertionError: If the training data names do not match the network
            input names (only while assertions are enabled).
    """
    def log_banner(title, width):
        # Only the plain ``print`` logger gets the leading blank lines;
        # any other logging function receives just the banner text.
        banner = width * '- ' + title + width * ' -'
        if self.logging_function is print:
            banner = '\n\n' + banner
        self.logging_function(banner)

    if self.verbose:
        log_banner("Before Training", 10)
    assert set(training_data_iter.data_shapes.keys()) == set(
        net.buffer.Input.outputs.keys()), \
        "The data names provided by the training data iterator {} do not "\
        "map to the network input names {}".format(
            training_data_iter.data_shapes.keys(),
            net.buffer.Input.outputs.keys())
    self.stepper.start(net)
    named_data_iters['training_data_iter'] = training_data_iter
    self._start_hooks(net, named_data_iters)
    # A truthy hook result means "stop". 'epoch' hooks are only emitted here
    # when the 'update' hooks did not already request a stop.
    if self._emit_hooks(net, 'update') or self._emit_hooks(net, 'epoch'):
        return

    should_stop = False
    while not should_stop:
        self.current_epoch_nr += 1
        sys.stdout.flush()
        train_scores = {s.__name__: [] for s in self.train_scorers}
        train_scores.update({n: [] for n in net.get_loss_values()})

        if self.verbose:
            # Keep a space between the label and the number ("Epoch 1").
            log_banner("Epoch " + str(self.current_epoch_nr), 12)
        iterator = training_data_iter(handler=net.handler)
        for _ in run_network(net, iterator):
            self.current_update_nr += 1
            self.stepper.run()
            gather_losses_and_scores(net, self.train_scorers, train_scores)
            net.apply_weight_modifiers()
            if self._emit_hooks(net, 'update'):
                should_stop = True
                break

        self._add_log('rolling_training', aggregate_losses_and_scores(
            train_scores, net, self.train_scorers))
        should_stop |= self._emit_hooks(net, 'epoch')