def optimize(self, acqf: MCAcquisitionFunction) -> Tuple[Tensor, Tensor]:
    """
    Optimizes the acquisition function in two passes: a short multi-restart
    search, followed by a full-length refinement of the best restart.

    :param acqf: The acquisition function being optimized
    :return: Best solution and value
    """
    constrained = self.inequality_constraints is not None

    # shape = num_restarts x *acqf.batch_shape x 1 x dim_X
    starting_points = self.generate_restart_points(acqf)
    if constrained:
        # collapse the middle dimensions while optimizing with constraints;
        # the original layout is restored further below
        restart_shape = starting_points.shape
        starting_points = starting_points.reshape(
            self.num_restarts, -1, self.dim_x
        )

    # pass 1: short run (1/25th of the iteration budget) from every restart
    coarse_options = {"maxiter": int(self.maxiter / 25)}
    with settings.propagate_grads(True):
        solutions, values = gen_candidates_scipy(
            initial_conditions=starting_points,
            acquisition_function=acqf,
            lower_bounds=self.bounds[0],
            upper_bounds=self.bounds[1],
            options=coarse_options,
            inequality_constraints=self.inequality_constraints,
        )
    self.add_solutions(solutions.view(-1, 1, self.dim_x).detach())

    # select the best restart along the leading (restart) dimension
    best_ind = torch.argmax(values, dim=0)
    if constrained:
        solutions = solutions.reshape(restart_shape)
    repeats = [1] * (best_ind.dim() + 2) + [self.dim_x]
    gather_index = best_ind.view(1, *best_ind.shape, 1, 1).repeat(*repeats)
    solution = solutions.gather(dim=0, index=gather_index)
    if constrained:
        best_shape = solution.shape
        solution = solution.reshape(1, -1, self.dim_x)

    # pass 2: refine the selected solution with the full iteration budget
    fine_options = {"maxiter": self.maxiter}
    with settings.propagate_grads(True):
        solution, value = gen_candidates_scipy(
            initial_conditions=solution,
            acquisition_function=acqf,
            lower_bounds=self.bounds[0],
            upper_bounds=self.bounds[1],
            options=fine_options,
            inequality_constraints=self.inequality_constraints,
        )
        # This is needed due to nested optimization
        value = acqf(solution)

    if constrained:
        solution = solution.reshape(best_shape)
    return solution, value.reshape(*acqf.batch_shape)
def fantasize(
    self,
    X: Tensor,
    sampler: MCSampler,
    observation_noise: bool = True,
    **kwargs: Any,
) -> Model:
    r"""Build a fantasy model by conditioning on sampled observations.

    The procedure is:
    (1) compute the model posterior at `X` (with observation noise when
    `observation_noise=True`),
    (2) draw "fake" observations from that posterior using `sampler`,
    (3) condition the model on those fake observations.

    Args:
        X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
            the feature space, `n'` is the number of points per batch, and
            `batch_shape` is the batch shape (must be compatible with the
            batch shape of the model).
        sampler: The sampler used for sampling from the posterior at `X`.
        observation_noise: If True, include observation noise.
        **kwargs: `propagate_grads` (default False) is consumed here; all
            remaining entries are forwarded to `condition_on_observations`.

    Returns:
        The constructed fantasy model.
    """
    grads_flag = kwargs.pop("propagate_grads", False)
    with settings.propagate_grads(grads_flag):
        posterior_at_X = self.posterior(X, observation_noise=observation_noise)
    # num_fantasies x batch_shape x n' x m
    fantasy_targets = sampler(posterior_at_X)
    return self.condition_on_observations(X=X, Y=fantasy_targets, **kwargs)
def forward(self, X: Tensor) -> Tensor:
    # The fantasy model is not used as a full model here; it is simply a
    # convenient way of computing fast posterior covariances.
    fantasy_model = self.model.fantasize(
        X=X, sampler=self.sampler, observation_noise=True
    )

    # one singleton dimension per batch dimension of X
    bdims = (1,) * len(X.shape[:-2])
    feature_dim = X.size(-1)
    if self.model.num_outputs > 1:
        # We use q=1 here b/c ScalarizedObjective currently does not fully exploit
        # lazy tensor operations and thus may be slow / overly memory-hungry.
        # TODO (T52818288): Properly use lazy tensors in scalarize_posterior
        mc_points = self.mc_points.view(-1, *bdims, 1, feature_dim)
    else:
        # While we only need marginal variances, we can evaluate for q>1
        # b/c for GPyTorch models lazy evaluation can make this quite a bit
        # faster than evaluating in t-batch mode with q-batch size of 1
        mc_points = self.mc_points.view(*bdims, -1, feature_dim)

    # evaluate the posterior at the grid points
    with settings.propagate_grads(True):
        grid_posterior = fantasy_model.posterior(
            mc_points, posterior_transform=self.posterior_transform
        )
    neg_variance = grid_posterior.variance.mul(-1.0)

    if self.posterior_transform is not None:
        # if multi-output + obj, shape is num_grid_points x batch_shape x 1 x 1
        return neg_variance.mean(dim=0).squeeze(-1).squeeze(-1)
    # if single-output, shape is 1 x batch_shape x num_grid_points x 1
    return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)
def test_propagate_grads(self):
    """The propagate_grads flag is off by default and on only inside the context."""
    flag = settings.propagate_grads

    # before entering the context the flag is off
    self.assertFalse(flag.on())
    self.assertTrue(flag.off())

    # inside the context the flag is on
    with settings.propagate_grads(True):
        self.assertTrue(flag.on())
        self.assertFalse(flag.off())

    # leaving the context restores the off state
    self.assertFalse(flag.on())
    self.assertTrue(flag.off())
def forward(self, X: Tensor) -> Tensor:
    r"""Evaluate qKnowledgeGradient on the candidate set `X`.

    Args:
        X: A `b x (q + num_fantasies) x d` Tensor with `b` t-batches of
            `q + num_fantasies` design points each. `X` is split in two
            along the `q` dimension (`dim=-2`): the first `q` points are
            the q-batch of design points, and the last `num_fantasies`
            points are the current solutions of the inner optimization
            problem.

            `X_fantasies = X[..., -num_fantasies:, :]`
            `X_fantasies.shape = b x num_fantasies x d`

            `X_actual = X[..., :-num_fantasies, :]`
            `X_actual.shape = b x q x d`

    Returns:
        A Tensor of shape `b`. For t-batch b, the q-KG value of the design
            `X_actual[b]` is averaged across the fantasy models, where
            `X_fantasies[b, i]` is chosen as the final selection for the
            `i`-th fantasy model.
            NOTE: If `current_value` is not provided, then this is not the
            true KG value of `X_actual[b]`, and `X_fantasies[b, : ]` must be
            maximized at fixed `X_actual[b]`.
    """
    X_actual, X_fantasies = _split_fantasy_points(X=X, n_f=self.num_fantasies)

    # Pending points are appended to the design part only, after the split
    pending = self.X_pending
    if pending is not None:
        pending = match_batch_shape(pending, X_actual)
        X_actual = torch.cat([X_actual, pending], dim=-2)

    # fantasy model of shape `num_fantasies x b`
    fantasy_model = self.model.fantasize(
        X=X_actual, sampler=self.sampler, observation_noise=True
    )

    value_function = _get_value_function(
        model=fantasy_model,
        objective=self.objective,
        posterior_transform=self.posterior_transform,
        sampler=self.inner_sampler,
    )

    # gradients must flow to the fantasy model's train inputs
    with settings.propagate_grads(True):
        fantasy_values = value_function(X=X_fantasies)  # num_fantasies x b

    if self.current_value is not None:
        fantasy_values = fantasy_values - self.current_value

    # average over the fantasy samples
    return fantasy_values.mean(dim=0)
def forward(self, X: Tensor) -> Tensor:
    r"""
    Approximates E_n[CVaR[F]] as described in ApxCVaRKG.

    :param X: The decision variable `x` and the `\beta` value.
        Shape: batch x num_fantasies x num_starting_sols x 1 x (dim_x + 1)
        (inputs with fewer than 4 dimensions are reshaped, see below)
    :return: -E_n[CVaR[F(x, W)]].
        Shape: batch x num_fantasies x num_starting_sols
        Note that the return value is negated since the optimizers we use
        do maximization.
    """
    if X.requires_grad:
        torch.set_grad_enabled(True)
    # match the dtype and device of the w samples
    X = X.to(self.w_samples)

    # X must be at least 4 dimensional to match the batch shape of rhoKG
    point_dim = self.dim_x + 1
    assert X.shape[-1] == point_dim
    if X.dim() < 4:
        X = X.reshape(-1, *self.model._input_batch_shape, 1, point_dim)

    X_fant = X[..., :self.dim_x]  # batch x num_fantasies x n x 1 x dim_x
    beta = X[..., -1:]  # batch x num_fantasies x n x 1 x 1

    # Pair every x with every w sample
    x_tiled = X_fant.repeat(*[1] * (X_fant.dim() - 2), self.num_samples, 1)
    w_tiled = self.w_samples.repeat(*X_fant.shape[:-2], 1, 1)
    z_fant = torch.cat([x_tiled, w_tiled], dim=-1)

    # posterior mean and standard deviation at the joined points
    with settings.propagate_grads(True):
        posterior = self.model.posterior(z_fant)
        mu = posterior.mean
        sigma = torch.sqrt(posterior.variance)

    # Calculate `E_f[[f(x) - \beta]^+]` (same closed form as EI)
    u = (mu - beta.expand_as(mu)) / sigma
    std_normal = Normal(torch.zeros_like(u), torch.ones_like(u))
    ucdf = std_normal.cdf(u)
    updf = torch.exp(std_normal.log_prob(u))
    values = sigma * (updf + u * ucdf)

    # expectation over W, weighted when weights are available
    weights = getattr(self, "weights", None)
    if weights is not None:
        values = values * weights.unsqueeze(-1)
        values = torch.sum(values, dim=-2)
    else:
        values = torch.mean(values, dim=-2)

    # add beta and divide by 1-alpha
    values = beta.view_as(values) + values / (1 - self.alpha)

    # squeeze the last dim; negated since CVaR is being minimized
    return -values.squeeze(-1)
def test_gpt_posterior_settings(self):
    """Check the gpytorch settings applied by `gpt_posterior_settings`."""
    for propagate_grads in (False, True):
        with settings.propagate_grads(propagate_grads), gpt_posterior_settings():
            self.assertTrue(gpt_settings.debug.off())
            self.assertTrue(gpt_settings.fast_pred_var.on())
            # test caches are detached exactly when gradients are not
            # being propagated
            detach_flag = gpt_settings.detach_test_caches
            if settings.propagate_grads.off():
                self.assertTrue(detach_flag.on())
            else:
                self.assertTrue(detach_flag.off())
def fantasize(self, X, sampler, observation_noise=True, **kwargs):
    """Build a fantasy model from posterior samples at `X`.

    `propagate_grads` (default False) is popped from `kwargs`; the remaining
    keyword arguments are forwarded to `self.posterior`. The fantasy
    observations are given a noise level equal to the mean of
    `self.likelihood.noise`.
    """
    grads_flag = kwargs.pop("propagate_grads", False)
    with settings.propagate_grads(grads_flag):
        posterior_at_X = self.posterior(
            X, observation_noise=observation_noise, **kwargs
        )
    # num_fantasies x batch_shape x n' x m
    fantasy_targets = sampler(posterior_at_X)
    # Use the mean of the previous noise values (TODO: be smarter here).
    # noise should be batch_shape x q x m when X is batch_shape x q x d, and
    # the fantasy targets are num_fantasies x batch_shape x q x m.
    mean_noise = self.likelihood.noise.mean()
    noise = mean_noise.expand(fantasy_targets.shape[1:])
    return self.condition_on_observations(X=X, Y=fantasy_targets, noise=noise)
def fantasize(
    self,
    X: Tensor,
    sampler: MCSampler,
    observation_noise: Union[bool, Tensor] = True,
    **kwargs: Any,
) -> FixedNoiseGP:
    r"""Build a fantasy model by conditioning on sampled observations.

    The procedure is:
    (1) compute the model posterior at `X` (if `observation_noise=True`,
    this includes observation noise taken as the mean across the observation
    noise in the training data. If `observation_noise` is a Tensor, use
    it directly as the observation noise to add),
    (2) draw "fake" observations from that posterior using `sampler`,
    (3) condition the model on those fake observations.

    Args:
        X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
            the feature space, `n'` is the number of points per batch, and
            `batch_shape` is the batch shape (must be compatible with the
            batch shape of the model).
        sampler: The sampler used for sampling from the posterior at `X`.
        observation_noise: If True, include the mean across the observation
            noise in the training data as observation noise in the posterior
            from which the samples are drawn. If a Tensor, use it directly
            as the specified measurement noise.
        **kwargs: `propagate_grads` (default False) is consumed here; all
            remaining entries are forwarded to `self.posterior`.

    Returns:
        The constructed fantasy model.
    """
    grads_flag = kwargs.pop("propagate_grads", False)
    with fantasize_flag():
        with settings.propagate_grads(grads_flag):
            posterior_at_X = self.posterior(
                X, observation_noise=observation_noise, **kwargs
            )
        # num_fantasies x batch_shape x n' x m
        fantasy_targets = sampler(posterior_at_X)
        # Use the mean of the previous noise values (TODO: be smarter here).
        # noise should be batch_shape x q x m when X is batch_shape x q x d, and
        # the fantasy targets are num_fantasies x batch_shape x q x m.
        mean_noise = self.likelihood.noise.mean()
        noise = mean_noise.expand(fantasy_targets.shape[1:])
        return self.condition_on_observations(
            X=self.transform_inputs(X), Y=fantasy_targets, noise=noise
        )
def forward(self, X: Tensor) -> Tensor:
    r"""Evaluate qMultiFidelityKnowledgeGradient on the candidate set `X`.

    Args:
        X: A `b x (q + num_fantasies) x d` Tensor with `b` t-batches of
            `q + num_fantasies` design points each. `X` is split in two
            along the `q` dimension (`dim=-2`): the first `q` points are
            the q-batch of design points, and the last `num_fantasies`
            points are the current solutions of the inner optimization
            problem.

            `X_fantasies = X[..., -num_fantasies:, :]`
            `X_fantasies.shape = b x num_fantasies x d`

            `X_actual = X[..., :-num_fantasies, :]`
            `X_actual.shape = b x q x d`

            In addition, `X` may be augmented with fidelity parameters as
            part of the `d`-dimension. Projecting fidelities to the target
            fidelity is handled by `project`.

    Returns:
        A Tensor of shape `b`. For t-batch b, the q-KG value of the design
            `X_actual[b]` is averaged across the fantasy models, where
            `X_fantasies[b, i]` is chosen as the final selection for the
            `i`-th fantasy model.
            NOTE: If `current_value` is not provided, then this is not the
            true KG value of `X_actual[b]`, and `X_fantasies[b, : ]` must be
            maximized at fixed `X_actual[b]`.
    """
    X_actual, X_fantasies = _split_fantasy_points(X=X, n_f=self.num_fantasies)

    # Pending points are appended to the design part only, after the split
    if self.X_pending is None:
        X_eval = X_actual
    else:
        pending = match_batch_shape(self.X_pending, X_actual)
        X_eval = torch.cat([X_actual, pending], dim=-2)

    # fantasy model of shape `num_fantasies x b`;
    # X is expanded first (to potentially add trace observations)
    fantasy_model = self.model.fantasize(
        X=self.expand(X_eval), sampler=self.sampler, observation_noise=True
    )

    value_function = _get_value_function(
        model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
    )

    # gradients must flow to the fantasy model's train inputs;
    # the fantasy points are projected before evaluation
    with settings.propagate_grads(True):
        fantasy_values = value_function(
            X=self.project(X_fantasies)
        )  # num_fantasies x b

    if self.current_value is not None:
        fantasy_values = fantasy_values - self.current_value

    if self.cost_aware_utility is not None:
        fantasy_values = self.cost_aware_utility(
            X=X_actual, deltas=fantasy_values, sampler=self.cost_sampler
        )

    # average over the fantasy samples
    return fantasy_values.mean(dim=0)
def forward(self, X: Tensor) -> Tensor:
    r"""
    Compute the rhoKG acquisition value by averaging over fantasies.

    :param X: `batch_size x q x dim` of solutions to evaluate
    :return: value of rhoKG at X (to be maximized) - size: batch_size
    """
    tkwargs = {"dtype": self.dtype, "device": self.device}

    # bring X into the expected shape, dtype and device
    X = X.reshape(-1, self.q, self.dim).to(**tkwargs)
    batch_size = X.size(0)

    # generate w_samples (drawn once and cached when samples are fixed)
    if not self.fix_samples:
        w_samples = torch.rand((self.num_samples, self.dim_w), **tkwargs)
    else:
        if self.fixed_samples is None:
            self.fixed_samples = torch.rand(
                (self.num_samples, self.dim_w), **tkwargs
            )
        w_samples = self.fixed_samples

    inner_seed = self.inner_seed
    if inner_seed is None:
        inner_seed = int(torch.randint(100000, (1,)))

    # to reduce memory usage, X is evaluated in mini batches of size
    # mini_batch_size
    num_batches = ceil(batch_size / self.mini_batch_size)
    values = torch.empty(batch_size, **tkwargs)

    if self.last_inner_solution is None:
        self.last_inner_solution = torch.empty(
            self.active_fantasies, batch_size, 1, self.dim_x, **tkwargs
        )

    for batch_idx in range(num_batches):
        left_index = batch_idx * self.mini_batch_size
        # the last mini batch is clipped at batch_size
        right_index = min(left_index + self.mini_batch_size, batch_size)
        batch_slice = slice(left_index, right_index)

        # construct the fantasy model
        fantasy_model = self.model.fantasize(X[batch_slice], self.sampler)

        inner_rho = InnerRho(
            model=fantasy_model,
            w_samples=w_samples,
            alpha=self.alpha,
            dim_x=self.dim_x,
            num_repetitions=self.num_repetitions,
            inner_seed=inner_seed,
            CVaR=self.CVaR,
            expectation=self.expectation,
            weights=self.weights,
        )

        # optimize inner VaR, or re-evaluate the stored inner solution
        with settings.propagate_grads(True):
            if self.call_count % self.tts_frequency == 0:
                solution, value = self.inner_optimizer(inner_rho)
                self.last_inner_solution[:, batch_slice] = solution
            else:
                value = inner_rho(self.last_inner_solution[:, batch_slice])
        value = -value
        values[batch_slice] = self.current_best_rho - torch.mean(value, dim=0)

    self.call_count += 1
    return values
def forward(self, X: Tensor) -> Tensor:
    """
    The rhoKGapx algorithm for C/VaR.

    When `call_count` is a multiple of `tts_frequency`, the inner problem is
    approximated by evaluating the inner objective at the `x` components of
    the candidates together with the points in `self.past_x`, and keeping
    the minimizer for each fantasy. The selected points are stored in
    `self.last_inner_solution`. On other calls the stored inner solutions
    are re-evaluated instead.

    :param X: The tensor of candidate points, batch_size x q x dim
    :return: the rhoKGapx value of batch_size
    """
    X = X.reshape(-1, self.q, self.dim).to(dtype=self.dtype, device=self.device)
    batch_size = X.size(0)

    # generate w_samples (drawn once and cached when samples are fixed)
    if self.fix_samples:
        if self.fixed_samples is None:
            self.fixed_samples = torch.rand(
                (self.num_samples, self.dim_w), dtype=self.dtype, device=self.device
            )
        w_samples = self.fixed_samples
    else:
        w_samples = torch.rand(
            (self.num_samples, self.dim_w), dtype=self.dtype, device=self.device
        )

    if self.inner_seed is None:
        inner_seed = int(torch.randint(100000, (1,)))
    else:
        inner_seed = self.inner_seed

    # in an attempt to reduce the memory usage, we will evaluate in mini batches
    # of size mini_batch_size
    num_batches = ceil(batch_size / self.mini_batch_size)
    # Allocate the output with the configured dtype and device. Without these
    # arguments the buffer uses the default dtype on the CPU, which does not
    # match the tensors assigned into it below.
    values = torch.empty(batch_size, dtype=self.dtype, device=self.device)

    if self.last_inner_solution is None:
        self.last_inner_solution = torch.empty(
            self.active_fantasies,
            batch_size,
            1,
            self.dim_x,
            dtype=self.dtype,
            device=self.device,
        )

    for i in range(num_batches):
        left_index = i * self.mini_batch_size
        if i == num_batches - 1:
            right_index = batch_size
        else:
            right_index = (i + 1) * self.mini_batch_size

        # construct the fantasy model
        fantasy_model = self.model.fantasize(
            X[left_index:right_index], self.sampler
        )

        inner_rho = InnerRho(
            model=fantasy_model,
            w_samples=w_samples,
            alpha=self.alpha,
            dim_x=self.dim_x,
            num_repetitions=self.num_repetitions,
            inner_seed=inner_seed,
            CVaR=self.CVaR,
            expectation=self.expectation,
            weights=self.weights,
        )

        if self.call_count % self.tts_frequency == 0:
            # inner candidates: the x components of the mini batch followed
            # by past_x, replicated for every fantasy
            x_comp = X[left_index:right_index, :, : self.dim_x]
            x_inner = torch.cat(
                (x_comp, self.past_x.repeat(right_index - left_index, 1, 1)), dim=-2
            ).repeat(self.active_fantasies, 1, 1, 1)

            # one row of inner objective values per inner candidate
            temp_values = torch.empty(
                self.past_x.size(0) + self.q,
                self.active_fantasies,
                right_index - left_index,
                dtype=self.dtype,
                device=self.device,
            )
            for j in range(temp_values.size(0)):
                with settings.propagate_grads(True):
                    temp_values[j] = -inner_rho(x_inner[..., j, :].unsqueeze(-2))

            # keep the minimizing candidate for each fantasy / batch entry
            best = torch.argmin(temp_values, dim=0)
            detailed_values = torch.gather(
                temp_values, 0, best.unsqueeze(0)
            ).reshape(self.active_fantasies, right_index - left_index)
            self.last_inner_solution[:, left_index:right_index] = torch.gather(
                x_inner,
                2,
                best.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, self.dim_x),
            )
        else:
            # re-use the inner solutions stored by an earlier call
            detailed_values = -inner_rho(
                self.last_inner_solution[:, left_index:right_index]
            )

        # average over fantasies and subtract from the current best value
        values[left_index:right_index] = self.current_best_rho - torch.mean(
            detailed_values, dim=0
        )

    self.call_count += 1
    return values
def forward(self, X: Tensor) -> Tensor:
    r"""
    Compute the VaRKG acquisition value by averaging over fantasies.
    NOTE: Does not return the value of rhoKG unless optimized! - Use rhoKG
    for that.

    :param X: `batch size x 1 x (q x dim + num_fantasies x dim_x)` of which
        the first `(q x dim)` is for q points being evaluated, the remaining
        `(num_fantasies x dim_x)` are the solutions to the inner problem.
    :return: value of rhoKG at X (to be maximized). shape: `batch size`
    """
    warnings.warn("This is only experimental. Use rhoKGapx if possible!")
    tkwargs = {"dtype": self.dtype, "device": self.device}

    # bring X into the expected shape, dtype and device
    X = X.reshape(-1, 1, X.shape[-1]).to(**tkwargs)
    batch_size = X.shape[0]

    # split into the evaluated points and the inner (fantasy) solutions
    split_sizes = [self.q * self.dim, self.num_fantasies * self.dim_x]
    if X.size(-1) != sum(split_sizes):
        raise ValueError(
            "X must be of size: batch size x 1 x (q x dim + num_fantasies x dim_x)"
        )
    X_actual, X_fantasies = torch.split(X, split_sizes, dim=-1)
    X_actual = X_actual.reshape(batch_size, self.q, self.dim)
    # resulting size: self.num_fantasies x batch size x 1 x dim_x
    X_fantasies = (
        X_fantasies.reshape(batch_size, self.num_fantasies, self.dim_x)
        .permute(1, 0, 2)
        .unsqueeze(-2)
    )

    # mini batches are used to reduce memory usage
    num_batches = ceil(batch_size / self.mini_batch_size)
    values = torch.empty(batch_size, **tkwargs)

    # generate w_samples (drawn once and cached when samples are fixed)
    if not self.fix_samples:
        w_samples = torch.rand((self.num_samples, self.dim_w), **tkwargs)
    else:
        if self.fixed_samples is None:
            self.fixed_samples = torch.rand(
                (self.num_samples, self.dim_w), **tkwargs
            )
        w_samples = self.fixed_samples

    inner_seed = self.inner_seed
    if inner_seed is None:
        inner_seed = int(torch.randint(100000, (1, )))

    # the last dim_w entries of each evaluated point
    w_actual = X_actual[..., -self.dim_w:]

    for batch_idx in range(num_batches):
        left_index = batch_idx * self.mini_batch_size
        # the last mini batch is clipped at batch_size
        right_index = min(left_index + self.mini_batch_size, batch_size)
        batch_slice = slice(left_index, right_index)

        # construct the fantasy model
        fantasy_model = self.model.fantasize(X_actual[batch_slice], self.sampler)

        inner_rho = InnerRho(
            model=fantasy_model,
            w_samples=w_samples,
            alpha=self.alpha,
            dim_x=self.dim_x,
            num_repetitions=self.num_repetitions,
            inner_seed=inner_seed,
            CVaR=self.CVaR,
            expectation=self.expectation,
            w_actual=w_actual[batch_slice],
            weights=getattr(self, "weights", None),
        )

        # evaluate the inner objective at the supplied inner solutions
        with settings.propagate_grads(True):
            inner_values = -inner_rho(X_fantasies[:, batch_slice, :, :])
        values[batch_slice] = self.current_best_rho - torch.mean(
            inner_values, dim=0
        )
    return values
def forward(self, X: Tensor) -> Tensor:
    r"""
    Compute the acquisition function value on the given solution set.

    :param X: An `n x 1 x ( q * dim + num_fantasies * (dim_x + 1)` tensor of
        `q` candidates `x, w` and `num_fantasies` solutions `x` and `\beta`
        values for each fantasy model.
    :return: An `n`-dim tensor of acquisition function values
    """
    if X.dim() == 2 and self.q == 1:
        X = X.unsqueeze(-2)
    if X.dim() != 3:
        raise ValueError("Only supports X.dim() = 3!")
    tkwargs = {"dtype": self.dtype, "device": self.device}
    X = X.to(**tkwargs)
    n = X.shape[0]

    # separate candidates and fantasy solutions
    num_candidate_cols = self.q * self.dim
    X_actual = X[..., :num_candidate_cols].reshape(n, self.q, self.dim)
    X_rem = X[..., num_candidate_cols:].reshape(
        n, self.num_fantasies, self.dim_x + 1
    )
    # shape num_fantasies x n x 1 x dim_x + 1
    X_rem = X_rem.permute(1, 0, 2).unsqueeze(-2)
    X_fant = X_rem[..., : self.dim_x]  # num_fantasies x n x 1 x dim_x
    beta = X_rem[..., -1:]  # num_fantasies x n x 1 x 1

    # generate w_samples (drawn once and cached when samples are fixed)
    if not self.fix_samples:
        w_samples = torch.rand((self.num_samples, self.dim_w), **tkwargs)
    else:
        if self.fixed_samples is None:
            self.fixed_samples = torch.rand(
                (self.num_samples, self.dim_w), **tkwargs
            )
        w_samples = self.fixed_samples

    # construct the fantasy model
    fantasy_model = self.model.fantasize(X_actual, self.sampler)
    # input shape of fantasy_model is `num_fantasies x n x * x dim` where * is the
    # number of solutions being evaluated jointly

    # Pair every fantasy solution with every w sample
    x_tiled = X_fant.repeat(1, 1, self.num_samples, 1)
    w_tiled = w_samples.repeat(self.num_fantasies, n, 1, 1)
    z_fant = torch.cat([x_tiled, w_tiled], dim=-1)

    # posterior mean and standard deviation at the joined points
    with settings.propagate_grads(True):
        posterior = fantasy_model.posterior(z_fant)
        mu = posterior.mean
        sigma = torch.sqrt(posterior.variance)

    # Calculate `E_f[[f(x) - \beta]^+]` (same closed form as EI)
    u = (mu - beta.expand_as(mu)) / sigma
    std_normal = Normal(torch.zeros_like(u), torch.ones_like(u))
    ucdf = std_normal.cdf(u)
    updf = torch.exp(std_normal.log_prob(u))
    values = sigma * (updf + u * ucdf)

    # expectation over W, weighted when weights are available
    weights = getattr(self, "weights", None)
    if weights is not None:
        values = values * weights.unsqueeze(-1)
        values = torch.sum(values, dim=-2)
    else:
        values = torch.mean(values, dim=-2)

    # add beta and divide by 1-alpha
    values = beta.view_as(values) + values / (1 - self.alpha)
    # expectation over fantasies
    values = torch.mean(values, dim=0)
    # squeeze the last dim; negated since CVaR is being minimized
    return -values.squeeze(-1)
def forward(self, X: Tensor) -> Tensor:
    r"""
    Compute the acquisition function value on the given solution set.

    :param X: An `n x q x dim` tensor of `q` candidates `x, w`.
    :return: An `n`-dim tensor of acquisition function values
    """
    if X.dim() == 2 and self.q == 1:
        X = X.unsqueeze(-2)
    if X.dim() != 3:
        raise ValueError("Only supports X.dim() = 3!")
    tkwargs = {"dtype": self.dtype, "device": self.device}
    X = X.to(**tkwargs)
    batch_size = X.size(0)

    # generate w_samples (drawn once and cached when samples are fixed)
    if not self.fix_samples:
        w_samples = torch.rand((self.num_samples, self.dim_w), **tkwargs)
    else:
        if self.fixed_samples is None:
            self.fixed_samples = torch.rand(
                (self.num_samples, self.dim_w), **tkwargs
            )
        w_samples = self.fixed_samples

    # to reduce memory usage, X is evaluated in mini batches of size
    # mini_batch_size
    num_batches = ceil(batch_size / self.mini_batch_size)
    values = torch.empty(batch_size, **tkwargs)

    if self.last_inner_solution is None:
        self.last_inner_solution = torch.empty(
            self.active_fantasies, batch_size, 1, self.dim_x + 1, **tkwargs
        )

    for batch_idx in range(num_batches):
        left_index = batch_idx * self.mini_batch_size
        # the last mini batch is clipped at batch_size
        right_index = min(left_index + self.mini_batch_size, batch_size)
        batch_slice = slice(left_index, right_index)

        # construct the fantasy model
        fantasy_model = self.model.fantasize(X[batch_slice], self.sampler)

        inner_rho = InnerApxCVaR(
            model=fantasy_model,
            w_samples=w_samples,
            alpha=self.alpha,
            dim_x=self.dim_x,
            CVaR=self.CVaR,
            weights=self.weights,
        )

        # optimize inner VaR, or re-evaluate the stored inner solution
        with settings.propagate_grads(True):
            if self.call_count % self.tts_frequency == 0:
                solution, value = self.inner_optimizer(inner_rho, self.model)
                self.last_inner_solution[:, batch_slice] = solution
            else:
                value = inner_rho(self.last_inner_solution[:, batch_slice])
        value = -value
        # drop the graph when the caller does not need gradients w.r.t. X
        if not X.requires_grad:
            value = value.detach()
        values[batch_slice] = self.current_best_rho - torch.mean(value, dim=0)

    self.call_count += 1
    return values