Example #1
    def save(self, filename, outdir, target, savefun, **kwds):
        # Serialize into a temporary directory first, then move the finished
        # file into ``outdir``, so a crash never leaves a partial snapshot.
        prefix = 'tmp' + filename
        with utils.tempdir(prefix=prefix, dir=outdir) as tmpdir:
            tmppath = os.path.join(tmpdir, filename)
            savefun(target, tmppath)
            shutil.move(tmppath, os.path.join(outdir, filename))

        self._post_save()
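
The method above is an atomic-save pattern: the snapshot is written into a temporary directory and only moved into ``outdir`` once complete, so readers never observe a half-written file. A minimal standalone sketch of the same pattern using only the standard library (the ``save_object`` helper and the pickle serializer are illustrative assumptions, not part of the library):

import os
import pickle
import shutil
import tempfile

def save_object(target, filename, outdir):
    # Create the temporary directory inside outdir so the final move stays
    # on one filesystem; shutil.move then reduces to an atomic os.rename.
    with tempfile.TemporaryDirectory(prefix='tmp' + filename, dir=outdir) as tmpdir:
        tmppath = os.path.join(tmpdir, filename)
        with open(tmppath, 'wb') as f:
            pickle.dump(target, f)
        shutil.move(tmppath, os.path.join(outdir, filename))
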
Example #2
    def test_call(self):
        target = mock.MagicMock()
        w = snapshot_writers.SimpleWriter()
        w.save = mock.MagicMock()
        with utils.tempdir() as tempd:
            w('myfile.dat', tempd, target)

        assert w.save.call_count == 1
Example #3
    def test_call(self):
        target = mock.MagicMock()
        w = snapshot_writers.StandardWriter()
        worker = mock.MagicMock()
        name = snapshot_writers_path + '.StandardWriter.create_worker'
        with mock.patch(name, return_value=worker):
            with utils.tempdir() as tempd:
                w('myfile.dat', tempd, target)
                w('myfile.dat', tempd, target)
                w.finalize()

            assert worker.start.call_count == 2
            assert worker.join.call_count == 2
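
In StandardWriter each call spawns a fresh worker and starts it immediately; the next call (and ``finalize()``) joins the outstanding worker, which is why two calls plus ``finalize()`` yield two starts and two joins in the test above. A condensed sketch of that lifecycle using threads (illustrative of the pattern, not the library's actual code):

import threading

class ThreadedWriter:
    def __init__(self):
        self._worker = None

    def __call__(self, save_callable):
        if self._worker is not None:
            self._worker.join()    # wait for the previous snapshot first
        self._worker = threading.Thread(target=save_callable)
        self._worker.start()

    def finalize(self):
        if self._worker is not None:
            self._worker.join()
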
Example #4
def save_and_load_pth(src, dst):
    """Saves ``src`` to an PTH file and loads it to ``dst``.

    This is a short cut of :func:`save_and_load` using PTH de/serializers.

    Args:
        src: An object to save.
        dst: An object to load to.

    """
    with utils.tempdir() as tempdir:
        path = os.path.join(tempdir, 'tmp.pth')
        torch.save(src.state_dict(), path)
        dst.load_state_dict(torch.load(path))
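
A hypothetical usage, assuming two modules with identical architectures (the ``torch.nn.Linear`` layers here are illustrative):

import torch

src = torch.nn.Linear(4, 2)
dst = torch.nn.Linear(4, 2)
save_and_load_pth(src, dst)  # round-trips src's parameters into dst
assert torch.equal(src.weight, dst.weight)
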
Example #5
    def __call__(self, trainer):
        # accumulate the observations
        keys = self._keys
        observation = trainer.observation
        summary = self._summary

        if keys is None:
            summary.add(observation)
        else:
            summary.add({k: observation[k] for k in keys if k in observation})

        if trainer.is_before_training or self._trigger(trainer):
            # output the result
            stats = self._summary.compute_mean()
            stats_cpu = {}
            for name, value in six.iteritems(stats):
                stats_cpu[name] = float(value)  # copy to CPU

            updater = trainer.updater
            stats_cpu['epoch'] = updater.epoch
            stats_cpu['iteration'] = updater.iteration
            stats_cpu['elapsed_time'] = trainer.elapsed_time

            if self._postprocess is not None:
                self._postprocess(stats_cpu)

            self._log.append(stats_cpu)

            # write to the log file
            if self._log_name is not None:
                log_name = self._log_name.format(**stats_cpu)
                with utils.tempdir(prefix=log_name, dir=trainer.out) as tempd:
                    path = os.path.join(tempd, 'log.json')
                    with open(path, 'w') as f:
                        json.dump(self._log, f, indent=4)

                    new_path = os.path.join(trainer.out, log_name)
                    shutil.move(path, new_path)

            # reset the summary for the next output
            self._init_summary()
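
Note that the extension writes its JSON log with the same tempdir-and-move pattern as the writers above, and that ``self._log_name`` is a ``str.format`` template expanded over the reported statistics. A small illustration of that templating (the dict keys mirror ``stats_cpu`` as built above):

stats_cpu = {'epoch': 3, 'iteration': 1200, 'elapsed_time': 45.6}
log_name = 'log_epoch_{epoch}'.format(**stats_cpu)
assert log_name == 'log_epoch_3'
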
Example #6
    def test_call(self):
        target = mock.MagicMock()
        q = mock.MagicMock()
        consumer = mock.MagicMock()
        names = [
            snapshot_writers_path + '.QueueWriter.create_queue',
            snapshot_writers_path + '.QueueWriter.create_consumer'
        ]
        with mock.patch(names[0], return_value=q):
            with mock.patch(names[1], return_value=consumer):
                w = snapshot_writers.QueueWriter()

                with utils.tempdir() as tempd:
                    w('myfile.dat', tempd, target)
                    w('myfile.dat', tempd, target)
                    w.finalize()

                assert consumer.start.call_count == 1
                assert q.put.call_count == 3
                assert q.join.call_count == 1
                assert consumer.join.call_count == 1
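
This test exercises a producer/consumer design: ``__call__`` only enqueues work (two saves plus, on ``finalize()``, a termination sentinel, hence three ``q.put`` calls), while a single consumer drains the queue. A minimal sketch of that shape with the standard library (an illustration of the pattern, not the library's actual implementation):

import queue
import threading

task_queue = queue.Queue()

def consumer_loop():
    while True:
        task = task_queue.get()
        try:
            if task is None:      # sentinel pushed by finalize()
                break
            task()                # run the queued save callable
        finally:
            task_queue.task_done()

consumer = threading.Thread(target=consumer_loop)
consumer.start()
task_queue.put(lambda: print('snapshot 1 written'))
task_queue.put(lambda: print('snapshot 2 written'))
task_queue.put(None)              # finalize: ask the consumer to stop
task_queue.join()                 # wait until all queued work is done
consumer.join()
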
Example #7
    def test_create_worker(self):
        target = mock.MagicMock()
        w = snapshot_writers.ProcessWriter()
        with utils.tempdir() as tempd:
            worker = w.create_worker('myfile.dat', tempd, target)
            assert isinstance(worker, multiprocessing.Process)
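
Consistent with this assertion, ``create_worker`` is expected to return an unstarted ``multiprocessing.Process`` that performs the save. A sketch of an implementation that would satisfy the test (an assumption about the internals, not the library's code; ``self._savefun`` is a hypothetical attribute holding the serializer):

import multiprocessing

def create_worker(self, filename, outdir, target, **kwds):
    # Hypothetical: delegate the actual save (see Example #1) to a child
    # process; everything passed here must be picklable.
    return multiprocessing.Process(
        target=self.save, args=(filename, outdir, target, self._savefun),
        kwargs=kwds)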