import logging

import numpy as np
import torch

# NOTE: transforms, various, and create_vector_transform are project-local
# modules; these import paths are assumptions and may need adjusting.
from manifold_flow import transforms
from manifold_flow.utils import various
from manifold_flow.architectures.vector_transforms import create_vector_transform

logger = logging.getLogger(__name__)


def _create_vector_linear_transform(linear_transform_type, features):
    """Creates the invertible linear layer used between flow steps."""
    if linear_transform_type == "permutation":
        return transforms.RandomPermutation(features=features)
    elif linear_transform_type == "lu":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.LULinear(features, identity_init=True)]
        )
    elif linear_transform_type == "svd":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.SVDLinear(features, num_householder=10)]
        )
    else:
        raise ValueError(f"Unknown linear transform type {linear_transform_type}")
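
# Usage sketch (not part of the original module): builds each supported
# invertible linear layer for a toy feature count. The helper name and the
# feature count 16 are illustrative assumptions.
def _demo_linear_transforms():
    for kind in ("permutation", "lu", "svd"):
        layer = _create_vector_linear_transform(kind, features=16)
        print(kind, "->", type(layer).__name__)
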
def _create_postprocessing(dim, multi_scale, postprocessing, postprocessing_channel_factor, postprocessing_layers, res, context_features, tail_bound, num_bins):
    """Creates the final transform applied after the main flow blocks."""
    # TODO: take context_features into account here
    if postprocessing == "linear":
        final_transform = transforms.LULinear(dim, identity_init=True)
        logger.debug("LULinear(%s)", dim)

    elif postprocessing == "partial_linear":
        # LU-decomposed linear layer applied only to the masked feature subset
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 8], dtype=int),  # np.int is removed in NumPy >= 1.24
                resolution=res,
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transform = transforms.LULinear(partial_dim, identity_init=True)
        final_transform = transforms.PartialTransform(mask, partial_transform)
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)

    elif postprocessing == "partial_mlp":
        # Invertible MLP on the masked subset: LU-linear layers interleaved with LogTanh nonlinearities
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 8], dtype=int),
                resolution=res,
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transforms = [transforms.LULinear(partial_dim, identity_init=True)]
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        for _ in range(postprocessing_layers - 1):
            partial_transforms.append(transforms.LogTanh(cut_point=1))
            logger.debug("PartialTransform (LogTanh) (%s)", partial_dim)
            partial_transforms.append(transforms.LULinear(partial_dim, identity_init=True))
            logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        partial_transform = transforms.CompositeTransform(partial_transforms)
        final_transform = transforms.CompositeTransform(
            [transforms.PartialTransform(mask, partial_transform), transforms.MaskBasedPermutation(mask)]
        )
        logger.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "partial_nsf":
        # Rational-quadratic neural spline flow on the masked feature subset
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 16], dtype=int),
                resolution=res,
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transform = create_vector_transform(
            dim=partial_dim,
            flow_steps=postprocessing_layers,
            linear_transform_type="permutation",
            tail_bound=tail_bound,
            num_bins=num_bins,
        )
        logger.debug("RQ-NSF transform on %s features with %s steps", partial_dim, postprocessing_layers)
        final_transform = transforms.CompositeTransform(
            [transforms.PartialTransform(mask, partial_transform), transforms.MaskBasedPermutation(mask)]
        )
        logger.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "permutation":
        # Random permutation of all features
        final_transform = transforms.RandomPermutation(dim)
        logger.debug("RandomPermutation(%s)", dim)

    elif postprocessing == "none":
        final_transform = transforms.IdentityTransform()

    else:
        raise NotImplementedError(postprocessing)

    return final_transform
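
# Minimal smoke test (not in the original file): builds the postprocessing
# variants that do not depend on the masking helpers, for a flattened 3x32x32
# image. All keyword values below are illustrative assumptions, not defaults
# taken from the original training configuration.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    for mode in ("linear", "permutation", "none"):
        transform = _create_postprocessing(
            dim=3 * 32 * 32,
            multi_scale=False,
            postprocessing=mode,
            postprocessing_channel_factor=1,
            postprocessing_layers=2,
            res=32,
            context_features=None,
            tail_bound=3.0,
            num_bins=8,
        )
        print(mode, "->", type(transform).__name__)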