Example #1
0
    def load_checkpoint(self,
                        *,
                        filepath: str = None,
                        db_server: DBSpec = None):
        """Load weights into ``self.agent`` and put it into eval mode.

        Exactly one source is used; ``filepath`` wins when both are given.

        Args:
            filepath: path handed to ``UtilsFactory.load_checkpoint``; the
                weights are read from the checkpoint entry named
                ``f"{self._sampler_weight_mode}_state_dict"``.
            db_server: server polled once per second via
                ``get_sample_flag()`` until it returns a truthy value, then
                asked for weights with ``prefix=self._sampler_weight_mode``.
                Every returned value is passed through ``self._to_tensor``.

        Raises:
            NotImplementedError: if neither ``filepath`` nor ``db_server``
                is provided.
        """
        # Guard clause: nothing to load from.
        if filepath is None and db_server is None:
            raise NotImplementedError

        if filepath is not None:
            loaded = UtilsFactory.load_checkpoint(filepath)
            state_dict = loaded[f"{self._sampler_weight_mode}_state_dict"]
        else:
            # Block until the server reports that sampling may proceed.
            while not db_server.get_sample_flag():
                time.sleep(1.0)
            raw_weights = db_server.load_weights(
                prefix=self._sampler_weight_mode
            )
            state_dict = {
                name: self._to_tensor(value)
                for name, value in raw_weights.items()
            }

        self.agent.load_state_dict(state_dict)
        self.agent.to(self._device)
        self.agent.eval()
Example #2
0
def _db2buffer_loop(db_server: DBSpec, buffer: OffpolicyReplayBuffer):
    """Endlessly move trajectories from ``db_server`` into ``buffer``.

    A trajectory fetched with ``db_server.get_trajectory()`` is held until
    ``buffer.push_trajectory`` returns a truthy value, so a rejected push is
    retried with the same trajectory instead of fetching a new one. The loop
    sleeps for one second whenever no trajectory is available or the push
    was rejected. This function never returns.
    """
    pending = None
    while True:
        # Only ask the DB for more when nothing is waiting to be pushed.
        if pending is None:
            pending = db_server.get_trajectory()

        if pending is not None and buffer.push_trajectory(pending):
            # Accepted: drop the reference and immediately try the next one.
            pending = None
            continue

        # Either the DB had nothing or the buffer refused the push; back off.
        time.sleep(1.0)
Example #3
0
def _db2queue_loop(db_server: DBSpec, queue: mp.Queue, max_size: int):
    """Endlessly move trajectories from ``db_server`` into ``queue``.

    A new trajectory is fetched only while ``queue.qsize()`` is below
    ``max_size`` (or when ``qsize`` is unavailable on the platform). A
    fetched trajectory is held until it has actually been enqueued: if
    ``queue.put`` times out with ``queue.Full`` the same trajectory is
    retried on the next iteration rather than the exception escaping the
    loop and the trajectory being dropped. This function never returns.

    Args:
        db_server: source object providing ``get_trajectory()``, which
            returns ``None`` when nothing is available.
        queue: destination queue.
        max_size: soft upper bound on the queue size used to decide whether
            to fetch more data.
    """
    # Local import: the ``queue`` parameter shadows the stdlib module name.
    from queue import Full as QueueFull

    trajectory = None
    while True:
        if trajectory is None:
            try:
                need_more = queue.qsize() < max_size
            except NotImplementedError:  # MacOS qsize issue (no sem_getvalue)
                need_more = True

            if need_more:
                trajectory = db_server.get_trajectory()

        if trajectory is None:
            # Queue is full enough or the DB had nothing to offer; back off.
            time.sleep(1.0)
            continue

        try:
            queue.put(trajectory, block=True, timeout=1.0)
        except QueueFull:
            # The put already waited up to a second; keep the trajectory
            # and try again instead of crashing the loop and losing it.
            continue
        trajectory = None