def save(self, filename, outdir, target, savefun, **kwds):
    """Write ``target`` to ``outdir/filename`` by way of a scratch directory.

    ``savefun(target, path)`` first writes into a temporary directory
    obtained from ``utils.tempdir`` inside ``outdir`` (its name is prefixed
    with ``'tmp' + filename``). The file is then moved to its final
    location, and ``self._post_save()`` is called once the scratch
    directory context has exited.

    Args:
        filename: Name of the file to create inside ``outdir``.
        outdir: Destination directory.
        target: Object handed to ``savefun``.
        savefun: Callable invoked as ``savefun(target, path)``.
        **kwds: Accepted for interface compatibility; not used here.
    """
    # Writing to a scratch location first presumably keeps a partially
    # written file from ever appearing at the final path.
    with utils.tempdir(prefix='tmp' + filename, dir=outdir) as scratch_dir:
        scratch_path = os.path.join(scratch_dir, filename)
        savefun(target, scratch_path)
        shutil.move(scratch_path, os.path.join(outdir, filename))
    self._post_save()
def test_call(self):
    """Calling a SimpleWriter delegates to its ``save`` exactly once."""
    writer = snapshot_writers.SimpleWriter()
    writer.save = mock.MagicMock()
    snapshot_target = mock.MagicMock()
    with utils.tempdir() as out_dir:
        writer('myfile.dat', out_dir, snapshot_target)
    assert writer.save.call_count == 1
def test_call(self):
    """StandardWriter starts a worker per call and joins each one by finalize."""
    snapshot_target = mock.MagicMock()
    writer = snapshot_writers.StandardWriter()
    fake_worker = mock.MagicMock()
    factory = snapshot_writers_path + '.StandardWriter.create_worker'
    # Every create_worker() call returns the same fake, so its counters
    # aggregate over both writes.
    with mock.patch(factory, return_value=fake_worker), \
            utils.tempdir() as out_dir:
        for _ in range(2):
            writer('myfile.dat', out_dir, snapshot_target)
        writer.finalize()
    assert fake_worker.start.call_count == 2
    assert fake_worker.join.call_count == 2
def save_and_load_pth(src, dst):
    """Round-trip ``src``'s state dict into ``dst`` through a PTH file.

    This is a shortcut for :func:`save_and_load` using PTH
    (``torch.save`` / ``torch.load``) de/serializers. The file lives in a
    scratch directory obtained from ``utils.tempdir``.

    Args:
        src: An object to save; its ``state_dict()`` is serialized.
        dst: An object to load into via ``load_state_dict()``.
    """
    with utils.tempdir() as workdir:
        pth_file = os.path.join(workdir, 'tmp.pth')
        torch.save(src.state_dict(), pth_file)
        restored = torch.load(pth_file)
        dst.load_state_dict(restored)
def __call__(self, trainer):
    """Accumulate the trainer's observation and emit a log entry when due.

    Every call adds ``trainer.observation`` to the running summary. When
    ``self._keys`` is set, only the listed keys that are present are added.
    If ``trainer.is_before_training`` is true or ``self._trigger(trainer)``
    fires, the following steps run in order:

    * the summary means are converted to floats;
    * ``epoch``, ``iteration`` and ``elapsed_time`` are added;
    * the entry is passed to ``self._postprocess`` (if set);
    * the entry is appended to ``self._log``;
    * the whole log is written as JSON under ``trainer.out`` (if
      ``self._log_name`` is set);
    * the summary is reset.

    Args:
        trainer: The trainer invoking this extension.
    """
    wanted = self._keys
    observation = trainer.observation
    self._summary.add(
        observation if wanted is None
        else {k: observation[k] for k in wanted if k in observation})

    if not (trainer.is_before_training or self._trigger(trainer)):
        return

    # float() turns each mean into a plain Python float (copying
    # device-side scalars to the CPU).
    stats_cpu = {
        name: float(value)
        for name, value in six.iteritems(self._summary.compute_mean())
    }
    updater = trainer.updater
    stats_cpu.update(
        epoch=updater.epoch,
        iteration=updater.iteration,
        elapsed_time=trainer.elapsed_time,
    )

    if self._postprocess is not None:
        self._postprocess(stats_cpu)
    self._log.append(stats_cpu)

    template = self._log_name
    if template is not None:
        log_name = template.format(**stats_cpu)
        # Dump into a scratch directory, then move into place; presumably so
        # readers never observe a half-written log file.
        with utils.tempdir(prefix=log_name, dir=trainer.out) as scratch:
            tmp_path = os.path.join(scratch, 'log.json')
            with open(tmp_path, 'w') as fp:
                json.dump(self._log, fp, indent=4)
            shutil.move(tmp_path, os.path.join(trainer.out, log_name))

    # Start a fresh summary for the next reporting interval.
    self._init_summary()
def test_call(self):
    """QueueWriter starts one consumer and shuts down cleanly on finalize.

    Two writes followed by ``finalize()`` must put three items on the
    queue, start the consumer once, and join both the queue and the
    consumer exactly once.
    """
    target = mock.MagicMock()
    q = mock.MagicMock()
    consumer = mock.MagicMock()
    names = [
        snapshot_writers_path + '.QueueWriter.create_queue',
        snapshot_writers_path + '.QueueWriter.create_consumer',
    ]
    with mock.patch(names[0], return_value=q):
        with mock.patch(names[1], return_value=consumer):
            w = snapshot_writers.QueueWriter()
            with utils.tempdir() as tempd:
                w('myfile.dat', tempd, target)
                w('myfile.dat', tempd, target)
                w.finalize()
            assert consumer.start.call_count == 1
            assert q.put.call_count == 3
            # Must be a comparison: ``assert q.join.call_count, 1`` treats
            # ``1`` as the failure message and passes for any non-zero count.
            assert q.join.call_count == 1
            assert consumer.join.call_count == 1
def test_create_worker(self):
    """ProcessWriter.create_worker builds a ``multiprocessing.Process``."""
    writer = snapshot_writers.ProcessWriter()
    snapshot_target = mock.MagicMock()
    with utils.tempdir() as out_dir:
        created = writer.create_worker('myfile.dat', out_dir, snapshot_target)
    assert isinstance(created, multiprocessing.Process)