def create_transform(c, h, w, num_bits, levels, steps_per_level, step_config):
    """Build a multi-scale image transform.

    Each level squeezes the input, applies `steps_per_level` coupling steps,
    and closes with a 1x1 convolution; inputs are first mapped from
    [0, 2**num_bits] into [-0.5, 0.5].
    """
    mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
    shape = (c, h, w)
    for _ in range(levels):
        squeeze = transforms.SqueezeTransform()
        shape = squeeze.get_output_shape(*shape)
        channels = shape[0]
        level_parts = [squeeze]
        for _ in range(steps_per_level):
            level_parts.append(create_transform_step(channels, **step_config))
        # End each level with a linear transformation.
        level_parts.append(transforms.OneByOneConvolution(channels))
        next_shape = mct.add_transform(
            transforms.CompositeTransform(level_parts), shape)
        if next_shape:  # falsy on the last level
            shape = next_shape
    # Map to [-0.5,0.5]
    preprocess_transform = transforms.AffineScalarTransform(
        scale=(1. / 2 ** num_bits), shift=-0.5)
    return transforms.CompositeTransform([preprocess_transform, mct])
def create_transform(num_flow_steps, param_dim, context_dim, base_transform_kwargs):
    """Build a sequence of NSF transforms mapping parameters x into the base
    distribution u (noise), conditioned on strain data y.

    The forward map is f^{-1}(x, y). Each step pairs a linear transform of x
    (which permutes components) with an NSF transform of x conditioned on y;
    one final linear transform closes the sequence. Adapted from the uci.py
    example in https://github.com/bayesiains/nsf

    Arguments:
        num_flow_steps {int} -- number of transforms in sequence
        param_dim {int} -- dimensionality of x
        context_dim {int} -- dimensionality of y
        base_transform_kwargs {dict} -- hyperparameters for NSF step

    Returns:
        Transform -- the constructed transform
    """
    steps = []
    for step_index in range(num_flow_steps):
        steps.append(
            transforms.CompositeTransform([
                create_linear_transform(param_dim),
                create_base_transform(step_index,
                                      param_dim,
                                      context_dim=context_dim,
                                      **base_transform_kwargs),
            ]))
    # Final linear transform after the (linear, NSF) pairs.
    steps.append(create_linear_transform(param_dim))
    return transforms.CompositeTransform(steps)
def create_linear_transform(linear_transform, features):
    """Function for creating linear transforms.

    Parameters
    ----------
    linear_transform : {'permutation', 'lu', 'svd'}
        Linear transform to use (matched case-insensitively).
    features : int
        Number of features.

    Returns
    -------
    Transform
        The constructed linear transform.

    Raises
    ------
    ValueError
        If ``linear_transform`` is not one of the supported options.
    """
    # Lower-case once instead of on every comparison.
    name = linear_transform.lower()
    if name == 'permutation':
        return transforms.RandomPermutation(features=features)
    elif name == 'lu':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=features),
            LULinear(features, identity_init=True, using_cache=True)
        ])
    elif name == 'svd':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=features),
            transforms.SVDLinear(features, num_householder=10,
                                 identity_init=True)
        ])
    else:
        raise ValueError(f'Unknown linear transform: {linear_transform}. '
                         'Choose from: {permutation, lu, svd}.')
def _create_transform(self, flow_steps, context_features=None):
    """Compose `flow_steps` (linear, MAF) pairs plus one final linear transform."""
    paired_steps = [
        transforms.CompositeTransform([
            self._create_linear_transform(),
            self._create_maf_transform(context_features),
        ])
        for _ in range(flow_steps)
    ]
    paired_steps.append(self._create_linear_transform())
    return transforms.CompositeTransform(paired_steps)
def __init__(self, features, hidden_features, num_layers, num_blocks_per_layer,
             num_bins=8, context_features=None, activation=F.relu,
             dropout_probability=0.0, batch_norm_within_layers=False,
             batch_norm_between_layers=False,
             apply_unconditional_transform=False,
             linear_transform='permutation', tails='linear', tail_bound=5.0,
             **kwargs):
    """Coupling-based neural spline flow.

    Stacks ``num_layers`` piecewise rational-quadratic coupling transforms,
    each optionally preceded by a linear transform and followed by batch
    norm, over a standard-normal base distribution.

    Parameters
    ----------
    features : int
        Dimensionality of the data; must be at least 2 for coupling layers.
    hidden_features : int
        Hidden units in each conditioner ResidualNet.
    num_layers : int
        Number of coupling layers.
    num_blocks_per_layer : int
        Residual blocks per conditioner network.
    num_bins : int
        Spline bins per coupling transform.
    context_features : int, optional
        Dimensionality of conditioning context, if any.
    activation : callable
        Activation for the conditioner networks.
    dropout_probability : float
        Dropout in the conditioner networks.
    batch_norm_within_layers : bool
        Batch norm inside the conditioner networks.
    batch_norm_between_layers : bool
        Insert ``transforms.BatchNorm`` after each coupling layer.
    apply_unconditional_transform : bool
        Passed through to the coupling transform.
    linear_transform : str or None
        Linear transform inserted before each coupling layer; skipped when
        None.
    tails, tail_bound :
        Spline tail configuration, passed through to the coupling transform.
    **kwargs :
        Extra arguments forwarded to the coupling transform.

    Raises
    ------
    ValueError
        If ``features`` <= 1.
    """
    if features <= 1:
        raise ValueError(
            'Coupling based Neural Spline flow requires at least 2 '
            f'dimensions. Specified dimensions: {features}.')

    def create_resnet(in_features, out_features):
        # Conditioner network used by every coupling transform.
        return ResidualNet(
            in_features,
            out_features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks_per_layer,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=batch_norm_within_layers,
        )

    def spline_constructor(i):
        # Alternate the binary mask parity per layer so that every
        # dimension gets transformed across consecutive layers.
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=create_alternating_binary_mask(features=features,
                                                even=(i % 2 == 0)),
            transform_net_create_fn=create_resnet,
            num_bins=num_bins,
            apply_unconditional_transform=apply_unconditional_transform,
            tails=tails,
            tail_bound=tail_bound,
            **kwargs)

    transforms_list = []
    for i in range(num_layers):
        if linear_transform is not None:
            transforms_list.append(
                create_linear_transform(linear_transform, features))
        transforms_list.append(spline_constructor(i))
        if batch_norm_between_layers:
            transforms_list.append(transforms.BatchNorm(features=features))

    distribution = StandardNormal([features])
    super().__init__(
        transform=transforms.CompositeTransform(transforms_list),
        distribution=distribution,
    )
def _make_scalar_linear_transform(transform, features):
    """Create a linear transform for scalar (vector) data.

    Parameters
    ----------
    transform : {'permutation', 'lu', 'svd'}
        Which linear transform to build.
    features : int
        Number of features.

    Returns
    -------
    Transform
        The constructed linear transform.

    Raises
    ------
    ValueError
        If ``transform`` is not a supported option.
    """
    if transform == "permutation":
        return transforms.RandomPermutation(features=features)
    elif transform == "lu":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features),
             transforms.LULinear(features, identity_init=True)]
        )
    elif transform == "svd":
        return transforms.CompositeTransform(
            [
                transforms.RandomPermutation(features=features),
                transforms.SVDLinear(features, num_householder=10,
                                     identity_init=True),
            ]
        )
    else:
        # Was a bare `raise ValueError` — name the offending value and the
        # valid choices so misconfiguration is diagnosable.
        raise ValueError(f"Unknown linear transform: {transform}. "
                         "Choose from: permutation, lu, svd.")
def __init__(
    self,
    x_size: int,
    y_size: int,
    arch: str = 'A',  # ['PRQ', 'UMNN']
    num_transforms: int = 5,
    lu_linear: bool = False,
    moments: Tuple[torch.Tensor, torch.Tensor] = None,
    **kwargs,
):
    """Masked autoregressive flow over x conditioned on y.

    Parameters
    ----------
    x_size : int
        Dimensionality of x.
    y_size : int
        Dimensionality of the conditioning variable y.
    arch : str
        Transform architecture: 'PRQ' (rational-quadratic spline), 'UMNN'
        (unconstrained monotonic NN), or anything else for affine ('A').
    num_transforms : int
        Number of (transform, permutation) pairs; collapsed to 1 when
        x is one-dimensional.
    lu_linear : bool
        Append a final LU-decomposed linear transform.
    moments : tuple of tensors, optional
        (shift, scale) used to standardize x before the flow.
    **kwargs :
        Hyperparameters forwarded to the chosen transform; defaults are
        filled in below.
    """
    # Defaults for the conditioner networks; the caller may override any.
    kwargs.setdefault('hidden_features', 64)
    kwargs.setdefault('num_blocks', 2)
    kwargs.setdefault('use_residual_blocks', False)
    kwargs.setdefault('use_batch_norm', False)
    # Resolve the activation name to an instantiated module.
    kwargs['activation'] = ACTIVATIONS[kwargs.get('activation', 'ReLU')]()

    if arch == 'PRQ':
        kwargs['tails'] = 'linear'
        kwargs.setdefault('num_bins', 8)
        kwargs.setdefault('tail_bound', 1.)
        tfrm = transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform
    elif arch == 'UMNN':
        kwargs.setdefault('integrand_net_layers', [64, 64, 64])
        kwargs.setdefault('cond_size', 32)
        kwargs.setdefault('nb_steps', 32)
        tfrm = transforms.MaskedUMNNAutoregressiveTransform
    else:  # arch == 'A'
        tfrm = transforms.MaskedAffineAutoregressiveTransform

    compose = []
    if moments is not None:
        # Affine standardization: x -> (x - shift) / scale.
        shift, scale = moments
        compose.append(
            transforms.PointwiseAffineTransform(-shift / scale, 1 / scale))

    # With 1-D x a permutation is a no-op, so one transform suffices.
    for _ in range(num_transforms if x_size > 1 else 1):
        compose.extend([
            tfrm(
                features=x_size,
                context_features=y_size,
                **kwargs,
            ),
            transforms.RandomPermutation(features=x_size),
        ])

    if lu_linear:
        compose.append(transforms.LULinear(x_size, identity_init=True), )

    transform = transforms.CompositeTransform(compose)
    distribution = distributions.StandardNormal((x_size, ))
    super().__init__(transform, distribution)
def create_flow(flow_type):
    """Create a small 3-dimensional flow over a standard normal base.

    Parameters
    ----------
    flow_type : {'lu_flow', 'qr_flow'}
        'lu_flow' builds a random permutation followed by an LU-decomposed
        linear layer; 'qr_flow' builds a QR-decomposed linear layer.

    Returns
    -------
    flows.Flow
        The constructed flow.

    Raises
    ------
    RuntimeError
        If ``flow_type`` is not recognised.
    """
    distribution = distributions.StandardNormal((3,))
    if flow_type == 'lu_flow':
        transform = transforms.CompositeTransform([
            transforms.RandomPermutation(3),
            transforms.LULinear(3, identity_init=False)
        ])
    elif flow_type == 'qr_flow':
        transform = transforms.QRLinear(3, num_householder=3)
    else:
        # Include the offending value so the failure is diagnosable
        # (message previously said only 'Unknown type').
        raise RuntimeError(f'Unknown flow type: {flow_type}')
    return flows.Flow(transform, distribution)
def create_linear_transform(param_dim):
    """Create the composite linear transform PLU.

    Arguments:
        param_dim {int} -- dimension of the space

    Returns:
        Transform -- nde.Transform object
    """
    permutation = transforms.RandomPermutation(features=param_dim)
    lu_layer = transforms.LULinear(param_dim, identity_init=True)
    return transforms.CompositeTransform([permutation, lu_layer])
def create_transform_step(num_channels, hidden_channels, num_res_blocks,
                          resnet_batchnorm, dropout_prob, actnorm,
                          spline_type, num_bins):
    """Build one flow step: [ActNorm] -> 1x1 convolution -> spline coupling.

    Parameters
    ----------
    num_channels : int
        Channels of the input at this step.
    hidden_channels : int
        Hidden channels in the coupling conditioner ConvResidualNet.
    num_res_blocks : int
        Residual blocks in the conditioner network.
    resnet_batchnorm : bool
        Use batch norm inside the conditioner network.
    dropout_prob : float
        Dropout probability inside the conditioner network.
    actnorm : bool
        Prepend an ActNorm layer.
    spline_type : {'rational_quadratic', 'quadratic'}
        Which piecewise spline coupling transform to use.
    num_bins : int
        Number of spline bins.

    Returns
    -------
    transforms.CompositeTransform
        The composed step.

    Raises
    ------
    RuntimeError
        If ``spline_type`` is not recognised.
    """
    def create_convnet(in_channels, out_channels):
        # Conditioner network for the coupling transform.
        net = ConvResidualNet(in_channels=in_channels,
                              out_channels=out_channels,
                              hidden_channels=hidden_channels,
                              num_blocks=num_res_blocks,
                              use_batch_norm=resnet_batchnorm,
                              dropout_probability=dropout_prob)
        return net

    step_transforms = []
    mask = utils.create_mid_split_binary_mask(num_channels)

    if spline_type == 'rational_quadratic':
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            num_bins=num_bins,
            tails='linear'
        )
    elif spline_type == 'quadratic':
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            num_bins=num_bins,
            tails='linear'
        )
    else:
        # Fixed typo ('Unkown') and include the offending value.
        raise RuntimeError(f'Unknown spline_type: {spline_type}')

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend([
        transforms.OneByOneConvolution(num_channels),
        coupling_layer
    ])

    return transforms.CompositeTransform(step_transforms)
def make_scalar_flow(
    dim,
    flow_steps=5,
    transform_type="rq",
    linear_transform="none",
    bins=10,
    tail_bound=10.0,
    hidden_features=64,
    num_transform_blocks=3,
    use_batch_norm=False,
    dropout_prob=0.0,
):
    """Build a flow for `dim`-dimensional unstructured (vector) data:
    `flow_steps` base transforms, optionally interlaced with linear
    transforms, over a standard-normal base."""
    logger.info(
        f"Creating flow for {dim}-dimensional unstructured data, using {flow_steps} blocks of {transform_type} transforms, "
        f"each with {num_transform_blocks} transform blocks and {hidden_features} hidden units, interlaced with {linear_transform} "
        f"linear transforms"
    )

    base_dist = distributions.StandardNormal((dim,))

    use_linear = linear_transform != "none"
    chain = []
    for step in range(flow_steps):
        if use_linear:
            chain.append(_make_scalar_linear_transform(linear_transform, dim))
        chain.append(
            _make_scalar_base_transform(
                step,
                dim,
                transform_type,
                bins,
                tail_bound,
                hidden_features,
                num_transform_blocks,
                use_batch_norm,
                dropout_prob=dropout_prob,
            )
        )
    if use_linear:
        chain.append(_make_scalar_linear_transform(linear_transform, dim))

    return flows.Flow(transforms.CompositeTransform(chain), base_dist)
def make_flow(latent_dim, n_layers):
    """Build a MAF over the latent space: an initial batch-norm transform
    followed by `n_layers` (autoregressive, permutation) pairs, with a
    standard-normal base distribution."""
    steps = [BatchNormTransform(latent_dim)]
    for _ in range(n_layers):
        steps.append(
            transforms.MaskedAffineAutoregressiveTransform(
                features=latent_dim,
                hidden_features=64,
            )
        )
        steps.append(transforms.RandomPermutation(latent_dim))

    # Combine the transform with the base distribution into a flow.
    return flows.Flow(
        transform=transforms.CompositeTransform(steps),
        distribution=distributions.StandardNormal(shape=[latent_dim]),
    )
def __init__(
    self,
    features,
    hidden_features,
    num_layers,
    num_blocks_per_layer,
    mask=None,
    context_features=None,
    net='resnet',
    use_volume_preserving=False,
    activation=F.relu,
    dropout_probability=0.0,
    batch_norm_within_layers=False,
    batch_norm_between_layers=False,
    linear_transform=None,
    distribution=None,
):
    """RealNVP-style coupling flow.

    Stacks ``num_layers`` (additive or affine) coupling transforms whose
    conditioner is either a ResidualNet or an MLP, optionally interlaced
    with linear transforms and batch norm.

    Parameters
    ----------
    features : int
        Data dimensionality; must be at least 2.
    hidden_features : int
        Hidden units in the conditioner network.
    num_layers : int
        Number of coupling layers.
    num_blocks_per_layer : int
        Residual blocks (resnet) or hidden layers (mlp) per conditioner.
    mask : array-like, optional
        Coupling mask; either 1-D of length ``features`` (sign-flipped per
        layer) or 2-D of shape (num_layers, features). Defaults to an
        alternating +1/-1 pattern.
    context_features : int, optional
        Dimensionality of conditioning context (resnet only).
    net : {'resnet', 'mlp'}
        Conditioner network type.
    use_volume_preserving : bool
        Use additive (volume-preserving) instead of affine couplings.
    activation : callable
        Conditioner activation.
    dropout_probability : float
        Conditioner dropout (resnet only; ignored with a warning for mlp).
    batch_norm_within_layers : bool
        Batch norm inside the conditioner (resnet only; ignored with a
        warning for mlp).
    batch_norm_between_layers : bool
        Insert ``transforms.BatchNorm`` between coupling layers.
    linear_transform : str, optional
        Linear transform inserted before each coupling layer; skipped when
        None.
    distribution : Distribution, optional
        Base distribution; defaults to a standard normal.

    Raises
    ------
    ValueError
        If ``features`` <= 1, the mask shape is inconsistent, or ``net`` is
        unknown.
    """
    if features <= 1:
        raise ValueError('RealNVP requires at least 2 dimensions. '
                         f'Specified dimensions: {features}.')

    if use_volume_preserving:
        coupling_constructor = transforms.AdditiveCouplingTransform
    else:
        coupling_constructor = transforms.AffineCouplingTransform

    if mask is None:
        # Default mask alternates +1/-1 across dimensions.
        mask = torch.ones(features)
        mask[::2] = -1
    else:
        mask = np.array(mask)
        if not mask.shape[-1] == features:
            raise ValueError('Mask does not match number of features')
        if mask.ndim == 2 and not mask.shape[0] == num_layers:
            raise ValueError('Mask does not match number of layers')
        mask = torch.from_numpy(mask).type(torch.get_default_dtype())

    if mask.dim() == 1:
        # Expand a single mask to one row per layer, flipping its sign
        # each layer so couplings alternate which half is transformed.
        mask_array = torch.empty([num_layers, features])
        for i in range(num_layers):
            mask_array[i] = mask
            mask *= -1
        mask = mask_array

    if net.lower() == 'resnet':
        from nflows.nn.nets import ResidualNet

        def create_net(in_features, out_features):
            # Conditioner network for each coupling transform.
            return ResidualNet(in_features,
                               out_features,
                               hidden_features=hidden_features,
                               context_features=context_features,
                               num_blocks=num_blocks_per_layer,
                               activation=activation,
                               dropout_probability=dropout_probability,
                               use_batch_norm=batch_norm_within_layers)
    elif net.lower() == 'mlp':
        from .utils import MLP
        if batch_norm_within_layers:
            logger.warning('Batch norm within layers not supported for '
                           'MLP, will be ignored')
        if dropout_probability:
            logger.warning('Dropout not supported for MLP, '
                           'will be ignored')
        # MLP takes a list of per-layer hidden sizes.
        hidden_features = num_blocks_per_layer * [hidden_features]

        def create_net(in_features, out_features):
            return MLP((in_features, ), (out_features, ),
                       hidden_features,
                       activation=activation)
    else:
        raise ValueError(f'Unknown nn type: {net}. '
                         'Choose from: {resnet, mlp}.')

    layers = []
    for i in range(num_layers):
        if linear_transform is not None:
            layers.append(
                create_linear_transform(linear_transform, features))
        transform = coupling_constructor(
            mask=mask[i], transform_net_create_fn=create_net)
        layers.append(transform)
        if batch_norm_between_layers:
            layers.append(transforms.BatchNorm(features=features))

    if distribution is None:
        distribution = StandardNormal([features])
    super().__init__(
        transform=transforms.CompositeTransform(layers),
        distribution=distribution,
    )
def _create_linear_transform(self):
    """Permutation followed by an identity-initialized LU linear layer."""
    dim = self.dimensions
    permute = transforms.RandomPermutation(features=dim)
    lu_layer = transforms.LULinear(dim, identity_init=True)
    return transforms.CompositeTransform([permute, lu_layer])
def neural_net_nsf(
        self,
        hidden_features,
        num_blocks,
        num_bins,
        xDim,
        thetaDim,
        batch_x=None,
        batch_theta=None,
        tail=3.,
        bounded=False,
        embedding_net=torch.nn.Identity()) -> torch.nn.Module:
    """Builds NSF p(x|theta).

    Args:
        hidden_features: Number of hidden units in each conditioner net.
        num_blocks: Number of (coupling, permutation, LU) blocks.
        num_bins: Number of spline bins per coupling transform.
        xDim: Dimensionality of x.
        thetaDim: Dimensionality of the conditioning variable theta.
        batch_x: Batch of xs, used for z-scoring x (only when
            ``batch_theta`` is given and ``bounded`` is False).
        batch_theta: Batch of thetas; when given, a standardizing net for
            theta is prepended to ``embedding_net``.
        tail: Spline tail bound.
        bounded: If True, map x from the simulator bounds into
            unconstrained space (affine rescale + logit) instead of
            z-scoring.
        embedding_net: Optional embedding network for theta.

    Returns:
        Neural network (normalizing flow).
    """
    basic_transform = [
        transforms.CompositeTransform([
            transforms.PiecewiseRationalQuadraticCouplingTransform(
                mask=create_alternating_binary_mask(
                    features=xDim, even=(i % 2 == 0)).to(self.args.device),
                transform_net_create_fn=lambda in_features, out_features:
                nets.ResidualNet(
                    in_features=in_features,
                    out_features=out_features,
                    hidden_features=hidden_features,
                    context_features=thetaDim,
                    num_blocks=2,
                    activation=torch.relu,
                    dropout_probability=0.,
                    use_batch_norm=False,
                ),
                num_bins=num_bins,
                tails='linear',
                tail_bound=tail,
                apply_unconditional_transform=False,
            ),
            transforms.RandomPermutation(features=xDim,
                                         device=self.args.device),
            transforms.LULinear(xDim, identity_init=True),
        ]) for i in range(num_blocks)
    ]
    transform = transforms.CompositeTransform(basic_transform).to(
        self.args.device)

    # Fixed: was `batch_theta != None`; identity comparison is correct here.
    if batch_theta is not None:
        if bounded:
            transform_bounded = transforms.Logit(self.args.device)
            if self.sim.min[0].item() != 0 or self.sim.max[0].item() != 1:
                # Rescale x from [min, max] to [0, 1] before the logit.
                # (Fixed misspelled local name `transfomr_affine`.)
                transform_affine = transforms.PointwiseAffineTransform(
                    shift=-self.sim.min / (self.sim.max - self.sim.min),
                    scale=1. / (self.sim.max - self.sim.min))
                transform = transforms.CompositeTransform(
                    [transform_affine, transform_bounded, transform])
            else:
                transform = transforms.CompositeTransform(
                    [transform_bounded, transform])
        else:
            # z-score x using the provided batch.
            transform_zx = standardizing_transform(batch_x)
            transform = transforms.CompositeTransform(
                [transform_zx, transform])
        embedding_net = torch.nn.Sequential(standardizing_net(batch_theta),
                                            embedding_net)
        distribution = distributions_.StandardNormal((xDim, ),
                                                     self.args.device)
        neural_net = flows.Flow(self,
                                transform,
                                distribution,
                                embedding_net=embedding_net).to(
                                    self.args.device)
    else:
        distribution = distributions_.StandardNormal((xDim, ),
                                                     self.args.device)
        neural_net = flows.Flow(self, transform,
                                distribution).to(self.args.device)
    return neural_net
def make_image_flow(
    chw,
    levels=7,
    steps_per_level=3,
    transform_type="rq",
    bins=4,
    tail_bound=3.0,
    hidden_channels=96,
    act_norm=True,
    batch_norm=False,
    dropout_prob=0.0,
    alpha=0.05,
    num_bits=8,
    preprocessing="glow",
    residual_blocks=3,
):
    """Build a multi-scale flow for image data of shape ``chw``.

    A preprocessing transform maps raw inputs from [0, 2**num_bits] into
    model space ('glow': affine to [-0.5, 0.5]; 'realnvp'/'realnvp_2alpha':
    affine + logit), followed by ``levels`` levels of squeeze +
    ``steps_per_level`` coupling steps + 1x1 convolution, over a
    standard-normal base.

    Raises:
        RuntimeError: If ``preprocessing`` is not recognised.
    """
    c, h, w = chw
    # Allow a single int for hidden_channels; broadcast it to all levels.
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels

    # Base density
    base_dist = distributions.StandardNormal((c * h * w,))
    logger.debug(f"Base density: standard normal in {c * h * w} dimensions")

    # Preprocessing: Inputs to the model in [0, 2 ** num_bits]
    if preprocessing == "glow":
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1.0 / 2 ** num_bits), shift=-0.5)
    elif preprocessing == "realnvp":
        preprocess_transform = transforms.CompositeTransform(
            [
                # Map to [0,1]
                transforms.AffineScalarTransform(scale=(1.0 / 2 ** num_bits)),
                # Map into unconstrained space as done in RealNVP
                transforms.AffineScalarTransform(shift=alpha,
                                                 scale=(1 - alpha)),
                transforms.Logit(),
            ]
        )
    elif preprocessing == "realnvp_2alpha":
        preprocess_transform = transforms.CompositeTransform(
            [
                transforms.AffineScalarTransform(scale=(1.0 / 2 ** num_bits)),
                transforms.AffineScalarTransform(shift=alpha,
                                                 scale=(1 - 2.0 * alpha)),
                transforms.Logit(),
            ]
        )
    else:
        raise RuntimeError("Unknown preprocessing type: {}".format(preprocessing))
    logger.debug(f"{preprocessing} preprocessing")

    # Multi-scale transform
    logger.debug("Input: c, h, w = %s, %s, %s", c, h, w)
    mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
    for level, level_hidden_channels in zip(range(levels), hidden_channels):
        logger.debug("Level %s", level)
        squeeze_transform = transforms.SqueezeTransform()
        # Squeezing trades spatial extent for channels; track the new shape.
        c, h, w = squeeze_transform.get_output_shape(c, h, w)
        logger.debug(" c, h, w = %s, %s, %s", c, h, w)
        transform_level = [squeeze_transform]
        logger.debug(" SqueezeTransform()")
        for _ in range(steps_per_level):
            transform_level.append(
                _make_image_base_transform(
                    c,
                    level_hidden_channels,
                    act_norm,
                    transform_type,
                    residual_blocks,
                    batch_norm,
                    dropout_prob,
                    tail_bound,
                    bins,
                )
            )
        transform_level.append(transforms.OneByOneConvolution(c))  # End each level with a linear transformation
        logger.debug(" OneByOneConvolution(%s)", c)
        transform_level = transforms.CompositeTransform(transform_level)

        new_shape = mct.add_transform(transform_level, (c, h, w))
        if new_shape:  # If not last layer
            c, h, w = new_shape
            logger.debug(" new_shape = %s, %s, %s", c, h, w)

    # Full transform and flow
    transform = transforms.CompositeTransform([preprocess_transform, mct])
    flow = flows.Flow(transform, base_dist)

    return flow
def _make_image_base_transform(
    num_channels,
    hidden_channels,
    actnorm,
    transform_type,
    num_res_blocks,
    resnet_batchnorm,
    dropout_prob,
    tail_bound,
    num_bins,
    apply_unconditional_transform=False,
    min_bin_width=0.001,
    min_bin_height=0.001,
    min_derivative=0.001,
):
    """Build one image-flow step: [ActNorm] -> 1x1 convolution -> coupling.

    The coupling transform type is selected by ``transform_type``
    ('cubic', 'quadratic', 'rq', 'affine', or 'additive'); its conditioner
    is a ConvResidualNet. Raises RuntimeError for unknown types.
    """
    def convnet_factory(in_channels, out_channels):
        # Conditioner network used by the coupling transform.
        net = nets.ConvResidualNet(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_res_blocks,
            use_batch_norm=resnet_batchnorm,
            dropout_probability=dropout_prob,
        )
        return net

    # Split channels down the middle for the coupling mask.
    mask = flowutils.create_mid_split_binary_mask(num_channels)

    if transform_type == "cubic":
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            mask=mask,
            transform_net_create_fn=convnet_factory,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=apply_unconditional_transform,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
        )
    elif transform_type == "quadratic":
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=convnet_factory,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=apply_unconditional_transform,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
        )
    elif transform_type == "rq":
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=convnet_factory,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=apply_unconditional_transform,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )
    elif transform_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=convnet_factory)
    elif transform_type == "additive":
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=convnet_factory)
    else:
        raise RuntimeError("Unknown transform type")

    step_transforms = []
    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))
    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])
    transform = transforms.CompositeTransform(step_transforms)

    logger.debug(f" Block with {transform_type} coupling layers")

    return transform
def get_flow(
    model: str,
    dim_distribution: int,
    dim_context: Optional[int] = None,
    embedding: Optional[torch.nn.Module] = None,
    hidden_features: int = 50,
    made_num_mixture_components: int = 10,
    made_num_blocks: int = 4,
    flow_num_transforms: int = 5,
    mean=0.0,
    std=1.0,
) -> torch.nn.Module:
    """Density estimator

    Args:
        model: Model, one of made / maf / nsf / nsf_bounded
        dim_distribution: Dim of distribution
        dim_context: Dim of context
        embedding: Embedding network
        hidden_features: For all, number of hidden features
        made_num_mixture_components: For MADEs only, number of mixture components
        made_num_blocks: For MADEs only, number of blocks
        flow_num_transforms: For flows only, number of transforms
        mean: For normalization
        std: For normalization

    Returns:
        Neural network

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    standardizing_transform = transforms.AffineTransform(shift=-mean / std,
                                                         scale=1 / std)

    features = dim_distribution
    context_features = dim_context

    if model == "made":
        transform = standardizing_transform
        distribution = distributions_.MADEMoG(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=made_num_blocks,
            num_mixture_components=made_num_mixture_components,
            use_residual_blocks=True,
            random_mask=False,
            activation=torch.relu,
            dropout_probability=0.0,
            use_batch_norm=False,
            custom_initialization=True,
        )
        neural_net = flows.Flow(transform, distribution, embedding)
    elif model == "maf":
        transform = transforms.CompositeTransform([
            transforms.CompositeTransform([
                transforms.MaskedAffineAutoregressiveTransform(
                    features=features,
                    hidden_features=hidden_features,
                    context_features=context_features,
                    num_blocks=2,
                    use_residual_blocks=False,
                    random_mask=False,
                    activation=torch.tanh,
                    dropout_probability=0.0,
                    use_batch_norm=True,
                ),
                transforms.RandomPermutation(features=features),
            ]) for _ in range(flow_num_transforms)
        ])
        transform = transforms.CompositeTransform(
            [standardizing_transform, transform])
        distribution = distributions_.StandardNormal((features, ))
        neural_net = flows.Flow(transform, distribution, embedding)
    elif model == "nsf":
        transform = transforms.CompositeTransform([
            transforms.CompositeTransform([
                transforms.PiecewiseRationalQuadraticCouplingTransform(
                    mask=create_alternating_binary_mask(features=features,
                                                        even=(i % 2 == 0)),
                    transform_net_create_fn=lambda in_features, out_features:
                    nets.ResidualNet(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=hidden_features,
                        context_features=context_features,
                        num_blocks=2,
                        activation=torch.relu,
                        dropout_probability=0.0,
                        use_batch_norm=False,
                    ),
                    num_bins=10,
                    tails="linear",
                    tail_bound=3.0,
                    apply_unconditional_transform=False,
                ),
                transforms.LULinear(features, identity_init=True),
            ]) for i in range(flow_num_transforms)
        ])
        transform = transforms.CompositeTransform(
            [standardizing_transform, transform])
        distribution = distributions_.StandardNormal((features, ))
        neural_net = flows.Flow(transform, distribution, embedding)
    elif model == "nsf_bounded":
        transform = transforms.CompositeTransform([
            transforms.CompositeTransform([
                transforms.PiecewiseRationalQuadraticCouplingTransform(
                    mask=create_alternating_binary_mask(
                        features=dim_distribution, even=(i % 2 == 0)),
                    transform_net_create_fn=lambda in_features, out_features:
                    nets.ResidualNet(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=hidden_features,
                        context_features=context_features,
                        num_blocks=2,
                        activation=F.relu,
                        dropout_probability=0.0,
                        use_batch_norm=False,
                    ),
                    num_bins=10,
                    tails="linear",
                    tail_bound=np.sqrt(
                        3),  # uniform with sqrt(3) bounds has unit-variance
                    apply_unconditional_transform=False,
                ),
                transforms.RandomPermutation(features=dim_distribution),
            ]) for i in range(flow_num_transforms)
        ])
        transform = transforms.CompositeTransform(
            [standardizing_transform, transform])
        distribution = StandardUniform(shape=(dim_distribution, ))
        neural_net = flows.Flow(transform, distribution, embedding)
    else:
        # Was a bare ``raise ValueError`` — name the offending value and
        # the valid choices so misconfiguration is diagnosable.
        raise ValueError(f"Unknown model: {model}. "
                         "Choose from: made, maf, nsf, nsf_bounded.")

    return neural_net
def create_transform(
    num_flow_steps: int,
    param_dim: int,
    context_dim: int,
    base_transform_kwargs: dict,
):
    """Build a sequence of NSF transforms, which maps parameters x into the
    base distribution u (noise), conditioned on strain data y.

    The forward map is f^{-1}(x, y). Each step pairs a linear transform of x
    (which permutes components) with an NSF transform of x conditioned on y;
    one final linear transform closes the sequence. Adapted from the uci.py
    example in https://github.com/bayesiains/nsf

    Arguments:
        num_flow_steps {int} -- number of transforms in sequence
        param_dim {int} -- dimensionality of x
        context_dim {int} -- dimensionality of y
        base_transform_kwargs {dict} -- hyperparameters for NSF step

    Returns:
        Transform -- the constructed transform
    """
    flow_blocks = []
    for step_index in range(num_flow_steps):
        flow_blocks.append(
            transforms.CompositeTransform([
                create_linear_transform(param_dim),
                create_base_transform(step_index,
                                      param_dim,
                                      context_dim=context_dim,
                                      **base_transform_kwargs),
            ]))
    flow_blocks.append(create_linear_transform(param_dim))
    # NOTE: an equivalent regrouping — an initial linear transform followed
    # by (NSF, linear) pairs — can make intermediate predictions easier to
    # visualise; the grouping used here matches lfigw/nde_flows.py exactly.
    return transforms.CompositeTransform(flow_blocks)