Example #1
 def _compute_estimates(self, input: RLEstimatorInput) -> EstimatorResults:
     results = EstimatorResults()
     estimate = self._mdps_value(input.log, input.gamma)
     results.append(
         EstimatorResult(
             self._log_reward(input.gamma, input.log),
             estimate,
             None if input.ground_truth is None else self._estimate_value(
                 input.gamma, input.log, input.ground_truth),
         ))
     return results
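Example #1 delegates the actual computation to helpers such as _mdps_value and _log_reward, which are not shown on this page. As a rough orientation only, the average discounted return of the logged episodes could be computed as in the sketch below; the function name and the plain-list episode format are assumptions for illustration, not the library's API.

from typing import Sequence

def discounted_log_reward(gamma: float, episodes: Sequence[Sequence[float]]) -> float:
    # Hypothetical stand-in for a helper like _log_reward: the average
    # discounted return of the logged (behavior-policy) episodes.
    total = 0.0
    for rewards in episodes:
        discount = 1.0
        episode_return = 0.0
        for r in rewards:
            episode_return += discount * r
            discount *= gamma
        total += episode_return
    return total / max(len(episodes), 1)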
Example #2
 def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
     assert input.value_function is not None
     logging.info(f"{self}: start evaluating")
     stime = time.process_time()
     results = EstimatorResults()
     for state, mdps in input.log.items():
         estimate = input.value_function(state)
         if input.ground_truth is not None:
             ground_truth = input.ground_truth(state)
         else:
             ground_truth = None
         results.append(
             EstimatorResult(self._log_reward(input.gamma, mdps), estimate,
                             ground_truth))
     logging.info(f"{self}: finishing evaluating["
                  f"process_time={time.process_time() - stime}]")
     return results
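Example #2 is essentially the direct method: the estimate for every logged state is simply the learned state-value function evaluated at that state. For a quick smoke test one could pass a toy value function; the dictionary-backed callable below is purely illustrative and not part of the library.

# Hypothetical table-backed value function for exercising the direct-method
# evaluate() above; real usage would pass a trained value model instead.
class TableValueFunction:
    def __init__(self, values):
        self._values = dict(values)

    def __call__(self, state) -> float:
        return self._values.get(state, 0.0)

toy_value_function = TableValueFunction({"s0": 1.5, "s1": 0.25})
assert toy_value_function("s0") == 1.5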
Example #3
    def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
        # kwargs is part of the function signature, so to satisfy pyre it must be included
        logging.info(f"{self}: start evaluating")
        stime = time.process_time()
        results = EstimatorResults()

        n = len(input.log)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
        ws = self._calc_weights(
            n, horizon, zip_longest(*input.log), input.target_policy
        )
        last_ws = torch.zeros((n, horizon), device=self._device)
        last_ws[:, 0] = 1.0 / n
        last_ws[:, 1:] = ws[:, :-1]
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rs = torch.zeros((n, horizon))
        vs = torch.zeros((n, horizon))
        qs = torch.zeros((n, horizon))
        for ts, j in zip(zip_longest(*input.log), count()):
            for t, i in zip(ts, count()):
                if t is not None and t.action is not None:
                    assert input.value_function is not None
                    qs[i, j] = input.value_function(t.last_state, t.action)
                    vs[i, j] = input.value_function(t.last_state)
                    rs[i, j] = t.reward
        vs = vs.to(device=self._device)
        qs = qs.to(device=self._device)
        rs = rs.to(device=self._device)
        estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, input.log),
                estimate,
                None
                if input.ground_truth is None
                else self._estimate_value(input.gamma, input.log, input.ground_truth),
            )
        )
        logging.info(
            f"{self}: finishing evaluating["
            f"process_time={time.process_time() - stime}]"
        )
        return results
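Example #3 computes a weighted doubly robust style estimate: at each time step the importance-weighted reward is corrected by the model's Q-value, the previous step's weights credit the state value V, and the per-step sums are discounted and accumulated. The self-contained sketch below reproduces just that combination on synthetic tensors; all shapes and numbers are made up for illustration, and _calc_weights and the real data structures are not reproduced here.

import torch

# synthetic problem: n trajectories over a fixed horizon, made-up values
n, horizon, gamma = 3, 4, 0.9
ws = torch.rand(n, horizon) / n   # stand-in for per-step importance weights
rs = torch.rand(n, horizon)       # logged rewards
qs = torch.rand(n, horizon)       # Q(s_t, a_t) from the value model
vs = torch.rand(n, horizon)       # V(s_t) from the value model

# weights shifted by one step: at t=0 every trajectory contributes 1/n
last_ws = torch.zeros(n, horizon)
last_ws[:, 0] = 1.0 / n
last_ws[:, 1:] = ws[:, :-1]

# discount vector [1, gamma, gamma^2, ...]
discount = torch.full((horizon,), gamma)
discount[0] = 1.0
discount = discount.cumprod(0)

# doubly robust combination, summed over trajectories and discounted over time
estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
print(estimate)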
Example #4
 def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
     logging.info(f"{self}: start evaluating")
     stime = time.process_time()
     results = EstimatorResults()
     for state, mdps in input.log.items():
         n = len(mdps)
         horizon = len(
             reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
         ws = self._calc_weights(n, horizon, zip_longest(*mdps),
                                 input.target_policy)
         last_ws = torch.zeros((n, horizon), device=self._device)
         last_ws[:, 0] = 1.0 / n
         last_ws[:, 1:] = ws[:, :-1]
         discount = torch.full((horizon, ),
                               input.gamma,
                               device=self._device)
         discount[0] = 1.0
         discount = discount.cumprod(0)
         rs = torch.zeros((n, horizon))
         vs = torch.zeros((n, horizon))
         qs = torch.zeros((n, horizon))
         for ts, j in zip(zip_longest(*mdps), count()):
             for t, i in zip(ts, count()):
                 if t is not None and t.action is not None:
                     assert input.value_function is not None
                     qs[i, j] = input.value_function(t.last_state, t.action)
                     vs[i, j] = input.value_function(t.last_state)
                     rs[i, j] = t.reward
         vs = vs.to(device=self._device)
         qs = qs.to(device=self._device)
         rs = rs.to(device=self._device)
         estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) *
                     discount).sum().item()
         if input.ground_truth is not None:
             ground_truth = input.ground_truth(state)
         else:
             ground_truth = None
         results.append(
             EstimatorResult(self._log_reward(input.gamma, mdps), estimate,
                             ground_truth))
     logging.info(f"{self}: finishing evaluating["
                  f"process_time={time.process_time() - stime}]")
     return results
Example #5
    def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
        # kwargs is part of the function signature, so to satisfy pyre it must be included
        assert input.value_function is not None
        logging.info(f"{self}: start evaluating")
        stime = time.process_time()
        results = EstimatorResults()

        estimate = self._estimate_value(input.gamma, input.log,
                                        input.value_function)
        gt = None
        if input.ground_truth is not None:
            gt = self._estimate_value(input.gamma, input.log,
                                      input.ground_truth)
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, input.log),
                estimate,
                gt,
            ))
        logging.info(f"{self}: finishing evaluating["
                     f"process_time={time.process_time() - stime}]")
        return results
Example #6
    def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
        # kwargs is part of the function signature, so to satisfy pyre it must be included
        logging.info(f"{self}: start evaluating")
        stime = time.process_time()
        results = EstimatorResults()

        n = len(input.log)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
        weights = self._calc_weights(
            n, horizon, zip_longest(*input.log), input.target_policy
        )
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rewards = torch.zeros((n, horizon))
        # fill the (trajectory, time step) reward matrix; zip_longest transposes
        # episodes of unequal length, padding the shorter ones with None
        for j, ts in enumerate(zip_longest(*input.log)):
            for i, t in enumerate(ts):
                if t is not None:
                    rewards[i, j] = t.reward
        rewards = rewards.to(device=self._device)
        estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()

        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, input.log),
                estimate,
                None
                if input.ground_truth is None
                else self._estimate_value(input.gamma, input.log, input.ground_truth),
            )
        )
        logging.info(
            f"{self}: finishing evaluating["
            f"process_time={time.process_time() - stime}]"
        )
        return results
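Example #6 is a step-wise importance sampling estimate: the per-step weights from _calc_weights are multiplied with the logged rewards, summed over trajectories, and then discounted over time. _calc_weights itself is not shown on this page; a common construction, sketched below on synthetic data, takes the cumulative product of target-to-behavior propensity ratios and normalizes it across trajectories at each step. Treat this as an assumption about the helper, not its actual implementation.

import torch

# synthetic propensities for n trajectories over a fixed horizon
n, horizon, gamma = 4, 5, 0.9
target_probs = torch.rand(n, horizon).clamp(min=1e-3)
behavior_probs = torch.rand(n, horizon).clamp(min=1e-3)

# cumulative importance ratios rho[i, t] = prod_{k<=t} pi_target / pi_behavior
ratios = (target_probs / behavior_probs).cumprod(dim=1)

# self-normalized ("weighted") variant: normalize across trajectories per step
weights = ratios / ratios.sum(dim=0, keepdim=True)

rewards = torch.rand(n, horizon)
discount = torch.full((horizon,), gamma)
discount[0] = 1.0
discount = discount.cumprod(0)

estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
print(estimate)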
Example #7
 def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
     logging.info(f"{self}: start evaluating")
     stime = time.process_time()
     results = EstimatorResults()
     for state, mdps in input.log.items():
         n = len(mdps)
         horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
         weights = self._calc_weights(
             n, horizon, zip_longest(*mdps), input.target_policy
         )
         discount = torch.full((horizon,), input.gamma, device=self._device)
         discount[0] = 1.0
         discount = discount.cumprod(0)
         rewards = torch.zeros((n, horizon))
         for j, ts in enumerate(zip_longest(*mdps)):
             for i, t in enumerate(ts):
                 if t is not None:
                     rewards[i, j] = t.reward
         rewards = rewards.to(device=self._device)
         estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
         if input.ground_truth is not None:
             ground_truth = input.ground_truth(state)
         else:
             ground_truth = None
         results.append(
             EstimatorResult(
                 self._log_reward(input.gamma, mdps), estimate, ground_truth
             )
         )
     logging.info(
         f"{self}: finishing evaluating["
         f"process_time={time.process_time() - stime}]"
     )
     return results
Example #8
 def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
     assert input.value_function is not None
     logging.info(f"{self}: start evaluating")
     stime = time.process_time()
     results = EstimatorResults()
     num_resamples = kwargs.get("num_resamples", 200)
     loss_threshold = kwargs.get("loss_threshold", 0.00001)
     lr = kwargs.get("lr", 0.0001)
     logging.info(f"  params: num_resamples[{num_resamples}], "
                  f"loss_threshold[{loss_threshold}], "
                  f"lr[{lr}]")
     for state, mdps in input.log.items():
         n = len(mdps)
         horizon = len(
             reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
         ws = self._calc_weights(n, horizon, zip_longest(*mdps),
                                 input.target_policy)
         last_ws = torch.zeros((n, horizon), device=self._device)
         last_ws[:, 0] = 1.0 / n
         last_ws[:, 1:] = ws[:, :-1]
         discount = torch.full((horizon, ),
                               input.gamma,
                               device=self._device)
         discount[0] = 1.0
         discount = discount.cumprod(0)
         rs = torch.zeros((n, horizon))
         vs = torch.zeros((n, horizon))
         qs = torch.zeros((n, horizon))
         for ts, j in zip(zip_longest(*mdps), count()):
             for t, i in zip(ts, count()):
                 if t is not None and t.action is not None:
                     qs[i, j] = input.value_function(t.last_state, t.action)
                     vs[i, j] = input.value_function(t.last_state)
                     rs[i, j] = t.reward
         vs = vs.to(device=self._device)
         qs = qs.to(device=self._device)
         rs = rs.to(device=self._device)
         wdrs = ((ws * (rs - qs) + last_ws * vs) * discount).cumsum(1)
         wdr = wdrs[:, -1].sum(0)
         next_vs = torch.zeros((n, horizon), device=self._device)
         next_vs[:, :-1] = vs[:, 1:]
         gs = wdrs + ws * next_vs * discount
         gs_normal = gs.sub(torch.mean(gs, 0))
         omega = n * torch.einsum("ij,ik->jk", gs_normal,
                                  gs_normal) / (n - 1.0)
         resample_wdrs = torch.zeros((num_resamples, ))
         for i in range(num_resamples):
             samples = random.choices(range(n), k=n)
             sws = ws[samples, :]
             last_sws = last_ws[samples, :]
             srs = rs[samples, :]
             svs = vs[samples, :]
             sqs = qs[samples, :]
             resample_wdrs[i] = (((sws *
                                   (srs - sqs) + last_sws * svs).sum(0) *
                                  discount).sum().item())
         resample_wdrs, _ = resample_wdrs.to(device=self._device).sort(0)
         lb = torch.min(wdr,
                        resample_wdrs[int(round(0.05 * num_resamples))])
         ub = torch.max(wdr,
                        resample_wdrs[int(round(0.95 * num_resamples)) - 1])
         b = torch.tensor(
             [a - ub if a > ub else (a - lb if a < lb else 0.0) for a in gs.sum(0)],
             device=self._device,
         )
         b.unsqueeze_(0)
         bb = b * b.t()
         cov = omega + bb
         # x = torch.rand((1, horizon), device=self.device, requires_grad=True)
         x = torch.zeros((1, horizon),
                         device=self._device,
                         requires_grad=True)
         # using SGD to find min x
         optimizer = torch.optim.SGD([x], lr=lr)
         last_y = 0.0
         for i in range(100):
             x = torch.nn.functional.softmax(x, dim=1)
             y = torch.mm(torch.mm(x, cov), x.t())
             if abs(y.item() - last_y) < loss_threshold:
                 print(f"{i}: {last_y} -> {y.item()}")
                 break
             last_y = y.item()
             optimizer.zero_grad()
             y.backward(retain_graph=True)
             optimizer.step()
         x = torch.nn.functional.softmax(x, dim=1)
         estimate = torch.mm(x, gs.sum(0, keepdim=True).t())
         if input.ground_truth is not None:
             ground_truth = input.ground_truth(state)
         else:
             ground_truth = None
         results.append(
             EstimatorResult(self._log_reward(input.gamma, mdps), estimate,
                             ground_truth))
     logging.info(f"{self}: finishing evaluating["
                  f"process_time={time.process_time() - stime}]")
     return results
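Example #8 goes one step further: it resamples trajectories with replacement to bootstrap the weighted doubly robust estimate, takes roughly the 5th and 95th percentiles as bounds, and then optimizes a softmax-weighted blend of the per-step partial sums with SGD. The fragment below isolates just the bootstrap percentile step on synthetic numbers; the function name, shapes, and values are illustrative only.

import random
import torch

def bootstrap_bounds(per_traj: torch.Tensor, num_resamples: int = 200):
    # Resample trajectories with replacement and return approximate
    # 5%/95% bounds of the summed estimate (illustrative sketch only).
    n = per_traj.shape[0]
    resampled = torch.zeros(num_resamples)
    for i in range(num_resamples):
        samples = random.choices(range(n), k=n)
        resampled[i] = per_traj[samples].sum()
    resampled, _ = resampled.sort(0)
    lb = resampled[int(round(0.05 * num_resamples))]
    ub = resampled[int(round(0.95 * num_resamples)) - 1]
    return lb.item(), ub.item()

per_trajectory = torch.randn(10)  # made-up per-trajectory contributions
print(bootstrap_bounds(per_trajectory))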