def load_checkpoint(self, *, filepath: str = None, db_server: DBSpec = None):
    """Load agent weights from a checkpoint file or from a DB server.

    When both sources are given, ``filepath`` takes precedence and
    ``db_server`` is ignored.

    Args:
        filepath: path passed to ``UtilsFactory.load_checkpoint``; the
            weights are read from the
            ``f"{self._sampler_weight_mode}_state_dict"`` entry of the
            returned checkpoint.
        db_server: server to pull weights from. The call polls
            ``db_server.get_sample_flag()`` once per second until it is
            truthy, then requests the weights with
            ``prefix=self._sampler_weight_mode`` and converts every value
            with ``self._to_tensor``.

    Raises:
        NotImplementedError: if neither ``filepath`` nor ``db_server``
            is provided.
    """
    # Guard clause: nothing to load from.
    if filepath is None and db_server is None:
        raise NotImplementedError

    if filepath is not None:
        loaded = UtilsFactory.load_checkpoint(filepath)
        state_dict = loaded[f"{self._sampler_weight_mode}_state_dict"]
    else:
        # Block until the server reports that sampling may start.
        while True:
            if db_server.get_sample_flag():
                break
            time.sleep(1.0)

        raw_weights = db_server.load_weights(prefix=self._sampler_weight_mode)
        state_dict = {}
        for name, value in raw_weights.items():
            state_dict[name] = self._to_tensor(value)

    agent = self.agent
    agent.load_state_dict(state_dict)

    # Looked up on ``self`` again on purpose, to keep the same attribute
    # access sequence as a plain ``self.agent.<call>()`` chain.
    self.agent.to(self._device)
    self.agent.eval()
def _db2buffer_loop(db_server: DBSpec, buffer: OffpolicyReplayBuffer): trajectory = None while True: if trajectory is None: trajectory = db_server.get_trajectory() if trajectory is not None: if buffer.push_trajectory(trajectory): trajectory = None else: time.sleep(1.0) else: time.sleep(1.0)
def _db2queue_loop(db_server: DBSpec, queue: mp.Queue, max_size: int): while True: try: need_more = queue.qsize() < max_size except NotImplementedError: # MacOS qsize issue (no sem_getvalue) need_more = True if need_more: trajectory = db_server.get_trajectory() if trajectory is not None: queue.put(trajectory, block=True, timeout=1.0) else: time.sleep(1.0) else: time.sleep(1.0)