def create_vector_encoder(data_dim,
                          latent_dim,
                          hidden_features=100,
                          num_blocks=2,
                          dropout_probability=0.0,
                          use_batch_norm=False,
                          context_features=None,
                          resnet=True):
    """Build an encoder network mapping ``data_dim`` inputs to ``latent_dim`` outputs.

    Args:
        data_dim: Input size; passed as ``in_features`` (ResidualNet) or as
            the 1-tuple ``in_shape`` (MLP).
        latent_dim: Output size; passed as ``out_features`` (ResidualNet) or
            as the 1-tuple ``out_shape`` (MLP).
        hidden_features: Hidden width used by either network type.
        num_blocks: ``num_blocks`` for ResidualNet; for MLP, the number of
            copies of ``hidden_features`` passed as ``hidden_sizes``.
        dropout_probability: Forwarded to ``nn_.ResidualNet`` only.
        use_batch_norm: Forwarded to ``nn_.ResidualNet`` only.
        context_features: Forwarded to both network types.
        resnet: If truthy, build ``nn_.ResidualNet``; otherwise ``nn_.MLP``.

    Returns:
        The constructed ``nn_.ResidualNet`` or ``nn_.MLP`` instance (ReLU
        activation in both cases).

    Note:
        On the MLP path ``dropout_probability`` and ``use_batch_norm`` are
        not passed on, so they have no effect there.
    """
    if not resnet:
        # Every hidden layer of the MLP gets the same width.
        mlp_hidden_sizes = [hidden_features] * num_blocks
        return nn_.MLP(
            in_shape=(data_dim,),
            out_shape=(latent_dim,),
            hidden_sizes=mlp_hidden_sizes,
            context_features=context_features,
            activation=F.relu,
        )

    return nn_.ResidualNet(
        in_features=data_dim,
        out_features=latent_dim,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        activation=F.relu,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )
Пример #2
0
def create_vector_encoder(data_dim,
                          latent_dim,
                          hidden_features=100,
                          num_blocks=2,
                          dropout_probability=0.0,
                          use_batch_norm=False,
                          context_features=None):
    """Build a ReLU ``nn_.ResidualNet`` mapping ``data_dim`` inputs to ``latent_dim`` outputs.

    Args:
        data_dim: Passed as ``in_features``.
        latent_dim: Passed as ``out_features``.
        hidden_features: Hidden width of the residual network.
        num_blocks: Number of residual blocks.
        dropout_probability: Forwarded unchanged to ``nn_.ResidualNet``.
        use_batch_norm: Forwarded unchanged to ``nn_.ResidualNet``.
        context_features: Forwarded unchanged; presumably ``None`` means an
            unconditional network (TODO confirm against ``nn_.ResidualNet``).

    Returns:
        The constructed ``nn_.ResidualNet`` instance.
    """
    residual_net_kwargs = dict(
        in_features=data_dim,
        out_features=latent_dim,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        activation=F.relu,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )
    encoder_net = nn_.ResidualNet(**residual_net_kwargs)
    return encoder_net
Пример #3
0
def _create_vector_base_transform(
    i,
    base_transform_type,
    features,
    hidden_features,
    num_transform_blocks,
    dropout_probability,
    use_batch_norm,
    num_bins,
    tail_bound,
    apply_unconditional_transform,
    context_features,
):
    transform_net_create_fn = lambda in_features, out_features: nn_.ResidualNet(
        in_features=in_features,
        out_features=out_features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_transform_blocks,
        activation=F.relu,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )

    if base_transform_type == "affine-coupling":
        return transforms.AffineCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn)
    elif base_transform_type == "quadratic-coupling":
        return transforms.PiecewiseQuadraticCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            apply_unconditional_transform=apply_unconditional_transform,
        )
    elif base_transform_type == "rq-coupling":
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            apply_unconditional_transform=apply_unconditional_transform,
        )
    elif base_transform_type == "affine-autoregressive":
        return transforms.MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    elif base_transform_type == "quadratic-autoregressive":
        return transforms.MaskedPiecewiseQuadraticAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    elif base_transform_type == "rq-autoregressive":
        return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    else:
        raise ValueError