Example #1
0
def build_affine_coupling(
    n_dim,
    n_coupling,
    hidden_layers,
    hidden_features,
    dropout_fraction,
    use_batch_norm=False,
):
    """Build ``n_coupling`` stages of random permutation + affine coupling layers.

    Each stage appends three transforms, in this order: a
    ``RandomPermutation`` over ``n_dim`` features, an
    ``AffineCouplingTransform`` using the even alternating binary mask, and an
    ``AffineCouplingTransform`` using the odd alternating binary mask.

    Args:
        n_dim: Number of features; passed to the permutation and both masks.
        n_coupling: Number of permutation/coupling stages to build.
        hidden_layers: Number of residual blocks (``num_blocks``) in each
            conditioner ``ResidualNet``.
        hidden_features: Hidden width of each conditioner ``ResidualNet``.
        dropout_fraction: Dropout probability of each conditioner network.
        use_batch_norm: Whether the conditioner networks use batch norm.
            Defaults to ``False``, the value previously hard-coded here.

    Returns:
        A list of ``3 * n_coupling`` transforms (empty when ``n_coupling`` is 0).
    """

    def make_conditioner(in_features, out_features):
        # Passed as ``transform_net_create_fn``; every call builds a fresh,
        # independently initialised network, so no weights are shared between
        # the coupling transforms (same as the two original lambdas).
        return nn.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=hidden_features,
            num_blocks=hidden_layers,
            dropout_probability=dropout_fraction,
            use_batch_norm=use_batch_norm,
        )

    # The masks are identical for every stage, so build them once. Mask
    # creation draws no random numbers, so the RNG call order below
    # (permutation, then the two conditioner nets) is unchanged.
    mask_even = utils.create_alternating_binary_mask(features=n_dim, even=True)
    mask_odd = utils.create_alternating_binary_mask(features=n_dim, even=False)

    layers = []
    for _ in range(n_coupling):
        layers.append(transforms.RandomPermutation(n_dim, 1))
        layers.append(
            transforms.AffineCouplingTransform(
                mask=mask_even, transform_net_create_fn=make_conditioner
            )
        )
        layers.append(
            transforms.AffineCouplingTransform(
                mask=mask_odd, transform_net_create_fn=make_conditioner
            )
        )
    return layers
Example #2
0
def build_nsf_coupling(n_dim, n_coupling, spline_points, hidden_layers,
                       hidden_features, dropout_fraction, tail_bound=15,
                       use_batch_norm=False):
    """Build ``n_coupling`` stages of random permutation + spline coupling layers.

    Each stage appends three transforms, in this order: a
    ``RandomPermutation`` over ``n_dim`` features, then two
    ``PiecewiseRationalQuadraticCouplingTransform`` layers using the even and
    the odd alternating binary mask respectively. Both use linear tails and
    no unconditional transform.

    Args:
        n_dim: Number of features; passed to the permutation and both masks.
        n_coupling: Number of permutation/coupling stages to build.
        spline_points: Number of spline bins (``num_bins``) per transform.
        hidden_layers: Number of residual blocks (``num_blocks``) in each
            conditioner ``ResidualNet``.
        hidden_features: Hidden width of each conditioner ``ResidualNet``.
        dropout_fraction: Dropout probability of each conditioner network.
        tail_bound: Passed as ``tail_bound`` to every spline transform.
            Defaults to 15, the value previously hard-coded here. With
            ``tails="linear"`` nflows uses it as the bound of the spline
            interval (see the nflows docs to confirm for your version).
        use_batch_norm: Whether the conditioner networks use batch norm.
            Defaults to ``False``, the value previously hard-coded here.

    Returns:
        A list of ``3 * n_coupling`` transforms (empty when ``n_coupling`` is 0).
    """

    def make_conditioner(in_features, out_features):
        # Passed as ``transform_net_create_fn``; every call builds a fresh,
        # independently initialised network (no weight sharing).
        return nn.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=hidden_features,
            num_blocks=hidden_layers,
            dropout_probability=dropout_fraction,
            use_batch_norm=use_batch_norm,
        )

    def make_spline_coupling(mask):
        # Single place for the spline settings shared by both masks.
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=make_conditioner,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=spline_points,
            apply_unconditional_transform=False,
        )

    # The masks are identical for every stage, so build them once. Mask
    # creation draws no random numbers, so the RNG call order below is
    # unchanged.
    mask_even = utils.create_alternating_binary_mask(features=n_dim, even=True)
    mask_odd = utils.create_alternating_binary_mask(features=n_dim, even=False)

    layers = []
    for _ in range(n_coupling):
        layers.append(transforms.RandomPermutation(n_dim, 1))
        layers.append(make_spline_coupling(mask_even))
        layers.append(make_spline_coupling(mask_odd))
    return layers
Example #3
0
# Report the loaded data; `training_data` is defined before this excerpt.
print("Trajectory loaded")
print("Data has size:", training_data.shape)


# Build the network
N_COUPLING = 4  # number of permutation/coupling stages built in the loop below
# Selects affine coupling layers in the loop below when True; the alternative
# branch lies outside this excerpt.
AFFINE_LAYER = False
layers = []  # ordered list of transforms; the mixed transform goes in first

# Create the mixed transform layer (a PCA block on "backbone") from the
# topology of `t` and the training data.
# NOTE(review): the coupling stages that follow operate on n_dim - 6 features,
# presumably because MixedTransform outputs 6 fewer features than n_dim
# (e.g. removed rigid-body degrees of freedom) — confirm against protein.MixedTransform.
pca_block = protein.PCABlock("backbone", True)
mixed = protein.MixedTransform(n_dim, t.topology, [pca_block], training_data)
layers.append(mixed)

for _ in range(N_COUPLING):
    p = transforms.RandomPermutation(n_dim - 6, 1)
    mask_even = utils.create_alternating_binary_mask(features=n_dim - 6, even=True)
    mask_odd = utils.create_alternating_binary_mask(features=n_dim - 6, even=False)
    if AFFINE_LAYER:
        t1 = transforms.AffineCouplingTransform(
            mask=mask_even,
            transform_net_create_fn=lambda in_features, out_features: nn.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=128,
                num_blocks=3,
                dropout_probability=0.5,
                use_batch_norm=True,
            ),
        )
        t2 = transforms.AffineCouplingTransform(