from typing import Any, Callable, Dict, List, Optional, Type

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from torch import Tensor

# type alias for callables producing stage-value kwargs from (model, X)
TAcqfArgConstructor = Callable[[Model, Tensor], Dict[str, Any]]


def _get_induced_fantasy_model(
    model: Model,
    Xs: List[Tensor],
    samplers: List[Optional[MCSampler]],
) -> Model:
    r"""Recursive computation of the fantasy model induced by an input tree.

    Args:
        model: A Model of appropriate batch size. Specifically, it must be
            possible to evaluate the model's posterior at `Xs[0]`.
        Xs: A list `[X_j, ..., X_k]` of tensors, where `X_i` has shape
            `f_i x ... x f_1 x batch_shape x q_i x d`.
        samplers: A list of `k - j` samplers, such that the number of samples
            of sampler `i` is `f_i`. The last element of this list is
            considered the "inner sampler", which is used for evaluating the
            objective in case it is an MCAcquisitionObjective.

    Returns:
        A Model obtained by iteratively fantasizing over the input tree `Xs`.
    """
    if len(Xs) == 1:
        return model

    fantasy_model = model.fantasize(
        X=Xs[0],
        sampler=samplers[0],
        observation_noise=True,
    )
    return _get_induced_fantasy_model(
        model=fantasy_model,
        Xs=Xs[1:],
        samplers=samplers[1:],
    )
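
# A minimal sketch (not part of the original module) of how the input tree
# shapes line up when fantasizing over two levels: `X_1` has shape
# `batch_shape x q_1 x d`, and `X_2` prepends a fantasy dim `f_1`. All names
# inside this demo are hypothetical. Assumes a recent BoTorch in which
# `SobolQMCNormalSampler` takes `sample_shape=torch.Size([...])` (older
# versions used `num_samples=...`).
def _demo_induced_fantasy_model() -> Model:
    from botorch.models import SingleTaskGP
    from botorch.sampling import SobolQMCNormalSampler

    f1, batch_shape, d = 4, 3, 2  # fantasy fan-out, t-batch size, input dim
    train_X = torch.rand(8, d, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)

    # two-level input tree: X_2 is never fantasized over (it is the base case)
    Xs = [
        torch.rand(batch_shape, 1, d, dtype=torch.double),
        torch.rand(f1, batch_shape, 1, d, dtype=torch.double),
    ]
    # one sampler per fantasy step; the last entry is unused in the base case
    samplers = [SobolQMCNormalSampler(sample_shape=torch.Size([f1])), None]

    # the induced fantasy model has batch shape f_1 x batch_shape
    return _get_induced_fantasy_model(model=model, Xs=Xs, samplers=samplers)
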
def _step(
    model: Model,
    Xs: List[Tensor],
    samplers: List[Optional[MCSampler]],
    valfunc_cls: List[Optional[Type[AcquisitionFunction]]],
    valfunc_argfacs: List[Optional[TAcqfArgConstructor]],
    inner_samplers: List[Optional[MCSampler]],
    objective: MCAcquisitionObjective,
    posterior_transform: PosteriorTransform,
    running_val: Optional[Tensor] = None,
    sample_weights: Optional[Tensor] = None,
    step_index: int = 0,
) -> Tensor:
    r"""Recursive multi-step look-ahead computation.

    Helper function computing the "value-to-go" of a multi-step lookahead scheme.

    Args:
        model: A Model of appropriate batch size. Specifically, it must be
            possible to evaluate the model's posterior at `Xs[0]`.
        Xs: A list `[X_j, ..., X_k]` of tensors, where `X_i` has shape
            `f_i x ... x f_1 x batch_shape x q_i x d`.
        samplers: A list of `k - j` samplers, such that the number of samples
            of sampler `i` is `f_i`. The last element of this list is
            considered the "inner sampler", which is used for evaluating the
            objective in case it is an MCAcquisitionObjective.
        valfunc_cls: A list of acquisition function classes to be used as the
            (stage + terminal) value functions. Each element (except for the
            last one) can be `None`, in which case a zero stage value is
            assumed for the respective stage.
        valfunc_argfacs: A list of callables that map a `Model` and input
            tensor `X` to a dictionary of kwargs for the respective stage
            value function constructor. If `None`, only the standard `model`,
            `sampler`, and `objective` kwargs will be used.
        inner_samplers: A list of `MCSampler` objects, each to be used in the
            stage value function at the corresponding index.
        objective: The MCAcquisitionObjective under which the model output is
            evaluated.
        posterior_transform: A PosteriorTransform. Used to transform the
            posterior before sampling / evaluating the model output.
        running_val: A tensor of shape `f_{i-1} x ... x f_1 x batch_shape`
            containing the current running value.
        sample_weights: A tensor of shape `f_i x ... x f_1 x batch_shape` when
            called in the `i`-th step by which to weight the stage value
            samples. Used in conjunction with Gauss-Hermite integration or
            importance sampling. Assumed to be `None` in the initial step
            (when `step_index=0`).
        step_index: The index of the look-ahead step. `step_index=0` indicates
            the initial step.

    Returns:
        A `b`-dim tensor containing the multi-step value of the design `X`.
    """
    X = Xs[0]
    if sample_weights is None:  # only happens in the initial step
        sample_weights = torch.ones(*X.shape[:-2], device=X.device, dtype=X.dtype)

    # compute stage value
    stage_val = _compute_stage_value(
        model=model,
        valfunc_cls=valfunc_cls[0],
        X=X,
        objective=objective,
        posterior_transform=posterior_transform,
        inner_sampler=inner_samplers[0],
        arg_fac=valfunc_argfacs[0],
    )
    if stage_val is not None:  # update running value
        # if not None, running_val has shape f_{i-1} x ... x f_1 x batch_shape
        # stage_val has shape f_i x ... x f_1 x batch_shape
        # this sum will add a dimension to running_val, so that the updated
        # running_val has shape f_i x ... x f_1 x batch_shape
        running_val = stage_val if running_val is None else running_val + stage_val

    # base case: no more fantasizing, return value
    if len(Xs) == 1:
        # compute weighted average over all leaf nodes of the tree
        batch_shape = running_val.shape[step_index:]
        # expand sample weights to make sure it is the same shape as
        # running_val, because we need to take a sum over sample weights
        # for computing the weighted average
        sample_weights = sample_weights.expand(running_val.shape)
        return (running_val * sample_weights).view(-1, *batch_shape).sum(dim=0)

    # construct fantasy model (with batch shape f_{j+1} x ... x f_1 x batch_shape)
    prop_grads = step_index > 0  # need to propagate gradients for steps > 0
    fantasy_model = model.fantasize(
        X=X, sampler=samplers[0], observation_noise=True, propagate_grads=prop_grads
    )

    # augment sample weights appropriately
    sample_weights = _construct_sample_weights(
        prev_weights=sample_weights, sampler=samplers[0]
    )

    return _step(
        model=fantasy_model,
        Xs=Xs[1:],
        samplers=samplers[1:],
        valfunc_cls=valfunc_cls[1:],
        valfunc_argfacs=valfunc_argfacs[1:],
        inner_samplers=inner_samplers[1:],
        objective=objective,
        posterior_transform=posterior_transform,
        sample_weights=sample_weights,
        running_val=running_val,
        step_index=step_index + 1,
    )
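
# A minimal sketch (hypothetical shapes, not part of the original module) of
# the base case's weighted average over leaf nodes. With two fantasy steps,
# `running_val` carries one leading fantasy dim per step; the uniform weights
# below stand in for what `_construct_sample_weights` (defined elsewhere in
# this module) would produce for plain MC sampling.
def _demo_leaf_average() -> Tensor:
    f2, f1, b = 2, 3, 5  # fantasy fan-outs and t-batch size
    step_index = 2  # the base case is reached at the final step
    running_val = torch.randn(f2, f1, b)
    # plain MC weights: each of the f_2 * f_1 leaves contributes equally
    sample_weights = torch.full((f2, f1, b), 1.0 / (f2 * f1))

    # the dims after the fantasy dims form the batch shape: torch.Size([b])
    batch_shape = running_val.shape[step_index:]
    # flatten all fantasy dims and sum them out, leaving a b-dim value
    return (running_val * sample_weights).view(-1, *batch_shape).sum(dim=0)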