def _create_vector_linear_transform(linear_transform_type, features):
    if linear_transform_type == "permutation":
        return transforms.RandomPermutation(features=features)
    elif linear_transform_type == "lu":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.LULinear(features, identity_init=True)]
        )
    elif linear_transform_type == "svd":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.SVDLinear(features, num_householder=10, identity_init=True)]
        )
    else:
        raise ValueError
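# A minimal usage sketch (not part of the original module): builds the LU-decomposed
# linear transform from the factory above and applies it to a random batch. It assumes
# the nflows-style `transforms` API used throughout this file, where calling a
# transform returns (outputs, logabsdet).
def _example_vector_linear_transform():
    import torch

    lu_transform = _create_vector_linear_transform("lu", features=16)
    x = torch.randn(4, 16)  # batch of 4 vectors with 16 features each
    z, logabsdet = lu_transform(x)
    return z, logabsdet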
def _create_postprocessing(
    dim, multi_scale, postprocessing, postprocessing_channel_factor, postprocessing_layers, res, context_features, tail_bound, num_bins
):
    # TODO: take context_features into account here

    if postprocessing == "linear":
        final_transform = transforms.LULinear(dim, identity_init=True)
        logger.debug("LULinear(%s)", dim)

    elif postprocessing == "partial_linear":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim, channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 8], dtype=int), resolution=res
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transform = transforms.LULinear(partial_dim, identity_init=True)
        final_transform = transforms.PartialTransform(mask, partial_transform)
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)

    elif postprocessing == "partial_mlp":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim, channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 8], dtype=int), resolution=res
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transforms = [transforms.LULinear(partial_dim, identity_init=True)]
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        for _ in range(postprocessing_layers - 1):
            partial_transforms.append(transforms.LogTanh(cut_point=1))
            logger.debug("PartialTransform (LogTanh) (%s)", partial_dim)
            partial_transforms.append(transforms.LULinear(partial_dim, identity_init=True))
            logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        partial_transform = transforms.CompositeTransform(partial_transforms)
        final_transform = transforms.CompositeTransform(
            [transforms.PartialTransform(mask, partial_transform), transforms.MaskBasedPermutation(mask)]
        )
        logger.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "partial_nsf":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim, channels_per_level=postprocessing_channel_factor * np.array([1, 2, 4, 16], dtype=int), resolution=res
            )
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)
        partial_transform = create_vector_transform(
            dim=partial_dim, flow_steps=postprocessing_layers, linear_transform_type="permutation", tail_bound=tail_bound, num_bins=num_bins
        )
        logger.debug("RQ-NSF transform on %s features with %s steps", partial_dim, postprocessing_layers)
        final_transform = transforms.CompositeTransform(
            [transforms.PartialTransform(mask, partial_transform), transforms.MaskBasedPermutation(mask)]
        )
        logger.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "permutation":
        # Random permutation
        final_transform = transforms.RandomPermutation(dim)
        logger.debug("RandomPermutation(%s)", dim)

    elif postprocessing == "none":
        final_transform = transforms.IdentityTransform()

    else:
        raise NotImplementedError(postprocessing)

    return final_transform
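# A minimal usage sketch (not part of the original module): exercises the simplest
# postprocessing option, "linear", which only needs `dim`. The partial_* options
# additionally rely on the `various` mask helpers and, for "partial_nsf", on
# create_vector_transform defined elsewhere in this module. The argument values
# below are illustrative, not defaults from the original code.
def _example_postprocessing_linear():
    import torch

    final_transform = _create_postprocessing(
        dim=3 * 32 * 32,  # flattened image dimension
        multi_scale=False,
        postprocessing="linear",  # single LULinear over all features
        postprocessing_channel_factor=1,
        postprocessing_layers=2,
        res=32,
        context_features=None,  # currently unused (see TODO above)
        tail_bound=10.0,
        num_bins=8,
    )
    x = torch.randn(4, 3 * 32 * 32)
    z, logabsdet = final_transform(x)
    return z, logabsdet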
def __init__(self, num_channels, using_cache=False, identity_init=True):
    # Set up the parent transform as usual, then add a fixed random permutation
    # over the channel dimension.
    super().__init__(num_channels, using_cache, identity_init)
    self.permutation = transforms.RandomPermutation(num_channels, dim=1)
def create_image_transform(
    c,
    h,
    w,
    levels=3,
    hidden_channels=96,
    steps_per_level=7,
    alpha=0.05,
    num_bits=8,
    preprocessing="glow",
    multi_scale=True,
    use_resnet=True,
    dropout_prob=0.0,
    num_res_blocks=3,
    coupling_layer_type="rational_quadratic_spline",
    use_batchnorm=False,
    use_actnorm=True,
    spline_params=None,
):
    dim = c * h * w
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for level, level_hidden_channels in zip(range(levels), hidden_channels):
            logger.debug("Level %s", level)
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)
            logger.debug(" c, h, w = %s, %s, %s", c, h, w)
            logger.debug(" SqueezeTransform()")

            transform_level = transforms.CompositeTransform(
                [squeeze_transform]
                + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    )
                    for _ in range(steps_per_level)
                ]
                + [transforms.OneByOneConvolution(c)]  # End each level with a linear transformation.
            )
            logger.debug(" OneByOneConvolution(%s)", c)

            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
                logger.debug(" new_shape = %s, %s, %s", c, h, w)

    else:
        all_transforms = []
        for level, level_hidden_channels in zip(range(levels), hidden_channels):
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)

            transform_level = transforms.CompositeTransform(
                [squeeze_transform]
                + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    )
                    for _ in range(steps_per_level)
                ]
                + [transforms.OneByOneConvolution(c)]  # End each level with a linear transformation.
            )
            all_transforms.append(transform_level)

        all_transforms.append(transforms.ReshapeTransform(input_shape=(c, h, w), output_shape=(c * h * w,)))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]
    if preprocessing == "glow":
        # Map to [-0.5, 0.5]
        preprocess_transform = transforms.AffineScalarTransform(scale=(1.0 / 2 ** num_bits), shift=-0.5)
    elif preprocessing == "realnvp":
        preprocess_transform = transforms.CompositeTransform(
            [
                # Map to [0, 1]
                transforms.AffineScalarTransform(scale=(1.0 / 2 ** num_bits)),
                # Map into unconstrained space as done in RealNVP
                transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
                transforms.Logit(),
            ]
        )
    elif preprocessing == "realnvp_2alpha":
        preprocess_transform = transforms.CompositeTransform(
            [
                transforms.AffineScalarTransform(scale=(1.0 / 2 ** num_bits)),
                transforms.AffineScalarTransform(shift=alpha, scale=(1 - 2.0 * alpha)),
                transforms.Logit(),
            ]
        )
    else:
        raise RuntimeError("Unknown preprocessing type: {}".format(preprocessing))

    # Random permutation
    permutation = transforms.RandomPermutation(dim)
    logger.debug("RandomPermutation(%s)", dim)

    return transforms.CompositeTransform([preprocess_transform, mct, permutation])
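# A minimal usage sketch (not part of the original module): builds the multi-scale
# image flow for 8-bit 3x32x32 inputs with the default settings and pushes a batch
# through it. It assumes _create_image_transform_step (defined elsewhere in this
# module) accepts spline_params=None, and that the multi-scale composite transform
# returns flattened outputs of shape (batch, c * h * w).
def _example_image_transform():
    import torch

    transform = create_image_transform(c=3, h=32, w=32)
    x = torch.rand(4, 3, 32, 32) * 256.0  # inputs in [0, 2 ** num_bits] with num_bits=8
    z, logabsdet = transform(x)
    return z, logabsdet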