def test_load(self, load_mode, last_epoch):
    """Load a snapshot from the temp dir and check the restored objects.

    Args:
        load_mode: Load mode forwarded to ``Snapshotter.load``.
        last_epoch: Epoch count the restored ``stats`` object must report.
    """
    loaded = Snapshotter().load(self.temp_dir.name, load_mode)
    algo = loaded['algo']
    # The restored algorithm, environment and policy must have the
    # expected concrete types.
    assert isinstance(algo, VPG)
    assert isinstance(loaded['env'], MetaRLEnv)
    assert isinstance(algo.policy, CategoricalMLPPolicy)
    # Training progress stored in the snapshot must be preserved.
    assert loaded['stats'].total_epoch == last_epoch
def test_get_available_itrs(self):
    """Check iteration discovery for multi-itr, single-params and empty dirs."""
    with tempfile.TemporaryDirectory() as root:
        many, one, empty = (tempfile.mkdtemp(dir=root) for _ in range(3))

        # Directory holding several numbered iteration snapshots.
        for fname in ('itr_1.pkl', 'itr_3.pkl', 'itr_5.pkl'):
            with open(osp.join(many, fname), 'a'):
                pass
        assert Snapshotter.get_available_itrs(many) == [1, 3, 5]

        # Directory holding only the "last" snapshot file.
        with open(osp.join(one, 'params.pkl'), 'a'):
            pass
        assert Snapshotter.get_available_itrs(one) == ['last']

        # Directory with no snapshot files at all.
        assert not Snapshotter.get_available_itrs(empty)
def test_snapshotter(self, mode, files):
    """Save two iterations and verify which pickle files the mode produces.

    Args:
        mode: Snapshot mode passed to ``Snapshotter``.
        files: Mapping of expected file name -> index into the saved data
            that the file must contain.
    """
    snap_dir = self.temp_dir.name
    snapshotter = Snapshotter(snap_dir, mode, 2)
    assert snapshotter.snapshot_dir == snap_dir
    assert snapshotter.snapshot_mode == mode
    assert snapshotter.snapshot_gap == 2

    snapshot_data = [{'testparam': 1}, {'testparam': 4}]
    # Iterations are numbered from 1.
    for itr, params in enumerate(snapshot_data, start=1):
        snapshotter.save_itr_params(itr, params)

    for fname, idx in files.items():
        path = osp.join(snap_dir, fname)
        assert osp.exists(path)
        with open(path, 'rb') as fh:
            loaded = pickle.load(fh)
        assert loaded == snapshot_data[idx]
def test_many_folders(self, folders, workers, skip_existing, to_merge, stride):
    """Run ``self.test_one_folder`` over the iterations found in each folder.

    Args:
        folders: Iterable of meta-train directories to evaluate.
        workers: Number of forked child processes that share the
            iterations; 0 runs everything in the current process.
        skip_existing: If true, drop iterations listed by
            ``self._get_tested_itrs`` for the folder.
        to_merge: If true, call ``self._merge_csv`` on each folder once its
            iterations have been processed.
        stride: Keep every ``stride``-th remaining iteration; values <= 1
            keep all of them.
    """
    import sys
    import traceback

    for meta_train_dir in folders:
        itrs = Snapshotter.get_available_itrs(meta_train_dir)
        tested_itrs = self._get_tested_itrs(meta_train_dir)

        if skip_existing:
            itrs = [itr for itr in itrs if itr not in tested_itrs]

        if stride > 1:
            itrs = itrs[::stride]

        if workers == 0:
            self.test_one_folder(meta_train_dir, itrs)
        else:
            # Split the iterations into `workers` contiguous slices.
            bite_size = math.ceil(len(itrs) / workers)
            bites = [
                itrs[i * bite_size:(i + 1) * bite_size]
                for i in range(workers)
            ]

            children = []
            for bite_itrs in bites:
                if not bite_itrs:
                    continue
                pid = os.fork()
                if pid == 0:
                    # Child process. It must never return into the caller's
                    # frames. If it did, it would resume the parent's loop
                    # (forking more workers) or the surrounding test session.
                    # The builtin exit() only raises SystemExit, so use
                    # os._exit after flushing stdio, which os._exit skips.
                    status = 1
                    try:
                        self.test_one_folder(meta_train_dir, bite_itrs)
                        status = 0
                    except BaseException:  # report, never unwind the child
                        traceback.print_exc()
                    finally:
                        sys.stdout.flush()
                        sys.stderr.flush()
                        os._exit(status)
                else:
                    # Parent process: remember the child so we can reap it.
                    children.append(pid)

            for child in children:
                os.waitpid(child, 0)

        if to_merge:
            self._merge_csv(meta_train_dir, itrs)
def test_load_with_invalid_load_mode(self):
    """Loading with an unrecognised mode string must raise ValueError."""
    loader = Snapshotter()
    with pytest.raises(ValueError):
        loader.load(self.temp_dir.name, 'foo')
def test_invalid_snapshot_mode(self):
    """An unknown snapshot mode must raise ValueError by the time params are saved."""
    # Both construction and saving happen inside the raises-block, so the
    # error may come from either step.
    with pytest.raises(ValueError):
        Snapshotter(
            snapshot_dir=self.temp_dir.name,
            snapshot_mode='invalid',
        ).save_itr_params(2, {'testparam': 'invalid'})