def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    if not self._train_model(input.samples, 0.8) and not input.has_model_outputs:
        return None
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    tgt_vals = []
    logged_vals = []
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        logged_vals.append(sample.log_reward)
        _, tgt_reward = self._calc_dm_reward(input.action_space, sample)
        tgt_avg.add(tgt_reward)
        tgt_vals.append(tgt_reward)
        gt_avg.add(sample.ground_truth_reward)
    (
        tgt_score,
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(
        torch.tensor(tgt_vals), torch.tensor(logged_vals), tgt_avg.average
    )
    return EstimatorResult(
        log_avg.average,
        tgt_score,
        gt_avg.average,
        tgt_avg.count,
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    )
def _evaluate(
    self,
    input: BanditsEstimatorInput,
    train_samples: Sequence[LogSample],
    eval_samples: Sequence[LogSample],
    force_train: bool = False,
    **kwargs,
) -> Optional[EstimatorResult]:
    logger.info("OPE DR Evaluating")
    self._train_model(train_samples, force_train)
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    tgt_vals = []
    gt_avg = RunningAverage()
    for sample in eval_samples:
        log_avg.add(sample.log_reward)
        dm_action_reward, dm_scores, dm_probs = self._calc_dm_reward(
            input.action_space, sample
        )
        dm_reward = torch.dot(dm_scores.reshape(-1), dm_probs.reshape(-1)).item()
        tgt_result = 0.0
        weight = 0.0
        if sample.log_action.value is not None:
            weight = (
                0.0
                if sample.log_action_probabilities[sample.log_action]
                < PROPENSITY_THRESHOLD
                else sample.tgt_action_probabilities[sample.log_action]
                / sample.log_action_probabilities[sample.log_action]
            )
            weight = self._weight_clamper(weight)
            assert dm_action_reward is not None
            assert dm_reward is not None
            tgt_result += (
                sample.log_reward - dm_action_reward
            ) * weight + dm_reward
        else:
            tgt_result = dm_reward
        tgt_avg.add(tgt_result)
        tgt_vals.append(tgt_result)
        gt_avg.add(sample.ground_truth_reward)
    (
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(torch.tensor(tgt_vals), log_avg.average)
    return EstimatorResult(
        log_reward=log_avg.average,
        estimated_reward=tgt_avg.average,
        ground_truth_reward=gt_avg.average,
        estimated_weight=tgt_avg.count,
        estimated_reward_normalized=tgt_score_normalized,
        estimated_reward_std_error=tgt_std_err,
        estimated_reward_normalized_std_error=tgt_std_err_normalized,
    )
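# Illustrative sketch, separate from the estimator above; all values are hypothetical.
# It shows the per-sample doubly robust score the loop computes: the model-based reward
# under the target policy plus an importance-weighted correction on the logged action,
#   dr_i = dm_reward_i + w_i * (log_reward_i - dm_action_reward_i)
import torch

w = torch.tensor([0.5, 2.0, 1.0])                  # clamped importance weights
log_reward = torch.tensor([1.0, 0.0, 1.0])         # rewards observed under the logging policy
dm_action_reward = torch.tensor([0.8, 0.3, 0.6])   # model reward for the logged action
dm_reward = torch.tensor([0.7, 0.4, 0.5])          # model reward averaged over the target policy
dr = dm_reward + w * (log_reward - dm_action_reward)
print(dr.mean().item())  # doubly robust estimate over these samples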
def _compute_estimates(self, input: RLEstimatorInput) -> EstimatorResults:
    results = EstimatorResults()
    estimate = self._mdps_value(input.log, input.gamma)
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    return results
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    if input.has_model_outputs:
        return self._evaluate(
            input, input.samples, input.samples, force_train=True, **kwargs
        )
    log_avg = RunningAverage()
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        gt_avg.add(sample.ground_truth_reward)
    # 2-fold cross "validation" as used by https://arxiv.org/pdf/1612.01205.pdf
    shuffled = list(input.samples)
    np.random.shuffle(shuffled)
    lower_half = shuffled[: len(shuffled) // 2]
    upper_half = shuffled[len(shuffled) // 2 :]
    er_lower = self._evaluate(
        input, lower_half, upper_half, force_train=True, **kwargs
    )
    er_upper = self._evaluate(
        input, upper_half, lower_half, force_train=True, **kwargs
    )
    if er_lower is None or er_upper is None:
        return None
    return EstimatorResult(
        log_reward=log_avg.average,
        estimated_reward=(
            (er_lower.estimated_reward + er_upper.estimated_reward) / 2
        ),
        estimated_reward_normalized=DMEstimator._calc_optional_avg(
            er_lower.estimated_reward_normalized,
            er_upper.estimated_reward_normalized,
        ),
        estimated_reward_normalized_std_error=DMEstimator._calc_optional_avg(
            er_lower.estimated_reward_normalized_std_error,
            er_upper.estimated_reward_normalized_std_error,
        ),
        estimated_reward_std_error=DMEstimator._calc_optional_avg(
            er_lower.estimated_reward_std_error,
            er_upper.estimated_reward_std_error,
        ),
        ground_truth_reward=gt_avg.average,
    )
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    acc_weight = RunningAverage()
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        weight = (
            sample.tgt_action_probabilities[sample.log_action]
            / sample.log_action_probabilities[sample.log_action]
        )
        weight = self._weight_clamper(weight)
        tgt_avg.add(sample.log_reward * weight)
        acc_weight.add(weight)
        gt_avg.add(sample.ground_truth_reward)
    if self._weighted:
        return EstimatorResult(
            log_avg.average,
            tgt_avg.total / acc_weight.total,
            gt_avg.average,
            acc_weight.average,
        )
    else:
        return EstimatorResult(
            log_avg.average, tgt_avg.average, gt_avg.average, tgt_avg.count
        )
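# Minimal standalone sketch (hypothetical values, not part of the method above) of the
# two branches it returns: plain IPS averages importance-weighted rewards over the sample
# count, while the weighted (self-normalized) variant divides by the accumulated weight.
import torch

weights = torch.tensor([0.5, 2.0, 1.0])  # tgt/log propensity ratios after clamping
rewards = torch.tensor([1.0, 0.0, 1.0])  # logged rewards
ips = (weights * rewards).mean()                   # plain IPS estimate
snips = (weights * rewards).sum() / weights.sum()  # self-normalized (weighted) estimate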
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    logger = Estimator.logger()
    if not self._train_model(input.samples, 0.8, logger):
        return None
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        _, tgt_reward = self._calc_dm_reward(input.action_space, sample)
        tgt_avg.add(tgt_reward)
        gt_avg.add(sample.ground_truth_reward)
    return EstimatorResult(
        log_avg.average, tgt_avg.average, gt_avg.average, tgt_avg.count
    )
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    self._train_model(input.samples, 0.8)
    log_avg = RunningAverage()
    logged_vals = []
    tgt_avg = RunningAverage()
    tgt_vals = []
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        logged_vals.append(sample.log_reward)
        dm_action_reward, dm_reward = self._calc_dm_reward(
            input.action_space, sample
        )
        tgt_result = 0.0
        weight = 0.0
        if sample.log_action is not None:
            weight = (
                0.0
                if sample.log_action_probabilities[sample.log_action]
                < PROPENSITY_THRESHOLD
                else sample.tgt_action_probabilities[sample.log_action]
                / sample.log_action_probabilities[sample.log_action]
            )
            weight = self._weight_clamper(weight)
            assert dm_action_reward is not None
            assert dm_reward is not None
            tgt_result += (sample.log_reward - dm_action_reward) * weight + dm_reward
        else:
            tgt_result = dm_reward
        tgt_avg.add(tgt_result)
        tgt_vals.append(tgt_result)
        gt_avg.add(sample.ground_truth_reward)
    (
        tgt_score,
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(
        torch.tensor(tgt_vals), torch.tensor(logged_vals), tgt_avg.average
    )
    return EstimatorResult(
        log_avg.average,
        tgt_score,
        gt_avg.average,
        tgt_avg.count,
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    )
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    for state, mdps in input.log.items():
        estimate = input.value_function(state)
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    n = len(input.log)
    horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
    ws = self._calc_weights(
        n, horizon, zip_longest(*input.log), input.target_policy
    )
    last_ws = torch.zeros((n, horizon), device=self._device)
    last_ws[:, 0] = 1.0 / n
    last_ws[:, 1:] = ws[:, :-1]
    discount = torch.full((horizon,), input.gamma, device=self._device)
    discount[0] = 1.0
    discount = discount.cumprod(0)
    rs = torch.zeros((n, horizon))
    vs = torch.zeros((n, horizon))
    qs = torch.zeros((n, horizon))
    for ts, j in zip(zip_longest(*input.log), count()):
        for t, i in zip(ts, count()):
            if t is not None and t.action is not None:
                assert input.value_function is not None
                qs[i, j] = input.value_function(t.last_state, t.action)
                vs[i, j] = input.value_function(t.last_state)
                rs[i, j] = t.reward
    vs = vs.to(device=self._device)
    qs = qs.to(device=self._device)
    rs = rs.to(device=self._device)
    estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
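# Tiny standalone sketch (hypothetical tensors, not part of the method above) of the
# sequential doubly robust combination it computes: at each step the importance-weighted
# correction ws * (rs - qs) is added to the previous step's weights applied to the state
# values, summed over trajectories, and discounted over time.
import torch

n, horizon, gamma = 2, 3, 0.9
ws = torch.tensor([[0.5, 0.4, 0.3], [0.5, 0.6, 0.7]])  # cumulative importance weights
last_ws = torch.zeros((n, horizon))
last_ws[:, 0] = 1.0 / n
last_ws[:, 1:] = ws[:, :-1]
rs = torch.rand((n, horizon))  # observed rewards
qs = torch.rand((n, horizon))  # Q(s_t, a_t) from the value function
vs = torch.rand((n, horizon))  # V(s_t) from the value function
discount = torch.full((horizon,), gamma)
discount[0] = 1.0
discount = discount.cumprod(0)
estimate = ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()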
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    logger.info("OPE IPS Evaluating")
    log_avg = RunningAverage()
    logged_vals = []
    tgt_avg = RunningAverage()
    tgt_vals = []
    acc_weight = RunningAverage()
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        logged_vals.append(sample.log_reward)
        weight = 0.0
        tgt_result = 0.0
        if sample.log_action.value is not None:
            weight = (
                0.0
                if sample.log_action_probabilities[sample.log_action]
                < PROPENSITY_THRESHOLD
                else sample.tgt_action_probabilities[sample.log_action]
                / sample.log_action_probabilities[sample.log_action]
            )
            weight = self._weight_clamper(weight)
            tgt_result = sample.log_reward * weight
        tgt_avg.add(tgt_result)
        tgt_vals.append(tgt_result)
        acc_weight.add(weight)
        gt_avg.add(sample.ground_truth_reward)
    (
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(torch.tensor(tgt_vals), log_avg.average)
    return EstimatorResult(
        log_reward=log_avg.average,
        # weighted variant is self-normalized IPS: divide by the average weight
        estimated_reward=tgt_avg.average
        if not self._weighted
        else tgt_avg.average / acc_weight.average,
        ground_truth_reward=gt_avg.average,
        estimated_weight=tgt_avg.count,
        estimated_reward_normalized=tgt_score_normalized,
        estimated_reward_std_error=tgt_std_err,
        estimated_reward_normalized_std_error=tgt_std_err_normalized,
    )
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        ws = self._calc_weights(
            n, horizon, zip_longest(*mdps), input.target_policy
        )
        last_ws = torch.zeros((n, horizon), device=self._device)
        last_ws[:, 0] = 1.0 / n
        last_ws[:, 1:] = ws[:, :-1]
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rs = torch.zeros((n, horizon))
        vs = torch.zeros((n, horizon))
        qs = torch.zeros((n, horizon))
        for ts, j in zip(zip_longest(*mdps), count()):
            for t, i in zip(ts, count()):
                if t is not None and t.action is not None:
                    assert input.value_function is not None
                    qs[i, j] = input.value_function(t.last_state, t.action)
                    vs[i, j] = input.value_function(t.last_state)
                    rs[i, j] = t.reward
        vs = vs.to(device=self._device)
        qs = qs.to(device=self._device)
        rs = rs.to(device=self._device)
        estimate = (
            ((ws * (rs - qs) + last_ws * vs).sum(0) * discount).sum().item()
        )
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(
    self, input: BanditsEstimatorInput, **kwargs
) -> Optional[EstimatorResult]:
    logger = Estimator.logger()
    self._train_model(input.samples, 0.8, logger)
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    gt_avg = RunningAverage()
    for sample in input.samples:
        log_avg.add(sample.log_reward)
        weight = (
            sample.tgt_action_probabilities[sample.log_action]
            / sample.log_action_probabilities[sample.log_action]
        )
        weight = self._weight_clamper(weight)
        dm_action_reward, dm_reward = self._calc_dm_reward(
            input.action_space, sample
        )
        tgt_avg.add((sample.log_reward - dm_action_reward) * weight + dm_reward)
        gt_avg.add(sample.ground_truth_reward)
    return EstimatorResult(
        log_avg.average, tgt_avg.average, gt_avg.average, tgt_avg.count
    )
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    estimate = self._estimate_value(input.gamma, input.log, input.value_function)
    gt = None
    if input.ground_truth is not None:
        gt = self._estimate_value(input.gamma, input.log, input.ground_truth)
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            gt,
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    # kwargs is part of the function signature, so to satisfy pyre it must be included
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    n = len(input.log)
    horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, input.log))
    weights = self._calc_weights(
        n, horizon, zip_longest(*input.log), input.target_policy
    )
    discount = torch.full((horizon,), input.gamma, device=self._device)
    discount[0] = 1.0
    discount = discount.cumprod(0)
    rewards = torch.zeros((n, horizon))
    j = 0
    for ts in zip_longest(*input.log):
        i = 0
        for t in ts:
            if t is not None:
                rewards[i, j] = t.reward
            i += 1
        j += 1
    rewards = rewards.to(device=self._device)
    estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
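# Small standalone sketch of the discount construction used in the sequential estimators
# above: filling a vector with gamma, setting the first entry to 1, and taking a
# cumulative product yields [1, gamma, gamma^2, ...] for a given horizon.
import torch

gamma, horizon = 0.9, 4
discount = torch.full((horizon,), gamma)
discount[0] = 1.0
discount = discount.cumprod(0)  # tensor([1.0000, 0.9000, 0.8100, 0.7290])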
def _evaluate(
    self,
    input: BanditsEstimatorInput,
    train_samples: Sequence[LogSample],
    eval_samples: Sequence[LogSample],
    force_train: bool = False,
    **kwargs,
) -> Optional[EstimatorResult]:
    logger.info("OPE DM Evaluating")
    if (
        not self._train_model(train_samples, force_train)
        and not input.has_model_outputs
    ):
        return None
    log_avg = RunningAverage()
    tgt_avg = RunningAverage()
    tgt_vals = []
    gt_avg = RunningAverage()
    for sample in eval_samples:
        log_avg.add(sample.log_reward)
        _, tgt_scores, tgt_probs = self._calc_dm_reward(input.action_space, sample)
        tgt_reward = torch.dot(tgt_scores.reshape(-1), tgt_probs.reshape(-1)).item()
        tgt_avg.add(tgt_reward)
        tgt_vals.append(tgt_reward)
        gt_avg.add(sample.ground_truth_reward)
    (
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(torch.tensor(tgt_vals), log_avg.average)
    return EstimatorResult(
        log_reward=log_avg.average,
        estimated_reward=tgt_avg.average,
        ground_truth_reward=gt_avg.average,
        estimated_weight=tgt_avg.count,
        estimated_reward_normalized=tgt_score_normalized,
        estimated_reward_std_error=tgt_std_err,
        estimated_reward_normalized_std_error=tgt_std_err_normalized,
    )
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        weights = self._calc_weights(
            n, horizon, zip_longest(*mdps), input.target_policy
        )
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rewards = torch.zeros((n, horizon))
        j = 0
        for ts in zip_longest(*mdps):
            i = 0
            for t in ts:
                if t is not None:
                    rewards[i, j] = t.reward
                i += 1
            j += 1
        rewards = rewards.to(device=self._device)
        estimate = weights.mul(rewards).sum(0).mul(discount).sum().item()
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    num_resamples = kwargs.get("num_resamples", 200)
    loss_threshold = kwargs.get("loss_threshold", 0.00001)
    lr = kwargs.get("lr", 0.0001)
    logging.info(
        f" params: num_resamples[{num_resamples}], "
        f"loss_threshold[{loss_threshold}], "
        f"lr[{lr}]"
    )
    for state, mdps in input.log.items():
        n = len(mdps)
        horizon = len(reduce(lambda a, b: a if len(a) > len(b) else b, mdps))
        ws = self._calc_weights(
            n, horizon, zip_longest(*mdps), input.target_policy
        )
        last_ws = torch.zeros((n, horizon), device=self._device)
        last_ws[:, 0] = 1.0 / n
        last_ws[:, 1:] = ws[:, :-1]
        discount = torch.full((horizon,), input.gamma, device=self._device)
        discount[0] = 1.0
        discount = discount.cumprod(0)
        rs = torch.zeros((n, horizon))
        vs = torch.zeros((n, horizon))
        qs = torch.zeros((n, horizon))
        for ts, j in zip(zip_longest(*mdps), count()):
            for t, i in zip(ts, count()):
                if t is not None and t.action is not None:
                    qs[i, j] = input.value_function(t.last_state, t.action)
                    vs[i, j] = input.value_function(t.last_state)
                    rs[i, j] = t.reward
        vs = vs.to(device=self._device)
        qs = qs.to(device=self._device)
        rs = rs.to(device=self._device)
        wdrs = ((ws * (rs - qs) + last_ws * vs) * discount).cumsum(1)
        wdr = wdrs[:, -1].sum(0)
        next_vs = torch.zeros((n, horizon), device=self._device)
        next_vs[:, :-1] = vs[:, 1:]
        gs = wdrs + ws * next_vs * discount
        gs_normal = gs.sub(torch.mean(gs, 0))
        omega = n * torch.einsum("ij,ik->jk", gs_normal, gs_normal) / (n - 1.0)
        resample_wdrs = torch.zeros((num_resamples,))
        for i in range(num_resamples):
            samples = random.choices(range(n), k=n)
            sws = ws[samples, :]
            last_sws = last_ws[samples, :]
            srs = rs[samples, :]
            svs = vs[samples, :]
            sqs = qs[samples, :]
            resample_wdrs[i] = (
                ((sws * (srs - sqs) + last_sws * svs).sum(0) * discount)
                .sum()
                .item()
            )
        resample_wdrs, _ = resample_wdrs.to(device=self._device).sort(0)
        lb = torch.min(wdr, resample_wdrs[int(round(0.05 * num_resamples))])
        ub = torch.max(wdr, resample_wdrs[int(round(0.95 * num_resamples)) - 1])
        b = torch.tensor(
            list(
                map(
                    lambda a: a - ub if a > ub else (a - lb if a < lb else 0.0),
                    gs.sum(0),
                )
            ),
            device=self._device,
        )
        b.unsqueeze_(0)
        bb = b * b.t()
        cov = omega + bb
        # x = torch.rand((1, horizon), device=self.device, requires_grad=True)
        x = torch.zeros((1, horizon), device=self._device, requires_grad=True)
        # using SGD to find min x
        optimizer = torch.optim.SGD([x], lr=lr)
        last_y = 0.0
        for i in range(100):
            x = torch.nn.functional.softmax(x, dim=1)
            y = torch.mm(torch.mm(x, cov), x.t())
            if abs(y.item() - last_y) < loss_threshold:
                print(f"{i}: {last_y} -> {y.item()}")
                break
            last_y = y.item()
            optimizer.zero_grad()
            y.backward(retain_graph=True)
            optimizer.step()
        x = torch.nn.functional.softmax(x, dim=1)
        estimate = torch.mm(x, gs.sum(0, keepdim=True).t())
        if input.ground_truth is not None:
            ground_truth = input.ground_truth(state)
        else:
            ground_truth = None
        results.append(
            EstimatorResult(
                self._log_reward(input.gamma, mdps), estimate, ground_truth
            )
        )
    logging.info(
        f"{self}: finishing evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results
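# Standalone sketch (hypothetical data, not the method above) of the percentile bootstrap
# it uses to bound the WDR estimate: resample trajectories with replacement, recompute the
# statistic on each resample, and take the 5th/95th percentiles of the resampled values.
import random
import torch

values = torch.rand(50)  # per-trajectory statistics (hypothetical)
num_resamples = 200
resampled = torch.zeros(num_resamples)
for i in range(num_resamples):
    idx = random.choices(range(len(values)), k=len(values))
    resampled[i] = values[idx].mean()
resampled, _ = resampled.sort(0)
lb = resampled[int(round(0.05 * num_resamples))]
ub = resampled[int(round(0.95 * num_resamples)) - 1]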
def _evaluate(
    self,
    input: BanditsEstimatorInput,
    train_samples: Sequence[LogSample],
    eval_samples: Sequence[LogSample],
    force_train: bool = False,
    **kwargs,
) -> Optional[EstimatorResult]:
    logger.info("OPE Switch Evaluating")
    self._train_model(train_samples, force_train)
    exp_base = kwargs.get("exp_base", SwitchEstimator.EXP_BASE)
    num_candidates = kwargs.get("candidates", SwitchEstimator.CANDIDATES)
    (
        actions,
        ws,
        rs,
        r_est,
        r_est_for_logged_action,
        propensities,
        expected_rmax,
        log_avg,
        gt_avg,
    ) = self._calc_weight_reward_tensors(input, eval_samples)
    min_w, max_w = float(torch.min(ws).item()), float(torch.max(ws).item())
    diff = max_w - min_w
    # The threshold lies in the range [min ips, max ips]
    # Picking a small threshold -> using mainly the model-based estimator
    # Picking a large threshold -> using mainly the ips-based estimator
    candidates = [
        min_w + ((exp_base ** x) / (exp_base ** (num_candidates - 1))) * diff
        for x in range(num_candidates)
    ]
    # This prevents the edge case where nearly all scores being min_w prevents
    # switch from trying a purely DM estimate
    tau = min_w - SwitchEstimator.EPSILON
    loss = float("inf")
    for candidate in candidates:
        estimated_values = self._calc_estimated_values(
            rs, ws, actions, candidate, r_est, r_est_for_logged_action, propensities
        )
        var = (1.0 / (estimated_values.shape[0] ** 2)) * torch.sum(
            (estimated_values - torch.mean(estimated_values)) ** 2
        ).item()
        bias = torch.mean(
            torch.sum(expected_rmax * (ws > candidate).float(), dim=1, keepdim=True)
        ).item()
        cand_loss = var + bias * bias
        if cand_loss < loss:
            tau = candidate
            loss = cand_loss
    estimated_values = self._calc_estimated_values(
        rs, ws, actions, tau, r_est, r_est_for_logged_action, propensities
    )
    (
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(estimated_values.detach(), log_avg.average)
    return EstimatorResult(
        log_reward=log_avg.average,
        estimated_reward=torch.mean(estimated_values).item(),
        ground_truth_reward=gt_avg.average,
        estimated_weight=float(estimated_values.shape[0]),
        estimated_reward_normalized=tgt_score_normalized,
        estimated_reward_std_error=tgt_std_err,
        estimated_reward_normalized_std_error=tgt_std_err_normalized,
    )
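# Standalone sketch of the exponential threshold grid used above (hypothetical min_w,
# max_w, exp_base, and candidate count): candidates start near min_w and grow toward
# max_w, so early candidates lean on the DM estimate and later ones on the IPS estimate.
min_w, max_w = 0.1, 5.0
exp_base, num_candidates = 1.5, 8
diff = max_w - min_w
candidates = [
    min_w + ((exp_base ** x) / (exp_base ** (num_candidates - 1))) * diff
    for x in range(num_candidates)
]
print(candidates)  # monotonically increasing thresholds ending at max_w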
def _evaluate(
    self,
    input: BanditsEstimatorInput,
    train_samples: Sequence[LogSample],
    eval_samples: Sequence[LogSample],
    **kwargs,
) -> Optional[EstimatorResult]:
    self._train_model(train_samples)
    (
        actions,
        ws,
        rs,
        r_est,
        propensities,
        expected_rmax,
        log_avg,
        gt_avg,
    ) = self._calc_weight_reward_tensors(input, eval_samples)
    min_w, max_w = float(torch.min(ws).item()), float(torch.max(ws).item())
    diff = max_w - min_w
    # The threshold lies in the range [min ips, max ips]
    # Picking a small threshold -> using mainly the model-based estimator
    # Picking a large threshold -> using mainly the ips-based estimator
    candidates = [
        min_w
        + (
            (SwitchEstimator.EXP_BASE ** x)
            / (SwitchEstimator.EXP_BASE ** (SwitchEstimator.CANDIDATES - 1))
        )
        * diff
        for x in range(SwitchEstimator.CANDIDATES)
    ]
    tau = min_w
    loss = float("inf")
    for candidate in candidates:
        estimated_values = self._calc_estimated_values(
            rs, ws, actions, candidate, r_est, propensities
        )
        var = (1.0 / (estimated_values.shape[0] ** 2)) * torch.sum(
            (estimated_values - torch.mean(estimated_values)) ** 2
        ).item()
        bias = torch.mean(
            torch.sum(expected_rmax * (ws > candidate).float(), dim=1, keepdim=True)
        ).item()
        cand_loss = var + bias * bias
        if cand_loss < loss:
            tau = candidate
            loss = cand_loss
    estimated_values = self._calc_estimated_values(
        rs, ws, actions, tau, r_est, propensities
    )
    (
        tgt_score_normalized,
        tgt_std_err,
        tgt_std_err_normalized,
    ) = self._compute_metric_data(estimated_values, log_avg.average)
    return EstimatorResult(
        log_reward=log_avg.average,
        estimated_reward=torch.mean(estimated_values).item(),
        ground_truth_reward=gt_avg.average,
        estimated_weight=float(estimated_values.shape[0]),
        estimated_reward_normalized=tgt_score_normalized,
        estimated_reward_std_error=tgt_std_err,
        estimated_reward_normalized_std_error=tgt_std_err_normalized,
    )