def build(self, input_shape):
    """Build the wrapped layer and attach the weight-normalization variables.

    When the wrapped layer has not been built yet, it is built here, its
    `kernel` is exposed on it as `v`, a scale variable `g` with one entry per
    unit of the kernel's last dimension is created on it, `g` is assigned
    from `self._init_norm(v)`, and `self._compute_weights()` is invoked.
    An already-built wrapped layer is left untouched.

    Args:
      input_shape: Shape of the input; anything `TensorShape` accepts.

    Raises:
      ValueError: if the wrapped layer has no `kernel` attribute after it
        has been built.
    """
    shape_list = tensor_shape.TensorShape(input_shape).as_list()
    self.input_spec = base_layer.InputSpec(shape=shape_list)

    wrapped = self.layer
    if not wrapped.built:
        wrapped.build(shape_list)
        # Keep the wrapped layer flagged as unbuilt while `g` is added to it;
        # the flag is restored at the end of this branch.
        wrapped.built = False

        if not hasattr(wrapped, 'kernel'):
            raise ValueError(
                '`WeightNormalization` must wrap a layer that'
                ' contains a `kernel` for weights')

        kernel = wrapped.kernel

        # The kernel's filter or unit dimension is -1
        depth = int(kernel.shape[-1])
        self.layer_depth = depth
        # Every kernel axis except the last one.
        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))

        wrapped.v = kernel
        wrapped.g = wrapped.add_variable(
            name="g",
            shape=(depth, ),
            initializer=initializers.get('ones'),
            dtype=kernel.dtype,
            trainable=True,
            aggregation=tf_variables.VariableAggregation.MEAN)

        # TODO: Check if this needs control deps in TF2 graph mode
        initial_norm = self._init_norm(wrapped.v)
        wrapped.g.assign(initial_norm)
        self._compute_weights()

        wrapped.built = True

    super(WeightNormalization, self).build()
    self.built = True
def build(self, input_shape):
    """Record the static input dimensions and finish building the layer.

    Stores the static sizes of input dimensions 0-3 as `self.n`, `self.h`,
    `self.w` and `self.c`, pins the input rank and channel count in
    `self.input_spec`, and then defers to the parent class's `build`.

    Args:
      input_shape: Shape of the input tensor. Dimensions 0 through 3 are
        indexed, so it needs at least four of them.

    Raises:
      ValueError: if `input_shape` doesn't specify the length of the channel
        dimension.
    """
    shape = tensor_shape.TensorShape(input_shape)
    rank = shape.ndims
    axis = self._channel_axis(rank)
    depth = shape[axis].value

    # Keep the first four static dimensions around under short names.
    # NOTE(review): the names suggest a batch/height/width/channel layout,
    # independent of `axis` above -- confirm against the callers.
    for position, attr in enumerate(("n", "h", "w", "c")):
        setattr(self, attr, shape[position].value)

    if depth is None:
        raise ValueError(
            "The channel dimension of the inputs must be defined.")

    self.input_spec = base_layer.InputSpec(ndim=rank, axes={axis: depth})
    super(EntropyBottleneck_gauss, self).build(shape)
def __init__(self,
             init_scale=10,
             filters=(3, 3, 3),
             tail_mass=1e-9,
             optimize_integer_offset=True,
             likelihood_bound=1e-9,
             range_coder_precision=16,
             data_format="channels_last",
             **kwargs):
    """Store the layer configuration after coercing it to plain types.

    Args:
      init_scale: Stored as a float.
      filters: Iterable whose entries are each converted to int; stored as a
        tuple.
      tail_mass: Stored as a float; must lie strictly between 0 and 1.
      optimize_integer_offset: Stored as a bool.
      likelihood_bound: Stored as a float.
      range_coder_precision: Stored as an int.
      data_format: Stored unchanged.
      **kwargs: Forwarded to the parent class constructor.

    Raises:
      ValueError: if `tail_mass` is not strictly between 0 and 1. The
        `_channel_axis` call at the end may raise as well (see below).
    """
    super(EntropyBottleneck, self).__init__(**kwargs)

    self._init_scale = float(init_scale)
    self._filters = tuple(map(int, filters))

    self._tail_mass = float(tail_mass)
    # Negated chained comparison: anything outside the open interval (0, 1),
    # NaN included, is rejected.
    if not 0 < self.tail_mass < 1:
        message = "`tail_mass` must be between 0 and 1, got {}.".format(
            self.tail_mass)
        raise ValueError(message)

    self._optimize_integer_offset = bool(optimize_integer_offset)
    self._likelihood_bound = float(likelihood_bound)
    self._range_coder_precision = int(range_coder_precision)
    self._data_format = data_format

    # The result is discarded; the call is made only so that a bad
    # configuration raises ValueError here rather than later
    # (presumably it validates `data_format` -- confirm in `_channel_axis`).
    self._channel_axis(2)

    self.input_spec = base_layer.InputSpec(min_ndim=2)
def build(self, input_shape):
    """Builds the layer.

    Creates the variables for the network modeling the densities, creates the
    auxiliary loss estimating the median and tail quantiles of the densities,
    and then uses that to create the probability mass functions and the update
    op that produces the discrete cumulative density functions used by the
    range coder.

    Attributes set here: `input_spec`, `_matrices`, `_biases`, `_factors`,
    `_medians`, `_pmf` and `_quantized_cdf`. One loss is registered through
    `add_loss` and one update op through `add_update`.

    Args:
      input_shape: Shape of the input tensor, used to get the number of
        channels.

    Raises:
      ValueError: if `input_shape` doesn't specify the length of the channel
        dimension.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    channel_axis = self._channel_axis(input_shape.ndims)
    channels = input_shape[channel_axis].value
    if channels is None:
        raise ValueError(
            "The channel dimension of the inputs must be defined.")
    self.input_spec = base_layer.InputSpec(ndim=input_shape.ndims,
                                           axes={channel_axis: channels})
    # Pad the configured filter sizes with 1 on both ends, so that
    # `filters[i]` / `filters[i + 1]` give the last / middle dimension of the
    # i-th matrix created below.
    filters = (1, ) + self.filters + (1, )
    # `init_scale` is spread evenly (as a root) over the
    # `len(self.filters) + 1` matrices.
    scale = self.init_scale**(1 / (len(self.filters) + 1))

    # Create variables.
    self._matrices = []
    self._biases = []
    self._factors = []
    for i in range(len(self.filters) + 1):
        # Inverse softplus of `1 / scale / filters[i + 1]`: after the
        # softplus below, every matrix entry starts at that value.
        init = np.log(np.expm1(1 / scale / filters[i + 1]))
        matrix = self.add_variable("matrix_{}".format(i),
                                   dtype=self.dtype,
                                   shape=(channels, filters[i + 1],
                                          filters[i]),
                                   initializer=init_ops.Constant(init))
        # The softplus-transformed tensor is stored, not the raw variable.
        matrix = nn.softplus(matrix)
        self._matrices.append(matrix)

        bias = self.add_variable("bias_{}".format(i),
                                 dtype=self.dtype,
                                 shape=(channels, filters[i + 1], 1),
                                 initializer=init_ops.RandomUniform(
                                     -.5, .5))
        self._biases.append(bias)

        # A factor is created for every step except the last one.
        if i < len(self.filters):
            factor = self.add_variable("factor_{}".format(i),
                                       dtype=self.dtype,
                                       shape=(channels, filters[i + 1], 1),
                                       initializer=init_ops.Zeros())
            # The tanh-transformed tensor is stored, not the raw variable.
            factor = math_ops.tanh(factor)
            self._factors.append(factor)

    # To figure out what range of the densities to sample, we need to compute
    # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we
    # can't take inverses of the cumulative directly, we make it an
    # optimization problem:
    # `quantiles = argmin(|logit(cumulative) - target|)`
    # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`.
    # Taking the logit (inverse of sigmoid) of the cumulative makes the
    # representation of the right target more numerically stable.

    # Numerically stable way of computing logits of `tail_mass / 2`
    # and `1 - tail_mass / 2`.
    target = np.log(2 / self.tail_mass - 1)
    # Compute lower and upper tail quantile as well as median.
    target = constant_op.constant([-target, 0, target], dtype=self.dtype)

    def quantiles_initializer(shape, dtype=None, partition_info=None):
        # Initial value: `[-init_scale, 0, init_scale]`, tiled `shape[0]`
        # times along the first axis.
        del partition_info  # unused
        assert tuple(shape[1:]) == (1, 3)
        init = constant_op.constant(
            [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype)
        return array_ops.tile(init, (shape[0], 1, 1))

    quantiles = self.add_variable("quantiles",
                                  shape=(channels, 1, 3),
                                  dtype=self.dtype,
                                  initializer=quantiles_initializer)
    logits = self._logits_cumulative(quantiles, stop_gradient=True)
    # Auxiliary loss: summed absolute distance between the logits at the
    # current quantile estimates and their targets.
    loss = math_ops.reduce_sum(abs(logits - target))
    self.add_loss(loss, inputs=None)

    # Save medians for `call`, `compress`, and `decompress`.
    self._medians = quantiles[:, :, 1:2]
    if not self.optimize_integer_offset:
        self._medians = math_ops.round(self._medians)

    # Largest distance observed between lower tail quantile and median,
    # or between median and upper tail quantile.
    minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1])
    maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians)
    minmax = math_ops.maximum(minima, maxima)
    # Round up to a whole number of steps, and use at least 1.
    minmax = math_ops.ceil(minmax)
    minmax = math_ops.maximum(minmax, 1)

    # Sample the density up to `minmax` around the median.
    samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype)
    samples += self._medians

    half = constant_op.constant(.5, dtype=self.dtype)
    # We strip the sigmoid from the end here, so we can use the special rule
    # below to only compute differences in the left tail of the sigmoid.
    # This increases numerical stability (see explanation in `call`).
    lower = self._logits_cumulative(samples - half, stop_gradient=True)
    upper = self._logits_cumulative(samples + half, stop_gradient=True)
    # Flip signs if we can move more towards the left tail of the sigmoid.
    sign = -math_ops.sign(math_ops.add_n([lower, upper]))
    pmf = abs(
        math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower))
    # Add tail masses to first and last bin of pmf, as we clip values for
    # compression, meaning that out-of-range values get mapped to these bins.
    # Indexing with `[:, 0, ...]` also drops the middle axis of `pmf`.
    pmf = array_ops.concat([
        math_ops.add_n([pmf[:, 0, :1],
                        math_ops.sigmoid(lower[:, 0, :1])]),
        pmf[:, 0, 1:-1],
        math_ops.add_n(
            [pmf[:, 0, -1:],
             math_ops.sigmoid(-upper[:, 0, -1:])]),
    ],
                           axis=-1)
    self._pmf = pmf

    cdf = coder_ops.pmf_to_quantized_cdf(
        pmf, precision=self.range_coder_precision)

    # We need to supply an initializer without fully defined static shape here,
    # or the variable will return the wrong dynamic shape later. A placeholder
    # with default gets the trick done.
    def cdf_init(*args, **kwargs):
        # Ignores its arguments; yields int32 zeros of shape `(channels, 1)`
        # behind a placeholder whose static shape is `(channels, None)`.
        del args, kwargs  # unused
        return array_ops.placeholder_with_default(array_ops.zeros(
            (channels, 1), dtype=dtypes.int32),
                                                  shape=(channels, None))

    self._quantized_cdf = self.add_variable("quantized_cdf",
                                            shape=None,
                                            initializer=cdf_init,
                                            dtype=dtypes.int32,
                                            trainable=False)

    # `validate_shape=False` lets the assigned CDF differ in shape from the
    # variable's initial value.
    update_op = state_ops.assign(self._quantized_cdf,
                                 cdf,
                                 validate_shape=False)
    self.add_update(update_op, inputs=None)

    super(EntropyBottleneck, self).build(input_shape)