Example #1
import os

from pytorch_pfn_extras import writing
from pytorch_pfn_extras.training import extensions

# `remover` is a fixture and `get_trainer` a helper defined elsewhere
# in this test module.
def test_save_file(remover):
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    w = writing.SimpleWriter()
    # snapshot_object returns an extension; calling it writes the file.
    snapshot = extensions.snapshot_object(trainer, 'myfile.dat', writer=w)
    snapshot(trainer)

    assert os.path.exists('myfile.dat')
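Since snapshot_object serializes through the writer, the saved file can typically be read back with torch.load (an assumption: this library's SimpleWriter defaults to torch.save, and the payload is usually the target's state_dict; verify for your version):

import torch

# Hedged sketch: reload the snapshot written by the test above.
state = torch.load('myfile.dat')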
Example #2
import tempfile
from unittest import mock

from pytorch_pfn_extras import writing


def test_simple_writer():
    target = mock.MagicMock()
    w = writing.SimpleWriter()
    # Stub out the real serialization; only the dispatch is under test.
    w.save = mock.MagicMock()
    with tempfile.TemporaryDirectory() as tempd:
        w('myfile.dat', tempd, target)

    assert w.save.call_count == 1
Example #3
import tempfile
from unittest import mock

from pytorch_pfn_extras import writing


def test_simple_writer():
    target = mock.MagicMock()
    # Extra constructor kwargs (foo=True) must reach savefun unchanged.
    w = writing.SimpleWriter(foo=True)
    savefun = mock.MagicMock()
    with tempfile.TemporaryDirectory() as tempd:
        w('myfile.dat', tempd, target, savefun=savefun)
    assert savefun.call_count == 1
    assert savefun.call_args.args[0] == target
    assert savefun.call_args.kwargs['foo'] is True
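Together, these two tests pin down SimpleWriter's dispatch contract: __call__(filename, out_dir, target) routes through save, and savefun is invoked as savefun(target, dest, **constructor_kwds). A standalone sketch of that contract; my_savefun and the dict target are illustrative:

import tempfile

from pytorch_pfn_extras import writing


def my_savefun(target, dest, foo=False):
    # Receives the target first, then the destination the writer
    # prepared (a path or file object), plus forwarded kwargs.
    print('saving', target, 'with foo =', foo)


w = writing.SimpleWriter(foo=True)
with tempfile.TemporaryDirectory() as tempd:
    w('myfile.dat', tempd, {'state': 1}, savefun=my_savefun)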
Example #4
    def __init__(
            self,
            models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
            optimizers: Union[torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]],
            max_epochs: int,
            extensions: Optional[List['extension_module.ExtensionLike']],
            out_dir: str,
            writer: Optional[writing.Writer],
            stop_trigger: 'trigger_module.TriggerLike' = None
    ) -> None:
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = trigger_module.get_trigger(
                stop_trigger)
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation: Dict[str, reporting.ReportValue] = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = reporting.Reporter()
        self._start_extensions_called = False

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(
                name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time: Optional[float] = None
        self._iters_per_epoch: Optional[int] = None
        self._extensions: Dict[str, _ExtensionEntry] = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)
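For orientation, a minimal construction sketch matching the signature above; the enclosing class is not named in this excerpt, so Trainer is a placeholder, and the model and optimizer are illustrative:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# A bare module/optimizer is wrapped into {'main': ...} internally.
trainer = Trainer(
    models=model,
    optimizers=optimizer,
    max_epochs=5,
    extensions=None,    # defaults to []
    out_dir='result',
    writer=None,        # defaults to writing.SimpleWriter(out_dir=out_dir)
    stop_trigger=None,  # defaults to a (max_epochs, 'epoch') trigger
)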
Example #5
from unittest import mock

import pytest

from pytorch_pfn_extras import writing
from pytorch_pfn_extras.training import extensions


def test_savefun_and_writer_exclusive():
    # The savefun and writer arguments cannot be specified together.
    def savefun(*args, **kwargs):
        assert False
    writer = writing.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=savefun, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
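The two arguments are exclusive because a writer already owns serialization. A hedged sketch of the two valid call styles, assuming these extensions come from pytorch_pfn_extras and that torch.save is an acceptable savefun:

import torch

from pytorch_pfn_extras import writing
from pytorch_pfn_extras.training import extensions

# Either pass a savefun and let the default writer wrap it...
snap_a = extensions.snapshot(savefun=torch.save)
# ...or pass a fully configured writer, but never both.
snap_b = extensions.snapshot(writer=writing.SimpleWriter())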
Example #6
    def __init__(
            self,
            models,
            optimizers,
            max_epochs,
            extensions,
            out_dir,
            writer,
            stop_trigger=None):
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = stop_trigger
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = Reporter()

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(
                name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time = None
        self._extensions = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)
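One behavioral difference from Examples #4 and #7: this revision stores a user-supplied stop_trigger verbatim rather than normalizing it with trigger_module.get_trigger. The normalized path accepts a trigger object or a (period, unit) tuple; a brief sketch of the tuple spelling:

# Hedged sketch: get_trigger turns the tuple form into a trigger object,
# here one that fires after 1000 iterations.
stop_trigger = trigger_module.get_trigger((1000, 'iteration'))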
Example #7
    def __init__(
        self,
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Mapping[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[Sequence['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: 'trigger_module.TriggerLike' = None,
        transform_model: _TransformModel = default_transform_model,
        enable_profile: bool = False,
    ) -> None:
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = trigger_module.get_trigger(stop_trigger)
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation: reporting.Observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = reporting.Reporter()
        self._transform_model = transform_model
        self._start_extensions_called = False
        self._run_on_error_called = False

        # Indicates whether models can be accessed from extensions in the
        # current iteration.
        # The default value (True) indicates that it is allowed to access
        # models before starting a training loop.
        self._model_available = True

        if isinstance(models, collections.abc.Mapping):
            self._models = models
        else:
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        if isinstance(optimizers, collections.abc.Mapping):
            self._optimizers = optimizers
        else:
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}

        for name, model in self._models.items():
            # TODO we should not initialize extensions at this point
            # so, we cannot use `self.models`
            model = self._transform_model(name, model)
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time: Optional[float] = None
        self.__iters_per_epoch: Optional[int] = None
        self._extensions: Dict[
            str, extension_module.ExtensionEntry] = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        self._enable_profile = enable_profile
        # Initialize the writer
        self.writer.initialize(self.out)
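Grounded in the call self._transform_model(name, model) above, a transform receives the model's registered name and the module, and returns the module that the reporter's observers should see. A hypothetical transform that unwraps DistributedDataParallel:

import torch


def unwrap_ddp(name: str, model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical example: expose the underlying module when a model
    # was wrapped in DistributedDataParallel.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model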