def _compute_estimates(self, input: RLEstimatorInput) -> EstimatorResults:
    results = EstimatorResults()
    estimate = self._mdps_value(input.log, input.gamma)
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    for state, mdps in input.log.items():
        estimate = input.value_function(state)
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    n = len(input.log)
    # horizon is the length of the longest episode in the log
    horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
    ws = self._calc_weights(
        n, horizon, zip_longest(*input.log), input.target_policy
    )
    # last_ws[:, t] holds the previous step's weight, with 1/n at t = 0
    last_ws = torch.zeros((n, horizon), device=self._device)
    last_ws[:, 0] = 1.0 / n
    last_ws[:, 1:] = ws[:, :-1]
    # discount[t] = gamma^t
    discount = torch.full((horizon,), input.gamma, device=self._device)
    discount[0] = 1.0
    discount = discount.cumprod(0)
    rs = torch.zeros((n, horizon))
    vs = torch.zeros((n, horizon))
    qs = torch.zeros((n, horizon))
    # walk the episodes step by step and fill in rewards, state values,
    # and state-action values
    for ts, j in zip(zip_longest(*input.log), count()):
        for t, i in zip(ts, count()):
            if t is not None and t.action is not None:
                assert input.value_function is not None
                qs[i, j] = input.value_function(t.last_state, t.action)
                vs[i, j] = input.value_function(t.last_state)
                rs[i, j] = t.reward
    vs = vs.to(device=self._device)
    qs = qs.to(device=self._device)
    rs = rs.to(device=self._device)
    # doubly-robust estimate: importance-weighted TD correction plus baseline value
    estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
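# The tensor expression above computes, per step t and episode i,
# ws[i, t] * (r - q) + last_ws[i, t] * v, discounts it by gamma^t and sums,
# where last_ws shifts the importance weights one step right with 1/n in the
# first column. The sketch below is a standalone toy illustration of that
# weight-shift and discount construction only; `_dr_building_blocks_sketch`,
# `n`, `horizon`, and `gamma` are made-up names/values, not part of this module.
def _dr_building_blocks_sketch():
    import torch

    n, horizon, gamma = 3, 4, 0.9
    ws = torch.rand((n, horizon))  # stand-in for the per-step importance weights
    last_ws = torch.zeros((n, horizon))
    last_ws[:, 0] = 1.0 / n  # weight "before" the first step is 1 / n
    last_ws[:, 1:] = ws[:, :-1]  # previous step's weight at every later step
    discount = torch.full((horizon,), gamma)
    discount[0] = 1.0
    discount = discount.cumprod(0)  # [1, gamma, gamma^2, gamma^3]
    return last_ws, discount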
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    # per-state variant: compute the doubly-robust estimate for each group of episodes
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        ws = self._calc_weights(n, horizon, zip_longest(*mdps), input.target_policy)
        last_ws = torch.zeros((n, horizon), device=self._device)
        last_ws[:, 0] = 1.0 / n
        last_ws[:, 1:] = ws[:, :-1]
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rs = torch.zeros((n, horizon))
        vs = torch.zeros((n, horizon))
        qs = torch.zeros((n, horizon))
        for ts, j in zip(zip_longest(*mdps), count()):
            for t, i in zip(ts, count()):
                if t is not None and t.action is not None:
                    assert input.value_function is not None
                    qs[i, j] = input.value_function(t.last_state, t.action)
                    vs[i, j] = input.value_function(t.last_state)
                    rs[i, j] = t.reward
        vs = vs.to(device=self._device)
        qs = qs.to(device=self._device)
        rs = rs.to(device=self._device)
        estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
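# The nested `zip_longest(*mdps)` loops above walk episodes of different
# lengths step by step, padding exhausted episodes with None. A standalone toy
# illustration (hypothetical names, not part of this module's API):
def _zip_longest_sketch():
    from itertools import zip_longest

    episodes = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]
    # transpose to per-timestep tuples; exhausted episodes yield None
    return list(zip_longest(*episodes))
    # -> [("a0", "b0", "c0"), ("a1", None, "c1"), ("a2", None, None)]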
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    estimate = self._estimate_value(input.gamma, input.log, input.value_function)
    ground_truth = None
    if input.ground_truth is not None:
        ground_truth = self._estimate_value(
            input.gamma, input.log, input.ground_truth
        )
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log), estimate, ground_truth
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    n = len(input.log)
    # horizon is the length of the longest episode in the log
    horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
    weights = self._calc_weights(
        n, horizon, zip_longest(*input.log), input.target_policy
    )
    # discount[t] = gamma^t
    discount = torch.full((horizon,), input.gamma, device=self._device)
    discount[0] = 1.0
    discount = discount.cumprod(0)
    rewards = torch.zeros((n, horizon))
    for j, ts in enumerate(zip_longest(*input.log)):
        for i, t in enumerate(ts):
            if t is not None:
                rewards[i, j] = t.reward
    rewards = rewards.to(device=self._device)
    # importance-weighted rewards, discounted and summed over the horizon
    estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
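# `weights.mul(rewards).sum(0).mul(discount).sum()` above is the vectorized
# form of sum over t of gamma^t * sum over i of w[i, t] * r[i, t]. A standalone
# toy check of that equivalence (hypothetical helper, made-up numbers):
def _ips_vectorization_sketch():
    import torch

    weights = torch.tensor([[1.0, 0.5], [2.0, 0.25]])
    rewards = torch.tensor([[1.0, 1.0], [0.0, 2.0]])
    discount = torch.tensor([1.0, 0.9])
    vectorized = weights.mul(rewards).sum(0).mul(discount).sum()
    looped = sum(
        discount[t] * sum(weights[i, t] * rewards[i, t] for i in range(2))
        for t in range(2)
    )
    return torch.isclose(vectorized, looped)  # tensor(True)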
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        weights = self._calc_weights(
            n, horizon, zip_longest(*mdps), input.target_policy
        )
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rewards = torch.zeros((n, horizon))
        for j, ts in enumerate(zip_longest(*mdps)):
            for i, t in enumerate(ts):
                if t is not None:
                    rewards[i, j] = t.reward
        rewards = rewards.to(device=self._device)
        estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    num_resamples = kwargs.get("num_resamples", 200)
    loss_threshold = kwargs.get("loss_threshold", 0.00001)
    lr = kwargs.get("lr", 0.0001)
    logging.info(
        f"  params: num_resamples[{num_resamples}], "
        f"loss_threshold[{loss_threshold}], "
        f"lr[{lr}]"
    )
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        ws = self._calc_weights(n, horizon, zip_longest(*mdps), input.target_policy)
        last_ws = torch.zeros((n, horizon), device=self._device)
        last_ws[:, 0] = 1.0 / n
        last_ws[:, 1:] = ws[:, :-1]
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rs = torch.zeros((n, horizon))
        vs = torch.zeros((n, horizon))
        qs = torch.zeros((n, horizon))
        for ts, j in zip(zip_longest(*mdps), count()):
            for t, i in zip(ts, count()):
                if t is not None and t.action is not None:
                    qs[i, j] = input.value_function(t.last_state, t.action)
                    vs[i, j] = input.value_function(t.last_state)
                    rs[i, j] = t.reward
        vs = vs.to(device=self._device)
        qs = qs.to(device=self._device)
        rs = rs.to(device=self._device)
        # partial doubly-robust returns truncated at each horizon step
        wdrs = ((ws * (rs - qs) + last_ws * vs) * discount).cumsum(1)
        wdr = wdrs[:, -1].sum(0)
        next_vs = torch.zeros((n, horizon), device=self._device)
        next_vs[:, :-1] = vs[:, 1:]
        gs = wdrs + ws * next_vs * discount
        # covariance of the per-episode partial returns
        gs_normal = gs.sub(torch.mean(gs, 0))
        omega = n * torch.einsum("ij,ik->jk", gs_normal, gs_normal) / (n - 1.0)
        # bootstrap the full WDR estimate to get a confidence interval
        resample_wdrs = torch.zeros((num_resamples,))
        for i in range(num_resamples):
            samples = random.choices(range(n), k=n)
            sws = ws[samples, :]
            last_sws = last_ws[samples, :]
            srs = rs[samples, :]
            svs = vs[samples, :]
            sqs = qs[samples, :]
            resample_wdrs[i] = (
                ((sws * (srs - sqs) + last_sws * svs).sum(0) * discount)
                .sum()
                .item()
            )
        resample_wdrs, _ = resample_wdrs.to(device=self._device).sort(0)
        lb = torch.min(wdr, resample_wdrs[int(round(0.05 * num_resamples))])
        ub = torch.max(wdr, resample_wdrs[int(round(0.95 * num_resamples)) - 1])
        # bias term: distance of each partial return outside the [lb, ub] interval
        b = torch.tensor(
            list(
                map(
                    lambda a: a - ub if a > ub else (a - lb if a < lb else 0.0),
                    gs.sum(0),
                )
            ),
            device=self._device,
        )
        b.unsqueeze_(0)
        bb = b * b.t()
        cov = omega + bb
        # find simplex weights x minimizing x (omega + b b^T) x^T by running SGD
        # on unconstrained logits mapped through softmax
        logits = torch.zeros((1, horizon), device=self._device, requires_grad=True)
        optimizer = torch.optim.SGD([logits], lr=lr)
        last_y = 0.0
        for i in range(100):
            x = torch.nn.functional.softmax(logits, dim=1)
            y = torch.mm(torch.mm(x, cov), x.t())
            if abs(y.item() - last_y) < loss_threshold:
                logging.info(f"{i}: {last_y} -> {y.item()}")
                break
            last_y = y.item()
            optimizer.zero_grad()
            y.backward()
            optimizer.step()
        x = torch.nn.functional.softmax(logits, dim=1)
        # blend the partial returns with the learned weights
        estimate = torch.mm(x, gs.sum(0, keepdim=True).t()).item()
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
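# The SGD loop above minimizes x @ (omega + b b^T) @ x^T over the probability
# simplex by optimizing unconstrained logits and mapping them through softmax,
# then uses the resulting x to blend the per-horizon partial returns. Below is
# a standalone toy version of that projection step only; `_simplex_sgd_sketch`
# and the covariance values are made up, while the real loop uses `cov`, `lr`,
# and `loss_threshold` from the method above.
def _simplex_sgd_sketch():
    import torch

    cov = torch.tensor([[2.0, 0.0], [0.0, 1.0]])  # stand-in for omega + b * b.t()
    logits = torch.zeros((1, 2), requires_grad=True)
    optimizer = torch.optim.SGD([logits], lr=0.5)
    for _ in range(300):
        x = torch.nn.functional.softmax(logits, dim=1)
        y = torch.mm(torch.mm(x, cov), x.t())
        optimizer.zero_grad()
        y.backward()
        optimizer.step()
    # converges toward [1/3, 2/3]: more weight on the lower-variance direction
    return torch.nn.functional.softmax(logits, dim=1)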