Example #1
def validate_conv2d_input(input, channels_last, arg_name='input'):
    """
    Validate the input for 2-d convolution.

    Args:
        input: The input tensor, must be at least 4-d.
        channels_last (bool): Whether or not the last dimension is the
            channels dimension (i.e., `data_format` is "NHWC").
        arg_name (str): Name of the input argument.

    Returns:
        (tf.Tensor, int, str): The validated input tensor, the number of input
            channels, and the data format.
    """
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        channel_axis = -1
        data_format = 'NHWC'
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        channel_axis = -3
        data_format = 'NCHW'
    input = input_spec.validate(arg_name, input)
    input_shape = get_static_shape(input)
    in_channels = input_shape[channel_axis]

    return input, in_channels, data_format
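
A minimal usage sketch (hypothetical shapes; it assumes TensorFlow 1.x graph
mode, with `InputSpec` and `get_static_shape` imported from tfsnippet as in
the listing above):

import tensorflow as tf

# a hypothetical NHWC batch of 32x32 RGB images
images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
images, in_channels, data_format = validate_conv2d_input(
    images, channels_last=True)
# in_channels == 3, data_format == 'NHWC'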
Example #2
def transpose_conv2d_axis(input,
                          from_channels_last,
                          to_channels_last,
                          name=None):
    """
    Ensure that the channels axis of the `input` tensor is placed at the
    desired axis.

    Args:
        input (tf.Tensor): The input tensor, at least 4-d.
        from_channels_last (bool): Whether or not the channels axis
            is the last axis in `input` (i.e., the data format is "NHWC").
        to_channels_last (bool): Whether or not the channels axis
            should be the last axis in the output tensor.

    Returns:
        tf.Tensor: The (maybe) transposed output tensor.
    """
    if from_channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
    input = input_spec.validate('input', input)
    input_shape = get_static_shape(input)
    sample_and_batch_axis = list(range(len(input_shape) - 3))

    # check whether or not the axes should be transposed
    if from_channels_last and not to_channels_last:
        transpose_axis = [-1, -3, -2]
    elif not from_channels_last and to_channels_last:
        transpose_axis = [-2, -1, -3]
    else:
        transpose_axis = None

    # transpose the axis
    if transpose_axis is not None:
        transpose_axis = [i + len(input_shape) for i in transpose_axis]
        input = tf.transpose(input,
                             sample_and_batch_axis + transpose_axis,
                             name=name or 'transpose_conv2d_axis')

    return input
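
A short usage sketch (hypothetical tensor; assumes TensorFlow 1.x and the
same tfsnippet helpers as above):

x_nhwc = tf.zeros([8, 32, 32, 3])              # NHWC input
x_nchw = transpose_conv2d_axis(x_nhwc,
                               from_channels_last=True,
                               to_channels_last=False)
# x_nchw has shape (8, 3, 32, 32); when both flags agree,
# the input is returned unchanged.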
Example #3
class BaseFlow(BaseLayer):
    """
    The basic class for normalizing flows.

    A normalizing flow transforms a random variable `x` into `y` by an
    (implicitly) invertible mapping :math:`y = f(x)`, whose Jacobian matrix
    determinant :math:`\\det \\frac{\\partial f(x)}{\\partial x} \\neq 0`, so
    that :math:`\\log p(y)` can be derived from a given :math:`\\log p(x)`.
    """
    @add_name_and_scope_arg_doc
    def __init__(self,
                 x_value_ndims,
                 y_value_ndims=None,
                 require_batch_dims=False,
                 name=None,
                 scope=None):
        """
        Construct a new :class:`BaseFlow`.

        Args:
            x_value_ndims (int): Number of value dimensions in `x`.
                `x.ndims - x_value_ndims == log_det.ndims`.
            y_value_ndims (int): Number of value dimensions in `y`.
                `y.ndims - y_value_ndims == log_det.ndims`.
                If not specified, use `x_value_ndims`.
            require_batch_dims (bool): If :obj:`True`, `x` is required
                to have at least `x_value_ndims + 1` dimensions, and `y`
                is required to have at least `y_value_ndims + 1` dimensions.

                If :obj:`False`, `x` is required to have at least
                `x_value_ndims` dimensions, and `y` is required to have
                at least `y_value_ndims` dimensions.
        """
        x_value_ndims = int(x_value_ndims)
        if y_value_ndims is None:
            y_value_ndims = x_value_ndims
        else:
            y_value_ndims = int(y_value_ndims)

        super(BaseFlow, self).__init__(name=name, scope=scope)
        self._x_value_ndims = x_value_ndims
        self._y_value_ndims = y_value_ndims
        self._require_batch_dims = bool(require_batch_dims)

        self._x_input_spec = None  # type: InputSpec
        self._y_input_spec = None  # type: InputSpec

    def invert(self):
        """
        Get the inverted flow from this flow.

        The :meth:`transform()` of the inverted flow will be the
        :meth:`inverse_transform()` of this flow, and vice versa.

        If the current flow has not been initialized, it must be
        initialized via :meth:`inverse_transform()` of the new flow.

        Returns:
            tfsnippet.layers.InvertFlow: The inverted flow.
        """
        from .invert import InvertFlow
        return InvertFlow(self)

    @property
    def x_value_ndims(self):
        """
        Get the number of value dimensions in `x`.

        Returns:
            int: The number of value dimensions in `x`.
        """
        return self._x_value_ndims

    @property
    def y_value_ndims(self):
        """
        Get the number of value dimensions in `y`.

        Returns:
            int: The number of value dimensions in `y`.
        """
        return self._y_value_ndims

    @property
    def require_batch_dims(self):
        """Whether or not this flow requires batch dimensions."""
        return self._require_batch_dims

    @property
    def explicitly_invertible(self):
        """
        Whether or not this flow is explicitly invertible.

        If a flow is not explicitly invertible, then it only supports
        transforming `x` into `y`, and the corresponding :math:`\\log p(x)`
        into :math:`\\log p(y)`.  It cannot compute :math:`\\log p(y)`
        directly without knowing `x`, nor can it transform `y` back into `x`.

        Returns:
            bool: A boolean indicating whether or not the flow is explicitly
                invertible.
        """
        raise NotImplementedError()

    def _build_input_spec(self, input):
        batch_ndims = int(self.require_batch_dims)
        dtype = input.dtype.base_dtype

        x_input_shape = ['...'] + ['?'] * (self.x_value_ndims + batch_ndims)
        y_input_shape = ['...'] + ['?'] * (self.y_value_ndims + batch_ndims)

        self._x_input_spec = InputSpec(shape=x_input_shape, dtype=dtype)
        self._y_input_spec = InputSpec(shape=y_input_shape, dtype=dtype)

    def build(self, input=None):
        # check the input.
        if input is None:
            raise ValueError('`input` is required to build {}.'.format(
                self.__class__.__name__))

        input = tf.convert_to_tensor(input)
        shape = get_static_shape(input)
        require_ndims = self.x_value_ndims + int(self.require_batch_dims)
        require_ndims_text = ('x_value_ndims + 1'
                              if self.require_batch_dims else 'x_value_ndims')

        if shape is None or len(shape) < require_ndims:
            raise ValueError('`x.ndims` must be known and >= `{}`: x is '
                             '{}, but the required ndims is {}.'.format(
                                 require_ndims_text, input, require_ndims))

        # build the input spec
        self._build_input_spec(input)

        # build the layer
        return super(BaseFlow, self).build(input)

    def _transform(self, x, compute_y, compute_log_det):
        raise NotImplementedError()

    def transform(self, x, compute_y=True, compute_log_det=True, name=None):
        """
        Transform `x` into `y`, and compute the log-determinant of `f` at `x`
        (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

        Args:
            x (Tensor): The samples of `x`.
            compute_y (bool): Whether or not to compute :math:`y = f(x)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `y` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None` if the
                corresponding `compute_?` argument is set to :obj:`False`.

        Raises:
            ValueError: If both `compute_y` and `compute_log_det` are set
                to :obj:`False`.
        """
        if not compute_y and not compute_log_det:
            raise ValueError('At least one of `compute_y` and '
                             '`compute_log_det` should be True.')

        x = tf.convert_to_tensor(x)
        if not self._has_built:
            self.build(x)

        x = self._x_input_spec.validate('x', x)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'transform', self),
                           values=[x]):
            y, log_det = self._transform(x, compute_y, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=x,
                            value_ndims=self.x_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        raise NotImplementedError()

    def inverse_transform(self,
                          y,
                          compute_x=True,
                          compute_log_det=True,
                          name=None):
        """
        Transform `y` into `x`, and compute the log-determinant of `f^{-1}` at
        `y` (i.e.,
        :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

        Args:
            y (Tensor): The samples of `y`.
            compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `x` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None` if the
                corresponding `compute_?` argument is set to :obj:`False`.

        Raises:
            ValueError: If both `compute_x` and `compute_log_det` are set
                to :obj:`False`.
            RuntimeError: If the flow is not explicitly invertible.
        """
        if not self.explicitly_invertible:
            raise RuntimeError(
                'The flow is not explicitly invertible: {!r}'.format(self))
        if not compute_x and not compute_log_det:
            raise ValueError('At least one of `compute_x` and '
                             '`compute_log_det` should be True.')
        if not self._has_built:
            raise RuntimeError('`inverse_transform` cannot be called before '
                               'the flow has been built; it can be built by '
                               'calling `build`, `apply` or `transform`: '
                               '{!r}'.format(self))

        y = tf.convert_to_tensor(y)
        y = self._y_input_spec.validate('y', y)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'inverse_transform', self),
                           values=[y]):
            x, log_det = self._inverse_transform(y, compute_x, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=y,
                            value_ndims=self.y_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return x, log_det

    def _apply(self, x):
        y, _ = self.transform(x, compute_y=True, compute_log_det=False)
        return y
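
To make the subclass contract concrete, here is a minimal toy sketch (not
part of the library): a flow that scales its input by 2.  It assumes that
`BaseLayer` exposes a `_build` hook as in tfsnippet, and that `tf` and
`math` are imported.  Note how summing over the value dimensions preserves
the invariant `log_det.ndims == x.ndims - x_value_ndims`:

import math

class ScaleFlow(BaseFlow):
    """A toy flow y = 2 * x, illustrating the subclass interface only."""

    @property
    def explicitly_invertible(self):
        return True

    def _build(self, input=None):
        pass  # this toy flow creates no variables

    def _transform(self, x, compute_y, compute_log_det):
        y = 2. * x if compute_y else None
        log_det = None
        if compute_log_det:
            # each value element contributes log|dy/dx| = log(2); summing
            # over the last `x_value_ndims` axes leaves the batch dimensions
            log_det = tf.reduce_sum(
                tf.zeros_like(x) + math.log(2.),
                axis=list(range(-self.x_value_ndims, 0)))
        return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        x = y / 2. if compute_x else None
        log_det = None
        if compute_log_det:
            log_det = tf.reduce_sum(
                tf.zeros_like(y) - math.log(2.),
                axis=list(range(-self.y_value_ndims, 0)))
        return x, log_det

For example, `ScaleFlow(x_value_ndims=1).transform(tf.zeros([5, 3]))` would
yield a `y` of shape `(5, 3)` and a `log_det` of shape `(5,)`.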
Example #4
def resnet_deconv2d_block(input,
                          out_channels,
                          kernel_size,
                          strides=(1, 1),
                          shortcut_kernel_size=(1, 1),
                          channels_last=True,
                          resize_at_exit=False,
                          activation_fn=None,
                          normalizer_fn=None,
                          weight_norm=False,
                          dropout_fn=None,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          kernel_constraint=None,
                          use_bias=None,
                          bias_initializer=tf.zeros_initializer(),
                          bias_regularizer=None,
                          bias_constraint=None,
                          trainable=True,
                          name=None,
                          scope=None):
    """
    2D deconvolutional ResNet block.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple[int]): Kernel size over spatial dimensions,
            for "conv" and "conv_1" deconvolutional layers.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for all three deconvolutional layers.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" deconvolutional layer.
        channels_last (bool): Whether or not the channels axis is the last
            axis in `input` (i.e., the data format is "NHWC").
        resize_at_exit (bool): See :func:`resnet_general_block`.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm: Passed to :func:`deconv2d`.
        dropout_fn: The dropout function.
        kernel_initializer: Passed to :func:`deconv2d`.
        kernel_regularizer: Passed to :func:`deconv2d`.
        kernel_constraint: Passed to :func:`deconv2d`.
        use_bias: Whether or not to use `bias` in :func:`deconv2d`.
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias_initializer: Passed to :func:`deconv2d`.
        bias_regularizer: Passed to :func:`deconv2d`.
        bias_constraint: Passed to :func:`deconv2d`.
        trainable: Passed to :func:`deconv2d`.

    Returns:
        tf.Tensor: The output tensor.

    See Also:
        :func:`resnet_general_block`
    """
    # check the input and infer the input shape
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        c_axis = -1
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        c_axis = -3
    input = input_spec.validate('input', input)
    in_channels = get_static_shape(input)[c_axis]

    # check the functional arguments
    if use_bias is None:
        use_bias = normalizer_fn is None

    # derive the convolution function
    conv_fn = partial(
        deconv2d,
        channels_last=channels_last,
        weight_norm=weight_norm,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_constraint=kernel_constraint,
        use_bias=use_bias,
        bias_initializer=bias_initializer,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        trainable=trainable,
    )

    # build the resnet block
    return resnet_general_block(conv_fn,
                                input=input,
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                strides=strides,
                                shortcut_kernel_size=shortcut_kernel_size,
                                resize_at_exit=resize_at_exit,
                                activation_fn=activation_fn,
                                normalizer_fn=normalizer_fn,
                                dropout_fn=dropout_fn,
                                name=name or 'resnet_deconv2d_block',
                                scope=scope)
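
A hedged usage sketch (hypothetical shapes; assumes `deconv2d` and
`resnet_general_block` from tfsnippet are importable as above):

h = tf.zeros([16, 8, 8, 64])                  # NHWC feature map
h = resnet_deconv2d_block(h, out_channels=32,
                          kernel_size=(3, 3), strides=(2, 2))
# with strides of 2, the deconvolutions roughly double the spatial size
# (e.g. 8x8 -> 16x16), and the output has 32 channels.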
Example #5
def pixelcnn_2d_sample(fn,
                       inputs,
                       height,
                       width,
                       channels_last=True,
                       start=0,
                       end=None,
                       back_prop=False,
                       parallel_iterations=1,
                       swap_memory=False,
                       name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of the PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            from iterations `0` to `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channels axis is the last
            axis of the input tensors (i.e., the data format is "NHWC").
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.

    Returns:
        tuple[tf.Tensor]: The final outputs.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(inputs[0],
                                            channels_last=channels_last,
                                            arg_name='inputs[0]')
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]

        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare the output mask: zero at position `idx`, one elsewhere
            selector = tf.reshape(
                tf.concat([tf.ones([idx], dtype=tf.uint8),
                           tf.zeros([1], dtype=tf.uint8),
                           tf.ones([total_size - idx - 1], dtype=tf.uint8)],
                          axis=0),
                mask_shape)
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs)))

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = input.dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(idx=i,
                                                     output=output_dtype,
                                                     input=input_dtype))
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return outputs
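
A hedged usage sketch: `network` below stands for a hypothetical PixelCNN
forward pass that maps the current canvas to per-pixel outputs of the same
shape and dtype.  At iteration `i`, only the pixel at index `i` is taken
from the output; the rest of the canvas is kept as-is.

def fn(i, inputs):
    canvas, = inputs
    return (network(canvas),)  # hypothetical one-step PixelCNN pass

canvas0 = tf.zeros([4, 28, 28, 1])  # an initially blank canvas, NHWC
samples, = pixelcnn_2d_sample(fn, (canvas0,), height=28, width=28)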