class Storage(Callback):
    """Stores training examples.

    Transitions of the current episode are buffered in ``small_bag``. When a
    terminal transition arrives, discounted returns are computed backwards
    over the buffered episode and ``(state, mcts_pi, return)`` examples are
    moved into ``big_bag``.

    Args:
        out_path (str): Path to the output (pickle) file.
        exp_replay_size (int): How many transitions to keep at max. If this
            number is exceeded, the oldest transition is dropped.
        gamma (float): Discount factor.
    """

    def __init__(self, out_path, exp_replay_size, gamma):
        # Per-episode buffer of (state, reward, action_probs) tuples.
        self.small_bag = deque()
        # Finished examples: (state, mcts_pi, discounted_return) tuples.
        self.big_bag = deque()
        self.out_path = out_path
        self.exp_replay_size = exp_replay_size
        self.gamma = gamma
        # Action distribution from the most recent on_action_planned call;
        # None until the first call.
        self._recent_action_probs = None

    def on_action_planned(self, step, logits, info):
        """Remember the normalized action distribution for the next step."""
        # Proportional without temperature.
        # NOTE(review): assumes `logits` are non-negative (e.g. visit counts)
        # with a non-zero sum; otherwise this yields negative values or NaNs.
        self._recent_action_probs = logits / np.sum(logits)

    def on_step_taken(self, step, transition, info):
        """Buffer a transition; on episode end, emit discounted-return examples."""
        # NOTE: We never pass terminal state (it would be next_state), so NN
        # can't learn directly what is the value of terminal/end state.
        self.small_bag.append(self._create_small_package(transition))
        if len(self.small_bag) > self.exp_replay_size:
            # Very long episodes lose their earliest transitions.
            self.small_bag.popleft()

        if transition.is_terminal:
            # Walk the episode backwards so each return accumulates the
            # discounted sum of all later rewards.
            return_t = 0
            for state, reward, mcts_pi in reversed(self.small_bag):
                return_t = reward + self.gamma * return_t
                self.big_bag.append((state, mcts_pi, return_t))
                if len(self.big_bag) > self.exp_replay_size:
                    self.big_bag.popleft()

            self.small_bag.clear()

    def store(self):
        """Pickle ``big_bag`` to ``out_path``, creating the parent directory if needed."""
        path = self.out_path
        folder = os.path.dirname(path)
        # A bare filename has an empty dirname; os.makedirs("") would raise,
        # so only create a directory when there actually is one.
        if folder and not os.path.exists(folder):
            log.warning("Examples store directory does not exist! Creating directory %s", folder)
            # exist_ok guards against the directory appearing after the check.
            os.makedirs(folder, exist_ok=True)

        with open(path, "wb") as f:
            Pickler(f).dump(self.big_bag)

    def load(self):
        """Load ``big_bag`` from ``out_path`` if it exists, pruning it to ``exp_replay_size``."""
        path = self.out_path
        if not os.path.isfile(path):
            log.warning("File with train examples was not found.")
        else:
            log.info("File with train examples found. Reading it.")
            # SECURITY: unpickling can execute arbitrary code; only load
            # example files produced by a trusted `store()`.
            with open(path, "rb") as f:
                # Wrap in deque so popleft() below works even if the file
                # holds a plain sequence rather than a deque.
                self.big_bag = deque(Unpickler(f).load())

            # Prune dataset if too big (drop oldest examples first).
            while len(self.big_bag) > self.exp_replay_size:
                self.big_bag.popleft()

    @property
    def metrics(self):
        """Return a dict with the number of stored training examples."""
        logs = {"# samples": len(self.big_bag)}
        return logs

    def _create_small_package(self, transition):
        """Pack a transition with the most recently planned action distribution."""
        return (transition.state, transition.reward, self._recent_action_probs)