def test_save_file(remover):
    """A snapshot_object extension backed by SimpleWriter creates its file.

    ``remover`` is a fixture defined elsewhere; presumably it cleans up the
    written file afterwards (TODO confirm).
    """
    target_name = 'myfile.dat'
    trainer = get_trainer(out_dir='.')
    # NOTE(review): marks the trainer as finished; presumably required for
    # the snapshot to fire here -- confirm against the trainer class.
    trainer._done = True
    simple_writer = writing.SimpleWriter()
    snapshot_ext = extensions.snapshot_object(
        trainer, target_name, writer=simple_writer)
    snapshot_ext(trainer)
    assert os.path.exists(target_name)
def test_simple_writer():
    """Calling a SimpleWriter delegates to its ``save`` method exactly once."""
    saved_obj = mock.MagicMock()
    simple_writer = writing.SimpleWriter()
    save_spy = mock.MagicMock()
    simple_writer.save = save_spy
    with tempfile.TemporaryDirectory() as out_dir:
        simple_writer('myfile.dat', out_dir, saved_obj)
        assert save_spy.call_count == 1
def test_simple_writer_savefun():
    """A call-time ``savefun`` receives the target and the writer's kwargs.

    The writer is constructed with ``foo=True``; the test asserts that
    ``savefun`` is invoked once, with the target as its first positional
    argument and ``foo=True`` among its keyword arguments.
    """
    target = mock.MagicMock()
    w = writing.SimpleWriter(foo=True)
    savefun = mock.MagicMock()
    with tempfile.TemporaryDirectory() as tempd:
        w('myfile.dat', tempd, target, savefun=savefun)
    savefun.assert_called_once()
    assert savefun.call_args.args[0] == target
    assert savefun.call_args.kwargs['foo'] is True
def __init__(
        self,
        models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Dict[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[List['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: Optional['trigger_module.TriggerLike'] = None,
) -> None:
    """Set up models, optimizers, triggers, reporter, writer and extensions.

    Args:
        models: A single ``torch.nn.Module`` (registered under the name
            ``'main'``) or a dict mapping names to modules.
        optimizers: A single optimizer (registered under ``'main'``) or a
            dict mapping names to optimizers. The optimizer type is not
            validated.
        max_epochs: Number of epochs, stored as ``self.max_epochs``. Also
            used to build the default stop trigger ``(max_epochs, 'epoch')``
            when ``stop_trigger`` is None.
        extensions: Extensions passed one by one to ``self.extend``. None is
            treated as an empty list.
        out_dir: Output directory. Stored as ``self._out`` and used for the
            default writer.
        writer: Writer used for outputs. When None, a
            ``writing.SimpleWriter(out_dir=out_dir)`` is created.
        stop_trigger: Any value accepted by ``trigger_module.get_trigger``.
            When None, ``get_trigger((max_epochs, 'epoch'))`` is used.

    Raises:
        ValueError: If ``models`` is neither a dict nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        self._stop_trigger = trigger_module.get_trigger(
            stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation: Dict[str, reporting.ReportValue] = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = reporting.Reporter()
    self._start_extensions_called = False

    if not isinstance(models, dict):
        if not isinstance(models, torch.nn.Module):
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    else:
        self._models = models
    if not isinstance(optimizers, dict):
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}
    else:
        self._optimizers = optimizers

    # Register every model and all of its named submodules as observers.
    for name, model in self._models.items():
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(
            name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time: Optional[float] = None
    self._iters_per_epoch: Optional[int] = None
    self._extensions: Dict[str, _ExtensionEntry] = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)

    # Initialize the writer
    # NOTE(review): ``self.out`` is defined outside this block; presumably
    # a property returning ``self._out`` -- confirm.
    self.writer.initialize(self.out)
def test_savefun_and_writer_exclusive():
    """``savefun`` and ``writer`` cannot be specified together.

    Both ``extensions.snapshot`` and ``extensions.snapshot_object`` must
    reject the combination with ``TypeError``.
    """
    def savefun(*args, **kwargs):
        # Must never be reached. Raise explicitly so the guard still works
        # under ``python -O``, which strips a bare ``assert False``.
        raise AssertionError('savefun should not be called')

    writer = writing.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=savefun, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
def __init__(
        self,
        models,
        optimizers,
        max_epochs,
        extensions,
        out_dir,
        writer,
        stop_trigger=None):
    """Set up models, optimizers, triggers, reporter, writer and extensions.

    Args:
        models: A single ``torch.nn.Module`` (registered under ``'main'``)
            or a dict mapping names to modules.
        optimizers: A single optimizer (registered under ``'main'``) or a
            dict mapping names to optimizers. The optimizer type is not
            validated.
        max_epochs: Number of epochs, stored as ``self.max_epochs``. Also
            used for the default stop trigger ``(max_epochs, 'epoch')``.
        extensions: Extensions passed one by one to ``self.extend``. None is
            treated as an empty list.
        out_dir: Output directory. Stored as ``self._out`` and used for the
            default writer.
        writer: Writer used for outputs. When None, a
            ``writing.SimpleWriter(out_dir=out_dir)`` is created.
        stop_trigger: Any value accepted by ``trigger_module.get_trigger``.
            When None, ``get_trigger((max_epochs, 'epoch'))`` is used.

    Raises:
        ValueError: If ``models`` is neither a dict nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        # Normalize user-supplied triggers the same way as the default, so
        # tuple specs such as ``(10, 'iteration')`` are accepted as well.
        self._stop_trigger = trigger_module.get_trigger(stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = Reporter()

    if not isinstance(models, dict):
        if not isinstance(models, torch.nn.Module):
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    else:
        self._models = models
    if not isinstance(optimizers, dict):
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}
    else:
        self._optimizers = optimizers

    # Register every model and all of its named submodules as observers.
    for name, model in self._models.items():
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(
            name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time = None
    self._extensions = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)

    # Initialize the writer
    # NOTE(review): ``self.out`` is defined outside this block; presumably
    # a property returning ``self._out`` -- confirm.
    self.writer.initialize(self.out)
def __init__(
        self,
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Mapping[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[Sequence['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: Optional['trigger_module.TriggerLike'] = None,
        transform_model: _TransformModel = default_transform_model,
        enable_profile: bool = False,
) -> None:
    """Set up models, optimizers, triggers, reporter, writer and extensions.

    Args:
        models: A single ``torch.nn.Module`` (registered under ``'main'``)
            or a mapping from names to modules.
        optimizers: A single optimizer (registered under ``'main'``) or a
            mapping from names to optimizers. The optimizer type is not
            validated.
        max_epochs: Number of epochs, stored as ``self.max_epochs``. Also
            used for the default stop trigger ``(max_epochs, 'epoch')``.
        extensions: Extensions passed one by one to ``self.extend``. None is
            treated as an empty list.
        out_dir: Output directory. Stored as ``self._out`` and used for the
            default writer.
        writer: Writer used for outputs. When None, a
            ``writing.SimpleWriter(out_dir=out_dir)`` is created.
        stop_trigger: Any value accepted by ``trigger_module.get_trigger``.
            When None, ``get_trigger((max_epochs, 'epoch'))`` is used.
        transform_model: Called as ``transform_model(name, model)``; its
            result, not the stored model, is registered with the reporter.
        enable_profile: Stored as ``self._enable_profile``.

    Raises:
        ValueError: If ``models`` is neither a mapping nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        self._stop_trigger = trigger_module.get_trigger(stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation: reporting.Observation = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = reporting.Reporter()
    self._transform_model = transform_model
    self._start_extensions_called = False
    self._run_on_error_called = False

    # Indicates whether models can be accessed from extensions in the
    # current iteration.
    # The default value (True) indicates that it is allowed to access
    # models before starting a training loop.
    self._model_available = True

    if isinstance(models, collections.abc.Mapping):
        self._models = models
    else:
        if not isinstance(models, torch.nn.Module):
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    if isinstance(optimizers, collections.abc.Mapping):
        self._optimizers = optimizers
    else:
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}

    for name, model in self._models.items():
        # TODO we should not initialize extensions at this point
        # so, we cannot use `self.models`
        model = self._transform_model(name, model)
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time: Optional[float] = None
    # Double underscore: name-mangled to the defining class.
    self.__iters_per_epoch: Optional[int] = None
    self._extensions: Dict[
        str, extension_module.ExtensionEntry] = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)

    self._enable_profile = enable_profile
    # Initialize the writer
    # NOTE(review): ``self.out`` is defined outside this block; presumably
    # a property returning ``self._out`` -- confirm.
    self.writer.initialize(self.out)