def validate_conv2d_input(input, channels_last, arg_name='input'):
    """
    Validate the input for 2-d convolution.

    Args:
        input: The input tensor, must be at least 4-d.
        channels_last (bool): Whether or not the last dimension is the
            channels dimension? (i.e., `data_format` is "NHWC")
        arg_name (str): Name of the input argument.

    Returns:
        (tf.Tensor, int, str): The validated input tensor, the number of
            input channels, and the data format.
    """
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        channel_axis = -1
        data_format = 'NHWC'
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        channel_axis = -3
        data_format = 'NCHW'

    input = input_spec.validate(arg_name, input)
    input_shape = get_static_shape(input)
    in_channels = input_shape[channel_axis]

    return input, in_channels, data_format

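
# Hedged usage sketch (not part of the library API): the helper below is
# illustrative only; it assumes TensorFlow 1.x graph mode, and the function
# name, placeholder name and shape are made up for this example.
def _example_validate_conv2d_input():
    images = tf.placeholder(tf.float32, [None, 32, 32, 3])  # hypothetical NHWC batch
    images, in_channels, data_format = validate_conv2d_input(
        images, channels_last=True, arg_name='images')
    # For this input: in_channels == 3 and data_format == 'NHWC'.
    return images, in_channels, data_format
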
def transpose_conv2d_axis(input, from_channels_last, to_channels_last,
                          name=None):
    """
    Ensure the channels axis of the `input` tensor is placed at the desired
    position.

    Args:
        input (tf.Tensor): The input tensor, at least 4-d.
        from_channels_last (bool): Whether or not the channels axis
            is the last axis in `input`? (i.e., the data format is "NHWC")
        to_channels_last (bool): Whether or not the channels axis should
            be the last axis in the output tensor?

    Returns:
        tf.Tensor: The (maybe) transposed output tensor.
    """
    if from_channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
    input = input_spec.validate('input', input)
    input_shape = get_static_shape(input)
    sample_and_batch_axis = [i for i in range(len(input_shape) - 3)]

    # check whether or not the axis should be transposed
    if from_channels_last and not to_channels_last:
        transpose_axis = [-1, -3, -2]
    elif not from_channels_last and to_channels_last:
        transpose_axis = [-2, -1, -3]
    else:
        transpose_axis = None

    # transpose the axis
    if transpose_axis is not None:
        transpose_axis = [i + len(input_shape) for i in transpose_axis]
        input = tf.transpose(input, sample_and_batch_axis + transpose_axis,
                             name=name or 'transpose_conv2d_axis')

    return input

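
# Hedged usage sketch (illustrative only): converting an NHWC tensor to NCHW
# with `transpose_conv2d_axis`; assumes TensorFlow 1.x graph mode, and the
# helper name and placeholder shape are made up for this example.
def _example_transpose_conv2d_axis():
    nhwc = tf.placeholder(tf.float32, [None, 32, 32, 3])
    nchw = transpose_conv2d_axis(nhwc, from_channels_last=True,
                                 to_channels_last=False)
    # The resulting permutation is [0, 3, 1, 2], so `nchw` has static
    # shape [None, 3, 32, 32].
    return nchw
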
class BaseFlow(BaseLayer):
    """
    The basic class for normalizing flows.

    A normalizing flow transforms a random variable `x` into `y` by an
    (implicitly) invertible mapping :math:`y = f(x)`, whose Jacobian matrix
    determinant :math:`\\det \\frac{\\partial f(x)}{\\partial x} \\neq 0`,
    so that :math:`\\log p(y)` can be derived from a given :math:`\\log p(x)`.
    """

    @add_name_and_scope_arg_doc
    def __init__(self, x_value_ndims, y_value_ndims=None,
                 require_batch_dims=False, name=None, scope=None):
        """
        Construct a new :class:`BaseFlow`.

        Args:
            x_value_ndims (int): Number of value dimensions in `x`.
                `x.ndims - x_value_ndims == log_det.ndims`.
            y_value_ndims (int): Number of value dimensions in `y`.
                `y.ndims - y_value_ndims == log_det.ndims`.
                If not specified, use `x_value_ndims`.
            require_batch_dims (bool): If :obj:`True`, `x` is required to
                have at least `x_value_ndims + 1` dimensions, and `y` is
                required to have at least `y_value_ndims + 1` dimensions.

                If :obj:`False`, `x` is required to have at least
                `x_value_ndims` dimensions, and `y` is required to have at
                least `y_value_ndims` dimensions.
        """
        x_value_ndims = int(x_value_ndims)
        if y_value_ndims is None:
            y_value_ndims = x_value_ndims
        else:
            y_value_ndims = int(y_value_ndims)

        super(BaseFlow, self).__init__(name=name, scope=scope)
        self._x_value_ndims = x_value_ndims
        self._y_value_ndims = y_value_ndims
        self._require_batch_dims = bool(require_batch_dims)

        self._x_input_spec = None  # type: InputSpec
        self._y_input_spec = None  # type: InputSpec

    def invert(self):
        """
        Get the inverted flow from this flow.

        The :meth:`transform()` will become the :meth:`inverse_transform()`
        in the inverted flow, and the :meth:`inverse_transform()` will become
        the :meth:`transform()` in the inverted flow.

        If the current flow has not been initialized, it must be initialized
        via :meth:`inverse_transform()` in the new flow.

        Returns:
            tfsnippet.layers.InvertFlow: The inverted flow.
        """
        from .invert import InvertFlow
        return InvertFlow(self)

    @property
    def x_value_ndims(self):
        """
        Get the number of value dimensions in `x`.

        Returns:
            int: The number of value dimensions in `x`.
        """
        return self._x_value_ndims

    @property
    def y_value_ndims(self):
        """
        Get the number of value dimensions in `y`.

        Returns:
            int: The number of value dimensions in `y`.
        """
        return self._y_value_ndims

    @property
    def require_batch_dims(self):
        """Whether or not this flow requires batch dimensions."""
        return self._require_batch_dims

    @property
    def explicitly_invertible(self):
        """
        Whether or not this flow is explicitly invertible?

        If a flow is not explicitly invertible, then it only supports
        transforming `x` into `y`, and the corresponding :math:`\\log p(x)`
        into :math:`\\log p(y)`.  It cannot compute :math:`\\log p(y)`
        directly without knowing `x`, nor can it transform `y` back into `x`.

        Returns:
            bool: A boolean indicating whether or not the flow is explicitly
                invertible.
        """
        raise NotImplementedError()

    def _build_input_spec(self, input):
        batch_ndims = int(self.require_batch_dims)
        dtype = input.dtype.base_dtype

        x_input_shape = ['...'] + ['?'] * (self.x_value_ndims + batch_ndims)
        y_input_shape = ['...'] + ['?'] * (self.y_value_ndims + batch_ndims)

        self._x_input_spec = InputSpec(shape=x_input_shape, dtype=dtype)
        self._y_input_spec = InputSpec(shape=y_input_shape, dtype=dtype)

    def build(self, input=None):
        # check the input.
        if input is None:
            raise ValueError('`input` is required to build {}.'.format(
                self.__class__.__name__))
        input = tf.convert_to_tensor(input)
        shape = get_static_shape(input)
        require_ndims = self.x_value_ndims + int(self.require_batch_dims)
        require_ndims_text = ('x_value_ndims + 1' if self.require_batch_dims
                              else 'x_value_ndims')

        if shape is None or len(shape) < require_ndims:
            raise ValueError('`x.ndims` must be known and >= `{}`: x '
                             '{} vs ndims `{}`.'.format(
                                 require_ndims_text, input, require_ndims))

        # build the input spec
        self._build_input_spec(input)

        # build the layer
        return super(BaseFlow, self).build(input)

    def _transform(self, x, compute_y, compute_log_det):
        raise NotImplementedError()

    def transform(self, x, compute_y=True, compute_log_det=True, name=None):
        """
        Transform `x` into `y`, and compute the log-determinant of `f` at `x`
        (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

        Args:
            x (Tensor): The samples of `x`.
            compute_y (bool): Whether or not to compute :math:`y = f(x)`?
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant?  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `y` and the (maybe summed)
                log-determinant.  The items in the returned tuple might be
                :obj:`None` if the corresponding `compute_?` argument is set
                to :obj:`False`.

        Raises:
            ValueError: If both `compute_y` and `compute_log_det` are set
                to :obj:`False`.
        """
        if not compute_y and not compute_log_det:
            raise ValueError('At least one of `compute_y` and '
                             '`compute_log_det` should be True.')

        x = tf.convert_to_tensor(x)
        if not self._has_built:
            self.build(x)

        x = self._x_input_spec.validate('x', x)

        with tf.name_scope(
                name,
                default_name=get_default_scope_name('transform', self),
                values=[x]):
            y, log_det = self._transform(x, compute_y, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=x,
                            value_ndims=self.x_value_ndims)
                        ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        raise NotImplementedError()

    def inverse_transform(self, y, compute_x=True, compute_log_det=True,
                          name=None):
        """
        Transform `y` into `x`, and compute the log-determinant of `f^{-1}`
        at `y` (i.e.,
        :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

        Args:
            y (Tensor): The samples of `y`.
            compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`?
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant?  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `x` and the (maybe summed)
                log-determinant.  The items in the returned tuple might be
                :obj:`None` if the corresponding `compute_?` argument is set
                to :obj:`False`.

        Raises:
            ValueError: If both `compute_x` and `compute_log_det` are set
                to :obj:`False`.
            RuntimeError: If the flow is not explicitly invertible.
            RuntimeError: If the flow has not been built yet.
""" if not self.explicitly_invertible: raise RuntimeError( 'The flow is not explicitly invertible: {!r}'.format(self)) if not compute_x and not compute_log_det: raise ValueError('At least one of `compute_x` and ' '`compute_log_det` should be True.') if not self._has_built: raise RuntimeError('`inverse_transform` cannot be called before ' 'the flow has been built; it can be built by ' 'calling `build`, `apply` or `transform`: ' '{!r}'.format(self)) y = tf.convert_to_tensor(y) y = self._y_input_spec.validate('y', y) with tf.name_scope(name, default_name=get_default_scope_name( 'inverse_transform', self), values=[y]): x, log_det = self._inverse_transform(y, compute_x, compute_log_det) if compute_log_det: with assert_deps([ assert_log_det_shape_matches_input( log_det=log_det, input=y, value_ndims=self.y_value_ndims) ]) as asserted: if asserted: # pragma: no cover log_det = tf.identity(log_det) return x, log_det def _apply(self, x): y, _ = self.transform(x, compute_y=True, compute_log_det=False) return y
def resnet_deconv2d_block(input,
                          out_channels,
                          kernel_size,
                          strides=(1, 1),
                          shortcut_kernel_size=(1, 1),
                          channels_last=True,
                          resize_at_exit=False,
                          activation_fn=None,
                          normalizer_fn=None,
                          weight_norm=False,
                          dropout_fn=None,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          kernel_constraint=None,
                          use_bias=None,
                          bias_initializer=tf.zeros_initializer(),
                          bias_regularizer=None,
                          bias_constraint=None,
                          trainable=True,
                          name=None,
                          scope=None):
    """
    2D deconvolutional ResNet block.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple[int]): Kernel size over spatial dimensions,
            for the "conv" and "conv_1" deconvolutional layers.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for all three deconvolutional layers.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" deconvolutional layer.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        resize_at_exit (bool): See :func:`resnet_general_block`.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm: Passed to :func:`deconv2d`.
        dropout_fn: The dropout function.
        kernel_initializer: Passed to :func:`deconv2d`.
        kernel_regularizer: Passed to :func:`deconv2d`.
        kernel_constraint: Passed to :func:`deconv2d`.
        use_bias: Whether or not to use `bias` in :func:`deconv2d`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not
            given.  If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias_initializer: Passed to :func:`deconv2d`.
        bias_regularizer: Passed to :func:`deconv2d`.
        bias_constraint: Passed to :func:`deconv2d`.
        trainable: Passed to :func:`deconv2d`.

    Returns:
        tf.Tensor: The output tensor.

    See Also:
        :func:`resnet_general_block`
    """
    # check the input and infer the input shape
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        c_axis = -1
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        c_axis = -3
    input = input_spec.validate('input', input)
    in_channels = get_static_shape(input)[c_axis]

    # check the functional arguments
    if use_bias is None:
        use_bias = normalizer_fn is None

    # derive the convolution function
    conv_fn = partial(
        deconv2d,
        channels_last=channels_last,
        weight_norm=weight_norm,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_constraint=kernel_constraint,
        use_bias=use_bias,
        bias_initializer=bias_initializer,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        trainable=trainable,
    )

    # build the resnet block
    return resnet_general_block(
        conv_fn,
        input=input,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        strides=strides,
        shortcut_kernel_size=shortcut_kernel_size,
        resize_at_exit=resize_at_exit,
        activation_fn=activation_fn,
        normalizer_fn=normalizer_fn,
        dropout_fn=dropout_fn,
        name=name or 'resnet_deconv2d_block',
        scope=scope
    )

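
# Hedged usage sketch (illustrative only): a single deconvolutional ResNet
# block on an NHWC feature map; assumes TensorFlow 1.x graph mode, and the
# helper name, shapes and hyper-parameters are made up for this example.
def _example_resnet_deconv2d_block():
    feature_map = tf.placeholder(tf.float32, [None, 8, 8, 64])
    output = resnet_deconv2d_block(
        feature_map,
        out_channels=32,
        kernel_size=(3, 3),
        strides=(2, 2),                  # strides of 2 up-sample the spatial dims
        channels_last=True,
        activation_fn=tf.nn.leaky_relu,  # example activation choice
    )
    return output
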
def pixelcnn_2d_sample(fn, inputs, height, width, channels_last=True,
                       start=0, end=None, back_prop=False,
                       parallel_iterations=1, swap_memory=False, name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of the PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            through iteration `0` to iteration `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.

    Returns:
        tuple[tf.Tensor]: The final outputs.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(
        inputs[0], channels_last=channels_last, arg_name='inputs[0]')
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample',
                       values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]
        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare for the output mask
            selector = tf.reshape(
                tf.concat([
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ], axis=0),
                mask_shape
            )
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs)))

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = inputs[i].dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(
                            idx=i, output=output_dtype, input=input_dtype))
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )

        return outputs

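
# Hedged usage sketch (illustrative only): drawing binary pixels one at a time
# from a hypothetical single-input PixelCNN.  `my_pixelcnn_logits` stands for
# a user-supplied network returning per-pixel logits and is NOT defined here;
# assumes TensorFlow 1.x graph mode and an NHWC layout.
def _example_pixelcnn_2d_sample(my_pixelcnn_logits, height=28, width=28):
    initial = tf.zeros([1, height, width, 1], dtype=tf.float32)

    def fn(i, inputs):
        # run the network on the pixels generated so far, then sample new
        # Bernoulli pixels by comparing uniform noise against the probabilities
        logits = my_pixelcnn_logits(inputs[0])
        sample = tf.cast(
            tf.random_uniform(tf.shape(logits)) < tf.sigmoid(logits),
            tf.float32)
        return (sample,)

    (images,) = pixelcnn_2d_sample(fn, [initial], height, width,
                                   channels_last=True)
    return images
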