Example 1
def validate_conv2d_input(input, channels_last, arg_name='input'):
    """
    Validate the input for 2-d convolution.

    Args:
        input: The input tensor, must be at least 4-d.
        channels_last (bool): Whether or not the last dimension is the
            channels dimension (i.e., the data format is "NHWC").
        arg_name (str): Name of the input argument.

    Returns:
        (tf.Tensor, int, str): The validated input tensor, the number of input
            channels, and the data format.
    """
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        channel_axis = -1
        data_format = 'NHWC'
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        channel_axis = -3
        data_format = 'NCHW'
    input = input_spec.validate(arg_name, input)
    input_shape = get_static_shape(input)
    in_channels = input_shape[channel_axis]

    return input, in_channels, data_format
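
A minimal usage sketch of the validator above (TensorFlow 1.x assumed, with the
helpers shown in these examples in scope):

import tensorflow as tf

# a batch of 32x32 RGB images in "NHWC" format
x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
x, in_channels, data_format = validate_conv2d_input(x, channels_last=True)
# in_channels == 3, data_format == 'NHWC'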
Example 2
    def _build(self, input=None):
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype
        assert (shape is not None and len(shape) >= self.x_value_ndims)

        # re-build the x input spec
        x_shape_spec = []
        if self.x_value_ndims > 0:
            x_shape_spec = list(shape)[-self.x_value_ndims:]
        if self.require_batch_dims:
            x_shape_spec = ['?'] + x_shape_spec
        x_shape_spec = ['...'] + x_shape_spec
        self._x_input_spec = InputSpec(shape=x_shape_spec, dtype=dtype)

        # infer the dynamic value shape of x, and store it for inverse transform
        x_value_shape = []
        if self.x_value_ndims > 0:
            x_value_shape = list(shape)[-self.x_value_ndims:]

        neg_one_count = 0
        for i, s in enumerate(x_value_shape):
            if s is None:
                if neg_one_count > 0:
                    x_value_shape = get_shape(input)
                    if self.x_value_ndims > 0:
                        x_value_shape = x_value_shape[-self.x_value_ndims:]
                    break
                else:
                    x_value_shape[i] = -1
                    neg_one_count += 1

        self._x_value_shape = x_value_shape

        # now infer the y value shape according to new info obtained from x
        y_value_shape = list(self._y_value_shape)
        if isinstance(x_value_shape, list) and -1 not in x_value_shape:
            x_value_size = int(np.prod(x_value_shape))
            y_value_size = int(np.prod([s for s in y_value_shape if s != -1]))

            if (-1 in y_value_shape and x_value_size % y_value_size != 0) or \
                    (-1 not in y_value_shape and x_value_size != y_value_size):
                raise ValueError(
                    'Cannot reshape the tail dimensions of `x` into `y`: '
                    'x value shape {!r}, y value shape {!r}.'.format(
                        x_value_shape, y_value_shape))

            if -1 in y_value_shape:
                y_value_shape[y_value_shape.index(-1)] = \
                    x_value_size // y_value_size

            assert (-1 not in y_value_shape)
            self._y_value_shape = tuple(y_value_shape)

        # re-build the y input spec
        y_shape_spec = list(y_value_shape)
        if self.require_batch_dims:
            y_shape_spec = ['?'] + y_shape_spec
        y_shape_spec = ['...'] + y_shape_spec
        self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)
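
A standalone sketch of the shape-inference step above, using hypothetical tail
shapes; only `numpy` is required:

import numpy as np

x_value_shape = [4, 6]      # fully-known tail dimensions of x
y_value_shape = [3, -1]     # y has exactly one unknown tail dimension

x_value_size = int(np.prod(x_value_shape))                          # 24
y_value_size = int(np.prod([s for s in y_value_shape if s != -1]))  # 3
assert x_value_size % y_value_size == 0

# the unknown dimension of y is inferred from the value size of x
y_value_shape[y_value_shape.index(-1)] = x_value_size // y_value_size
# y_value_shape is now [3, 8]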
Example 3
    def _build_input_spec(self, input):
        batch_ndims = int(self.require_batch_dims)
        dtype = input.dtype.base_dtype

        x_input_shape = ['...'] + ['?'] * (self.x_value_ndims + batch_ndims)
        y_input_shape = ['...'] + ['?'] * (self.y_value_ndims + batch_ndims)

        self._x_input_spec = InputSpec(shape=x_input_shape, dtype=dtype)
        self._y_input_spec = InputSpec(shape=y_input_shape, dtype=dtype)
Example 4
    def _build(self, input=None):
        # check the input.
        input = tf.convert_to_tensor(input)
        dtype = input.dtype.base_dtype
        shape = get_static_shape(input)

        # These facts should have been checked in `BaseFlow.build`.
        assert (shape is not None)
        assert (len(shape) >= self.value_ndims)

        # compute var spec and input spec
        min_axis = min(self.axis)
        shape_spec = [None] * len(shape)
        for a in self.axis:
            shape_spec[a] = shape[a]
        shape_spec = shape_spec[min_axis:]
        assert shape_spec
        assert (self.value_ndims >= len(shape_spec))

        self._y_input_spec = self._x_input_spec = InputSpec(
            shape=(('...', ) + ('?', ) * (self.value_ndims - len(shape_spec)) +
                   tuple(shape_spec)),
            dtype=dtype)
        # the shape of the variables must only have the necessary dimensions,
        # such that we can switch freely between `channels_last = True`
        # (in which case `input.shape = (..., *)`) and `channels_last = False`
        # (in which case `input.shape = (..., *, 1, 1)`).
        self._var_shape = tuple(s for s in shape_spec if s is not None)
        # and we still need to compute the aligned variable shape, such that
        # we can immediately reshape the variables into this aligned shape,
        # then compute `scale * input + bias`.
        self._var_shape_aligned = tuple(s or 1 for s in shape_spec)
        self._var_spec = ParamSpec(self._var_shape)

        # validate the input
        self._x_input_spec.validate('input', input)

        # build the variables
        self._bias = model_variable('bias',
                                    dtype=dtype,
                                    shape=self._var_shape,
                                    regularizer=self._bias_regularizer,
                                    constraint=self._bias_constraint,
                                    trainable=self._trainable)
        if self._scale_type == 'exp':
            self._pre_scale = model_variable(
                'log_scale',
                dtype=dtype,
                shape=self._var_shape,
                regularizer=self._log_scale_regularizer,
                constraint=self._log_scale_constraint,
                trainable=self._trainable)
        else:
            self._pre_scale = model_variable(
                'scale',
                dtype=dtype,
                shape=self._var_shape,
                regularizer=self._scale_regularizer,
                constraint=self._scale_constraint,
                trainable=self._trainable)
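
A worked sketch of the variable-shape computation above, for a hypothetical
"NCHW" input with 64 channels and `axis = (-3,)`:

shape = (None, 64, 32, 32)   # static input shape, batch dimension unknown
axis = (-3,)                 # the channels axis under "NCHW"

min_axis = min(axis)                  # -3
shape_spec = [None] * len(shape)
for a in axis:
    shape_spec[a] = shape[a]
shape_spec = shape_spec[min_axis:]    # [64, None, None]

var_shape = tuple(s for s in shape_spec if s is not None)  # (64,)
var_shape_aligned = tuple(s or 1 for s in shape_spec)      # (64, 1, 1)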
Example 5
    def _build_input_spec(self, input):
        super(FeatureMappingFlow, self)._build_input_spec(input)

        dtype = input.dtype.base_dtype
        shape = get_static_shape(input)

        # These facts should have been checked in `BaseFlow.build`.
        assert (shape is not None)
        assert (len(shape) >= self.value_ndims)

        # validate the feature axis, ensure it is covered by `value_ndims`.
        axis = self._axis
        axis_is_int = is_integer(axis)
        if axis_is_int:
            axis = [axis]
        else:
            axis = list(axis)

        for i, a in enumerate(axis):
            if a < 0:
                a += len(shape)
            if a < 0 or a < len(shape) - self.value_ndims:
                raise ValueError('`axis` out of range, or not covered by '
                                 '`value_ndims`: axis {}, value_ndims {}, '
                                 'input {}'.format(self._axis,
                                                   self.value_ndims, input))
            if shape[a] is None:
                raise ValueError('The feature axis of `input` is not '
                                 'deterministic: input {}, axis {}'.format(
                                     input, self._axis))

            # Store the axis as a negative index, so that when new inputs
            # have more dimensions than this `input`, the axis can still be
            # resolved correctly.
            axis[i] = a - len(shape)

        if axis_is_int:
            assert (len(axis) == 1)
            self._axis = axis[0]
        else:
            axis_len = len(axis)
            axis = tuple(sorted(set(axis)))
            if len(axis) != axis_len:
                raise ValueError(
                    'Duplicated elements after resolving negative '
                    '`axis` with respect to the `input`: '
                    'input {}, axis {}'.format(input, self._axis))
            self._axis = tuple(axis)

        # re-build the input spec
        batch_ndims = int(self.require_batch_dims)
        shape_spec = ['...'] + ['?'] * (self.value_ndims + batch_ndims)
        for a in axis:
            shape_spec[a] = shape[a]
        self._y_input_spec = self._x_input_spec = InputSpec(shape=shape_spec,
                                                            dtype=dtype)
        self._x_input_spec.validate('input', input)
Example 6
def transpose_conv2d_axis(input,
                          from_channels_last,
                          to_channels_last,
                          name=None):
    """
    Ensure the channels axis of the `input` tensor is placed at the desired position.

    Args:
        input (tf.Tensor): The input tensor, at least 4-d.
        from_channels_last (bool): Whether or not the channels axis
            is the last axis in `input` (i.e., the data format is "NHWC").
        to_channels_last (bool): Whether or not the channels axis
            should be the last axis in the output tensor.

    Returns:
        tf.Tensor: The (maybe) transposed output tensor.
    """
    if from_channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
    input = input_spec.validate('input', input)
    input_shape = get_static_shape(input)
    sample_and_batch_axis = [i for i in range(len(input_shape) - 3)]

    # check whether or not the axes should be transposed
    if from_channels_last and not to_channels_last:
        transpose_axis = [-1, -3, -2]
    elif not from_channels_last and to_channels_last:
        transpose_axis = [-2, -1, -3]
    else:
        transpose_axis = None

    # transpose the axis
    if transpose_axis is not None:
        transpose_axis = [i + len(input_shape) for i in transpose_axis]
        input = tf.transpose(input,
                             sample_and_batch_axis + transpose_axis,
                             name=name or 'transpose_conv2d_axis')

    return input
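
A usage sketch (TensorFlow 1.x assumed), converting an "NHWC" tensor to "NCHW":

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])   # "NHWC"
y = transpose_conv2d_axis(x, from_channels_last=True, to_channels_last=False)
# the permutation is [0, 3, 1, 2], so y has static shape (None, 3, 32, 32)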
Example 7
    def _build(self, input=None):
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype

        # resolve the split axis
        split_axis = self._split_axis
        if split_axis < 0:
            split_axis += len(shape)
        if split_axis < 0 or split_axis < len(shape) - self.x_value_ndims:
            raise ValueError(
                '`split_axis` out of range, or not covered by `x_value_ndims`: '
                'split_axis {}, x_value_ndims {}, input {}'.format(
                    self._split_axis, self.x_value_ndims, input))
        split_axis -= len(shape)

        err_msg = (
            'The size of the split axis of `input` must be at least 2: '
            'input {}, axis {}.'.format(input, split_axis))
        if shape[split_axis] is not None:
            x_n_features = shape[split_axis]
            if x_n_features < 2:
                raise ValueError(err_msg)
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left
        else:
            x_n_features = get_shape(input)[split_axis]
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left

            with assert_deps(
                [tf.assert_greater_equal(x_n_features, 2,
                                         message=err_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    x_n_left = tf.identity(x_n_left)
                    x_n_right = tf.identity(x_n_right)

            x_n_features = None

        self._split_axis = split_axis
        self._x_n_left = x_n_left
        self._x_n_right = x_n_right

        # build the x spec
        shape_spec = ['?'] * self.x_value_ndims
        if x_n_features is not None:
            shape_spec[self._split_axis] = x_n_features
        assert (not self.require_batch_dims)
        shape_spec = ['...'] + shape_spec
        self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
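
The split arithmetic above, traced for a statically-known feature axis of
size 5:

x_n_features = 5
x_n_left = x_n_features // 2          # 2
x_n_right = x_n_features - x_n_left   # 3
# tf.split(x, [x_n_left, x_n_right], axis=split_axis) then yields the halves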
Example 8
def softmax_classification_output(logits, name=None):
    """
    Get the most probable softmax classification output for each set of logits.

    Args:
        logits: The softmax logits.  Its last dimension will be treated
            as the softmax logits dimension, and will be reduced.

    Returns:
        tf.Tensor: tf.int32 tensor, the class label for each logit.
    """
    logits = InputSpec(shape=('...', '?', '?')).validate('logits', logits)
    with tf.name_scope(name,
                       default_name='softmax_classification_output',
                       values=[logits]):
        return tf.argmax(logits, axis=-1, output_type=tf.int32)
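
A usage sketch of the function above (TensorFlow 1.x assumed):

import tensorflow as tf

logits = tf.constant([[0.1, 2.0, -1.0],
                      [3.0, 0.5, 0.2]])
labels = softmax_classification_output(logits)
# evaluates to [1, 0], with dtype tf.int32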
Example 9
def classification_accuracy(y_pred, y_true, name=None):
    """
    Compute the classification accuracy for `y_pred` and `y_true`.

    Args:
        y_pred: The predicted labels.
        y_true: The ground truth labels.  Its shape must match `y_pred`.

    Returns:
        tf.Tensor: The accuracy.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = InputSpec(shape=get_static_shape(y_pred)). \
        validate('y_true', y_true)
    with tf.name_scope(name,
                       default_name='classification_accuracy',
                       values=[y_pred, y_true]):
        return tf.reduce_mean(
            tf.cast(tf.equal(y_pred, y_true), dtype=tf.float32))
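
A usage sketch (TensorFlow 1.x assumed); three of the four predictions match,
so the accuracy evaluates to 0.75:

import tensorflow as tf

y_pred = tf.constant([0, 1, 2, 1], dtype=tf.int32)
y_true = tf.constant([0, 1, 1, 1], dtype=tf.int32)
acc = classification_accuracy(y_pred, y_true)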
Example 10
class BaseFlow(BaseLayer):
    """
    The basic class for normalizing flows.

    A normalizing flow transforms a random variable `x` into `y` by an
    (implicitly) invertible mapping :math:`y = f(x)`, whose Jacobian matrix
    determinant :math:`\\det \\frac{\\partial f(x)}{\\partial x} \\neq 0`, such
    that :math:`\\log p(y)` can be derived from a given :math:`\\log p(x)`.
    """
    @add_name_and_scope_arg_doc
    def __init__(self,
                 x_value_ndims,
                 y_value_ndims=None,
                 require_batch_dims=False,
                 name=None,
                 scope=None):
        """
        Construct a new :class:`BaseFlow`.

        Args:
            x_value_ndims (int): Number of value dimensions in `x`.
                `x.ndims - x_value_ndims == log_det.ndims`.
            y_value_ndims (int): Number of value dimensions in `y`.
                `y.ndims - y_value_ndims == log_det.ndims`.
                If not specified, use `x_value_ndims`.
            require_batch_dims (bool): If :obj:`True`, `x` are required
                to have at least `x_value_ndims + 1` dimensions, and `y`
                are required to have at least `y_value_ndims + 1` dimensions.

                If :obj:`False`, `x` are required to have at least
                `x_value_ndims` dimensions, and `y` are required to have
                at least `y_value_ndims` dimensions.
        """
        x_value_ndims = int(x_value_ndims)
        if y_value_ndims is None:
            y_value_ndims = x_value_ndims
        else:
            y_value_ndims = int(y_value_ndims)

        super(BaseFlow, self).__init__(name=name, scope=scope)
        self._x_value_ndims = x_value_ndims
        self._y_value_ndims = y_value_ndims
        self._require_batch_dims = bool(require_batch_dims)

        self._x_input_spec = None  # type: InputSpec
        self._y_input_spec = None  # type: InputSpec

    def invert(self):
        """
        Get the inverted flow from this flow.

        The :meth:`transform()` of this flow becomes the
        :meth:`inverse_transform()` of the inverted flow, and vice versa.

        If the current flow has not been initialized, it must be initialized
        via :meth:`inverse_transform()` in the new flow.

        Returns:
            tfsnippet.layers.InvertFlow: The inverted flow.
        """
        from .invert import InvertFlow
        return InvertFlow(self)

    @property
    def x_value_ndims(self):
        """
        Get the number of value dimensions in `x`.

        Returns:
            int: The number of value dimensions in `x`.
        """
        return self._x_value_ndims

    @property
    def y_value_ndims(self):
        """
        Get the number of value dimensions in `y`.

        Returns:
            int: The number of value dimensions in `y`.
        """
        return self._y_value_ndims

    @property
    def require_batch_dims(self):
        """Whether or not this flow requires batch dimensions."""
        return self._require_batch_dims

    @property
    def explicitly_invertible(self):
        """
        Whether or not this flow is explicitly invertible.

        If a flow is not explicitly invertible, then it only supports
        transforming `x` into `y`, and the corresponding :math:`\\log p(x)`
        into :math:`\\log p(y)`.  It cannot compute :math:`\\log p(y)` directly
        without knowing `x`, nor can it transform `y` back into `x`.

        Returns:
            bool: A boolean indicating whether or not the flow is explicitly
                invertible.
        """
        raise NotImplementedError()

    def _build_input_spec(self, input):
        batch_ndims = int(self.require_batch_dims)
        dtype = input.dtype.base_dtype

        x_input_shape = ['...'] + ['?'] * (self.x_value_ndims + batch_ndims)
        y_input_shape = ['...'] + ['?'] * (self.y_value_ndims + batch_ndims)

        self._x_input_spec = InputSpec(shape=x_input_shape, dtype=dtype)
        self._y_input_spec = InputSpec(shape=y_input_shape, dtype=dtype)

    def build(self, input=None):
        # check the input.
        if input is None:
            raise ValueError('`input` is required to build {}.'.format(
                self.__class__.__name__))

        input = tf.convert_to_tensor(input)
        shape = get_static_shape(input)
        require_ndims = self.x_value_ndims + int(self.require_batch_dims)
        require_ndims_text = ('x_value_ndims + 1'
                              if self.require_batch_dims else 'x_value_ndims')

        if shape is None or len(shape) < require_ndims:
            raise ValueError('`x.ndims` must be known and >= `{}`: x '
                             '{} vs ndims `{}`.'.format(
                                 require_ndims_text, input, require_ndims))

        # build the input spec
        self._build_input_spec(input)

        # build the layer
        return super(BaseFlow, self).build(input)

    def _transform(self, x, compute_y, compute_log_det):
        raise NotImplementedError()

    def transform(self, x, compute_y=True, compute_log_det=True, name=None):
        """
        Transform `x` into `y`, and compute the log-determinant of `f` at `x`
        (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

        Args:
            x (Tensor): The samples of `x`.
            compute_y (bool): Whether or not to compute :math:`y = f(x)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, use this name as the name scope of
                the TensorFlow operations.

        Returns:
            (tf.Tensor, tf.Tensor): `y` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None`
                if the corresponding `compute_?` argument is set to
                :obj:`False`.

        Raises:
            ValueError: If both `compute_y` and `compute_log_det` are set
                to :obj:`False`.
        """
        if not compute_y and not compute_log_det:
            raise ValueError('At least one of `compute_y` and '
                             '`compute_log_det` should be True.')

        x = tf.convert_to_tensor(x)
        if not self._has_built:
            self.build(x)

        x = self._x_input_spec.validate('x', x)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'transform', self),
                           values=[x]):
            y, log_det = self._transform(x, compute_y, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=x,
                            value_ndims=self.x_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        raise NotImplementedError()

    def inverse_transform(self,
                          y,
                          compute_x=True,
                          compute_log_det=True,
                          name=None):
        """
        Transform `y` into `x`, and compute the log-determinant of `f^{-1}` at
        `y` (i.e.,
        :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

        Args:
            y (Tensor): The samples of `y`.
            compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, use this name as the name scope of
                the TensorFlow operations.

        Returns:
            (tf.Tensor, tf.Tensor): `x` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None`
                if the corresponding `compute_?` argument is set to
                :obj:`False`.

        Raises:
            ValueError: If both `compute_x` and `compute_log_det` are set
                to :obj:`False`.
            RuntimeError: If the flow is not explicitly invertible.
        """
        if not self.explicitly_invertible:
            raise RuntimeError(
                'The flow is not explicitly invertible: {!r}'.format(self))
        if not compute_x and not compute_log_det:
            raise ValueError('At least one of `compute_x` and '
                             '`compute_log_det` should be True.')
        if not self._has_built:
            raise RuntimeError('`inverse_transform` cannot be called before '
                               'the flow has been built; it can be built by '
                               'calling `build`, `apply` or `transform`: '
                               '{!r}'.format(self))

        y = tf.convert_to_tensor(y)
        y = self._y_input_spec.validate('y', y)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'inverse_transform', self),
                           values=[y]):
            x, log_det = self._inverse_transform(y, compute_x, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=y,
                            value_ndims=self.y_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return x, log_det

    def _apply(self, x):
        y, _ = self.transform(x, compute_y=True, compute_log_det=False)
        return y
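
A minimal, hypothetical subclass sketch showing the `BaseFlow` contract above:
an identity flow whose log-determinant is zero.  The class is illustrative and
not part of the library:

import tensorflow as tf

class IdentityFlow(BaseFlow):

    @property
    def explicitly_invertible(self):
        return True

    def _build(self, input=None):
        pass  # no variables to create for the identity map

    def _transform(self, x, compute_y, compute_log_det):
        y = x if compute_y else None
        log_det = None
        if compute_log_det:
            # log|det| of the identity map is 0; its shape is the shape of
            # `x` with the trailing `x_value_ndims` value dimensions dropped
            batch_ndims = x.shape.ndims - self.x_value_ndims
            log_det = tf.zeros(tf.shape(x)[:batch_ndims], dtype=x.dtype)
        return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        # the identity map is its own inverse
        return self._transform(y, compute_x, compute_log_det)

flow = IdentityFlow(x_value_ndims=1)
y, log_det = flow.transform(tf.placeholder(tf.float32, [None, 3]))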
Example 11
def resnet_deconv2d_block(input,
                          out_channels,
                          kernel_size,
                          strides=(1, 1),
                          shortcut_kernel_size=(1, 1),
                          channels_last=True,
                          resize_at_exit=False,
                          activation_fn=None,
                          normalizer_fn=None,
                          weight_norm=False,
                          dropout_fn=None,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          kernel_constraint=None,
                          use_bias=None,
                          bias_initializer=tf.zeros_initializer(),
                          bias_regularizer=None,
                          bias_constraint=None,
                          trainable=True,
                          name=None,
                          scope=None):
    """
    2D deconvolutional ResNet block.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple[int]): Kernel size over spatial dimensions,
            for "conv" and "conv_1" deconvolutional layers.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for all three deconvolutional layers.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" deconvolutional layer.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input` (i.e., the data format is "NHWC").
        resize_at_exit (bool): See :func:`resnet_general_block`.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm: Passed to :func:`deconv2d`.
        dropout_fn: The dropout function.
        kernel_initializer: Passed to :func:`deconv2d`.
        kernel_regularizer: Passed to :func:`deconv2d`.
        kernel_constraint: Passed to :func:`deconv2d`.
        use_bias: Whether or not to use `bias` in :func:`deconv2d`.
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias_initializer: Passed to :func:`deconv2d`.
        bias_regularizer: Passed to :func:`deconv2d`.
        bias_constraint: Passed to :func:`deconv2d`.
        trainable: Passed to :func:`deconv2d`.

    Returns:
        tf.Tensor: The output tensor.

    See Also:
        :func:`resnet_general_block`
    """
    # check the input and infer the input shape
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        c_axis = -1
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        c_axis = -3
    input = input_spec.validate('input', input)
    in_channels = get_static_shape(input)[c_axis]

    # check the functional arguments
    if use_bias is None:
        use_bias = normalizer_fn is None

    # derive the convolution function
    conv_fn = partial(
        deconv2d,
        channels_last=channels_last,
        weight_norm=weight_norm,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_constraint=kernel_constraint,
        use_bias=use_bias,
        bias_initializer=bias_initializer,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        trainable=trainable,
    )

    # build the resnet block
    return resnet_general_block(conv_fn,
                                input=input,
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                strides=strides,
                                shortcut_kernel_size=shortcut_kernel_size,
                                resize_at_exit=resize_at_exit,
                                activation_fn=activation_fn,
                                normalizer_fn=normalizer_fn,
                                dropout_fn=dropout_fn,
                                name=name or 'resnet_deconv2d_block',
                                scope=scope)
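
A usage sketch (TensorFlow 1.x assumed): upsampling a 16x16 feature map with
stride 2.  The doubled spatial size assumes the library's default
"same"-style padding:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 16, 16, 64])   # "NHWC"
y = resnet_deconv2d_block(x, out_channels=32,
                          kernel_size=(3, 3), strides=(2, 2))
# y should have static shape (None, 32, 32, 32)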
Example 12
def pixelcnn_conv2d_resnet(
        input,
        out_channels,
        conv_fn=conv2d,
        # the default values for the following two arguments
        # are from the PixelCNN++ paper.
        vertical_kernel_size=(2, 3),
        horizontal_kernel_size=(2, 2),
        strides=(1, 1),
        channels_last=True,
        use_shortcut_conv=None,
        shortcut_conv_fn=None,
        shortcut_kernel_size=(1, 1),
        activation_fn=None,
        normalizer_fn=None,
        dropout_fn=None,
        gated=False,
        gate_sigmoid_bias=2.,
        use_bias=None,
        name=None,
        scope=None,
        **kwargs):
    """
    PixelCNN 2D convolutional ResNet block.

    Args:
        input (PixelCNN2DOutput): The output from the previous PixelCNN layer.
        out_channels (int): The number of output channels.
        conv_fn: The convolution function for "conv_0" and "conv_1"
            convolutional layers.  See :func:`resnet_general_block`.
        vertical_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for "conv_0" and "conv_1" convolutional layers
            in the PixelCNN vertical stack.
        horizontal_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for "conv_0" and "conv_1" convolutional layers
            in the PixelCNN horizontal stack.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for "conv_0", "conv_1" and "shortcut" convolutional layers.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input` (i.e., the data format is "NHWC").
        use_shortcut_conv (True or None): If :obj:`True`, always apply a
            linear convolution transformation on the shortcut path.
            If :obj:`None` (the default), only use the shortcut convolution
            when necessary.
        shortcut_conv_fn: The convolution function for the "shortcut"
            convolutional layer.  If not specified, use `conv_fn`.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" convolutional layer.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        dropout_fn: The dropout function.
        gated (bool): Whether or not to use a gate on the output of "conv_1":
            `conv_1_output = activation_fn(conv_1_output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        use_bias (bool or None): Whether or not to use `bias` in "conv_0" and
            "conv_1".  If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        \\**kwargs: Other named arguments passed to "conv_0", "conv_1" and
            "shortcut" convolutional layers.

    Returns:
        PixelCNN2DOutput: The PixelCNN layer output.
    """
    if not isinstance(input, PixelCNN2DOutput):
        raise TypeError('`input` is not an instance of `PixelCNN2DOutput`: '
                        'got {!r}'.format(input))

    vertical, in_channels, _ = validate_conv2d_input(input.vertical,
                                                     channels_last,
                                                     'input.vertical')
    horizontal = InputSpec(shape=get_static_shape(vertical),
                           dtype=vertical.dtype.base_dtype). \
        validate('input.horizontal', input.horizontal)

    if shortcut_conv_fn is None:
        shortcut_conv_fn = conv_fn

    # derive the convolution functions
    vertical_conv_fn = partial(shifted_conv2d,
                               conv_fn=conv_fn,
                               spatial_shift=(1, 0))
    horizon_conv_fn = partial(shifted_conv2d,
                              conv_fn=conv_fn,
                              spatial_shift=(1, 1))

    with tf.variable_scope(scope,
                           default_name=name or 'pixelcnn_conv2d_resnet'):
        # first, derive the vertical stack output
        vertical = resnet_general_block(
            conv_fn=vertical_conv_fn,
            input=vertical,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=vertical_kernel_size,
            strides=strides,
            channels_last=channels_last,
            use_shortcut_conv=use_shortcut_conv,
            shortcut_conv_fn=shortcut_conv_fn,
            shortcut_kernel_size=shortcut_kernel_size,
            resize_at_exit=False,  # always resize at conv_0
            after_conv_0=None,
            after_conv_1=None,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            dropout_fn=dropout_fn,
            gated=gated,
            gate_sigmoid_bias=gate_sigmoid_bias,
            use_bias=use_bias,
            scope='vertical',
            **kwargs)

        horizontal = resnet_general_block(
            conv_fn=horizon_conv_fn,
            input=horizontal,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=horizontal_kernel_size,
            strides=strides,
            channels_last=channels_last,
            use_shortcut_conv=use_shortcut_conv,
            shortcut_conv_fn=shortcut_conv_fn,
            shortcut_kernel_size=shortcut_kernel_size,
            resize_at_exit=False,  # always resize at conv_0
            after_conv_0=partial(pixelcnn_conv2d_resnet_after_conv_0,
                                 vertical=vertical,
                                 out_channels=out_channels,
                                 channels_last=channels_last,
                                 activation_fn=activation_fn,
                                 shortcut_conv_fn=shortcut_conv_fn,
                                 **kwargs),
            after_conv_1=None,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            dropout_fn=dropout_fn,
            gated=gated,
            gate_sigmoid_bias=gate_sigmoid_bias,
            use_bias=use_bias,
            scope='horizontal',
            **kwargs)

    return PixelCNN2DOutput(vertical=vertical, horizontal=horizontal)
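
A usage sketch; `pixelcnn_2d_input` is assumed to be the entry point that
wraps an image tensor into a `PixelCNN2DOutput` (an assumption, included only
so the example reads end-to-end):

import tensorflow as tf

images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
out = pixelcnn_2d_input(images, channels_last=True)   # assumed entry point
out = pixelcnn_conv2d_resnet(out, out_channels=64,
                             activation_fn=tf.nn.leaky_relu)
# out.vertical and out.horizontal carry the two PixelCNN stacks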
Example 13
    def _transform(self, x, compute_y, compute_log_det):
        # split the input x
        x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                          axis=self._split_axis)
        do_compute_y = compute_y or (self._y_n_left is None)

        # apply the left transformation
        y1, log_det1 = self._left.transform(x1,
                                            compute_y=do_compute_y,
                                            compute_log_det=compute_log_det)

        # apply the right transformation
        if self._right is not None:
            y2, log_det2 = self._right.transform(
                x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
        else:
            y2, log_det2 = x2, None

        # check the outputs
        y1_shape = get_static_shape(y1)
        y2_shape = get_static_shape(y2)

        if len(y1_shape) != len(y2_shape):
            raise RuntimeError('`y_left.ndims` != `y_right.ndims`: y_left {} '
                               'vs y_right {}'.format(y1, y2))

        # build the y spec if not built
        join_axis = self._join_axis

        if self._y_n_left is None:
            # resolve the join axis
            if join_axis is None:
                join_axis = self._split_axis
            if join_axis < 0:
                join_axis += len(y1_shape)

            if join_axis < 0 or join_axis < len(y1_shape) - self.y_value_ndims:
                raise ValueError(
                    '`join_axis` out of range, or not covered by '
                    '`y_value_ndims`: join_axis {}, y_value_ndims {}, '
                    'y_left {}, y_right {}'.format(
                        self._join_axis, self.y_value_ndims, y1, y2))
            join_axis -= len(y1_shape)

            err_msg = (
                '`y_left.shape[join_axis] + y_right.shape[join_axis]` must '
                'be at least 2: y_left {}, y_right {} axis {}.'.format(
                    y1, y2, join_axis))
            y_n_left = y1_shape[join_axis]
            y_n_right = y2_shape[join_axis]
            if y_n_left is not None and y_n_right is not None:
                y_n_features = y_n_left + y_n_right
                assert (y_n_features >= 2)
            else:
                y_n_left = get_shape(y1)[join_axis]
                y_n_right = get_shape(y2)[join_axis]
                y_n_features = None
                with assert_deps([
                        tf.assert_greater_equal(y_n_left + y_n_right,
                                                2,
                                                message=err_msg)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        y_n_left = tf.identity(y_n_left)
                        y_n_right = tf.identity(y_n_right)

            self._join_axis = join_axis
            self._y_n_left = y_n_left
            self._y_n_right = y_n_right

            # build the y spec
            dtype = self._x_input_spec.dtype
            y_shape_spec = ['?'] * self.y_value_ndims
            if y_n_features is not None:
                y_shape_spec[self._join_axis] = y_n_features
            assert (not self.require_batch_dims)
            y_shape_spec = ['...'] + y_shape_spec
            self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

        assert (join_axis is not None)

        # compute y
        y = None
        if compute_y:
            y = tf.concat([y1, y2], axis=join_axis)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = log_det1
            if log_det2 is not None:
                log_det = sum_log_det([log_det, log_det2])

        return y, log_det
Example 14
def pixelcnn_2d_sample(fn,
                       inputs,
                       height,
                       width,
                       channels_last=True,
                       start=0,
                       end=None,
                       back_prop=False,
                       parallel_iterations=1,
                       swap_memory=False,
                       name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            through iteration `0` to iteration `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shapes.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in each input tensor (i.e., the data format is "NHWC").
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.

    Returns:
        tuple[tf.Tensor]: The final outputs.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(inputs[0],
                                            channels_last=channels_last,
                                            arg_name='inputs[0]')
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]

        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare for the output mask
            selector = tf.reshape(
                tf.concat([
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ],
                          axis=0), mask_shape)
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs)))

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = inputs[i].dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(idx=i,
                                                     output=output_dtype,
                                                     input=input_dtype))
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return outputs
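
A usage sketch; `sample_pixels` is a hypothetical helper that runs the trained
PixelCNN network on the current canvas and samples every pixel (the masking in
the loop body above keeps only the pixel produced at iteration `i`):

import tensorflow as tf

def fn(i, inputs):
    # hypothetical: run the network and sample all pixels at once;
    # `pixelcnn_2d_sample` masks out everything except pixel `i`
    return (sample_pixels(inputs[0]),)

canvas = tf.zeros([1, 28, 28, 1])   # initial blank image, "NHWC"
(samples,) = pixelcnn_2d_sample(fn, [canvas], height=28, width=28)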
Example 15
def act_norm(input,
             axis=-1,
             initializing=False,
             scale_type='exp',
             bias_regularizer=None,
             bias_constraint=None,
             log_scale_regularizer=None,
             log_scale_constraint=None,
             scale_regularizer=None,
             scale_constraint=None,
             trainable=True,
             epsilon=1e-6,
             name=None,
             scope=None,
             value_ndims=None):
    """
    ActNorm proposed by (Kingma & Dhariwal, 2018).

    Examples::

        import tfsnippet as spt

        # apply act_norm on a dense layer
        x = spt.layers.dense(x, units, activation_fn=tf.nn.relu,
                             normalizer_fn=functools.partial(
                                 act_norm, initializing=initializing))

        # apply act_norm on a conv2d layer
        x = spt.layers.conv2d(x, out_channels, (3, 3),
                              channels_last=channels_last,
                              activation_fn=tf.nn.relu,
                              normalizer_fn=functools.partial(
                                  act_norm,
                                  axis=-1 if channels_last else -3,
                                  value_ndims=3,
                                  initializing=initializing,
                              ))

    Args:
        input (Tensor): The input tensor.
        axis (int or Iterable[int]): The axis to apply ActNorm on.
            Dimensions not in `axis` will be averaged out when computing
            the mean of activations.  Default `-1`, the last dimension.
            All items of `axis` should be covered by `value_ndims`.
        initializing (bool): Whether or not to use the input to initialize
            the layer parameters (default :obj:`False`).
        scale_type: One of {"exp", "linear"}.
            If "exp", ``y = (x + bias) * tf.exp(log_scale)``.
            If "linear", ``y = (x + bias) * scale``.
            Default is "exp".
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        log_scale_regularizer: The regularizer for `log_scale`.
        log_scale_constraint: The constraint for `log_scale`.
        scale_regularizer: The regularizer for `scale`.
        scale_constraint: The constraint for `scale`.
        trainable (bool): Whether or not the variables are trainable.
        epsilon: Small float to avoid dividing by zero or taking
            logarithm of zero.

    Returns:
        tf.Tensor: The output after the ActNorm has been applied.
    """
    input = InputSpec(shape=['...']).validate('input', input)
    rank = len(get_static_shape(input))
    axis = list(validate_int_tuple_arg('axis', axis))
    for i, a in enumerate(axis):
        if a >= 0:
            axis[i] = a - rank
    value_ndims = max(-a for a in axis)
    layer = ActNorm(
        axis=axis,
        value_ndims=value_ndims,
        initialized=not initializing,
        scale_type=scale_type,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        log_scale_regularizer=log_scale_regularizer,
        log_scale_constraint=log_scale_constraint,
        scale_regularizer=scale_regularizer,
        scale_constraint=scale_constraint,
        trainable=trainable,
        epsilon=epsilon,
        name=name,
        scope=scope
    )
    return layer.apply(input)