Example #1
0
 def build_fn(batch_theta, batch_x):
     """Build the density estimator selected by ``model``.

     ``model`` and ``kwargs`` are free variables of this function; they are
     presumably bound by an enclosing factory that is not visible here --
     confirm at the definition site.

     Args:
         batch_theta: forwarded to the selected builder as its ``batch_x``.
         batch_x: forwarded to the selected builder as its ``batch_y``.

     Returns:
         Whatever the selected ``build_*`` function returns.

     Raises:
         NotImplementedError: if ``model`` is none of "mdn", "made", "maf"
             or "nsf".
     """
     # Pick the builder first; each name is looked up only on the branch taken.
     if model == "mdn":
         builder = build_mdn
     elif model == "made":
         builder = build_made
     elif model == "maf":
         builder = build_maf
     elif model == "nsf":
         builder = build_nsf
     else:
         raise NotImplementedError

     # Note the argument swap: theta becomes the builder's `batch_x`, and x
     # becomes its `batch_y`.
     return builder(batch_x=batch_theta, batch_y=batch_x, **kwargs)
Example #2
0
File: mnle.py Project: bkmi/sbi
def build_mnle(
    batch_x: Tensor,
    batch_y: Tensor,
    z_score_x: Optional[str] = "independent",
    z_score_y: Optional[str] = "independent",
    num_transforms: int = 2,
    num_bins: int = 5,
    hidden_features: int = 50,
    hidden_layers: int = 2,
    tail_bound: float = 10.0,
    log_transform_x: bool = True,
    **kwargs,
):
    """Build a density estimator for data with continuous and discrete parts.

    The discrete part of `batch_x` is modelled by a `CategoricalNet` that takes
    `batch_y` as input. The continuous part is modelled by a neural spline flow
    (NSF, built with `build_nsf`) that is conditioned on `batch_y` concatenated
    with the discrete part. A warning describing the expected column layout of
    `batch_x` is emitted on every call.

    Args:
        batch_x: batch of data. It is split by `_separate_x` into a continuous
            and a discrete part; per the emitted warning, the continuous data
            is expected in the first n-1 columns and the categorical data in
            the last column.
        batch_y: batch of parameters. The categorical net's input size is the
            number of elements in `batch_y[0]`.
        z_score_x: z-scoring option for x, forwarded to `build_nsf`.
        z_score_y: z-scoring option for y, forwarded to `build_nsf`. In
            addition, only when it equals "independent" does the categorical
            net receive a `standardizing_net(batch_y)` embedding; for any other
            value the embedding is None.
        num_transforms: number of transforms in the NSF.
        num_bins: bins per spline for the NSF.
        hidden_features: number of hidden features used in both nets.
        hidden_layers: number of hidden layers in the categorical net.
        tail_bound: spline tail bound for the NSF.
        log_transform_x: if True, `torch.log` is applied to the continuous part
            before it is handed to `build_nsf`, e.g., for reaction time data
            (bounded by zero). The flag is also passed on to the returned
            estimator.
        **kwargs: accepted but not used by this function.

    Returns:
        MixedDensityEstimator: nn.Module for performing MNLE.
    """

    check_data_device(batch_x, batch_y)

    # Only the "independent" option yields a standardizing embedding for the
    # categorical net; every other value leaves it without one.
    embedding = standardizing_net(batch_y) if z_score_y == "independent" else None

    warnings.warn(
        """The mixed neural likelihood estimator assumes that x contains
        continuous data in the first n-1 columns (e.g., reaction times) and
        categorical data in the last column (e.g., corresponding choices). If
        this is not the case for the passed `x` do not use this function.""")

    # Split the data into its continuous and discrete parts.
    continuous_part, discrete_part = _separate_x(batch_x)

    # Sizes for the categorical net: parameter dimensionality and the number of
    # distinct values observed in the discrete part.
    theta_dim = batch_y[0].numel()
    n_categories = unique(discrete_part).numel()

    # Categorical net for the discrete data, built before the flow.
    categorical_net = CategoricalNet(
        num_input=theta_dim,
        num_categories=n_categories,
        num_hidden=hidden_features,
        num_layers=hidden_layers,
        embedding=embedding,
    )

    # The log-transform is applied manually here, before the flow is built.
    if log_transform_x:
        flow_target = torch.log(continuous_part)
    else:
        flow_target = continuous_part
    # The flow is conditioned on the parameters and on the discrete data.
    flow_context = torch.cat((batch_y, discrete_part), dim=1)

    spline_flow = build_nsf(
        batch_x=flow_target,
        batch_y=flow_context,
        z_score_y=z_score_y,
        z_score_x=z_score_x,
        num_bins=num_bins,
        num_transforms=num_transforms,
        tail_bound=tail_bound,
        hidden_features=hidden_features,
    )

    return MixedDensityEstimator(
        discrete_net=categorical_net,
        continuous_net=spline_flow,
        log_transform_x=log_transform_x,
    )