示例#1
0
文件: db.py 项目: zkid18/catalyst
 def get_checkpoint(self):
     """Fetch the checkpoint stored under the ``<prefix>_checkpoint`` key.

     Returns None when the server has nothing under that key. Otherwise the
     raw value is unpacked, its ``"epoch"`` entry (None if absent) is cached
     on ``self._epoch``, and the ``"checkpoint"`` entry is returned.
     """
     raw = self._server.get(f"{self._prefix}_checkpoint")
     if raw is None:
         return None
     payload = unpack(raw)
     # Remember which epoch this checkpoint belongs to before handing it out.
     self._epoch = payload.get("epoch")
     return payload["checkpoint"]
示例#2
0
文件: db.py 项目: zkid18/catalyst
        def get_trajectory(self, index=None):
            """Return the next trajectory stored after the last one consumed.

            Only sequential reads are supported here: ``index`` must be None.

            The earliest document whose ``"date"`` is strictly greater than
            ``self._last_datetime`` is fetched, and ``self._last_datetime``
            advances to that document's date. Its ``"trajectory"`` field is
            unpacked and converted with ``dict2structed_trajectory``.

            Returns None when no newer document exists. It also returns None
            when ``self._sync_epoch`` is set and the document's ``"epoch"``
            differs from ``self._epoch``; that document still counts as
            consumed.

            While ``pymongo.errors.AutoReconnect`` is raised, sleeps
            ``self._reconnect_timeout`` seconds and retries.
            """
            assert index is None

            # Retry in a loop rather than by recursion: a long outage would
            # otherwise exhaust the interpreter's recursion limit.
            while True:
                try:
                    trajectory_obj = self._trajectory_collection.find_one(
                        {"date": {"$gt": self._last_datetime}},
                        # Without an explicit sort MongoDB returns matches in
                        # natural order. A newer document could then come
                        # first, and advancing the cursor to its date would
                        # skip the older ones.
                        sort=[("date", pymongo.ASCENDING)],
                    )
                except pymongo.errors.AutoReconnect:
                    time.sleep(self._reconnect_timeout)
                else:
                    break

            if trajectory_obj is None:
                return None

            self._last_datetime = trajectory_obj["date"]
            trajectory = unpack(trajectory_obj["trajectory"])
            trajectory_epoch = trajectory_obj["epoch"]
            if self._sync_epoch and self._epoch != trajectory_epoch:
                return None
            return dict2structed_trajectory(trajectory)
示例#3
0
文件: db.py 项目: zkid18/catalyst
        def get_checkpoint(self):
            """Load the stored checkpoint entry named ``"checkpoint"``, if any.

            Looks up the entry with ``filename == "checkpoint"`` in
            ``self._checkpoint_collection``. That collection is presumably a
            GridFS store, since the result has ``.read()`` and ``.epoch``;
            confirm against the constructor. If found, its contents are
            decoded as ASCII and then unpacked, and its ``epoch`` attribute
            is cached on ``self._epoch``.

            Returns the unpacked checkpoint, or None when no entry exists.
            While ``pymongo.errors.AutoReconnect`` is raised, sleeps
            ``self._reconnect_timeout`` seconds and retries.
            """
            # Retry in a loop rather than by recursion: a long outage would
            # otherwise exhaust the interpreter's recursion limit.
            while True:
                try:
                    checkpoint_obj = self._checkpoint_collection.find_one(
                        {"filename": "checkpoint"})
                except pymongo.errors.AutoReconnect:
                    time.sleep(self._reconnect_timeout)
                else:
                    break

            if checkpoint_obj is None:
                return None

            checkpoint = checkpoint_obj.read().decode("ascii")
            self._epoch = checkpoint_obj.epoch
            return unpack(checkpoint)
示例#4
0
文件: db.py 项目: zkid18/catalyst
        def get_trajectory(self, index=None) -> Trajectory:
            """Read one trajectory from the Redis list ``"trajectories"``.

            ``index`` defaults to the internal cursor ``self._index``. When
            ``lindex`` yields an element at that position:
            - the cursor moves to ``index + 1``, even if an explicit
              ``index`` was given;
            - the element is unpacked;
            - its ``"trajectory"`` payload is converted with
              ``dict2structed_trajectory``.

            Despite the ``Trajectory`` annotation, returns None in two cases:
            - ``lindex`` returns None;
            - ``self._sync_epoch`` is set and the stored ``"epoch"`` differs
              from ``self._epoch`` (the cursor has still advanced).
            """
            if index is None:
                index = self._index
            packed = self._server.lindex("trajectories", index)
            if packed is None:
                return None

            self._index = index + 1
            record = unpack(packed)
            payload, stored_epoch = record["trajectory"], record["epoch"]
            if self._sync_epoch and self._epoch != stored_epoch:
                return None
            return dict2structed_trajectory(payload)
示例#5
0
def _dump_pickle(obj, path):
    """Pickle ``obj`` into the file at ``path``, overwriting any existing file."""
    with open(path, "wb") as fout:
        pickle.dump(obj, fout)


def main(args, _=None):
    """Dump elements of the Redis list ``"trajectories"`` into pickle files.

    Scans list indices from ``args.start_from`` up to the list length, which
    is measured once before the scan. When ``args.min_reward`` is set, an
    element is kept only if ``sum(trajectory["trajectory"][-2])`` exceeds
    it. Kept elements are stored in their original packed form, the same as
    when no filter is applied.

    Each output file is named ``args.out_pkl.format(suffix=i)``, where ``i``
    is the Redis index of the last element scanned into that file. With
    ``args.chunk_size`` set, a file is written after every ``chunk_size``
    scanned elements, plus one more for any remainder. Without it, a single
    file is written, even if it holds an empty list.

    Args:
        args: namespace providing ``host``, ``port``, ``start_from``,
            ``min_reward``, ``chunk_size`` and ``out_pkl``.
        _: unused.
    """
    db = Redis(host=args.host, port=args.port)
    # ``llen`` is the element count, so valid indices are 0..llen-1. The
    # range must stop at ``llen`` itself to include the last element.
    redis_len = db.llen("trajectories")
    trajectories = []
    last_index = args.start_from  # suffix used if nothing is scanned
    dumped_chunk = False

    for i in tqdm(range(args.start_from, redis_len)):
        last_index = i
        trajectory = db.lindex("trajectories", i)

        if args.min_reward is None:
            trajectories.append(trajectory)
        elif sum(utils.unpack(trajectory)["trajectory"][-2]) > args.min_reward:
            # Keep the original packed bytes; re-packing is unnecessary work.
            trajectories.append(trajectory)

        # ``+ 1`` makes a chunk close after ``chunk_size`` scanned elements
        # instead of right after the very first one.
        if args.chunk_size is not None \
                and (i - args.start_from + 1) % args.chunk_size == 0:
            _dump_pickle(trajectories, args.out_pkl.format(suffix=i))
            trajectories = []
            dumped_chunk = True

    # Write the remainder. When a chunk was just written and nothing is
    # pending, skip it, so that chunk's file (same suffix) is not overwritten
    # with an empty list.
    if trajectories or not dumped_chunk:
        _dump_pickle(trajectories, args.out_pkl.format(suffix=last_index))