Esempio n. 1
0
    def estimate(self, batch):
        """Compute the stepwise importance-sampling (IS) estimate for a batch.

        Args:
            batch: Object indexable by ``"rewards"`` and ``"action_prob"``
                (per-timestep sequences) that exposes its number of
                timesteps as ``batch.count``. It is first passed to
                ``self.check_can_estimate_for``.

        Returns:
            An ``OffPolicyEstimate`` labelled ``"is"`` with these metrics:
              * ``V_prev``: discounted sum of the logged rewards.
              * ``V_step_IS``: the same sum with every reward scaled by the
                cumulative new/old action-probability ratio up to that step.
              * ``V_gain_est``: ``V_step_IS / max(1e-8, V_prev)``.
        """
        self.check_can_estimate_for(batch)

        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.action_prob(batch)

        # Visit every timestep 0 .. count - 1. Iterating over
        # range(batch.count - 1) would silently drop the last transition,
        # so its reward would never reach either estimate.
        cumulative_ratio = 1.0
        V_prev, V_step_IS = 0.0, 0.0
        for t in range(batch.count):
            # Running product of per-step probability ratios up to step t.
            cumulative_ratio = cumulative_ratio * new_prob[t] / old_prob[t]
            discount = self.gamma**t
            V_prev += rewards[t] * discount
            V_step_IS += cumulative_ratio * rewards[t] * discount

        # The denominator is clamped to at least 1e-8 so the gain stays
        # defined when the logged return is zero or negative.
        estimation = OffPolicyEstimate(
            "is", {
                "V_prev": V_prev,
                "V_step_IS": V_step_IS,
                "V_gain_est": V_step_IS / max(1e-8, V_prev),
            })
        return estimation
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        """Compute the stepwise weighted importance-sampling (WIS) estimate.

        The cumulative probability ratios of this batch are first added to
        the per-timestep running totals kept in ``self.filter_values`` /
        ``self.filter_counts``; each ratio is then normalised by the running
        mean for its timestep before weighting the discounted reward.

        NOTE(review): the result of ``self.action_log_likelihood`` is used
        directly as a probability in the ratio -- confirm the helper returns
        probabilities rather than log-probabilities.

        Args:
            batch: Object indexable by ``"rewards"`` and ``"action_prob"``
                that exposes its number of timesteps as ``batch.count``.

        Returns:
            An ``OffPolicyEstimate`` labelled
            ``"weighted_importance_sampling"`` with metrics ``V_prev``,
            ``V_step_WIS`` and ``V_gain_est``
            (``V_step_WIS / max(1e-8, V_prev)``).
        """
        self.check_can_estimate_for(batch)

        rewards = batch["rewards"]
        old_prob = batch["action_prob"]
        new_prob = self.action_log_likelihood(batch)
        horizon = batch.count

        # Cumulative product of per-step new/old ratios, one entry per step.
        ratios = []
        running = 1.0
        for step in range(horizon):
            running = running * new_prob[step] / old_prob[step]
            ratios.append(running)

        # Fold this batch into the running per-timestep statistics:
        # timesteps seen before are accumulated, new ones are appended.
        seen = len(self.filter_values)
        for step, ratio in enumerate(ratios[:seen]):
            self.filter_values[step] += ratio
            self.filter_counts[step] += 1.0
        for ratio in ratios[seen:]:
            self.filter_values.append(ratio)
            self.filter_counts.append(1.0)

        # Discounted returns: raw, and weighted by the normalised ratio.
        V_prev = 0.0
        V_step_WIS = 0.0
        for step, ratio in enumerate(ratios):
            discount = self.gamma**step
            mean_ratio = self.filter_values[step] / self.filter_counts[step]
            V_prev += rewards[step] * discount
            V_step_WIS += ratio / mean_ratio * rewards[step] * discount

        metrics = {
            "V_prev": V_prev,
            "V_step_WIS": V_step_WIS,
            "V_gain_est": V_step_WIS / max(1e-8, V_prev),
        }
        return OffPolicyEstimate("weighted_importance_sampling", metrics)