Example #1
def broadcast_to_shape_strict(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    This method requires `rank(x)` to be less than or equal to `len(shape)`.
    You may use :func:`broadcast_to_shape` instead, to allow the cases where
    ``rank(x) > len(shape)``.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # assert ``rank(x) <= len(shape)``
        if isinstance(shape, tuple) and x_shape is not None:
            if len(x_shape) > len(shape):
                raise ValueError(cannot_broadcast_msg)
        elif isinstance(shape, tuple):
            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         len(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)
        else:
            with assert_deps(
                [assert_rank(shape, 1,
                             message=cannot_broadcast_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    shape = tf.identity(shape)

            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         tf.size(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)

        # do broadcast
        return broadcast_to_shape(x, shape)
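
A minimal usage sketch, assuming TensorFlow 1.x graph mode and that the function above is importable:

# Hypothetical example: broadcast a [3, 1] tensor to the target shape (4, 3, 5).
# rank(x) <= len(shape), so the strict variant accepts it.
x = tf.zeros([3, 1])
y = broadcast_to_shape_strict(x, (4, 3, 5))
# y has shape [4, 3, 5]; a target shape with fewer dimensions than rank(x)
# would raise ValueError instead.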
Example #2
def broadcast_log_det_against_input(log_det, input, value_ndims, name=None):
    """
    Broadcast the shape of `log_det` to match the shape of `input`.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each value sample.

    Returns:
        tf.Tensor: The broadcasted log-determinant.
    """
    log_det = tf.convert_to_tensor(log_det)
    input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'broadcast_log_det_to_input_shape',
                       values=[log_det, input]):
        shape = get_shape(input)
        if value_ndims > 0:
            err_msg = (
                'Cannot broadcast `log_det` against `input`: log_det is {}, '
                'input is {}, value_ndims is {}.'.format(
                    log_det, input, value_ndims))
            with assert_deps(
                [assert_rank_at_least(input, value_ndims, message=err_msg)]):
                shape = shape[:-value_ndims]

        return broadcast_to_shape_strict(log_det, shape)
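
A minimal usage sketch, assuming a scalar `log_det` and a two-dimensional `input` (both hypothetical):

# Hypothetical example: a scalar log_det broadcast against a [batch, features] input.
log_det = tf.constant(0.5)
input = tf.zeros([16, 8])
out = broadcast_log_det_against_input(log_det, input, value_ndims=1)
# The last `value_ndims` dimension(s) of `input` are dropped before
# broadcasting, so `out` has shape [16].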
Example #3
 def multiply_branch():
     with assert_deps(assertions):
         ones_template = tf.ones(shape, dtype=x.dtype.base_dtype)
     try:
         return x * ones_template
     except ValueError:  # pragma: no cover
         raise ValueError(cannot_broadcast_msg)
Example #4
    def transform(self, x, compute_y=True, compute_log_det=True, name=None):
        """
        Transform `x` into `y`, and compute the log-determinant of `f` at `x`
        (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

        Args:
            x (Tensor): The samples of `x`.
            compute_y (bool): Whether or not to compute :math:`y = f(x)`?
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant?  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `y` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None`
                if the corresponding `compute_?` argument is set to :obj:`False`.

        Raises:
            ValueError: If both `compute_y` and `compute_log_det` are set
                to :obj:`False`.
        """
        if not compute_y and not compute_log_det:
            raise ValueError('At least one of `compute_y` and '
                             '`compute_log_det` should be True.')

        x = tf.convert_to_tensor(x)
        if not self._has_built:
            self.build(x)

        x = self._x_input_spec.validate('x', x)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'transform', self),
                           values=[x]):
            y, log_det = self._transform(x, compute_y, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=x,
                            value_ndims=self.x_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            if y is not None:
                maybe_add_histogram(y, 'y')
                y = maybe_check_numerics(y, 'y')

            if log_det is not None:
                maybe_add_histogram(log_det, 'log_det')
                log_det = maybe_check_numerics(log_det, 'log_det')

            return y, log_det
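
A minimal usage sketch, assuming `flow` is some already-constructed flow object exposing this method (a hypothetical placeholder):

# Hypothetical example: `flow` is some flow instance; transform() builds it on first call.
x = tf.zeros([64, 10])
y, log_det = flow.transform(x)                         # compute both outputs
_, log_det_only = flow.transform(x, compute_y=False)   # skip computing y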
Example #5
 def _check_scale_or_shift_shape(self, name, tensor, x2):
     assert_op = assert_shape_equal(
         tensor,
         x2,
         message='`{}.shape` expected to be {}, but got {}'.format(
             name, get_static_shape(x2), get_static_shape(tensor)))
     with assert_deps([assert_op]) as asserted:
         if asserted:  # pragma: no cover
             tensor = tf.identity(tensor)
     return tensor
Example #6
    def __init__(self, distribution, ndims):
        """
        Construct a new :class:`BatchToValueDistribution`.

        Args:
            distribution (Distribution): The source distribution.
            ndims (int): The number of trailing `batch_ndims` to be converted
                into `value_ndims`.  Must be non-negative.
        """

        distribution = as_distribution(distribution)
        ndims = int(ndims)
        if ndims < 0:
            raise ValueError('`ndims` must be a non-negative integer: '
                             'got {!r}'.format(ndims))

        with tf.name_scope('BatchToValueDistribution.init'):
            # get new batch shape
            batch_shape = distribution.batch_shape
            batch_static_shape = distribution.get_batch_shape()
            if ndims > 0:
                # static shape
                if batch_static_shape.ndims < ndims:
                    raise ValueError(
                        '`distribution.batch_shape.ndims` is less than `ndims`'
                        ': distribution {}, batch_shape.ndims {}, ndims {}'.
                        format(distribution, batch_static_shape.ndims, ndims))
                batch_static_shape = batch_static_shape[:-ndims]

                # dynamic shape
                batch_shape = batch_shape[:-ndims]
                with assert_deps([
                        tf.assert_greater_equal(
                            tf.size(distribution.batch_shape), ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        batch_shape = tf.identity(batch_shape)

            # get new value ndims
            value_ndims = ndims + distribution.value_ndims

        self._distribution = distribution
        self._ndims = ndims

        super(BatchToValueDistribution, self).__init__(
            dtype=distribution.dtype,
            is_continuous=distribution.is_continuous,
            is_reparameterized=distribution.is_reparameterized,
            batch_shape=batch_shape,
            batch_static_shape=batch_static_shape,
            value_ndims=value_ndims,
        )
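
A minimal construction sketch, assuming `base_dist` is a hypothetical source Distribution with batch shape [16, 4]:

# Hypothetical example: convert the last batch dimension into a value dimension,
# so batch_shape becomes [16] and value_ndims increases by 1.
dist = BatchToValueDistribution(base_dist, ndims=1)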
Example #7
 def assert_batch_shape(c, batch_shape):
     c_batch_shape = c.batch_shape
     with assert_deps([
             tf.assert_equal(
                 tf.reduce_all(
                     tf.equal(
                         tf.concat([batch_shape, c_batch_shape], 0),
                         tf.concat([c_batch_shape, batch_shape], 0))),
                 True)
     ]) as asserted:
         if asserted:  # pragma: no cover
             batch_shape = tf.identity(batch_shape)
     return batch_shape
Example #8
    def _build(self, input=None):
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype

        # resolve the split axis
        split_axis = self._split_axis
        if split_axis < 0:
            split_axis += len(shape)
        if split_axis < 0 or split_axis < len(shape) - self.x_value_ndims:
            raise ValueError(
                '`split_axis` out of range, or not covered by `x_value_ndims`: '
                'split_axis {}, x_value_ndims {}, input {}'.format(
                    self._split_axis, self.x_value_ndims, input))
        split_axis -= len(shape)

        err_msg = (
            'The size of the split axis of `input` must be at least 2: '
            'input {}, axis {}.'.format(input, split_axis))
        if shape[split_axis] is not None:
            x_n_features = shape[split_axis]
            if x_n_features < 2:
                raise ValueError(err_msg)
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left
        else:
            x_n_features = get_shape(input)[split_axis]
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left

            with assert_deps(
                [tf.assert_greater_equal(x_n_features, 2,
                                         message=err_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    x_n_left = tf.identity(x_n_left)
                    x_n_right = tf.identity(x_n_right)

            x_n_features = None

        self._split_axis = split_axis
        self._x_n_left = x_n_left
        self._x_n_right = x_n_right

        # build the x spec
        shape_spec = ['?'] * self.x_value_ndims
        if x_n_features is not None:
            shape_spec[self._split_axis] = x_n_features
        assert (not self.require_batch_dims)
        shape_spec = ['...'] + shape_spec
        self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
Example #9
    def inverse_transform(self,
                          y,
                          compute_x=True,
                          compute_log_det=True,
                          name=None):
        """
        Transform `y` into `x`, and compute the log-determinant of `f^{-1}` at
        `y` (i.e.,
        :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

        Args:
            y (Tensor): The samples of `y`.
            compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`?
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant?  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `x` and the (maybe summed) log-determinant.
                The items in the returned tuple might be :obj:`None`
                if the corresponding `compute_?` argument is set to :obj:`False`.

        Raises:
            ValueError: If both `compute_x` and `compute_log_det` are set
                to :obj:`False`.
            RuntimeError: If the flow is not explicitly invertible.
        """
        if not self.explicitly_invertible:
            raise RuntimeError(
                'The flow is not explicitly invertible: {!r}'.format(self))
        if not compute_x and not compute_log_det:
            raise ValueError('At least one of `compute_x` and '
                             '`compute_log_det` should be True.')
        if not self._has_built:
            raise RuntimeError('`inverse_transform` cannot be called before '
                               'the flow has been built; it can be built by '
                               'calling `build`, `apply` or `transform`: '
                               '{!r}'.format(self))

        y = tf.convert_to_tensor(y)
        y = self._y_input_spec.validate('y', y)

        with tf.name_scope(name,
                           default_name=get_default_scope_name(
                               'inverse_transform', self),
                           values=[y]):
            x, log_det = self._inverse_transform(y, compute_x, compute_log_det)

            if compute_log_det:
                with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=y,
                            value_ndims=self.y_value_ndims)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return x, log_det
Example #10
def deconv2d(input,
             out_channels,
             kernel_size,
             strides=(1, 1),
             padding='same',
             channels_last=True,
             output_shape=None,
             activation_fn=None,
             normalizer_fn=None,
             weight_norm=False,
             gated=False,
             gate_sigmoid_bias=2.,
             kernel=None,
             kernel_initializer=None,
             kernel_regularizer=None,
             kernel_constraint=None,
             use_bias=None,
             bias=None,
             bias_initializer=tf.zeros_initializer(),
             bias_regularizer=None,
             bias_constraint=None,
             trainable=True,
             name=None,
             scope=None):
    """
    2D deconvolutional layer.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The channel numbers of the deconvolution output.
        kernel_size (int or (int, int)): Kernel size over spatial dimensions.
        strides (int or (int, int)): Strides over spatial dimensions.
        padding: One of {"valid", "same"}, case in-sensitive.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        output_shape: If specified, use this as the shape of the
            deconvolution output; otherwise compute the size of each dimension
            by::

                output_size = input_size * strides
                if padding == 'valid':
                    output_size += max(kernel_size - strides, 0)

        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm (bool or (tf.Tensor) -> tf.Tensor):
            If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on
            `kernel`.  `use_scale` will be :obj:`True` if `normalizer_fn`
            is not specified, and :obj:`False` otherwise.  The axis reduction
            will be determined by the layer.

            If it is a callable function, then it will be used to normalize
            the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`.
            The user must ensure the axis reduction is correct by themselves.
        gated (bool): Whether or not to use gate on output?
            `output = activation_fn(output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        kernel (Tensor): Instead of creating a new variable, use this tensor.
        kernel_initializer: The initializer for `kernel`.
            Would be ``default_kernel_initializer(...)`` if not specified.
        kernel_regularizer: The regularizer for `kernel`.
        kernel_constraint: The constraint for `kernel`.
        use_bias (bool or None): Whether or not to use `bias`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias (Tensor): Instead of creating a new variable, use this tensor.
        bias_initializer: The initializer for `bias`.
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        trainable (bool): Whether or not the parameters are trainable?

    Returns:
        tf.Tensor: The output tensor.
    """
    input, in_channels, data_format = \
        validate_conv2d_input(input, channels_last)
    out_channels = validate_positive_int_arg('out_channels', out_channels)
    dtype = input.dtype.base_dtype
    if gated:
        out_channels *= 2

    # check functional arguments
    padding = validate_enum_arg('padding',
                                str(padding).upper(), ['VALID', 'SAME'])
    strides = validate_conv2d_strides_tuple('strides', strides, channels_last)

    weight_norm_fn = validate_weight_norm_arg(weight_norm,
                                              axis=-1,
                                              use_scale=normalizer_fn is None)
    if use_bias is None:
        use_bias = normalizer_fn is None

    # get the specification of outputs and parameters
    kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size)
    kernel_shape = kernel_size + (out_channels, in_channels)
    bias_shape = (out_channels, )

    given_h, given_w = None, None
    given_output_shape = output_shape

    if is_tensor_object(given_output_shape):
        given_output_shape = tf.convert_to_tensor(given_output_shape)
    elif given_output_shape is not None:
        given_h, given_w = given_output_shape

    # validate the parameters
    if kernel is not None:
        kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype)
        kernel = kernel_spec.validate('kernel', kernel)
    if kernel_initializer is None:
        kernel_initializer = default_kernel_initializer(weight_norm)
    if bias is not None:
        bias_spec = ParamSpec(shape=bias_shape, dtype=dtype)
        bias = bias_spec.validate('bias', bias)

    # the main part of the conv2d layer
    with tf.variable_scope(scope, default_name=name or 'deconv2d'):
        with tf.name_scope('output_shape'):
            # detect the input shape and axis arrangements
            input_shape = get_static_shape(input)
            if channels_last:
                c_axis, h_axis, w_axis = -1, -3, -2
            else:
                c_axis, h_axis, w_axis = -3, -2, -1

            output_shape = [None, None, None, None]
            output_shape[c_axis] = out_channels
            if given_output_shape is None:
                if input_shape[h_axis] is not None:
                    output_shape[h_axis] = get_deconv_output_length(
                        input_shape[h_axis], kernel_shape[0], strides[h_axis],
                        padding)
                if input_shape[w_axis] is not None:
                    output_shape[w_axis] = get_deconv_output_length(
                        input_shape[w_axis], kernel_shape[1], strides[w_axis],
                        padding)
            else:
                if not is_tensor_object(given_output_shape):
                    output_shape[h_axis] = given_h
                    output_shape[w_axis] = given_w

            # infer the batch shape in 4-d
            batch_shape = input_shape[:-3]
            if None not in batch_shape:
                output_shape[0] = int(np.prod(batch_shape))

            # now the static output shape is ready
            output_static_shape = tf.TensorShape(output_shape)

            # prepare for the dynamic batch shape
            if output_shape[0] is None:
                output_shape[0] = tf.reduce_prod(get_shape(input)[:-3])

            # prepare for the dynamic spatial dimensions
            if output_shape[h_axis] is None or output_shape[w_axis] is None:
                if given_output_shape is None:
                    input_shape = get_shape(input)
                    if output_shape[h_axis] is None:
                        output_shape[h_axis] = get_deconv_output_length(
                            input_shape[h_axis], kernel_shape[0],
                            strides[h_axis], padding)
                    if output_shape[w_axis] is None:
                        output_shape[w_axis] = get_deconv_output_length(
                            input_shape[w_axis], kernel_shape[1],
                            strides[w_axis], padding)
                else:
                    assert (is_tensor_object(given_output_shape))
                    with assert_deps([
                            assert_rank(given_output_shape, 1),
                            assert_scalar_equal(tf.size(given_output_shape), 2)
                    ]):
                        output_shape[h_axis] = given_output_shape[0]
                        output_shape[w_axis] = given_output_shape[1]

            # compose the final dynamic shape
            if any(is_tensor_object(s) for s in output_shape):
                output_shape = tf.stack(output_shape)
            else:
                output_shape = tuple(output_shape)

        # create the variables
        if kernel is None:
            kernel = model_variable('kernel',
                                    shape=kernel_shape,
                                    dtype=dtype,
                                    initializer=kernel_initializer,
                                    regularizer=kernel_regularizer,
                                    constraint=kernel_constraint,
                                    trainable=trainable)

        if weight_norm_fn is not None:
            kernel = weight_norm_fn(kernel)

        maybe_add_histogram(kernel, 'kernel')
        kernel = maybe_check_numerics(kernel, 'kernel')

        if use_bias and bias is None:
            bias = model_variable('bias',
                                  shape=bias_shape,
                                  initializer=bias_initializer,
                                  regularizer=bias_regularizer,
                                  constraint=bias_constraint,
                                  trainable=trainable)
            maybe_add_histogram(bias, 'bias')
            bias = maybe_check_numerics(bias, 'bias')

        # flatten to 4d
        output, s1, s2 = flatten_to_ndims(input, 4)

        # do convolution or deconvolution
        output = tf.nn.conv2d_transpose(value=output,
                                        filter=kernel,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding=padding,
                                        data_format=data_format)
        if output_static_shape is not None:
            output.set_shape(output_static_shape)

        # add bias
        if use_bias:
            output = tf.nn.bias_add(output, bias, data_format=data_format)

        # apply the normalization function if specified
        if normalizer_fn is not None:
            output = normalizer_fn(output)

        # split into halves if gated
        if gated:
            output, gate = tf.split(output, 2, axis=c_axis)

        # apply the activation function if specified
        if activation_fn is not None:
            output = activation_fn(output)

        # apply the gate if required
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate')

        # unflatten back to original shape
        output = unflatten_from_ndims(output, s1, s2)

        maybe_add_histogram(output, 'output')
        output = maybe_check_numerics(output, 'output')

    return output
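
A minimal usage sketch, assuming an NHWC input and the default 'same' padding:

# Hypothetical example: upsample a [batch, 16, 16, 32] feature map by a factor of 2.
x = tf.zeros([8, 16, 16, 32])
y = deconv2d(x, out_channels=16, kernel_size=3, strides=2)
# With padding='same', output_size = input_size * strides,
# so `y` has shape [8, 32, 32, 16].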
Example #11
    def _transform(self, x, compute_y, compute_log_det):
        # split the input x
        x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                          axis=self._split_axis)
        do_compute_y = compute_y or (self._y_n_left is None)

        # apply the left transformation
        y1, log_det1 = self._left.transform(x1,
                                            compute_y=do_compute_y,
                                            compute_log_det=compute_log_det)

        # apply the right transformation
        if self._right is not None:
            y2, log_det2 = self._right.transform(
                x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
        else:
            y2, log_det2 = x2, None

        # check the outputs
        y1_shape = get_static_shape(y1)
        y2_shape = get_static_shape(y2)

        if len(y1_shape) != len(y2_shape):
            raise RuntimeError('`y_left.ndims` != `y_right.ndims`: y_left {} '
                               'vs y_right {}'.format(y1, y2))

        # build the y spec if not built
        join_axis = self._join_axis

        if self._y_n_left is None:
            # resolve the join axis
            if join_axis is None:
                join_axis = self._split_axis
            if join_axis < 0:
                join_axis += len(y1_shape)

            if join_axis < 0 or join_axis < len(y1_shape) - self.y_value_ndims:
                raise ValueError(
                    '`join_axis` out of range, or not covered by `y_value_ndims'
                    '`: split_axis {}, y_value_ndims {}, y_left {}, y_right {}'
                    .format(self._split_axis, self.y_value_ndims, y1, y2))
            join_axis -= len(y1_shape)

            err_msg = (
                '`y_left.shape[join_axis] + y_right.shape[join_axis]` must '
                'be at least 2: y_left {}, y_right {} axis {}.'.format(
                    y1, y2, join_axis))
            y_n_left = y1_shape[join_axis]
            y_n_right = y2_shape[join_axis]
            if y_n_left is not None and y_n_right is not None:
                y_n_features = y_n_left + y_n_right
                assert (y_n_features >= 2)
            else:
                y_n_left = get_shape(y1)[join_axis]
                y_n_right = get_shape(y2)[join_axis]
                y_n_features = None
                with assert_deps([
                        tf.assert_greater_equal(y_n_left + y_n_right,
                                                2,
                                                message=err_msg)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        y_n_left = tf.identity(y_n_left)
                        y_n_right = tf.identity(y_n_right)

            self._join_axis = join_axis
            self._y_n_left = y_n_left
            self._y_n_right = y_n_right

            # build the y spec
            dtype = self._x_input_spec.dtype
            y_shape_spec = ['?'] * self.y_value_ndims
            if y_n_features is not None:
                y_shape_spec[self._split_axis] = y_n_features
            assert (not self.require_batch_dims)
            y_shape_spec = ['...'] + y_shape_spec
            self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

        assert (join_axis is not None)

        # compute y
        y = None
        if compute_y:
            y = tf.concat([y1, y2], axis=join_axis)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = log_det1
            if log_det2 is not None:
                log_det = sum_log_det([log_det, log_det2])

        return y, log_det
Example #12
def shift(input, shift, name=None):
    """
    Shift each axis of `input` according to `shift`, but keep identical size.
    The extra content will be discarded if shifted outside the original size.
    Zeros will be padded to the front or end of shifted axes.

    Args:
        input (Tensor): The tensor to be shifted.
        shift (Iterable[int]): The shift length for each axis.
            Its length must equal the rank of `input`.
            For each axis, if its corresponding shift < 0, then the
            `input` will be shifted to left by `-shift` at that axis.
            If its shift > 0, then the `input` will be shifted to right
            by `shift` at that axis.

    Returns:
        tf.Tensor: The output tensor.
    """
    shift = tuple(int(s) for s in shift)
    input = tf.convert_to_tensor(input)
    shape = get_static_shape(input)

    if shape is None:
        raise ValueError(
            'The rank of `input` is required to be deterministic: '
            'got {}'.format(input))
    if len(shift) != len(shape):
        raise ValueError('The length of `shift` is required to equal the rank '
                         'of `input`: shift {} vs input {}'.format(
                             shift, input))

    # cache for the dynamic shape
    def get_dynamic_shape():
        if cached[0] is None:
            cached[0] = get_shape(input)
        return cached[0]

    cached = [None]

    # main routine
    with tf.name_scope(name, default_name='shift', values=[input]):
        # compute the slicing and padding arguments
        has_shift = False
        assert_ops = []
        slice_begin = []
        slice_size = []
        paddings = []
        err_msg = ('Cannot shift `input`: input {} vs shift {}'.format(
            input, shift))

        for i, (axis_shift, axis_size) in enumerate(zip(shift, shape)):
            # fast approach: shift is zero, no slicing at the axis
            if axis_shift == 0:
                slice_begin.append(0)
                slice_size.append(-1)
                paddings.append((0, 0))
                continue

            # slow approach: shift is not zero, should slice at the axis
            axis_shift_abs = abs(axis_shift)

            # we first check whether or not the axis size is big enough
            if axis_size is None:
                dynamic_axis_size = get_dynamic_shape()[i]
                assert_ops.append(
                    tf.assert_greater_equal(dynamic_axis_size,
                                            axis_shift_abs,
                                            message=err_msg))
            else:
                if axis_size < axis_shift_abs:
                    raise ValueError(err_msg)

            # next, we compose the slicing range
            if axis_shift < 0:  # shift to left
                slice_begin.append(-axis_shift)
                slice_size.append(-1)
                paddings.append((0, -axis_shift))

            else:  # shift to right
                slice_begin.append(0)
                if axis_size is None:
                    slice_size.append(get_dynamic_shape()[i] - axis_shift)
                else:
                    slice_size.append(axis_size - axis_shift)
                paddings.append((axis_shift, 0))

            # mark the flag to indicate that we've got any axis to shift
            has_shift = True

        if assert_ops:
            with assert_deps(assert_ops) as asserted:
                if asserted:
                    input = tf.identity(input)

        # no axis to shift, directly return the input
        if not has_shift:
            return input

        # do slicing and padding
        if any(is_tensor_object(s) for s in slice_size):
            slice_size = tf.stack(slice_size, axis=0)

        output = tf.slice(input, slice_begin, slice_size)
        output = tf.pad(output, paddings)

        return output
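
A minimal usage sketch, assuming a small constant input:

# Hypothetical example: shift the last axis of a [2, 3] tensor right by one.
x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
y = shift(x, [0, 1])
# y == [[0, 1, 2],
#       [0, 4, 5]]: the last column is discarded, zeros are padded at the front.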
Example #13
def reshape_tail(input, ndims, shape, name=None):
    """
    Reshape the tail (last) `ndims` into specified `shape`.

    Usage::

        x = tf.zeros([2, 3, 4, 5, 6])
        reshape_tail(x, 3, [-1])  # output: zeros([2, 3, 120])
        reshape_tail(x, 1, [3, 2])  # output: zeros([2, 3, 4, 5, 3, 2])

    Args:
        input (Tensor): The input tensor, at least `ndims` dimensions.
        ndims (int): The number of tail dimensions to be reshaped.
        shape (Iterable[int] or tf.Tensor): The shape of the new tail.

    Returns:
        tf.Tensor: The reshaped tensor.
    """
    input = tf.convert_to_tensor(input)
    if not is_tensor_object(shape):
        shape = list(int(s) for s in shape)
        neg_one_count = 0
        for s in shape:
            if s <= 0:
                if s == -1:
                    if neg_one_count > 0:
                        raise ValueError('`shape` is not a valid shape: at '
                                         'most one `-1` can be specified.')
                    else:
                        neg_one_count += 1
                else:
                    raise ValueError('`shape` is not a valid shape: {} is '
                                     'not allowed.'.format(s))

    with tf.name_scope(name or 'reshape_tail', values=[input]):
        # assert the dimension
        with assert_deps([
                assert_rank_at_least(
                    input, ndims, message='rank(input) must be at least ndims')
        ]) as asserted:
            if asserted:  # pragma: no cover
                input = tf.identity(input)

        # compute the static shape
        static_input_shape = get_static_shape(input)
        static_output_shape = None

        if static_input_shape is not None:
            if ndims > 0:
                left_shape = static_input_shape[:-ndims]
                right_shape = static_input_shape[-ndims:]
            else:
                left_shape = static_input_shape
                right_shape = ()

            # attempt to resolve "-1" in `shape`
            if isinstance(shape, list):
                if None not in right_shape:
                    shape_size = int(np.prod([s for s in shape if s != -1]))
                    right_shape_size = int(np.prod(right_shape))

                    if (-1 not in shape and shape_size != right_shape_size) or \
                            (-1 in shape and right_shape_size % shape_size != 0):
                        raise ValueError(
                            'Cannot reshape the tail dimensions of '
                            '`input` into `shape`: input {!r}, ndims '
                            '{}, shape {}.'.format(input, ndims, shape))

                    if -1 in shape:
                        pos = shape.index(-1)
                        shape[pos] = right_shape_size // shape_size

                static_output_shape = left_shape + \
                    tuple(s if s != -1 else None for s in shape)

        static_output_shape = tf.TensorShape(static_output_shape)

        # compute the dynamic shape
        input_shape = get_shape(input)
        if ndims > 0:
            output_shape = concat_shapes([input_shape[:-ndims], shape])
        else:
            output_shape = concat_shapes([input_shape, shape])

        # do reshape
        output = tf.reshape(input, output_shape)
        output.set_shape(static_output_shape)
        return output
Example #14
def broadcast_concat(x, y, axis, name=None):
    """
    Broadcast `x` and `y`, then concat them along `axis`.

    This method cannot deal with all possible situations yet.
    `x` and `y` must have a known number of dimensions, and only the deterministic
    axes will be broadcasted.  You must ensure the non-deterministic axes are
    properly broadcasted by yourself.

    Args:
        x: The tensor `x`.
        y: The tensor `y`.
        axis: The axis to be concatenated.

    Returns:
        tf.Tensor: The broadcast and concatenated tensor.
    """
    x = tf.convert_to_tensor(x)
    y = tf.convert_to_tensor(y)

    # check the arguments
    x_static_shape = get_static_shape(x)
    if x_static_shape is None:
        raise ValueError('`x` with non-deterministic shape is not supported.')
    y_static_shape = get_static_shape(y)
    if y_static_shape is None:
        raise ValueError('`y` with non-deterministic shape is not supported.')

    x_rank = len(x_static_shape)
    y_rank = len(y_static_shape)
    out_ndims = max(x_rank, y_rank)
    min_axis = -out_ndims
    max_axis = out_ndims - 1

    if axis < min_axis or axis > max_axis:
        raise ValueError('Invalid axis: must be >= {} and <= {}, got {}'.format(
            min_axis, max_axis, axis))
    if axis >= 0:
        axis = axis - out_ndims

    # compute the broadcast shape
    out_static_shape = [None] * out_ndims

    x_tile = [1] * out_ndims
    y_tile = [1] * out_ndims
    assertions = []

    dynamic_shape_cache = {}

    def get_dynamic_shape(t):
        if t not in dynamic_shape_cache:
            dynamic_shape_cache[t] = get_shape(t)
        return dynamic_shape_cache[t]

    def broadcast_axis(i, a, b, a_tile, b_tile, a_tensor, b_tensor):
        err_msg = ('`x` and `y` cannot be broadcast concat: {} vs {}'.format(
            x, y))

        # validate whether or not a == b or can be broadcasted
        if a is None and b is None:
            # both dynamic, must be equal
            a = get_dynamic_shape(a_tensor)[i]
            b = get_dynamic_shape(b_tensor)[i]
            assertions.append(tf.assert_equal(a, b, message=err_msg))

        elif a is not None and b is not None:
            # both static, check immediately
            if a != 1 and b != 1 and a != b:
                raise ValueError(err_msg)

            if a == 1:
                a_tile[i] = b
            elif b == 1:
                b_tile[i] = a

            out_static_shape[i] = max(a, b)

        elif a is None:
            # a dynamic, b can be 1 or equal to a
            a = get_dynamic_shape(a_tensor)[i]
            if b == 1:
                b_tile[i] = a
            else:
                assertions.append(tf.assert_equal(a, b, message=err_msg))
                out_static_shape[i] = b

        else:
            broadcast_axis(i, b, a, b_tile, a_tile, b_tensor, a_tensor)

    def maybe_prepend_dims(t, rank, name):
        if rank < out_ndims:
            t = prepend_dims(t, out_ndims - rank, name=name)
        return t

    def maybe_tile(t, tile, name):
        if any(s != 1 for s in tile):
            if any(is_tensor_object(s) for s in tile):
                tile = tf.stack(tile, axis=0)
            t = tf.tile(t, tile, name=name)
        return t

    with tf.name_scope(name, default_name='broadcast_concat', values=[x, y]):
        # infer the configurations
        for i in range(-1, -out_ndims - 1, -1):
            a = x_static_shape[i] if i >= -x_rank else 1
            b = y_static_shape[i] if i >= -y_rank else 1

            if i != axis:
                broadcast_axis(i, a, b, x_tile, y_tile, x, y)
            else:
                if a is not None and b is not None:
                    out_static_shape[i] = a + b

        # do broadcast
        x = maybe_tile(maybe_prepend_dims(x, x_rank, name='prepend_dims_to_x'),
                       x_tile,
                       name='tile_x')
        y = maybe_tile(maybe_prepend_dims(y, y_rank, name='prepend_dims_to_y'),
                       y_tile,
                       name='tile_y')

        with assert_deps(assertions) as asserted:
            if asserted:
                x = tf.identity(x)
                y = tf.identity(y)

        # do concat
        ret = tf.concat([x, y], axis=axis)
        ret.set_shape(tf.TensorShape(out_static_shape))
        return ret
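
A minimal usage sketch, assuming fully static input shapes:

# Hypothetical example: broadcast [3, 1, 5] against [4, 2], then concat on the last axis.
x = tf.zeros([3, 1, 5])
y = tf.zeros([4, 2])
z = broadcast_concat(x, y, axis=-1)
# x is tiled to [3, 4, 5]; y is lifted to [1, 4, 2] and tiled to [3, 4, 2];
# z therefore has shape [3, 4, 7].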
Example #15
 def identity_branch():
     with assert_deps(assertions) as asserted:
         if asserted:
             return tf.identity(x)
         else:  # pragma: no cover
             return x
Example #16
def broadcast_to_shape(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    If ``rank(x) > len(shape)``, only the tail dimensions will be broadcasted
    to match `shape`.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # fast routine: shape is tuple[int] and x_shape is all known,
        # we can use reshape + tile to do the broadcast, which should be faster
        # than using ``x * ones(shape)``.
        if isinstance(shape, tuple) and x_shape is not None and \
                all(s is not None for s in x_shape):
            # reshape to have the same dimension
            if len(x_shape) < len(shape):
                x_shape = (1, ) * (len(shape) - len(x_shape)) + x_shape
                x = tf.reshape(x, x_shape)

            # tile to have the same shape
            tile = []
            i = -1
            while i > -len(shape) - 1:
                a, b = x_shape[i], shape[i]
                if a == 1 and b > 1:
                    tile.append(b)
                elif a != b:
                    raise ValueError(cannot_broadcast_msg)
                else:
                    tile.append(1)
                i -= 1
            tile = [1] * (len(x_shape) - len(shape)) + list(reversed(tile))
            if any(s > 1 for s in tile):
                x = tf.tile(x, tile)

            return x

        # slow routine: we may need ``x * ones(shape)`` to do the broadcast
        assertions = []
        post_assert_shape = False
        static_shape = tf.TensorShape(None)

        if isinstance(shape, tuple) and x_shape is not None:
            need_multiply_ones = False

            # it should always broadcast if len(x_shape) < len(shape)
            if len(x_shape) < len(shape):
                need_multiply_ones = True

            # check the consistency of x and shape
            static_shape_hint = []  # list to gather the static shape hint
            axis_to_check = []  # list to gather the axis to check
            i = -1
            while i >= -len(shape) and i >= -len(x_shape):
                a, b = x_shape[i], shape[i]
                if a is None:
                    axis_to_check.append(i)
                else:
                    if a != b:
                        if a == 1:
                            need_multiply_ones = True
                        else:
                            raise ValueError(cannot_broadcast_msg)
                static_shape_hint.append(b)
                i -= 1

            # compose the static shape hint
            if len(shape) < len(x_shape):
                static_shape = x_shape[:-len(shape)]
            elif len(shape) > len(x_shape):
                static_shape = shape[:-len(x_shape)]
            else:
                static_shape = ()
            static_shape = tf.TensorShape(static_shape +
                                          tuple(reversed(static_shape_hint)))

            # compose the assertion operations and the multiply flag
            if axis_to_check:
                need_multiply_flags = []
                x_dynamic_shape = tf.shape(x)

                for i in axis_to_check:
                    assertions.append(
                        tf.assert_equal(tf.logical_or(
                            tf.equal(x_dynamic_shape[i], shape[i]),
                            tf.equal(x_dynamic_shape[i], 1),
                        ),
                                        True,
                                        message=cannot_broadcast_msg))
                    if len(x_shape) >= len(shape):
                        need_multiply_flags.append(
                            tf.not_equal(x_dynamic_shape[i], shape[i]))

                if not need_multiply_ones:
                    need_multiply_ones = \
                        tf.reduce_any(tf.stack(need_multiply_flags))

        else:
            # we have no idea what `shape` is here, thus we need to
            # assert the shape after ``x * ones(shape)``.
            need_multiply_ones = True
            post_assert_shape = True

        # do broadcast if `x_shape` != `shape`
        def multiply_branch():
            with assert_deps(assertions):
                ones_template = tf.ones(shape, dtype=x.dtype.base_dtype)
            try:
                return x * ones_template
            except ValueError:  # pragma: no cover
                raise ValueError(cannot_broadcast_msg)

        def identity_branch():
            with assert_deps(assertions) as asserted:
                if asserted:
                    return tf.identity(x)
                else:  # pragma: no cover
                    return x

        t = smart_cond(need_multiply_ones, multiply_branch, identity_branch)
        t.set_shape(static_shape)

        if post_assert_shape:
            post_assert_op = tf.assert_equal(tf.reduce_all(
                tf.equal(tf.shape(t)[-tf.size(shape):], shape)),
                                             True,
                                             message=cannot_broadcast_msg)
            with assert_deps([post_assert_op]) as asserted:
                if asserted:
                    t = tf.identity(t)

        return t
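
A minimal usage sketch, assuming ``rank(x) > len(shape)`` so that only the tail dimensions are broadcast (unlike the strict variant in Example #1):

# Hypothetical example: broadcast only the tail dimensions.
x = tf.zeros([2, 1, 3])
y = broadcast_to_shape(x, (5, 3))
# The tail (1, 3) is broadcast against (5, 3) via tile,
# so `y` has shape [2, 5, 3].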