Example #1
0
    def __call__(self, epoch_nr, update_nr, net, stepper, logs):
        """Save (or cache) the network whenever the monitored log improves.

        The most recent entry of the log at ``self.log_name`` is compared
        with ``self.best_so_far``. With ``self.criterion == 'min'`` a smaller
        value counts as an improvement; for any other criterion a larger one
        does. On improvement:

        - the best value is recorded;
        - the time of improvement is recorded (the epoch number if
          ``self.timescale == 'epoch'``, otherwise the update number);
        - the network is written with ``net.save_as_hdf5(self.filename)``,
          or, when ``self.filename`` is None, its parameters are cached in
          ``self.parameters``.

        Without improvement, a message reports when the last save happened.

        Args:
            epoch_nr: Current epoch number.
            update_nr: Current update number.
            net: The network; this hook calls its ``get`` and
                ``save_as_hdf5`` methods.
            stepper: Unused by this hook.
            logs: Log structure, read with ``get_by_path``.

        Raises:
            KeyError: Propagated from ``get_by_path`` when the monitored log
                is missing and ``epoch_nr != 0``.
        """
        try:
            e = get_by_path(logs, self.log_name)
        except KeyError:
            # At epoch 0 the monitored log may not exist yet; later on a
            # missing log is not tolerated.
            if epoch_nr == 0:
                return
            raise
        if len(e) == 0:
            # Nothing recorded yet, so there is no value to compare.
            return

        last = e[-1]
        if self.criterion == 'min':
            imp = last < self.best_so_far
        else:
            imp = last > self.best_so_far

        if imp:
            self.best_so_far = last
            self.best_t = epoch_nr if self.timescale == 'epoch' else update_nr
            if self.filename is not None:
                self.message("{} improved (criterion: {}). Saving network to "
                             "{}".format(self.log_name, self.criterion,
                                         self.filename))
                net.save_as_hdf5(self.filename)
            else:
                self.message("{} improved (criterion: {}). Caching parameters".
                             format(self.log_name, self.criterion))
                # Fetch parameters only on the path that actually keeps them.
                self.parameters = net.get('parameters')
        else:
            self.message("Last saved parameters at {} {} when {} was {}".
                         format(self.timescale, self.best_t, self.log_name,
                                self.best_so_far))
Example #2
0
    def __call__(self, epoch_nr, update_nr, net, stepper, logs):
        """Save (or cache) the network whenever the monitored log improves.

        The newest value of the log at ``self.log_name`` is compared with
        ``self.best_so_far``. Lower is better when
        ``self.criterion == 'min'``; otherwise higher is better. On
        improvement, the value and the current epoch or update number
        (chosen by ``self.timescale``) are recorded. The network is then
        written to ``self.filename`` via ``net.save_as_hdf5``, or, if no
        filename is set, its parameters are cached in ``self.parameters``.
        Otherwise a message reports when the last save happened.

        Args:
            epoch_nr: Current epoch number.
            update_nr: Current update number.
            net: The network; this hook calls ``get`` and ``save_as_hdf5``.
            stepper: Unused by this hook.
            logs: Log structure, read with ``get_by_path``.

        Raises:
            KeyError: Propagated from ``get_by_path`` when the monitored log
                is missing and ``epoch_nr != 0``.
        """
        try:
            e = get_by_path(logs, self.log_name)
        except KeyError:
            # Tolerate a missing log only before training has produced one.
            if epoch_nr == 0:
                return
            raise
        if len(e) == 0:
            # An empty log has no last value to compare against.
            return

        last = e[-1]
        if self.criterion == 'min':
            imp = last < self.best_so_far
        else:
            imp = last > self.best_so_far

        if imp:
            self.best_so_far = last
            self.best_t = epoch_nr if self.timescale == 'epoch' else update_nr
            if self.filename is not None:
                self.message("{} improved. Saving network to {} ...".
                             format(self.log_name, self.filename))
                net.save_as_hdf5(self.filename)
            else:
                self.message("{} improved. Caching parameters ...".
                             format(self.log_name))
                # Parameters are only needed when they are cached in memory.
                self.parameters = net.get('parameters')
        else:
            self.message("Last saved parameters at {} {} when {} was {}".
                         format(self.timescale, self.best_t, self.log_name,
                                self.best_so_far))
Example #3
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training once the monitored log has stopped improving.

     The best entry of the log at ``self.log_name`` is its minimum when
     ``self.criterion == 'min'`` and its maximum otherwise. Training stops
     once at least ``self.patience`` entries have been logged after that
     best entry. A missing log at epoch 0 and an empty log are both ignored.

     Raises:
         StopIteration: When patience is exhausted.
         KeyError: Propagated from ``get_by_path`` when the log is missing
             and ``epoch_nr != 0``.
     """
     try:
         e = get_by_path(logs, self.log_name)
     except KeyError:
         # The monitored log may not exist yet at epoch 0.
         if epoch_nr == 0:
             return
         raise
     if len(e) == 0:
         # np.argmin/np.argmax raise ValueError on an empty sequence.
         return
     best_idx = np.argmin(e) if self.criterion == 'min' else np.argmax(e)
     if len(e) > best_idx + self.patience:
         self.message("Stopping because {} did not improve for {} checks.".
                      format(self.log_name, self.patience))
         raise StopIteration()
Example #4
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training once the monitored log has stopped improving.

     The best entry of the log at ``self.log_name`` is its minimum for
     ``self.criterion == 'min'`` and its maximum otherwise. When at least
     ``self.patience`` further entries follow it, a message naming the
     criterion is emitted and training stops. A log missing at epoch 0, or
     an empty log, is skipped.

     Raises:
         StopIteration: When patience is exhausted.
         KeyError: Propagated from ``get_by_path`` when the log is missing
             and ``epoch_nr != 0``.
     """
     try:
         e = get_by_path(logs, self.log_name)
     except KeyError:
         # Before the first epoch the log may simply not be there yet.
         if epoch_nr == 0:
             return
         raise
     if len(e) == 0:
         # argmin/argmax are undefined for an empty log.
         return
     best_idx = np.argmin(e) if self.criterion == 'min' else np.argmax(e)
     if len(e) > best_idx + self.patience:
         self.message("Stopping because {} did not improve for {} checks "
                      "(criterion used : {}).".format(
                          self.log_name, self.patience, self.criterion))
         raise StopIteration()
Example #5
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training when the monitored log has not decreased recently.

     Training stops once at least ``self.patience`` entries have been
     logged after the minimum of the log at ``self.log_name``. An empty log
     is ignored.

     Raises:
         StopIteration: When patience is exhausted.
     """
     e = get_by_path(logs, self.log_name)
     if len(e) == 0:
         # np.argmin raises ValueError on an empty sequence.
         return
     best_error_idx = np.argmin(e)
     if len(e) > best_error_idx + self.patience:
         self.message("Stopping because {} did not decrease for {} epochs.".
                      format(self.log_name, self.patience))
         raise StopIteration()
Example #6
0
    def create(source_set, sink_set, layout, connections):
        """Build and configure a ``Hub`` for a group of sources and sinks.

        The sources are sorted and flattened, and the ``BufferStructure`` of
        each one is read from ``layout``. All sources must share a single
        buffer type. The hub's context size is the largest context size
        among them. After ``hub.setup(connections)`` the hub receives:

        - the per-source feature sizes, in ``hub.perm`` order;
        - their sum as ``hub.size``;
        - a uniform ``is_backward_only`` flag.

        No ``self``/``cls`` parameter is visible, so this is presumably a
        staticmethod (the decorator is outside this view).

        Args:
            source_set: Source buffer names. Presumably may contain nested
                groups, given the ``flatten``/``convert_to_nested_indices``
                calls.
            sink_set: Sink buffer names.
            layout: Layout structure, read with ``get_by_path``.
            connections: Passed unchanged to ``hub.setup``.

        Returns:
            The configured ``Hub``.

        Raises:
            ValueError: If there are no sources, or if the sources disagree
                on buffer type or on ``is_backward_only``.
        """
        def ensure_uniform(values):
            # Explicit check rather than ``assert`` so the validation still
            # runs under ``python -O``. An empty list is rejected, matching
            # the ValueError that ``min([])`` used to raise.
            if not values:
                raise ValueError("Expected at least one value")
            first = values[0]
            if any(v != first for v in values):
                raise ValueError(
                    "Expected uniform values, got {}".format(values))
            return first

        sorted_sources = sorted(source_set)
        flat_sources = list(flatten(sorted_sources))
        nesting = convert_to_nested_indices(sorted_sources)

        # Get the buffer type for the hub and make sure it is uniform.
        structs = [
            BufferStructure.from_layout(get_by_path(layout, s))
            for s in flat_sources
        ]
        btype = ensure_uniform([s.buffer_type for s in structs])
        # The hub needs the largest context of any of its sources.
        context_size = max(s.context_size for s in structs)

        hub = Hub(flat_sources, nesting, sorted(sink_set), btype, context_size)
        hub.setup(connections)
        hub.sizes = [structs[i].feature_size for i in hub.perm]
        hub.size = sum(hub.sizes)
        hub.is_backward_only = ensure_uniform(
            [structs[i].is_backward_only for i in hub.perm])
        return hub
Example #7
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training when the monitored log has not decreased recently.

     Looks up the log at ``self.log_name`` and finds the index of its
     minimum. If at least ``self.patience`` entries were logged after it, a
     message is emitted and training stops. An empty log is skipped.

     Raises:
         StopIteration: When patience is exhausted.
     """
     e = get_by_path(logs, self.log_name)
     if len(e) == 0:
         # Guard against np.argmin's ValueError on an empty sequence.
         return
     best_error_idx = np.argmin(e)
     if len(e) > best_error_idx + self.patience:
         self.message(
             "Stopping because {} did not decrease for {} epochs.".format(
                 self.log_name, self.patience))
         raise StopIteration()
Example #8
0
def layout_hubs(hubs, layout):
    """
    Record in ``layout`` where every hub buffer lives.

    Hubs are numbered by their position in ``hubs``. For each
    ``(buffer_name, slice)`` pair yielded by ``hub.get_indices()``, the
    layout entry found at ``buffer_name`` (via ``get_by_path``) receives:

    - ``'@slice'``: that slice;
    - ``'@hub'``: the hub number.

    The entries are mutated through the object returned by ``get_by_path``.
    This presumably updates ``layout`` in place; that assumes
    ``get_by_path`` returns the nested entry itself rather than a copy.

    Returns:
        None. No buffer sizes are computed or returned here.
    """
    for hub_nr, hub in enumerate(hubs):
        for buffer_name, _slice in hub.get_indices():
            buffer_layout = get_by_path(layout, buffer_name)
            buffer_layout['@slice'] = _slice
            buffer_layout['@hub'] = hub_nr
Example #9
0
def layout_hubs(hubs, layout):
    """
    Write ``'@slice'`` and ``'@hub'`` entries into ``layout`` for each buffer
    owned by one of ``hubs``.

    The ``'@hub'`` value is the hub's index in ``hubs``. The ``'@slice'``
    value is the slice that ``hub.get_indices()`` reports for the buffer.
    Updates go through the entry returned by ``get_by_path``; this assumes
    that entry is the nested layout object itself, not a copy.

    Returns:
        None. The function returns no buffer sizes.
    """
    for hub_index, hub in enumerate(hubs):
        for name, buffer_slice in hub.get_indices():
            entry = get_by_path(layout, name)
            entry['@slice'] = buffer_slice
            entry['@hub'] = hub_index
Example #10
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training when the monitored log has not improved recently.

     The best entry of the log at ``self.log_name`` is its minimum when
     ``self.criterion == 'min'`` and its maximum otherwise. If at least
     ``self.patience`` entries were logged after it, a message is emitted
     and training stops. An empty log is ignored.

     Raises:
         StopIteration: When patience is exhausted.
     """
     e = get_by_path(logs, self.log_name)
     if len(e) == 0:
         # np.argmin/np.argmax raise ValueError on an empty sequence.
         return
     if self.criterion == 'min':
         best_error_idx = np.argmin(e)
     else:  # any other criterion (expected: 'max') is maximized
         best_error_idx = np.argmax(e)
     if len(e) > best_error_idx + self.patience:
         self.message("Stopping because {} did not improve for {} epochs.".
                      format(self.log_name, self.patience))
         raise StopIteration()
Example #11
0
    def __getitem__(self, item):
        """Look up an entry by position, by registered name, or by dotted path.

        Lookup order:

        - ``int`` items are delegated to the base class's ``__getitem__``;
        - names contained in ``self._keys`` return the matching instance
          attribute from ``self.__dict__``;
        - any other item containing a ``.`` is resolved with
          ``get_by_path``.

        Raises:
            KeyError: If ``item`` matches none of the cases above; the
                message lists the available names in sorted order.
        """
        if isinstance(item, int):
            return super(BufferView, self).__getitem__(item)
        if item in self._keys:
            return self.__dict__[item]
        if '.' not in item:
            available = ", ".join(sorted(self._keys))
            raise KeyError('{} is not present. Available items are [{}]'
                           .format(item, available))
        return get_by_path(self, item)
Example #12
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training once the monitored log reaches ``self.threshold``.

     The stop condition depends on ``self.criterion``:

     - ``'max'``: stop when any value of the log at ``self.log_name`` is
       ``>= self.threshold``;
     - ``'min'``: stop when any value is ``<= self.threshold``;
     - any other criterion never triggers a stop.

     An empty log is ignored.

     Raises:
         StopIteration: When the threshold has been reached.
     """
     e = get_by_path(logs, self.log_name)
     if len(e) == 0:
         # max()/min() raise ValueError on an empty sequence.
         return
     if self.criterion == 'max':
         is_threshold_reached = max(e) >= self.threshold
     elif self.criterion == 'min':
         is_threshold_reached = min(e) <= self.threshold
     else:
         is_threshold_reached = False
     if is_threshold_reached:
         self.message("Stopping because {} has reached the threshold {} "
                      "(criterion used : {})"
                      .format(self.log_name, self.threshold, self.criterion))
         raise StopIteration()
Example #13
0
 def __call__(self, epoch_nr, update_nr, net, stepper, logs):
     """Stop training once the monitored log reaches ``self.threshold``.

     For ``self.criterion == 'max'``, the threshold is reached when the
     largest logged value is at least ``self.threshold``. For ``'min'``, it
     is reached when the smallest logged value is at most
     ``self.threshold``. Any other criterion never stops training. An empty
     log is skipped.

     Raises:
         StopIteration: When the threshold has been reached.
     """
     e = get_by_path(logs, self.log_name)
     if len(e) == 0:
         # Nothing logged yet; max()/min() would raise ValueError.
         return
     reached = ((self.criterion == 'max' and max(e) >= self.threshold) or
                (self.criterion == 'min' and min(e) <= self.threshold))
     if reached:
         self.message("Stopping because {} has reached the threshold {} "
                      "(criterion used : {})".format(
                          self.log_name, self.threshold, self.criterion))
         raise StopIteration()
Example #14
0
 def get_shape(self, path):
     """Return the shape stored under a dotted ``path``.

     The part of ``path`` before the first ``.`` selects the category:

     - ``parameters`` reads from ``self.parameter_shapes``;
     - ``internals`` reads from ``self.internal_shapes``;
     - ``inputs`` reads from ``self.in_shapes``;
     - ``outputs`` reads from ``self.out_shapes``.

     The remainder of ``path`` is resolved with ``get_by_path`` inside the
     selected mapping.

     Raises:
         ValueError: If the category is not one of the four listed above.
     """
     category, _, subpath = path.partition('.')
     # Kept as a sorted tuple so the error message lists the choices in a
     # fixed order; formatting a set made the order hash-dependent.
     categories = ('inputs', 'internals', 'outputs', 'parameters')
     if category not in categories:
         raise ValueError("Category '{}' for path '{}' not found. Choices "
                          "are {}".format(category, path,
                                          ", ".join(categories)))
     category_shapes = {
         'parameters': self.parameter_shapes,
         'internals': self.internal_shapes,
         'inputs': self.in_shapes,
         'outputs': self.out_shapes
     }
     return get_by_path(category_shapes[category], subpath)
Example #15
0
 def get_shape(self, path):
     """Return the shape stored under a dotted ``path``.

     The first ``.``-separated component names a category (``inputs``,
     ``internals``, ``outputs`` or ``parameters``). That category maps to the
     instance attribute ``in_shapes``, ``internal_shapes``, ``out_shapes``
     or ``parameter_shapes`` respectively. The rest of ``path`` is resolved
     inside that attribute with ``get_by_path``.

     Raises:
         ValueError: If the category is unknown. The message lists the
             valid categories in sorted, deterministic order.
     """
     attr_for_category = {
         'inputs': 'in_shapes',
         'internals': 'internal_shapes',
         'outputs': 'out_shapes',
         'parameters': 'parameter_shapes',
     }
     category, _, subpath = path.partition('.')
     if category not in attr_for_category:
         raise ValueError("Category '{}' for path '{}' not found. Choices "
                          "are {}".format(category, path,
                                          ", ".join(sorted(attr_for_category))))
     return get_by_path(getattr(self, attr_for_category[category]), subpath)
Example #16
0
        def __call__(self, epoch_nr, update_nr, net, stepper, logs):
            """Refresh the plotted series for every name in ``self.log_names``.

            Does nothing at epoch 0. For each log name:

            - the first renderer selected from ``self.fig`` by that name is
              taken;
            - its data source gets the log values as ``"y"`` and their
              indices as ``"x"``;
            - the data source is stored via
              ``self.bk.cursession().store_objects``.

            If ``self.filename`` is set, the figure is also saved to
            ``self.filename + ".html"``.
            """
            if epoch_nr == 0:
                return
            for log_name in self.log_names:
                renderer = self.fig.select(dict(name=log_name))

                datasource = renderer[0].data_source
                datasource.data["y"] = get_by_path(logs, log_name)

                # Store a real list: on Python 3 ``range`` is a lazy object,
                # not a list, which serializers may not accept.
                datasource.data["x"] = list(range(len(datasource.data["y"])))
                self.bk.cursession().store_objects(datasource)

            if self.filename is not None:
                self.bk.save(self.fig, filename=self.filename + ".html")
Example #17
0
        def __call__(self, epoch_nr, update_nr, net, stepper, logs):
            """Push the current log values into the plot's data sources.

            Skipped at epoch 0. For each name in ``self.log_names``, the
            first renderer with that name in ``self.fig`` gets the log values
            as ``"y"`` and ``0..len-1`` as ``"x"``. The data source is then
            stored through ``self.bk.cursession()``. When ``self.filename``
            is not None, the figure is saved as ``<filename>.html``.
            """
            if epoch_nr == 0:
                return
            for log_name in self.log_names:
                renderer = self.fig.select(dict(name=log_name))

                datasource = renderer[0].data_source
                datasource.data["y"] = get_by_path(logs, log_name)

                # Materialize the x axis; a bare ``range`` is lazy on Python 3.
                datasource.data["x"] = list(range(len(datasource.data["y"])))
                self.bk.cursession().store_objects(datasource)

            if self.filename is not None:
                self.bk.save(self.fig, filename=self.filename + ".html")
Example #18
0
    def __call__(self, epoch_nr, update_nr, net, stepper, logs):
        """Stop training as soon as a NaN or infinity is detected.

        Checks run in this order:

        1. every log named in ``self.logs_to_check``;
        2. the network parameters, if ``self.check_parameters`` is set;
        3. the rolling training loss, if ``self.check_training_loss`` is set
           and ``logs`` has a ``'rolling_training'`` entry. The loss is read
           from ``'total_loss'`` when present, otherwise from ``'Loss'``.

        Raises:
            StopIteration: On the first non-finite value found.
        """
        def abort(text):
            # Report the problem, then end training.
            self.message(text)
            raise StopIteration()

        for name in self.logs_to_check:
            if not np.all(np.isfinite(get_by_path(logs, name))):
                abort("NaN or inf detected in {}!".format(name))

        if (self.check_parameters and
                not net.handler.is_fully_finite(net.buffer.parameters)):
            abort("NaN or inf detected in parameters!")

        if not self.check_training_loss or 'rolling_training' not in logs:
            return
        rolling = logs['rolling_training']
        loss_key = 'total_loss' if 'total_loss' in rolling else 'Loss'
        if not np.all(np.isfinite(rolling[loss_key])):
            abort("NaN or inf detected in rolling training loss!")
Example #19
0
    def __call__(self, epoch_nr, update_nr, net, stepper, logs):
        """Abort training when non-finite values appear.

        Checks, in order:

        - each log listed in ``self.logs_to_check``;
        - optionally (``self.check_parameters``) the network parameters, via
          ``net.handler.is_fully_finite``;
        - optionally (``self.check_training_loss``) the rolling training
          loss in ``logs['rolling_training']``, using ``'total_loss'`` if
          present and ``'Loss'`` otherwise.

        Raises:
            StopIteration: After messaging the first failing check.
        """
        for log_name in self.logs_to_check:
            values = get_by_path(logs, log_name)
            if np.all(np.isfinite(values)):
                continue
            self.message("NaN or inf detected in {}!".format(log_name))
            raise StopIteration()

        params_ok = (not self.check_parameters or
                     net.handler.is_fully_finite(net.buffer.parameters))
        if not params_ok:
            self.message("NaN or inf detected in parameters!")
            raise StopIteration()

        wants_loss_check = (self.check_training_loss and
                            'rolling_training' in logs)
        if wants_loss_check:
            rtrain = logs['rolling_training']
            loss = (rtrain['total_loss'] if 'total_loss' in rtrain
                    else rtrain['Loss'])
            if not np.all(np.isfinite(loss)):
                self.message("NaN or inf detected in rolling training loss!")
                raise StopIteration()
Example #20
0
    def create(source_set, sink_set, layout, connections):
        """Build and configure a ``Hub`` for the given sources and sinks.

        Steps:

        1. Sort and flatten the sources and read each one's
           ``BufferStructure`` from ``layout``.
        2. Require a single shared buffer type and use the maximum context
           size among the sources.
        3. Call ``hub.setup(connections)``.
        4. Fill in the feature sizes (in ``hub.perm`` order), their total,
           and a uniform ``is_backward_only`` flag.

        Presumably a staticmethod (no ``self``; decorator not visible here).

        Returns:
            The configured ``Hub``.

        Raises:
            ValueError: If there are no sources, or if buffer types or
                ``is_backward_only`` flags differ between sources.
        """
        def ensure_uniform(values):
            # Real check instead of ``assert`` (which ``python -O`` strips).
            # An empty list raises ValueError, as ``min([])`` did before.
            if not values:
                raise ValueError("Expected at least one value")
            first = values[0]
            if not all(v == first for v in values):
                raise ValueError(
                    "Expected uniform values, got {}".format(values))
            return first

        sorted_sources = sorted(source_set)
        flat_sources = list(flatten(sorted_sources))
        nesting = convert_to_nested_indices(sorted_sources)

        # Get the buffer type for the hub and make sure it is uniform.
        structs = [BufferStructure.from_layout(get_by_path(layout, s))
                   for s in flat_sources]
        btype = ensure_uniform([s.buffer_type for s in structs])
        # The hub must accommodate the largest context of any source.
        context_size = max(s.context_size for s in structs)

        hub = Hub(flat_sources, nesting, sorted(sink_set), btype, context_size)
        hub.setup(connections)
        hub.sizes = [structs[i].feature_size for i in hub.perm]
        hub.size = sum(hub.sizes)
        hub.is_backward_only = ensure_uniform(
            [structs[i].is_backward_only for i in hub.perm])
        return hub