def test_load(self, load_mode, last_epoch):
        """Load a saved experiment and check the restored objects.

        Args:
            load_mode: Passed straight to ``Snapshotter.load``; presumably an
                itr number or ``'last'`` (TODO confirm against Snapshotter).
            last_epoch: Expected value of ``stats.total_epoch`` after loading.
        """
        loaded = Snapshotter().load(self.temp_dir.name, load_mode)

        algo = loaded['algo']
        assert isinstance(algo, VPG)
        assert isinstance(loaded['env'], MetaRLEnv)
        # The algorithm's policy must round-trip with its concrete class.
        assert isinstance(algo.policy, CategoricalMLPPolicy)
        assert loaded['stats'].total_epoch == last_epoch
Пример #2
0
    def test_get_available_itrs(self):
        """Check ``Snapshotter.get_available_itrs`` on three directory layouts.

        A directory holding ``itr_1/3/5.pkl`` yields ``[1, 3, 5]``, one holding
        only ``params.pkl`` yields ``['last']``, and an empty directory yields
        a falsy result.
        """
        def touch(directory, name):
            # Create the file if missing; existing content is left intact.
            with open(osp.join(directory, name), 'a'):
                pass

        with tempfile.TemporaryDirectory() as root:
            multi_dir = tempfile.mkdtemp(dir=root)
            single_dir = tempfile.mkdtemp(dir=root)
            empty_dir = tempfile.mkdtemp(dir=root)

            for name in ('itr_1.pkl', 'itr_3.pkl', 'itr_5.pkl'):
                touch(multi_dir, name)
            assert Snapshotter.get_available_itrs(multi_dir) == [1, 3, 5]

            touch(single_dir, 'params.pkl')
            assert Snapshotter.get_available_itrs(single_dir) == ['last']

            assert not Snapshotter.get_available_itrs(empty_dir)
Пример #3
0
    def test_snapshotter(self, mode, files):
        """Save two iterations and verify which pickle files were written.

        Args:
            mode: Snapshot mode passed to the ``Snapshotter`` constructor.
            files: Mapping of expected file name (relative to the snapshot
                dir) to the index in the saved-params list whose contents the
                file must unpickle to.
        """
        snap_dir = self.temp_dir.name
        snapshotter = Snapshotter(snap_dir, mode, 2)

        # Constructor arguments must be exposed unchanged.
        assert snapshotter.snapshot_dir == snap_dir
        assert snapshotter.snapshot_mode == mode
        assert snapshotter.snapshot_gap == 2

        saved_params = [{'testparam': 1}, {'testparam': 4}]
        for itr, params in enumerate(saved_params, start=1):
            snapshotter.save_itr_params(itr, params)

        for fname, idx in files.items():
            path = osp.join(snap_dir, fname)
            assert osp.exists(path)
            with open(path, 'rb') as fh:
                assert pickle.load(fh) == saved_params[idx]
Пример #4
0
    def test_many_folders(self, folders, workers, skip_existing, to_merge,
                          stride):
        """Run ``test_one_folder`` over the snapshot itrs of several folders.

        For each folder the available iterations come from
        ``Snapshotter.get_available_itrs``; they are optionally filtered and
        strided, then processed either in this process or split across forked
        worker processes.

        Args:
            folders: Iterable of meta-train directories to process.
            workers: Maximum number of forked worker processes per folder;
                ``0`` runs everything in the current process.
            skip_existing: If true, drop iterations reported by
                ``self._get_tested_itrs`` for the folder.
            to_merge: If true, call ``self._merge_csv`` on each folder after
                its iterations have been processed.
            stride: When > 1, keep only every ``stride``-th iteration.
        """
        for meta_train_dir in folders:
            itrs = Snapshotter.get_available_itrs(meta_train_dir)
            tested_itrs = self._get_tested_itrs(meta_train_dir)

            if skip_existing:
                itrs = [itr for itr in itrs if itr not in tested_itrs]

            if stride > 1:
                itrs = itrs[::stride]

            if workers == 0:
                self.test_one_folder(meta_train_dir, itrs)
            else:
                # Split itrs into at most `workers` contiguous chunks.
                bite_size = math.ceil(len(itrs) / workers)
                bites = [
                    itrs[i * bite_size:(i + 1) * bite_size]
                    for i in range(workers)
                ]

                children = []
                for bite_itrs in bites:
                    if not bite_itrs:
                        continue
                    pid = os.fork()
                    if pid == 0:
                        # In child process. Always leave via os._exit so the
                        # child never unwinds back into the parent's copied
                        # call stack (caller frames, finally blocks, atexit
                        # handlers) -- including when the work raises, which
                        # would otherwise make the child keep running the
                        # parent's loop and fork further processes.
                        import sys
                        import traceback
                        exit_code = 0
                        try:
                            self.test_one_folder(meta_train_dir, bite_itrs)
                        except BaseException:
                            traceback.print_exc()
                            exit_code = 1
                        finally:
                            # os._exit skips stdio flushing; do it by hand.
                            sys.stdout.flush()
                            sys.stderr.flush()
                            os._exit(exit_code)
                    # In parent process
                    children.append(pid)

                # NOTE(review): child exit statuses are not inspected; a
                # failed worker only shows up via its printed traceback.
                for child in children:
                    os.waitpid(child, 0)

            if to_merge:
                self._merge_csv(meta_train_dir, itrs)
 def test_load_with_invalid_load_mode(self):
     """Loading with an unrecognised load mode must raise ValueError."""
     loader = Snapshotter()
     bogus_mode = 'foo'
     with pytest.raises(ValueError):
         loader.load(self.temp_dir.name, bogus_mode)
Пример #6
0
 def test_invalid_snapshot_mode(self):
     """An unknown ``snapshot_mode`` must raise ValueError.

     Both construction and the first save sit inside the ``raises`` block,
     so the error may come from either one.
     """
     bogus_params = {'testparam': 'invalid'}
     with pytest.raises(ValueError):
         bad = Snapshotter(snapshot_dir=self.temp_dir.name,
                           snapshot_mode='invalid')
         bad.save_itr_params(2, bogus_params)