Code Example #1
import random

# SegmentTree is assumed to be the project's sum-tree implementation,
# providing add/get/update and a running total of priorities.


class ReplayMemory:
    def __init__(self, max_memory=1000):
        self.max_memory = max_memory
        self.memory = SegmentTree(max_memory)
        self._count = 0

    @property
    def count(self):
        return self._count

    def add_memory(self, state_input, best_action, reward, done,
                   next_state_input, td):
        data = [state_input, best_action, reward, done, next_state_input]

        self.memory.add(td, data)  # Store the transition with its TD error as the priority

        if self._count < self.max_memory:
            self._count += 1

    def get_memory(self, batch_size):
        segment = self.memory.total / batch_size  # Split the total priority mass into equal segments

        batch_tree_index = []
        tds = []
        batch = []

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            sample = random.uniform(a, b)  # Draw a priority mass value within this segment
            tree_index, td, data = self.memory.get(sample)
            batch_tree_index.append(tree_index)
            tds.append(td)
            batch.append(data)

        return batch_tree_index, tds, batch

    def update_memory(self, tree_indexes, tds):
        for i in range(len(tree_indexes)):
            self.memory.update(tree_indexes[i], tds[i])
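
A minimal usage sketch of the class above (hedged: the observation vectors, action, reward, and TD values are hypothetical placeholders, and SegmentTree is assumed as noted in the imports):

memory = ReplayMemory(max_memory=1000)

state = [0.0] * 8          # hypothetical observation vector
next_state = [0.1] * 8
memory.add_memory(state, best_action=2, reward=1.0, done=False,
                  next_state_input=next_state, td=0.5)

tree_idxs, tds, batch = memory.get_memory(batch_size=1)
# After computing fresh TD errors for the sampled transitions, write them back:
memory.update_memory(tree_idxs, [0.8])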
Code Example #2
File: memory.py  Project: hknozturk/Lunarlander
import numpy as np
import torch

# SegmentTree, the Experience tuple, and the torch `device` handle are assumed
# to be defined elsewhere in the project (memory.py's other imports).


class PrioritizedReplayMemory:
    def __init__(self, args, capacity):
        self.capacity = capacity
        self.discount = args.gamma
        self.priority_weight = args.priority_weight
        self.priority_exponent = args.priority_exponent
        self.absolute_error_upper = args.absolute_error_upper
        self.t = 0  # Internal episode timestep counter
        # Experiences live in a wrap-around cyclic buffer inside a sum tree,
        # which can be queried by cumulative priority
        self.tree = SegmentTree(capacity)
        # Per-sample() increment used to anneal the importance-sampling exponent towards 1
        self.priority_weight_increase = (1 - args.priority_weight) / self.capacity

    # Stores a full (state, action, reward, next_state, done) transition
    def append(self, state, action, reward, next_state, done):
        # New transitions enter with the current maximum priority so they are
        # guaranteed to be sampled at least once
        self.tree.append(Experience(state, action, reward, next_state, done), self.tree.max)
        self.t = 0 if done else self.t + 1  # Reset the timestep counter at episode boundaries

    def _get_sample_from_segment(self, segment, i):
        valid = False
        while not valid:
            # Uniformly sample a priority mass value from within segment i
            sample = np.random.uniform(i * segment, (i + 1) * segment)
            # Retrieve the transition whose cumulative priority covers this value;
            # prob is its un-normalised priority
            prob, idx, tree_idx = self.tree.find(sample)
            # Resample if a zero-priority slot was hit (e.g. an unwritten entry
            # around buffer index 0)
            if prob != 0:
                valid = True

        experience = self.tree.get(idx)

        return prob, idx, tree_idx, experience

    def sample(self, batch_size):
        # Anneal the importance-sampling exponent towards 1 over the course of training
        self.priority_weight = min(self.priority_weight + self.priority_weight_increase, 1)
        p_total = self.tree.total()  # Sum of all priorities, used to normalise the sampling distribution
        segment = p_total / batch_size  # Split the total priority mass into batch_size equal segments

        # Draw one valid sample from each segment
        batch = [self._get_sample_from_segment(segment, i) for i in range(batch_size)]
        probs, idxs, tree_idxs, experiences = zip(*batch)

        # Stack the sampled transitions into tensors on the training device
        states = torch.from_numpy(np.vstack(
            [exp.state for exp in experiences if exp is not None]
        )).to(device=device, dtype=torch.float32)
        actions = torch.from_numpy(np.vstack(
            [exp.action for exp in experiences if exp is not None]
        )).to(device=device, dtype=torch.long)
        rewards = torch.from_numpy(np.vstack(
            [exp.reward for exp in experiences if exp is not None]
        )).to(device=device, dtype=torch.float32)
        next_states = torch.from_numpy(np.vstack(
            [exp.next_state for exp in experiences if exp is not None]
        )).to(device=device, dtype=torch.float32)
        dones = torch.from_numpy(np.vstack(
            [exp.done for exp in experiences if exp is not None]
        )).to(device=device, dtype=torch.float32)

        probs = np.array(probs, dtype=np.float32) / p_total  # Normalised sampling probabilities P(i)
        capacity = self.capacity if self.tree.full else self.tree.index
        # Importance-sampling weights w = (N * P(i)) ** -beta, normalised by the batch maximum
        weights = (capacity * probs) ** -self.priority_weight
        weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=device)
        return tree_idxs, states, actions, rewards, next_states, dones, weights

    def update_priorities(self, idxs, priorities):
        # Incoming priorities are raw TD errors: clip them, then apply the priority exponent
        clipped_errors = np.minimum(priorities, self.absolute_error_upper)
        clipped_errors = np.power(clipped_errors, self.priority_exponent)
        for idx, priority in zip(idxs, clipped_errors):
            self.tree.update(idx, priority)

    def __len__(self):
        return len(self.tree)
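
A minimal, hedged usage sketch of the class above. The Namespace fields, observation shape, and TD errors below are hypothetical placeholders, and SegmentTree, Experience, and `device` are assumed to come from the project's other modules as noted in the imports:

from argparse import Namespace

args = Namespace(gamma=0.99, priority_weight=0.4, priority_exponent=0.5,
                 absolute_error_upper=1.0)
memory = PrioritizedReplayMemory(args, capacity=10000)

state = np.zeros(8, dtype=np.float32)       # hypothetical LunarLander-sized observation
next_state = np.ones(8, dtype=np.float32)
memory.append(state, action=1, reward=0.1, next_state=next_state, done=False)

if len(memory) >= 32:
    tree_idxs, states, actions, rewards, next_states, dones, weights = memory.sample(32)
    # Compute per-sample TD errors with the learner, then feed them back:
    td_errors = np.abs(np.random.randn(32))  # placeholder for real TD errors
    memory.update_priorities(tree_idxs, td_errors)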