Esempio n. 1
0
    def on_epoch_end(self, net, **kwargs):
        """Save the net's parameters when the monitor requests a checkpoint.

        The checkpoint decision comes from ``self.monitor``:

        * ``None`` -- always checkpoint;
        * a callable -- called with ``net``; its return value decides;
        * anything else -- used as a key into the last row of
          ``net.history``; the stored value decides.

        When checkpointing, a string ``self.target`` is treated as a
        ``str.format`` template that may reference ``net``, ``last_epoch``
        (the last history row) and ``last_batch`` (the last entry of that
        row's ``'batches'``); any other ``target`` is handed to
        ``net.save_params`` unchanged. The decision is always recorded in
        the history under ``'event_cp'``.

        Raises
        ------
        SkorchException
            If ``self.monitor`` is a key that is absent from the last
            history row.

        """
        monitor = self.monitor
        if monitor is None:
            should_save = True
        elif callable(monitor):
            should_save = monitor(net)
        else:
            try:
                should_save = net.history[-1, monitor]
            except KeyError as exc:
                raise SkorchException(
                    "Monitor value '{}' cannot be found in history. "
                    "Make sure you have validation data if you use "
                    "validation scores for checkpointing.".format(exc.args[0]))

        if should_save:
            destination = self.target
            if isinstance(destination, str):
                # Both history lookups run even if the template ignores them.
                epoch_row = net.history[-1]
                batch_row = net.history[-1, 'batches', -1]
                destination = destination.format(
                    net=net, last_epoch=epoch_row, last_batch=batch_row)
            message = "Checkpoint! Saving model to {}.".format(destination)
            self._sink(message, net.verbose)
            net.save_params(destination)

        net.history.record('event_cp', bool(should_save))
Esempio n. 2
0
    def on_epoch_end(self, net, **kwargs):
        """Checkpoint the model at the end of an epoch if the monitor allows it.

        The checkpoint decision comes from ``self.monitor``: ``None`` always
        checkpoints, a callable is called with ``net`` and its return value
        decides, and any other value is used as a key into the last row of
        ``net.history``. If ``self.event_name`` is not ``None`` the decision
        (as a ``bool``) is recorded in the history under that name. When
        checkpointing, ``self.save_model(net)`` is called and a message
        naming the current epoch is passed to ``self._sink``.

        Warns
        -----
        UserWarning
            If ``self.monitor`` is a string ``'<key>'`` and the last history
            row also contains ``'<key>_best'``; the ``_best`` variant was
            probably intended.

        Raises
        ------
        SkorchException
            If ``self.monitor`` is a key that is absent from the last
            history row.

        """
        # Only a string monitor names a history key that can be confused
        # with its '_best' variant; None or a callable would be formatted
        # into a meaningless key.
        if (isinstance(self.monitor, str)
                and "{}_best".format(self.monitor) in net.history[-1]):
            warnings.warn(
                "Checkpoint monitor parameter is set to '{0}' and the history "
                "contains '{0}_best'. Perhaps you meant to set the parameter "
                "to '{0}_best'".format(self.monitor), UserWarning)

        if self.monitor is None:
            do_checkpoint = True
        elif callable(self.monitor):
            do_checkpoint = self.monitor(net)
        else:
            try:
                do_checkpoint = net.history[-1, self.monitor]
            except KeyError as e:
                raise SkorchException(
                    "Monitor value '{}' cannot be found in history. "
                    "Make sure you have validation data if you use "
                    "validation scores for checkpointing.".format(e.args[0])
                ) from e

        if self.event_name is not None:
            net.history.record(self.event_name, bool(do_checkpoint))

        if do_checkpoint:
            self.save_model(net)
            # net.history already holds a row for the epoch that just ended
            # (the monitor value above is read from it), so with one row per
            # epoch its length is the current epoch number; adding 1 would
            # name the following epoch.
            self._sink(
                "A checkpoint was triggered in epoch {}.".format(
                    len(net.history)), net.verbose)
Esempio n. 3
0
 def on_epoch_end(self, net, **kwargs):
     """Call ``self.save_function`` when the monitor requests a checkpoint.

     ``self.monitor`` decides: ``None`` always checkpoints, a callable is
     called with ``net`` and its return value decides, and anything else is
     looked up as a key in the last row of ``net.history``.

     A string ``self.target`` is expanded with ``str.format`` using ``net``,
     ``last_epoch`` (last history row), ``last_batch`` (last entry of that
     row's ``'batches'``) and ``type`` (the monitor itself); any other
     target is passed through unchanged. With ``net.verbose > 0`` the
     resolved target is printed before saving.

     Raises
     ------
     SkorchException
         If ``self.monitor`` is a key absent from the last history row.

     """
     monitor = self.monitor
     if monitor is None:
         checkpoint_now = True
     elif callable(monitor):
         checkpoint_now = monitor(net)
     else:
         try:
             checkpoint_now = net.history[-1, monitor]
         except KeyError as err:
             raise SkorchException(
                 "Monitor value '{}' cannot be found in history. "
                 "Make sure you have validation data if you use "
                 "validation scores for checkpointing.".format(err.args[0]))

     if not checkpoint_now:
         return

     destination = self.target
     if isinstance(destination, str):
         current_row = net.history[-1]
         final_batch = net.history[-1, 'batches', -1]
         destination = destination.format(
             net=net,
             last_epoch=current_row,
             last_batch=final_batch,
             type=self.monitor,
         )
     if net.verbose > 0:
         print("Checkpoint! : {}.".format(destination))
     self.save_function(destination)
Esempio n. 4
0
 def _validate_filenames(self):
     if not self.dirname:
         return
     if (self.f_optimizer and not isinstance(self.f_optimizer, str)
             or self.f_params and not isinstance(self.f_params, str)
             or self.f_history and not isinstance(self.f_history, str)
             or self.f_pickle and not isinstance(self.f_pickle, str)):
         raise SkorchException(
             'dirname can only be used when f_* are strings')
Esempio n. 5
0
    def _validate_filenames(self):
        """Validate the f_* filename arguments.

        All f_* arguments (as returned by ``self._f_kwargs()``) are first
        passed to ``_check_f_arguments`` together with the class name.
        Afterwards, if ``dirname`` is set, every truthy f_* value must be a
        ``str``; the first one that is not raises a ``SkorchException``.

        """
        _check_f_arguments(self.__class__.__name__, **self._f_kwargs())

        if self.dirname:
            for value in self._f_kwargs().values():
                if value and not isinstance(value, str):
                    raise SkorchException(
                        'dirname can only be used when f_* are strings')
Esempio n. 6
0
    def _validate_filenames(self):
        """Checks if passed filenames are valid.

        Specifically, f_* parameter should not be passed in
        conjunction with dirname.

        """
        if not self.dirname:
            return

        def _is_truthy_and_not_str(f):
            return f and not isinstance(f, str)

        if (_is_truthy_and_not_str(self.f_optimizer)
                or _is_truthy_and_not_str(self.f_params)
                or _is_truthy_and_not_str(self.f_history)
                or _is_truthy_and_not_str(self.f_pickle)):
            raise SkorchException(
                'dirname can only be used when f_* are strings')
Esempio n. 7
0
    def on_epoch_end(self, net, **kwargs):
        """Checkpoint the model at the end of an epoch if the monitor allows it.

        The checkpoint decision comes from ``self.monitor``: ``None`` always
        checkpoints, a callable is called with ``net`` and its return value
        decides, and any other value is used as a key into the last row of
        ``net.history``. When checkpointing, ``self.save_model(net)`` is
        called and a message naming the current epoch is passed to
        ``self._sink``. The decision is always recorded in the history under
        ``'event_cp'``.

        Raises
        ------
        SkorchException
            If ``self.monitor`` is a key that is absent from the last
            history row.

        """
        if self.monitor is None:
            do_checkpoint = True
        elif callable(self.monitor):
            do_checkpoint = self.monitor(net)
        else:
            try:
                do_checkpoint = net.history[-1, self.monitor]
            except KeyError as e:
                raise SkorchException(
                    "Monitor value '{}' cannot be found in history. "
                    "Make sure you have validation data if you use "
                    "validation scores for checkpointing.".format(e.args[0])
                ) from e

        if do_checkpoint:
            self.save_model(net)
            # net.history already holds a row for the epoch that just ended
            # (the monitor value above is read from it), so with one row per
            # epoch its length is the current epoch number; adding 1 would
            # name the following epoch.
            self._sink(
                "A checkpoint was triggered in epoch {}.".format(
                    len(net.history)), net.verbose)

        net.history.record('event_cp', bool(do_checkpoint))
Esempio n. 8
0
    def on_epoch_end(self, net, **kwargs):
        """Record a layer's forward input/output on checkpoint epochs.

        The decision comes from ``self.monitor``: ``None`` always fires, a
        callable is called with ``net`` and its return value decides, and any
        other value is used as a key into the last row of ``net.history``.

        When it fires, the submodule of ``net.module_`` named
        ``self.layer_name`` is looked up via ``named_modules()``. If no such
        submodule exists, nothing happens. Otherwise
        ``self._forward_layer_hook`` is registered on it as a forward hook.
        ``self.data`` is preprocessed with ``self._process_data`` and passed
        through ``self._infer_log_layer``, and the hook is removed again.
        Finally, ``self.callback_func`` receives a dict with ``name``,
        ``layer_name``, ``epoch`` (``len(net.history)``), ``data`` and
        ``input_output``.

        Parameters
        ----------
        net
            The net being trained.
        **kwargs
            Forwarded to ``self._process_data``.

        Raises
        ------
        SkorchException
            If ``self.monitor`` is a key that is absent from the last
            history row.

        """
        if self.monitor is None:
            do_checkpoint = True
        elif callable(self.monitor):
            do_checkpoint = self.monitor(net)
        else:
            try:
                do_checkpoint = net.history[-1, self.monitor]
            except KeyError as e:
                raise SkorchException(
                    "Monitor value '{}' cannot be found in history. "
                    "Make sure you have validation data if you use "
                    "validation scores for checkpointing.".format(e.args[0])
                ) from e
        if not do_checkpoint:
            return

        selected_layer = next(
            (module for name, module in net.module_.named_modules()
             if name == self.layer_name),
            None,
        )

        # Do nothing if the layer does not exist. Compare against None
        # explicitly: a module whose class defines __len__ (e.g. an empty
        # container module) is falsy even though it was found.
        if selected_layer is None:
            return

        handle = selected_layer.register_forward_hook(
            self._forward_layer_hook)
        try:
            to_infer_data = self._process_data(self.data, **kwargs)
            input_output = self._infer_log_layer(net, to_infer_data)
        finally:
            # Always detach the hook, so a failure during preprocessing or
            # inference does not leave it attached for later forward passes.
            handle.remove()

        record_obj = {
            'name': self.name,
            'layer_name': self.layer_name,
            'epoch': len(net.history),
            'data': to_infer_data,
            'input_output': input_output,
        }
        self.callback_func(record_obj)