def main(args, _=None):
    """Dump trajectories from a Redis or Mongo trajectory DB into pickle files.

    Trajectories with index in ``[args.start_from, db.num_trajectories)`` are
    read and optionally filtered by reward. Kept ones are converted with
    ``structed2dict_trajectory``, packed with ``utils.pack`` and pickled as a
    list.

    Args:
        args: parsed CLI namespace. Attributes read here:
            ``db``: "redis" selects ``RedisDB``; anything else selects
                ``MongoDB``.
            ``host``, ``port``: passed to the DB constructor.
            ``start_from``: first index to read.
            ``min_reward``: None disables filtering; otherwise trajectories
                with ``sum(trajectory["trajectory"][-2]) < min_reward`` are
                dropped.
            ``chunk_size``: None writes a single file; otherwise a file is
                written every ``chunk_size`` indices.
            ``out_pkl``: a format string with a ``{suffix}`` field that
                receives a trajectory index.
        _: unused; accepted for call-signature compatibility.
    """
    db_fn = RedisDB if args.db == "redis" else MongoDB
    db = db_fn(host=args.host, port=args.port)
    db_len = db.num_trajectories

    def _dump(items, suffix):
        # Pickle one batch of packed trajectories to the file for `suffix`.
        with open(args.out_pkl.format(suffix=suffix), "wb") as fout:
            pickle.dump(items, fout)

    trajectories = []
    last_index = args.start_from  # suffix for the tail file if nothing is read
    flushed_last = False  # True when the latest iteration already wrote a chunk
    for i in tqdm(range(args.start_from, db_len)):
        last_index = i
        flushed_last = False
        if args.db == "redis":
            trajectory = db.get_trajectory(i)
        else:
            # mongo does not support indexing yet
            trajectory = db.get_trajectory()

        rejected = (args.min_reward is not None
                    and sum(trajectory["trajectory"][-2]) < args.min_reward)
        # Filtering must not bypass the chunk check below. Otherwise a
        # rejected trajectory sitting on a chunk boundary would silently
        # merge two chunks into one file.
        if not rejected:
            trajectory = structed2dict_trajectory(trajectory)
            trajectory = utils.pack(trajectory)
            trajectories.append(trajectory)

        # NOTE(review): the first flush happens at i == start_from, so the
        # first file holds a single trajectory. This is kept unchanged
        # because the output file suffixes depend on it.
        if (args.chunk_size is not None
                and (i - args.start_from) % args.chunk_size == 0):
            _dump(trajectories, i)
            trajectories = []
            flushed_last = True

    # Write the remainder, unless the final iteration already wrote a chunk.
    # That chunk used the same suffix, so an empty list written now would
    # overwrite it.
    if not flushed_last:
        _dump(trajectories, last_index)
def add_trajectory(self, trajectory, raw=False):
    """Insert one trajectory document into MongoDB, retrying on reconnects.

    The trajectory is converted with ``structed2dict_trajectory`` and packed
    once. It is then inserted as a document with keys "trajectory", "date"
    (``datetime.datetime.utcnow()``) and "epoch" (``self._epoch``).

    Args:
        trajectory: value accepted by ``structed2dict_trajectory``.
        raw: when true, write to ``self._raw_trajectory_collection``;
            otherwise write to ``self._trajectory_collection``.

    On ``pymongo.errors.AutoReconnect`` it sleeps
    ``self._reconnect_timeout`` and retries indefinitely.
    """
    trajectory_ = structed2dict_trajectory(trajectory)
    trajectory_ = pack(trajectory_)
    collection = (self._raw_trajectory_collection if raw
                  else self._trajectory_collection)
    # Retry in a loop rather than by recursion, so a long outage cannot
    # exhaust the interpreter's recursion limit.
    while True:
        try:
            # Build a fresh document on every attempt, as the original
            # recursive version did. insert_one may add an "_id" to the
            # dict it receives, and each attempt also gets its own "date".
            collection.insert_one({
                "trajectory": trajectory_,
                "date": datetime.datetime.utcnow(),
                "epoch": self._epoch,
            })
            return None
        except pymongo.errors.AutoReconnect:
            time.sleep(self._reconnect_timeout)
def add_checkpoint(self, checkpoint, epoch):
    """Replace the stored checkpoint file "checkpoint", retrying on reconnects.

    Sets ``self._epoch = epoch`` and packs ``checkpoint`` once. Then, if a
    file with ``filename == "checkpoint"`` exists in
    ``self._checkpoint_collection``, it calls ``self.del_checkpoint()``.
    Finally it ``put``s the packed data with ``filename="checkpoint"`` and
    ``epoch=self._epoch`` (``encoding="ascii"`` is passed through to ``put``).

    On ``pymongo.errors.AutoReconnect`` it sleeps
    ``self._reconnect_timeout`` and retries the exists/delete/put sequence
    indefinitely.
    """
    self._epoch = epoch
    checkpoint_ = pack(checkpoint)
    # Retry in a loop rather than by recursion, so a long outage cannot
    # exhaust the interpreter's recursion limit. The existence check is
    # repeated on every attempt because a previous attempt may have deleted
    # the old file before failing.
    while True:
        try:
            if self._checkpoint_collection.exists(
                    {"filename": "checkpoint"}):
                self.del_checkpoint()
            self._checkpoint_collection.put(checkpoint_,
                                            encoding="ascii",
                                            filename="checkpoint",
                                            epoch=self._epoch)
            return None
        except pymongo.errors.AutoReconnect:
            time.sleep(self._reconnect_timeout)
def main(args, _=None):
    """Dump packed entries of the Redis list "trajectories" into pickle files.

    Entries are read with ``LINDEX`` for indices in
    ``[args.start_from, LLEN("trajectories") - 1)`` and collected as the
    packed values stored in Redis.

    Args:
        args: parsed CLI namespace. Attributes read here:
            ``host``, ``port``: passed to ``Redis``.
            ``start_from``: first index to read.
            ``min_reward``: None disables filtering; otherwise an entry is
                kept only when ``sum(unpacked["trajectory"][-2]) >=
                min_reward``.
            ``chunk_size``: None writes a single file; otherwise a file is
                written every ``chunk_size`` indices.
            ``out_pkl``: a format string with a ``{suffix}`` field that
                receives a list index.
        _: unused; accepted for call-signature compatibility.
    """
    db = Redis(host=args.host, port=args.port)
    # NOTE(review): the "- 1" makes the range below skip the last list
    # element. Confirm that this is intended.
    redis_len = db.llen("trajectories") - 1

    def _dump(items, suffix):
        # Pickle one batch of packed trajectories to the file for `suffix`.
        with open(args.out_pkl.format(suffix=suffix), "wb") as fout:
            pickle.dump(items, fout)

    trajectories = []
    last_index = args.start_from  # suffix for the tail file if nothing is read
    flushed_last = False  # True when the latest iteration already wrote a chunk
    for i in tqdm(range(args.start_from, redis_len)):
        last_index = i
        flushed_last = False
        packed = db.lindex("trajectories", i)
        if args.min_reward is None:
            trajectories.append(packed)
        else:
            trajectory = utils.unpack(packed)
            # Treat min_reward as an inclusive lower bound, matching its
            # name. Keep the bytes already read instead of re-packing them.
            if sum(trajectory["trajectory"][-2]) >= args.min_reward:
                trajectories.append(packed)

        # NOTE(review): the first flush happens at i == start_from, so the
        # first file holds a single entry. This is kept unchanged because
        # the output file suffixes depend on it.
        if (args.chunk_size is not None
                and (i - args.start_from) % args.chunk_size == 0):
            _dump(trajectories, i)
            trajectories = []
            flushed_last = True

    # Write the remainder, unless the final iteration already wrote a chunk.
    # That chunk used the same suffix, so an empty list written now would
    # overwrite it.
    if not flushed_last:
        _dump(trajectories, last_index)
def add_checkpoint(self, checkpoint, epoch):
    """Record ``epoch`` and store the packed checkpoint on ``self._server``.

    The value written is ``pack({"checkpoint": checkpoint, "epoch": epoch})``.
    It is stored under the key ``f"{self._prefix}_checkpoint"`` via
    ``self._server.set``, replacing whatever that call replaces.
    """
    self._epoch = epoch
    key = f"{self._prefix}_checkpoint"
    payload = pack({"checkpoint": checkpoint, "epoch": self._epoch})
    self._server.set(key, payload)
def add_trajectory(self, trajectory: Trajectory, raw=False):
    """Append one packed trajectory record to a list on ``self._server``.

    The record is ``pack({"trajectory": structed2dict_trajectory(trajectory),
    "epoch": self._epoch})``. It is pushed with ``rpush`` onto
    "raw_trajectories" when ``raw`` is truthy, otherwise onto "trajectories".
    """
    record = {
        "trajectory": structed2dict_trajectory(trajectory),
        "epoch": self._epoch,
    }
    if raw:
        list_name = "raw_trajectories"
    else:
        list_name = "trajectories"
    self._server.rpush(list_name, pack(record))