コード例 #1
0
ファイル: _trainer.py プロジェクト: shinh/pytorch-pfn-extras
    def __init__(
        self,
        handler: 'handler_module.BaseHandler',
        *,
        evaluator: Optional[Union['Evaluator', Tuple['Evaluator',
                                                     TriggerLike]]],
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        **kwargs: Any,
    ):
        """Initializes the trainer.

        Args:
            handler: Handler object implementing the actual training logic.
            evaluator: An ``Evaluator``, an ``(evaluator, trigger)`` pair,
                or ``None`` when no evaluation is performed.
            models: A single module or a name-to-module mapping; a bare
                module is registered under the name ``'main'``.
            kwargs: Stored and forwarded when the manager is set up later.
        """
        self.handler = handler
        self._manager: Optional['training.ExtensionsManager'] = None

        # The followings are used when setting up a manager instance
        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                # Fixed typo in the message: 'toch' -> 'torch'.
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        self._kwargs = kwargs
        self._extensions: List[  # list of (args, kwargs)
            Tuple[Tuple['training.Extension', Optional[str], 'TriggerLike',
                        Optional[int]], Dict[str, Any]]] = []
        self._manager_state: Optional[Dict[str, Any]] = None

        if isinstance(evaluator, tuple):
            self.evaluator: Optional['Evaluator'] = None
            self.evaluator, trigger = evaluator
            self.evaluator_trigger = trigger_module.get_trigger(trigger)
        else:
            self.evaluator = evaluator
            # Default: evaluate once per epoch.
            self.evaluator_trigger = trigger_module.get_trigger((1, 'epoch'))
        self.val_loader = None
コード例 #2
0
    def __init__(self,
                 check_trigger=(1, 'epoch'),
                 monitor='main/loss',
                 patience=None,
                 mode='auto',
                 verbose=False,
                 max_trigger=(100, 'epoch'),
                 **kwargs):
        """Trigger that requests a stop when ``monitor`` stops improving."""

        # `patients` is accepted as a legacy alias of `patience`;
        # specifying both is an error.
        patients = kwargs.get('patients', None)
        if patients is not None and patience is not None:
            raise TypeError(
                'Both \'patience\' and \'patients\' arguments are '
                'specified. \'patients\' is an alias of the former. '
                'Specify only \'patience\'.')
        if patience is None:
            patience = 3 if patients is None else patients

        self.count = 0
        self.patience = patience
        self.monitor = monitor
        self.verbose = verbose
        self.already_warning = False
        self._max_trigger = trigger.get_trigger(max_trigger)
        self._interval_trigger = trigger.get_trigger(check_trigger)

        self._init_summary()

        # 'auto' mode guesses the improvement direction from the key name.
        if mode == 'max':
            self._compare = operator.gt
        elif mode == 'min':
            self._compare = operator.lt
        else:
            self._compare = (
                operator.gt if 'accuracy' in monitor else operator.lt)

        # Seed the running best so that any first observation improves it.
        if self._compare is operator.gt:
            if verbose:
                print('early stopping: operator is greater')
            self.best = float('-inf')
        else:
            if verbose:
                print('early stopping: operator is less')
            self.best = float('inf')
コード例 #3
0
    def __init__(
            self,
            models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
            optimizers: Union[torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]],
            max_epochs: int,
            extensions: Optional[List['extension_module.ExtensionLike']],
            out_dir: str,
            writer: Optional[writing.Writer],
            stop_trigger: 'trigger_module.TriggerLike' = None
    ) -> None:
        """Initializes the manager driving training loops and extensions.

        Args:
            models: A single module or a name-to-module dict; a bare module
                is registered under the name ``'main'``.
            optimizers: A single optimizer or a name-to-optimizer dict.
            max_epochs: Number of epochs; also used for the default
                stop trigger.
            extensions: Extensions registered at construction time.
            out_dir: Output directory handed to the default writer.
            writer: Snapshot writer; a ``SimpleWriter`` is created when
                ``None``.
            stop_trigger: Overrides the ``(max_epochs, 'epoch')`` stop
                condition when given.
        """
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = trigger_module.get_trigger(
                stop_trigger)
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation: Dict[str, reporting.ReportValue] = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = reporting.Reporter()
        self._start_extensions_called = False

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                # Fixed typo in the message: 'toch' -> 'torch'.
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(
                name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time: Optional[float] = None
        self._iters_per_epoch: Optional[int] = None
        self._extensions: Dict[str, _ExtensionEntry] = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)
0
 def __init__(self, handler, *, evaluator, **kwargs):
     """Initializes the trainer with an optional evaluator.

     ``evaluator`` may be an Evaluator or an ``(evaluator, trigger)``
     pair; without an explicit trigger, evaluation runs once per epoch.
     """
     super().__init__(handler, **kwargs)
     # isinstance() is preferred over ``type(...) is tuple`` and also
     # accepts tuple subclasses such as namedtuples.
     if isinstance(evaluator, tuple):
         self.evaluator, trigger = evaluator
         self.evaluator_trigger = trigger_module.get_trigger(trigger)
     else:
         self.evaluator = evaluator
         self.evaluator_trigger = trigger_module.get_trigger((1, 'epoch'))
     self.val_loader = None
コード例 #5
0
ファイル: comparer.py プロジェクト: shinh/pytorch-pfn-extras
    def __init__(
            self,
            *,
            trigger=None,
            compare_fn=_default_comparer,
            concurrency=None,
            outputs=True,
            params=False,
    ):
        """A class for comparison of iteration outputs and model parameters.

        This class is mainly used to compare results between different devices.

        Args:
            trigger (Trigger):
                Trigger object that determines when to compare values.
            compare_fn (function):
                Comparison function. Default is ``get_default_comparer()``.
            concurrency (int, optional):
                The upper bound limit on the number of workers that run concurrently.
                If ``None``, inferred from the size of ``engines``.
            outputs (tuple of str or bool):
                A set of keys of output dict to compare.
            params (tuple of str or bool):
                A set of keys of model parameters to compare.

        Examples:
            >>> trainer_cpu = ppe.engine.create_trainer(
                    model, optimizer, 1, device='cpu')
            >>> trainer_gpu = ppe.engine.create_trainer(
                    model, optimizer, 1, device='cuda:0')
            >>> comp = ppe.utils.comparer.Comparer()
            >>> comp.add_engine("cpu", trainer_cpu, train_1, eval_1)
            >>> comp.add_engine("gpu", trainer_gpu, train_2, eval_2)
            >>> comp.compare()
        """
        self._engine_type = None
        self._engines = collections.OrderedDict()
        self._compare_fn = compare_fn
        self._targets = {}
        self._output_keys = outputs
        self._param_keys = params
        self._finalized = False
        self._concurrency = concurrency  # Upper limit of semaphore size
        self._semaphore = None  # Semaphore for training step execution
        self._barrier = None  # Synchronizes iteration timing
        self._report_lock = threading.Lock()  # Locks `Comparer._add_target`

        if trigger is None:
            self._trigger = trigger_module.get_trigger((1, "epoch"))
        else:
            # NOTE(review): setting a trigger restricts engines to Trainer
            # instances (engine type is pinned here) — confirm intended.
            self._engine_type = _trainer.Trainer
            self._trigger = trigger_module.get_trigger(trigger)
コード例 #6
0
    def __init__(self,
                 models,
                 optimizers,
                 max_epochs,
                 extensions,
                 out_dir,
                 writer,
                 stop_trigger=None):
        """Initializes the manager driving training loops and extensions.

        Args:
            models: A single ``torch.nn.Module`` or a name-to-module dict.
            optimizers: A single optimizer or a name-to-optimizer dict.
            max_epochs: Number of epochs; also the default stop condition.
            extensions: Iterable of extensions registered at construction.
            out_dir: Output directory handed to the default writer.
            writer: Snapshot writer; a ``SimpleWriter`` is created when
                ``None``.
            stop_trigger: Overrides the ``(max_epochs, 'epoch')`` stop
                condition when given.
        """
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = trigger_module.get_trigger(stop_trigger)
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = Reporter()

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                # Fixed typo in the message: 'toch' -> 'torch'.
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time = None
        self._extensions = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)
コード例 #7
0
    def __init__(self,
                 keys=None,
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 append=False,
                 format=None,
                 **kwargs):
        """Accumulates reported values and serializes them periodically."""
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess = postprocess
        self._log = []
        # A writer, when given, must define a savefun that accepts a string.
        self._writer = kwargs.get('writer', None)

        # 'log_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get('log_name', 'log')
        self._log_name = filename

        # Infer the serialization format from the file extension.
        if format is None and filename is not None:
            if filename.endswith('.jsonl'):
                format = 'json-lines'
            elif filename.endswith('.yaml'):
                format = 'yaml'
            else:
                format = 'json'

        self._append = append
        self._format = format
        self._init_summary()
コード例 #8
0
    def __init__(self,
                 y_keys,
                 x_key='iteration',
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 marker='x',
                 grid=True,
                 **kwargs):
        """Plots the given y keys against ``x_key`` as training proceeds."""

        # 'file_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get('file_name', 'plot.png')

        _check_available()

        self._x_key = x_key
        # A single key string is normalized into a one-element tuple.
        self._y_keys = (y_keys, ) if isinstance(y_keys, str) else y_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = filename
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {key: [] for key in self._y_keys}
        self._writer = kwargs.get('writer', None)
コード例 #9
0
    def __init__(self, links, statistics='default',
                 report_params=True, report_grads=True, prefix=None,
                 trigger=(1, 'epoch'), skip_nan_params=False):
        """Reports per-parameter statistics of the given links."""

        # A single link is accepted as well as a sequence of links.
        if not isinstance(links, (list, tuple)):
            links = (links,)
        self._links = links

        # 'default' selects the class-level preset; None disables them.
        if statistics is None:
            statistics = {}
        elif statistics == 'default':
            statistics = self.default_statistics
        self._statistics = dict(statistics)

        # Choose which attributes of each parameter are inspected.
        self._attrs = [
            attr for attr, enabled
            in (('data', report_params), ('grad', report_grads)) if enabled]

        self._prefix = prefix
        self._trigger = trigger_module.get_trigger(trigger)
        self._summary = reporting.DictSummary()
        self._skip_nan_params = skip_nan_params
コード例 #10
0
    def __init__(
        self,
        y_keys: Union[Iterable[str], str],
        x_key: str = 'iteration',
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
        postprocess: Any = None,
        filename: Optional[str] = None,
        marker: str = 'x',
        grid: bool = True,
        **kwargs: Any,
    ):
        """Plots the given y keys against ``x_key`` as training proceeds."""

        # 'file_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get('file_name', 'plot.png')

        _check_available()

        self._x_key = x_key
        # A single key string is normalized into a one-element tuple.
        self._y_keys = (y_keys, ) if isinstance(y_keys, str) else y_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = filename
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data: Dict[str, List[Tuple[Any, Any]]] = {
            key: [] for key in self._y_keys}
        self._writer = kwargs.get('writer', None)
コード例 #11
0
 def __init__(self,
              scheduler,
              *,
              stepper=_default_stepper,
              trigger=(1, 'epoch')):
     """Extension that advances an LR scheduler on the given trigger."""
     self.trigger = trigger_module.get_trigger(trigger)
     self.scheduler = scheduler
     self.stepper = stepper
コード例 #12
0
    def __init__(self, targets, max_sample_size=1000,
                 report_data=True, report_grad=True,
                 plot_mean=True, plot_std=True,
                 percentile_sigmas=(
                     0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
                 trigger=(1, 'epoch'), filename=None,
                 figsize=None, marker=None, grid=True, **kwargs):
        """Plots statistics (mean/std/percentiles) of target variables."""

        _check_available()

        # 'file_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get('file_name', 'statistics.png')

        self._vars = _unpack_variables(targets)
        if not self._vars:
            raise ValueError(
                'Need at least one variables for which to collect statistics.'
                '\nActual: 0 <= 0')

        if not (plot_mean or plot_std or bool(percentile_sigmas)):
            raise ValueError('Nothing to plot')

        # Choose which attributes (values and/or gradients) are sampled.
        self._keys = [
            key for key, enabled
            in (('data', report_data), ('grad', report_grad)) if enabled]

        self._report_data = report_data
        self._report_grad = report_grad

        self._statistician = Statistician(
            collect_mean=plot_mean, collect_std=plot_std,
            percentile_sigmas=percentile_sigmas)

        self._plot_mean = plot_mean
        self._plot_std = plot_std
        self._plot_percentile = bool(percentile_sigmas)

        self._trigger = trigger_module.get_trigger(trigger)
        self._filename = filename
        self._figsize = figsize
        self._marker = marker
        self._grid = grid
        self._writer = kwargs.get('writer', None)

        # Number of percentile rows in the collected sample matrix.
        if not self._plot_percentile:
            n_percentile = 0
        elif isinstance(percentile_sigmas, (list, tuple)):
            n_percentile = len(percentile_sigmas)
        else:
            n_percentile = 1  # scalar, single percentile
        self._data_shape = (
            len(self._keys), int(plot_mean) + int(plot_std) + n_percentile)
        self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
コード例 #13
0
 def __init__(
         self,
         key: str,
         compare: Callable[[float, float], bool],
         trigger: 'TriggerLike' = (1, 'epoch'),
 ) -> None:
     """Trigger that fires whenever ``key`` attains a new best value."""
     self._key = key
     self._compare = compare
     # No best value until the first observation arrives.
     self._best_value: Optional[float] = None
     self._interval_trigger = trigger_module.get_trigger(trigger)
     self._init_summary()
コード例 #14
0
 def __init__(
     self,
     scheduler: Any,
     *,
     stepper: Any = _default_stepper,
     trigger: trigger_module.TriggerLike = (1, 'epoch'),
     is_async: bool = True,
 ) -> None:
     """Extension that advances an LR scheduler on the given trigger."""
     self.trigger = trigger_module.get_trigger(trigger)
     self.scheduler = scheduler
     self.stepper = stepper
     self.is_async = is_async
コード例 #15
0
    def __init__(self,
                 numerator_key,
                 denominator_key,
                 result_key,
                 trigger=(1, 'epoch')):
        """Reports ``numerator / denominator`` under ``result_key``."""
        # Keys used to look up the observed values and report the ratio.
        self._numerator_key = numerator_key
        self._denominator_key = denominator_key
        self._result_key = result_key
        self._trigger = trigger_module.get_trigger(trigger)

        # Running totals accumulated between trigger firings.
        self._numerator = 0
        self._denominator = 0
コード例 #16
0
    def __init__(
            self,
            numerator_key: str,
            denominator_key: str,
            result_key: str,
            trigger: trigger_module.TriggerLike = (1, 'epoch'),
    ) -> None:
        """Reports ``numerator / denominator`` under ``result_key``."""
        # Keys used to look up the observed values and report the ratio.
        self._numerator_key = numerator_key
        self._denominator_key = denominator_key
        self._result_key = result_key
        self._trigger = trigger_module.get_trigger(trigger)

        # Running totals accumulated between trigger firings.
        self._numerator = 0.
        self._denominator = 0.
コード例 #17
0
    def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
                 filename=None, **kwargs):
        """Accumulates reported values and writes them out periodically."""
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess = postprocess
        self._log = []
        # A writer, when given, must define a savefun that accepts a string.
        self._writer = kwargs.get('writer', None)

        # 'log_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get('log_name', 'log')
        self._log_name = filename

        self._init_summary()
コード例 #18
0
    def __init__(
        self,
        store_keys: Optional[Iterable[str]] = None,
        report_keys: Optional[Iterable[str]] = None,
        trigger: trigger_module.TriggerLike = (1, "epoch"),
        filename: Optional[str] = None,
        append: bool = False,
        format: Optional[str] = None,
        **kwargs: Any,
    ):
        """Reports timing summaries collected by the global TimeSummary."""
        self.time_summary = get_time_summary()
        # Initializes global TimeSummary.
        self.time_summary.initialize()

        if store_keys is None:
            self._store_keys = None
        else:
            # Each stored key also gets its standard deviation recorded.
            self._store_keys = (
                list(store_keys) + [key + ".std" for key in store_keys])
        self._report_keys = report_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._log: List[Any] = []

        # 'log_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get("log_name", "log")
        self._log_name = filename
        self._writer = kwargs.get('writer', None)

        # Infer the serialization format from the file extension.
        if format is None and filename is not None:
            if filename.endswith('.jsonl'):
                format = 'json-lines'
            elif filename.endswith('.yaml'):
                format = 'yaml'
            else:
                format = 'json'

        self._append = append
        self._format = format
コード例 #19
0
    def __init__(
            self,
            check_trigger: 'TriggerLike' = (1, 'epoch'),
            monitor: str = 'main/loss',
            patience: int = 3,
            mode: str = 'auto',
            verbose: bool = False,
            max_trigger: Tuple[int, 'UnitLiteral'] = (100, 'epoch'),
    ) -> None:
        """Trigger that requests a stop when ``monitor`` stops improving."""
        self.count = 0
        self.patience = patience
        self.monitor = monitor
        self.verbose = verbose
        self.already_warning = False
        self._max_trigger = trigger.IntervalTrigger(*max_trigger)
        self._interval_trigger = trigger.get_trigger(check_trigger)

        self._init_summary()

        # 'auto' mode guesses the improvement direction from the key name.
        if mode == 'max':
            self._compare = operator.gt
        elif mode == 'min':
            self._compare = operator.lt
        else:
            self._compare = (
                operator.gt if 'accuracy' in monitor else operator.lt)

        # Seed the running best so that any first observation improves it.
        if self._compare is operator.gt:
            if verbose:
                print('early stopping: operator is greater')
            self.best = float('-inf')
        else:
            if verbose:
                print('early stopping: operator is less')
            self.best = float('inf')
コード例 #20
0
    def __init__(
        self,
        store_keys=None,
        report_keys=None,
        trigger=(1, "epoch"),
        filename=None,
        append=False,
        format=None,
        **kwargs,
    ):
        """Reports timing summaries collected by the global TimeSummary."""
        # Initializes global TimeSummary.
        time_summary.initialize()

        if store_keys is None:
            self._store_keys = None
        else:
            # Each stored key also gets its standard deviation recorded.
            self._store_keys = (
                list(store_keys) + [key + ".std" for key in store_keys])
        self._report_keys = report_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._log = []

        # 'log_name' is a legacy alias of 'filename'.
        if filename is None:
            filename = kwargs.get("log_name", "log")
        self._log_name = filename
        self._writer = kwargs.get('writer', None)

        # Infer the serialization format from the file extension.
        if format is None and filename is not None:
            if filename.endswith('.jsonl'):
                format = 'json-lines'
            elif filename.endswith('.yaml'):
                format = 'yaml'
            else:
                format = 'json'

        self._append = append
        self._format = format
コード例 #21
0
    def __init__(
        self,
        keys: Optional[Iterable[str]] = None,
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
        postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None,
        filename: Optional[str] = None,
        append: bool = False,
        format: Optional[str] = None,
        **kwargs: Any,
    ):
        """Accumulates observations and emits them in the chosen format."""
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess = postprocess
        self._log_buffer = _LogBuffer()
        self._log_looker = self._log_buffer.emit_new_looker()
        # A writer, when given, must define a savefun that accepts a string.
        self._writer = kwargs.get('writer', None)

        if filename is None:
            filename = 'log'

        if format is None:
            # Infer the format from the file extension; default to JSON.
            if filename.endswith('.jsonl'):
                format = 'json-lines'
            elif filename.endswith('.yaml'):
                format = 'yaml'
            else:
                format = 'json'
        elif format not in ('json', 'json-lines', 'yaml'):
            raise ValueError(f'unsupported log format: {format}')

        self._filename = filename
        self._append = append
        self._format = format
        self._init_summary()
コード例 #22
0
    def __init__(
        self,
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Mapping[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[Sequence['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: 'trigger_module.TriggerLike' = None,
        transform_model: _TransformModel = default_transform_model,
        enable_profile: bool = False,
    ) -> None:
        """Initializes the manager driving training loops and extensions.

        Args:
            models: A single module or a name-to-module mapping; a bare
                module is registered under the name ``'main'``.
            optimizers: A single optimizer or a name-to-optimizer mapping.
            max_epochs: Number of epochs; also the default stop condition.
            extensions: Extensions registered at construction time.
            out_dir: Output directory handed to the default writer.
            writer: Snapshot writer; a ``SimpleWriter`` is created when
                ``None``.
            stop_trigger: Overrides the ``(max_epochs, 'epoch')`` stop
                condition when given.
            transform_model: Callable applied to each model before it is
                registered as a reporter observer.
            enable_profile: Enables profiling records during training.
        """
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = trigger_module.get_trigger(stop_trigger)
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation: reporting.Observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = reporting.Reporter()
        self._transform_model = transform_model
        self._start_extensions_called = False
        self._run_on_error_called = False

        # Indicates whether models can be accessed from extensions in the
        # current iteration.
        # The default value (True) indicates that it is allowed to access
        # models before starting a training loop.
        self._model_available = True

        if isinstance(models, collections.abc.Mapping):
            self._models = models
        else:
            if not isinstance(models, torch.nn.Module):
                # Fixed typo in the message: 'toch' -> 'torch'.
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        if isinstance(optimizers, collections.abc.Mapping):
            self._optimizers = optimizers
        else:
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}

        for name, model in self._models.items():
            # TODO we should not initialize extensions at this point
            # so, we cannot use `self.models`
            model = self._transform_model(name, model)
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time: Optional[float] = None
        self.__iters_per_epoch: Optional[int] = None
        self._extensions: Dict[
            str, extension_module.ExtensionEntry] = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        self._enable_profile = enable_profile
        # Initialize the writer
        self.writer.initialize(self.out)
コード例 #23
0
 def __init__(self, key, compare, trigger=(1, 'epoch')):
     """Trigger that fires whenever ``key`` attains a new best value."""
     self._key = key
     self._compare = compare
     # No best value until the first observation arrives.
     self._best_value = None
     self._interval_trigger = trigger_module.get_trigger(trigger)
     self._init_summary()
コード例 #24
0
    def run(self,
            train_loader: Iterable[Any],
            val_loader: Optional[Iterable[Any]] = None,
            *,
            train_len: Optional[int] = None,
            eval_len: Optional[int] = None) -> None:
        """Executes the training loop.

        Args:
            train_loader (torch.utils.data.DataLoader):
                A data loader for training.
            val_loader (torch.utils.data.DataLoader, optional):
                A data loader passed to ``Evaluator.run()``.
            train_len (int, optional):
                The number of iterations per one training epoch. The default
                value is inferred from the size of training data loader.
            eval_len (int, optional):
                The number of iterations per one evaluation epoch, passed
                to ``Evaluator.run()``

        .. seealso::
            - :meth:`pytorch_pfn_extras.training._evaluator.Evaluator`
        """
        if train_len is None:
            train_len = len(train_loader)  # type: ignore[arg-type]
        if eval_len is None and val_loader is not None:
            eval_len = len(val_loader)  # type: ignore[arg-type]

        self._train_len = train_len
        self._eval_len = eval_len

        class _EvaluatorExt:
            def __init__(
                    self,
                    trainer: 'Trainer',
                    evaluator: 'Evaluator',
                    val_loader: Optional[Iterable[Any]],
                    eval_len: Optional[int],
            ) -> None:
                self.needs_model_state = True
                self._trainer = trainer
                self._evaluator = evaluator
                self._val_loader = val_loader
                self._eval_len = eval_len

            def __call__(self, manager: ExtensionsManagerProtocol) -> None:
                evaluator = self._evaluator
                if self._val_loader is None:
                    raise ValueError('"val_loader" is not given.')
                evaluator.handler.train_validation_begin(self._trainer, evaluator)
                evaluator.run(self._val_loader, eval_len=self._eval_len)
                evaluator.handler.train_validation_end(self._trainer, evaluator)

        if self._manager is None:
            self._manager = self._setup_manager(train_len)
            for name, (evaluator, trigger) in self._evaluators.items():
                # Register the evaluator as an extension to the manager
                # To be triggered with the correct timing
                self._manager.extend(
                    _EvaluatorExt(self, evaluator, val_loader, eval_len),
                    name=name,
                    trigger=trigger_module.get_trigger(trigger),
                    priority=extension.PRIORITY_WRITER,
                )
            self.handler.train_setup(self, train_loader)
            if len(self._evaluators) == 0:
                if val_loader is not None:
                    warnings.warn(
                        '`val_loader` is given whereas the evaluator is missing.',
                        UserWarning)
            else:
                if val_loader is None:
                    raise ValueError('`val_loader` is required')
                for _, (evaluator, _) in self._evaluators.items():
                    evaluator.handler.eval_setup(evaluator, val_loader)

        with self._profile or _nullcontext() as prof:
            while not self.manager.stop_trigger:
                self.handler.train_epoch_begin(self, train_loader)

                # When iterations are completed in the callback
                # This is needed to avoid being constantly passing parameters
                self._idxs: 'queue.Queue[int]' = queue.Queue()
                self._inputs: 'queue.Queue[Any]' = queue.Queue()
                self._times: 'queue.Queue[float]' = queue.Queue()
                self._observed: 'queue.Queue[reporting.Observation]' = queue.Queue()
                # Iterator must be created after `train_epoch_begin` as it may be
                #  using a DistributedSampler.
                loader_iter = iter(train_loader)
                self._profile_records: 'queue.Queue[List[_ReportNotification]]' \
                    = queue.Queue()
                for idx in range(train_len):
                    with record(
                        "pytorch_pfn_extras.training.Trainer:iteration",
                        use_cuda=torch.cuda.is_available(),
                        enable=self._enable_profile
                    ) as ntf0:
                        try:
                            with record(
                                "pytorch_pfn_extras.training.Trainer:get_data",
                                enable=self._enable_profile
                            ):
                                x = next(loader_iter)
                        except StopIteration:
                            loader_iter = iter(train_loader)
                            with record(
                                "pytorch_pfn_extras.training.Trainer:get_data",
                                enable=self._enable_profile
                            ):
                                x = next(loader_iter)
                        begin = time.time()
                        self._idxs.put(idx)
                        self._inputs.put(x)
                        self._times.put(begin)
                        with record(
                            "pytorch_pfn_extras.training.Trainer:run_iteration",
                            use_cuda=torch.cuda.is_available(),
                            enable=self._enable_profile
                        ) as ntf1, \
                                self.manager.run_iteration():
                            self._observed.put(self.manager.observation)
                            with record(
                                "pytorch_pfn_extras.training.Trainer:train_step",
                                use_cuda=torch.cuda.is_available(),
                                enable=self._enable_profile
                            ) as ntf2:
                                self._profile_records.put([ntf0, ntf1, ntf2])
                                self.handler.train_step(
                                    self, idx, x, complete_fn=self._complete_step)
                                # Check if the callback was called
                    if prof is not None:
                        prof.step()  # type: ignore[no-untyped-call]
                    # In some cases, DataLoaders are continuos
                    # And will keep yielding results even if the epoch
                    # is completed. We forcefully exit at the end of
                    # every epoch
                    if self.is_epoch_last_iter(idx) or self.manager.stop_trigger:
                        break
                # In handlers that support a completely Async model train_epoch_end
                # Will take care of completing pending work
                self.handler.train_epoch_end(self)
            if prof is not None:
                prof.on_trace_ready = None
        self.handler.train_cleanup(self)
コード例 #25
0
ファイル: manager.py プロジェクト: okdshin/pytorch-pfn-extras
    def extend(self,
               extension,
               name=None,
               trigger=None,
               priority=None,
               *,
               call_before_training=False,
               **kwargs):
        """Registers an extension to the manager.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each iteration.
        Extensions with the same priority are invoked in the order of
        registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second and subsequent extensions. The
        suffix is ``_N`` where N is the ordinal of the extensions.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is omitted, the
                :attr:`Extension.name` attribute of the extension is used or
                the :attr:`Extension.default_name` attribute of the extension
                if `name` is set to `None` or is undefined.
                Note that the name would be suffixed by an ordinal in case of
                duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            call_before_training (bool): Flag to call extension before
                training. Default is ``False``.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.

        Raises:
            TypeError: If no name is given and the extension carries none of
                ``name``, ``default_name`` or ``__name__``.
            ValueError: If the (resolved) name is the reserved string
                ``'training'``.
        """
        if name is None:
            # Resolve the name from the extension itself, in priority order:
            # explicit `name` attribute, then `default_name`, then `__name__`.
            for attr in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            # Fall back to the extension's own trigger; default to firing
            # on every iteration.
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority',
                               extension_module.PRIORITY_READER)

        # Deduplicate names by appending an ordinal suffix (`_1`, `_2`, ...).
        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger, call_before_training)
コード例 #26
0
    def extend(
        self,
        extension: 'extension_module.ExtensionLike',
        name: Optional[str] = None,
        trigger: 'trigger_module.TriggerLike' = None,
        priority: Optional[int] = None,
        *,
        call_before_training: bool = False,
        **kwargs: Dict[str, Any],
    ) -> None:
        """Registers an extension to the manager.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each iteration.
        Extensions with the same priority are invoked in the order of
        registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second and subsequent extensions. The
        suffix is ``_N`` where N is the ordinal of the extensions.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is omitted, the
                :attr:`Extension.name` attribute of the extension is used or
                the :attr:`Extension.default_name` attribute of the extension
                if `name` is set to `None` or is undefined.
                Note that the name would be suffixed by an ordinal in case of
                duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            call_before_training (bool): Flag to call extension before
                training. Default is ``False``.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.

        Raises:
            RuntimeError: If called after the extensions were already
                initialized by the training loop.
            ValueError: If the (resolved) name is the reserved string
                ``'training'``.
        """
        # NOTE(review): **kwargs is accepted but silently ignored here —
        # presumably kept for backward compatibility; confirm with callers.
        if self._start_extensions_called:
            # Registration order affects naming/priorities, so late
            # registration after initialization is rejected outright.
            raise RuntimeError(
                'extend called after the extensions were initialized')
        ext = extension_module._as_extension(extension)
        if name is None:
            name = ext.name or ext.default_name
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = ext.trigger
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = ext.priority

        # Deduplicate names by appending an ordinal suffix (`_1`, `_2`, ...).
        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        ext.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            ext, priority, trigger, call_before_training)
コード例 #27
0
ファイル: comparer.py プロジェクト: pfnet/pytorch-pfn-extras
    def __init__(
        self,
        *,
        trigger: Optional[trigger_module.TriggerLike] = None,
        compare_fn: _CompareFn = _default_comparer,
        concurrency: Optional[int] = None,
        outputs: Union[bool, str, Sequence[str]] = True,
        params: Union[bool, str, Sequence[str]] = False,
        baseline: Optional[str] = None,
    ) -> None:
        """A class for comparison of iteration outputs and model parameters.

        This class is mainly used to compare results between different devices.

        Args:
            trigger (Trigger):
                Trigger object that determines when to compare values.
            compare_fn (function):
                Comparison function. Default is ``get_default_comparer()``.
            concurrency (int, optional):
                The upper bound limit on the number of workers that run concurrently.
                If ``None``, inferred from the size of ``engines``.
            outputs (tuple of str or bool):
                A set of keys of output dict to compare.
            params (tuple of str or bool):
                A set of keys of model parameters to compare.
            baseline (str, optional):
                The baseline engine that is assumed to be correct.

        Examples:
            >>> trainer_cpu = ppe.engine.create_trainer(
                    model, optimizer, 1, device='cpu')
            >>> trainer_gpu = ppe.engine.create_trainer(
                    model, optimizer, 1, device='cuda:0')
            >>> comp = ppe.utils.comparer.Comparer()
            >>> comp.add_engine("cpu", trainer_cpu, train_1, eval_1)
            >>> comp.add_engine("gpu", trainer_gpu, train_2, eval_2)
            >>> comp.compare()
        """
        # Engine class shared by all registered engines; fixed on first
        # registration (or below, when an explicit trigger is given).
        self._engine_type: Optional[Type[_Engine]] = None
        # name -> (engine, train_args, eval_args); insertion order preserved.
        self._engines: Dict[str, Tuple[Union[_Engine, _LoadDumpsEngine], Any,
                                       Any]] = collections.OrderedDict()
        self._compare_fn = compare_fn
        # Per-engine collected values to be compared, keyed by engine name.
        self._targets: Dict[str, Dict[str, Any]] = {}
        self._output_keys = outputs
        self._param_keys = params
        self._baseline = baseline
        self._finalized = False
        self._concurrency = concurrency  # Upper limit of semaphore size
        # Semaphore for training step execution
        self._semaphore: Optional[threading.Semaphore] = None
        # Synchronizes iteration timing
        self._barrier: Optional[threading.Barrier] = None
        self._report_lock = threading.Lock()  # Locks `Comparer._get_target`
        self._count = 0

        if trigger is None:
            # Default: compare once per epoch.
            self._trigger = trigger_module.get_trigger((1, "epoch"))
        else:
            # NOTE(review): an explicit trigger pins the engine type to
            # Trainer — presumably because only trainers support custom
            # comparison triggers; confirm against add_engine.
            self._engine_type = _trainer.Trainer
            self._trigger = trigger_module.get_trigger(trigger)