def _create_vector_base_transform(
    i,
    base_transform_type,
    features,
    hidden_features,
    num_transform_blocks,
    dropout_probability,
    use_batch_norm,
    num_bins,
    tail_bound,
    apply_unconditional_transform,
    context_features,
):
    """Create the ``i``-th base transform of a flow over flat feature vectors.

    Args:
        i: Index of this layer in the stack. Its parity is passed as ``even``
            to ``various.create_alternating_binary_mask``, so the coupling
            mask alternates between consecutive layers.
        base_transform_type: One of ``"affine-coupling"``,
            ``"quadratic-coupling"``, ``"rq-coupling"``,
            ``"affine-autoregressive"``, ``"quadratic-autoregressive"`` or
            ``"rq-autoregressive"``.
        features: Number of features of the transformed vectors.
        hidden_features: Hidden width of the conditioner network.
        num_transform_blocks: Number of blocks in the conditioner network.
        dropout_probability: Dropout probability of the conditioner network.
        use_batch_norm: Whether the conditioner network uses batch norm.
        num_bins: Number of spline bins (spline variants only).
        tail_bound: Spline tail bound, used with ``tails="linear"``
            (spline variants only).
        apply_unconditional_transform: Forwarded to the spline coupling
            variants.
        context_features: Number of context features passed to the
            conditioner network.

    Returns:
        The constructed transform.

    Raises:
        ValueError: If ``base_transform_type`` is not recognised.
    """

    def transform_net_create_fn(in_features, out_features):
        # Conditioner factory used by the coupling variants.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_transform_blocks,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )

    coupling_types = ("affine-coupling", "quadratic-coupling", "rq-coupling")
    autoregressive_types = (
        "affine-autoregressive",
        "quadratic-autoregressive",
        "rq-autoregressive",
    )

    if base_transform_type in coupling_types:
        # Flip the mask parity between consecutive layers.
        mask = various.create_alternating_binary_mask(features, even=(i % 2 == 0))
        if base_transform_type == "affine-coupling":
            return transforms.AffineCouplingTransform(
                mask=mask, transform_net_create_fn=transform_net_create_fn
            )
        spline_cls = (
            transforms.PiecewiseQuadraticCouplingTransform
            if base_transform_type == "quadratic-coupling"
            else transforms.PiecewiseRationalQuadraticCouplingTransform
        )
        return spline_cls(
            mask=mask,
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            apply_unconditional_transform=apply_unconditional_transform,
        )

    if base_transform_type in autoregressive_types:
        # Arguments shared by all three masked autoregressive variants.
        autoregressive_kwargs = dict(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        if base_transform_type == "affine-autoregressive":
            return transforms.MaskedAffineAutoregressiveTransform(
                **autoregressive_kwargs
            )
        spline_cls = (
            transforms.MaskedPiecewiseQuadraticAutoregressiveTransform
            if base_transform_type == "quadratic-autoregressive"
            else transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform
        )
        return spline_cls(
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            **autoregressive_kwargs
        )

    # Previously a bare `raise ValueError`; name the offending value.
    raise ValueError(
        "Unknown base_transform_type: {!r}".format(base_transform_type)
    )
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    context_channels=None,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
    num_bins=8,
    tail_bound=3.0,
):
    """Assemble one image flow step with a ResNet conditioner.

    The step is ``[ActNorm] -> OneByOneConvolution -> coupling layer``,
    wrapped in ``transforms.CompositeTransform``. Every conditioner is an
    ``nn_.ConvResidualNet`` that receives ``context_channels``.

    NOTE(review): a second ``def _create_image_transform_step`` appears later
    in this module. That later definition rebinds the name at import time, so
    this version is not reachable under this name. Rename or remove one of
    them.

    Args:
        num_channels: Channel count used for the ActNorm, the 1x1 convolution
            and the mid-split coupling mask.
        hidden_channels: Hidden width of the conditioner network.
        context_channels: Context channels forwarded to the conditioner.
        actnorm: If true, an ``ActNorm`` layer is placed first.
        coupling_layer_type: ``"cubic_spline"``, ``"quadratic_spline"``,
            ``"rational_quadratic_spline"``, ``"affine"`` or ``"additive"``.
        num_res_blocks: Number of blocks in the conditioner.
        resnet_batchnorm: Whether the conditioner uses batch norm.
        dropout_prob: Dropout probability of the conditioner.
        num_bins: Number of spline bins (spline couplings only).
        tail_bound: Spline tail bound (spline couplings only).

    Returns:
        A ``transforms.CompositeTransform`` holding the step's layers.

    Raises:
        RuntimeError: If ``coupling_layer_type`` is not recognised.
    """

    def conditioner_factory(in_channels, out_channels):
        return nn_.ConvResidualNet(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            context_channels=context_channels,
            num_blocks=num_res_blocks,
            use_batch_norm=resnet_batchnorm,
            dropout_probability=dropout_prob,
        )

    mask = various.create_mid_split_binary_mask(num_channels)

    # Settings shared verbatim by the three spline couplings.
    spline_kwargs = dict(
        tails="linear",
        tail_bound=tail_bound,
        num_bins=num_bins,
        apply_unconditional_transform=False,
        min_bin_width=0.001,
        min_bin_height=0.001,
    )

    if coupling_layer_type == "cubic_spline":
        coupling = transforms.PiecewiseCubicCouplingTransform(
            mask=mask, transform_net_create_fn=conditioner_factory, **spline_kwargs
        )
    elif coupling_layer_type == "quadratic_spline":
        coupling = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask, transform_net_create_fn=conditioner_factory, **spline_kwargs
        )
    elif coupling_layer_type == "rational_quadratic_spline":
        coupling = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=conditioner_factory,
            min_derivative=0.001,
            **spline_kwargs
        )
    elif coupling_layer_type == "affine":
        coupling = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=conditioner_factory
        )
    elif coupling_layer_type == "additive":
        coupling = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=conditioner_factory
        )
    else:
        raise RuntimeError("Unknown coupling_layer_type")

    # ActNorm (if any) is built before the 1x1 convolution, as originally.
    layers = [transforms.ActNorm(num_channels)] if actnorm else []
    layers += [transforms.OneByOneConvolution(num_channels), coupling]

    logger.debug(" Flow based on %s", coupling_layer_type)
    return transforms.CompositeTransform(layers)
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    spline_params=None,
    use_resnet=True,
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
    context_channels=None,
):
    """Assemble one image flow step: ``[ActNorm] -> 1x1 conv -> coupling``.

    Args:
        num_channels: Channel count used for the ActNorm, the 1x1 convolution
            and the mid-split coupling mask.
        hidden_channels: Hidden width of the conditioner network.
        actnorm: If true, an ``ActNorm`` layer is placed first.
        coupling_layer_type: ``"cubic_spline"``, ``"quadratic_spline"``,
            ``"rational_quadratic_spline"``, ``"affine"`` or ``"additive"``.
        spline_params: Optional mapping overriding the spline settings
            ``apply_unconditional_transform`` (default False),
            ``min_bin_height`` / ``min_bin_width`` / ``min_derivative``
            (default 0.001), ``num_bins`` (default 4) and ``tail_bound``
            (default 3.0). Missing keys fall back to these defaults; the
            caller's mapping is not modified.
        use_resnet: If true, the conditioner is ``nn_.ConvResidualNet``;
            otherwise ``ConvNet``.
        num_res_blocks: Number of blocks in the ResNet conditioner.
        resnet_batchnorm: Whether the ResNet conditioner uses batch norm.
        dropout_prob: Dropout probability (ResNet conditioner only).
        context_channels: Optional context channel count forwarded to the
            ResNet conditioner. ``None`` (default) leaves the conditioner
            construction unchanged.

    Returns:
        A ``transforms.CompositeTransform`` holding the step's layers.

    Raises:
        ValueError: If ``dropout_prob`` is non-zero or ``context_channels`` is
            given while ``use_resnet`` is false.
        RuntimeError: If ``coupling_layer_type`` is not recognised.
    """
    spline_types = ("cubic_spline", "quadratic_spline", "rational_quadratic_spline")

    if use_resnet:
        resnet_kwargs = {
            "hidden_channels": hidden_channels,
            "num_blocks": num_res_blocks,
            "use_batch_norm": resnet_batchnorm,
            "dropout_probability": dropout_prob,
        }
        # Forward context only when requested, so existing (unconditional)
        # callers construct exactly the same network as before.
        if context_channels is not None:
            resnet_kwargs["context_channels"] = context_channels

        def create_convnet(in_channels, out_channels):
            return nn_.ConvResidualNet(
                in_channels=in_channels, out_channels=out_channels, **resnet_kwargs
            )

    else:
        # The plain ConvNet conditioner is called without dropout or context.
        if dropout_prob != 0.0:
            raise ValueError(
                "dropout_prob={!r} is only supported with use_resnet=True".format(
                    dropout_prob
                )
            )
        if context_channels is not None:
            raise ValueError(
                "context_channels is only supported with use_resnet=True"
            )

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    # Validate before doing any construction work.
    if coupling_layer_type not in spline_types + ("affine", "additive"):
        raise RuntimeError(
            "Unknown coupling_layer_type: {!r}".format(coupling_layer_type)
        )

    # Merge onto defaults: a partial spline_params used to raise KeyError.
    merged_spline_params = {
        "apply_unconditional_transform": False,
        "min_bin_height": 0.001,
        "min_bin_width": 0.001,
        "min_derivative": 0.001,
        "num_bins": 4,
        "tail_bound": 3.0,
    }
    if spline_params is not None:
        merged_spline_params.update(spline_params)

    mask = various.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type in spline_types:
        spline_kwargs = dict(
            tails="linear",
            tail_bound=merged_spline_params["tail_bound"],
            num_bins=merged_spline_params["num_bins"],
            apply_unconditional_transform=merged_spline_params[
                "apply_unconditional_transform"
            ],
            min_bin_width=merged_spline_params["min_bin_width"],
            min_bin_height=merged_spline_params["min_bin_height"],
        )
        if coupling_layer_type == "cubic_spline":
            spline_cls = transforms.PiecewiseCubicCouplingTransform
        elif coupling_layer_type == "quadratic_spline":
            spline_cls = transforms.PiecewiseQuadraticCouplingTransform
        else:
            spline_cls = transforms.PiecewiseRationalQuadraticCouplingTransform
            # Only the rational-quadratic spline takes a minimum derivative.
            spline_kwargs["min_derivative"] = merged_spline_params["min_derivative"]
        coupling_layer = spline_cls(
            mask=mask, transform_net_create_fn=create_convnet, **spline_kwargs
        )
    elif coupling_layer_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet
        )
    else:  # "additive" -- the only remaining type after validation above.
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet
        )

    # Layers are built after the coupling layer, ActNorm first, as before.
    step_transforms = []
    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))
    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer]
    )
    logger.debug(" Flow based on %s", coupling_layer_type)
    return transforms.CompositeTransform(step_transforms)