コード例 #1
0
ファイル: shape_utils.py プロジェクト: shliujing/tfsnippet
def prepend_dims(x, ndims=1, name=None):
    """
    Insert `ndims` leading axes of size `1` in front of the shape of `x`.

    Args:
        x: The tensor `x` (anything accepted by `tf.convert_to_tensor`).
        ndims: Number of `1` to prepend.  Converted with `int()`, and
            must not be negative.
        name: Name of the TensorFlow name scope; the default name is
            "prepend_dims".

    Returns:
        tf.Tensor: The tensor with prepended dimensions.  When `ndims`
            is zero, the converted `x` is returned without any reshape.

    Raises:
        ValueError: If `ndims` is negative.
    """
    ndims = int(ndims)
    if ndims < 0:
        raise ValueError('`ndims` must be >= 0: got {}'.format(ndims))

    x = tf.convert_to_tensor(x)
    if not ndims:
        # nothing to prepend, no graph op is required
        return x

    with tf.name_scope(name, default_name='prepend_dims', values=[x]):
        known_shape = get_static_shape(x)

        # the reshape always uses the dynamic shape of `x`
        target_shape = concat_shapes([[1] * ndims, get_shape(x)])
        reshaped = tf.reshape(x, target_shape)

        # re-attach the static shape information whenever it is available
        if known_shape is not None:
            reshaped.set_shape(
                tf.TensorShape([1] * ndims + list(known_shape)))

        return reshaped
コード例 #2
0
def broadcast_log_det_against_input(log_det, input, value_ndims, name=None):
    """
    Broadcast `log_det` to the shape of `input` with the trailing
    `value_ndims` axes removed.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each values sample.
            If not positive, the full shape of `input` is the target.
        name: Name of the TensorFlow name scope; defaults to
            "broadcast_log_det_to_input_shape".

    Returns:
        tf.Tensor: The broadcasted log-determinant.
    """
    log_det = tf.convert_to_tensor(log_det)
    input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    scope_name = name or 'broadcast_log_det_to_input_shape'
    with tf.name_scope(scope_name, values=[log_det, input]):
        target_shape = get_shape(input)

        if value_ndims > 0:
            rank_err = (
                'Cannot broadcast `log_det` against `input`: log_det is {}, '
                'input is {}, value_ndims is {}.'.format(
                    log_det, input, value_ndims))
            rank_check = assert_rank_at_least(
                input, value_ndims, message=rank_err)

            # the slice is taken under the rank assertion
            with assert_deps([rank_check]):
                target_shape = target_shape[:-value_ndims]

        return broadcast_to_shape_strict(log_det, target_shape)
コード例 #3
0
ファイル: reshape.py プロジェクト: mengyuan404/tfsnippet
    def _build(self, input=None):
        """
        Derive the input specs and value shapes from the first `input`.

        Sets `_x_input_spec`, `_x_value_shape`, `_y_input_spec`, and,
        when the x value shape is fully known, also resolves the `-1`
        in `_y_value_shape`.

        Args:
            input: The tensor used to build this object.  Its static shape
                must be available, with at least `x_value_ndims` axes.

        Raises:
            ValueError: If the statically known x value shape cannot be
                reshaped into the y value shape.
        """
        static_shape = get_static_shape(input)
        dtype = input.dtype.base_dtype
        assert (static_shape is not None and
                len(static_shape) >= self.x_value_ndims)

        def make_spec(value_spec):
            # '...' matches any leading axes, and '?' demands one more
            # axis if batch dimensions are required
            spec = list(value_spec)
            if self.require_batch_dims:
                spec = ['?'] + spec
            return ['...'] + spec

        def static_tail():
            # a fresh list of the last `x_value_ndims` static dimensions
            if self.x_value_ndims > 0:
                return list(static_shape)[-self.x_value_ndims:]
            return []

        # the input spec of x
        self._x_input_spec = InputSpec(shape=make_spec(static_tail()),
                                       dtype=dtype)

        # The value shape of x, stored for the inverse transform.  A single
        # unknown axis is represented by -1; with two or more unknown axes
        # the dynamic shape tensor has to be used instead.
        x_value_shape = static_tail()
        unknown_seen = False
        for pos, size in enumerate(x_value_shape):
            if size is not None:
                continue
            if unknown_seen:
                x_value_shape = get_shape(input)
                if self.x_value_ndims > 0:
                    x_value_shape = x_value_shape[-self.x_value_ndims:]
                break
            x_value_shape[pos] = -1
            unknown_seen = True

        self._x_value_shape = x_value_shape

        # With a fully determined x value shape, validate the y value shape
        # against it and fill in its `-1` (if any).
        y_value_shape = list(self._y_value_shape)
        if isinstance(x_value_shape, list) and -1 not in x_value_shape:
            x_value_size = int(np.prod(x_value_shape))
            y_known_size = int(
                np.prod([s for s in y_value_shape if s != -1]))
            y_has_wildcard = -1 in y_value_shape

            if y_has_wildcard:
                compatible = (x_value_size % y_known_size == 0)
            else:
                compatible = (x_value_size == y_known_size)

            if not compatible:
                raise ValueError(
                    'Cannot reshape the tail dimensions of `x` into `y`: '
                    'x value shape {!r}, y value shape {!r}.'.format(
                        x_value_shape, y_value_shape))

            if y_has_wildcard:
                wildcard_pos = y_value_shape.index(-1)
                y_value_shape[wildcard_pos] = x_value_size // y_known_size

            assert (-1 not in y_value_shape)
            self._y_value_shape = tuple(y_value_shape)

        # the input spec of y
        self._y_input_spec = InputSpec(shape=make_spec(y_value_shape),
                                       dtype=dtype)
コード例 #4
0
ファイル: branch.py プロジェクト: shliujing/tfsnippet
    def _build(self, input=None):
        """
        Resolve the split axis and the sizes of the two halves from `input`.

        After this method, `_split_axis` is stored as a negative index
        (relative to the end of the shape), `_x_n_left` is half the
        size of the split axis (rounded down), and `_x_n_right` is the
        remainder.  If the size of the split axis is not statically known,
        both sizes are derived from the dynamic shape of `input`, guarded
        by a runtime assertion that the size is at least 2.

        Args:
            input: The tensor used to build this object.

        Raises:
            ValueError: If `split_axis` is out of the range of the input
                shape, is not among the last `x_value_ndims` axes, or
                the statically known size of the split axis is below 2.
        """
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype
        rank = len(shape)

        # resolve the split axis
        split_axis = self._split_axis
        if split_axis < 0:
            split_axis += rank
        # An axis >= rank must be rejected as well: otherwise it would become
        # a non-negative index after the subtraction below, and silently
        # select one of the leading axes.
        if split_axis < 0 or split_axis >= rank or \
                split_axis < rank - self.x_value_ndims:
            raise ValueError(
                '`split_axis` out of range, or not covered by `x_value_ndims`: '
                'split_axis {}, x_value_ndims {}, input {}'.format(
                    self._split_axis, self.x_value_ndims, input))
        # from now on the axis is a negative index, counted from the end
        split_axis -= rank

        err_msg = (
            'The split axis of `input` must be at least 2: input {}, axis {}.'.
            format(input, split_axis))
        if shape[split_axis] is not None:
            # the size of the split axis is statically known
            x_n_features = shape[split_axis]
            if x_n_features < 2:
                raise ValueError(err_msg)
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left
        else:
            # the size is only known at runtime: compute it with tensors
            x_n_features = get_shape(input)[split_axis]
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left

            with assert_deps(
                [tf.assert_greater_equal(x_n_features, 2,
                                         message=err_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    x_n_left = tf.identity(x_n_left)
                    x_n_right = tf.identity(x_n_right)

            # no static size to be put into the input spec
            x_n_features = None

        self._split_axis = split_axis
        self._x_n_left = x_n_left
        self._x_n_right = x_n_right

        # build the x spec
        shape_spec = ['?'] * self.x_value_ndims
        if x_n_features is not None:
            shape_spec[self._split_axis] = x_n_features
        assert (not self.require_batch_dims)
        shape_spec = ['...'] + shape_spec
        self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
コード例 #5
0
ファイル: reshape.py プロジェクト: mengyuan404/tfsnippet
    def _transform(self, x, compute_y, compute_log_det):
        # compute y
        y = None
        if compute_y:
            y = space_to_depth(x,
                               block_size=self._block_size,
                               channels_last=self._channels_last)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = ZeroLogDet(shape=get_shape(x)[:-3],
                                 dtype=x.dtype.base_dtype)

        return y, log_det
コード例 #6
0
ファイル: reshape.py プロジェクト: mengyuan404/tfsnippet
    def _inverse_transform(self, y, compute_x, compute_log_det):
        # compute x
        x = None
        if compute_x:
            x = depth_to_space(y,
                               block_size=self._block_size,
                               channels_last=self._channels_last)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = ZeroLogDet(shape=get_shape(y)[:-3],
                                 dtype=y.dtype.base_dtype)

        return x, log_det
コード例 #7
0
    def sample(self,
               n_samples=None,
               group_ndims=0,
               is_reparameterized=None,
               compute_density=None,
               name=None):
        """
        Draw samples from the mixture.

        A component index is sampled from `self.categorical`, every
        component is sampled as well, and the sample of the chosen component
        is picked via a one-hot mask (all components are evaluated, which is
        the slow but generic routine).

        Args:
            n_samples: Passed to the `sample` method of the categorical
                and of each component, and to the `StochasticTensor`.
            group_ndims: Passed to the returned `StochasticTensor`.
            is_reparameterized: Checked by
                `_validate_sample_is_reparameterized_arg`, then passed to
                the returned `StochasticTensor`.
            compute_density: If true, `compute_density_immediately` is
                called on the returned tensor.
            name: Name of the TensorFlow name scope; defaults to
                "Mixture.sample".

        Returns:
            StochasticTensor: The samples.
        """
        self._validate_sample_is_reparameterized_arg(is_reparameterized)

        with tf.name_scope(name or 'Mixture.sample'):
            # one-hot selection of the component, no gradient through it
            chosen = self.categorical.sample(n_samples, group_ndims=0)
            mask = tf.one_hot(chosen,
                              self.n_components,
                              dtype=self.dtype,
                              axis=-1)
            if self.value_ndims > 0:
                # append size-1 axes, so that the mask can be broadcast
                # against the value axes of the component samples
                mask_static_shape = (mask.get_shape().as_list() +
                                     [1] * self.value_ndims)
                mask_dynamic_shape = concat_shapes(
                    [get_shape(mask), [1] * self.value_ndims])
                mask = tf.reshape(mask, mask_dynamic_shape)
                mask.set_shape(mask_static_shape)
            mask = tf.stop_gradient(mask)

            # sample from every component, stack the samples on a new axis
            # right before the value axes, then select by mask and sum
            component_axis = -self.value_ndims - 1
            per_component = [
                component.sample(n_samples, group_ndims=0)
                for component in self.components
            ]
            stacked = tf.stack(per_component, axis=component_axis)
            samples = tf.reduce_sum(mask * stacked, axis=component_axis)

            if not self.is_reparameterized:
                samples = tf.stop_gradient(samples)

            result = StochasticTensor(distribution=self,
                                      tensor=samples,
                                      n_samples=n_samples,
                                      group_ndims=group_ndims,
                                      is_reparameterized=is_reparameterized)

            if compute_density:
                compute_density_immediately(result)

            return result
コード例 #8
0
ファイル: reshape.py プロジェクト: mengyuan404/tfsnippet
    def _inverse_transform(self, y, compute_x, compute_log_det):
        """
        Reshape the last `y_value_ndims` axes of `y` into the x value shape.

        Args:
            y: The tensor to be transformed back; its static shape must have
                at least `y_value_ndims` axes.
            compute_x (bool): Whether or not to compute the output `x`?
            compute_log_det (bool): Whether or not to compute the
                log-determinant?

        Returns:
            (x, log_det): Each element is None if not requested.  The
                log-determinant is a `ZeroLogDet`, whose shape is the
                dynamic shape of `y` without its last `y_value_ndims` axes.
        """
        assert (len(get_static_shape(y)) >= self.y_value_ndims)

        # compute x
        x = (reshape_tail(y, self.y_value_ndims, self._x_value_shape)
             if compute_x else None)

        # compute log_det
        log_det = None
        if compute_log_det:
            batch_shape = get_shape(y)
            if self.y_value_ndims > 0:
                batch_shape = batch_shape[:-self.y_value_ndims]
            log_det = ZeroLogDet(batch_shape, dtype=y.dtype.base_dtype)

        return x, log_det
コード例 #9
0
ファイル: rearrangement.py プロジェクト: shliujing/tfsnippet
    def _transform_or_inverse_transform(self, x, compute_y, compute_log_det,
                                        permutation):
        """
        Rearrange `x` along `self.axis` according to `permutation`.

        Args:
            x: The input tensor.  The static size of `self.axis` must be
                equal to `self._n_features`.
            compute_y (bool): Whether or not to compute the output?
            compute_log_det (bool): Whether or not to compute the
                log-determinant?
            permutation: The indices passed to `tf.gather`.

        Returns:
            (y, log_det): Each element is None if not requested.  The
                log-determinant is a `ZeroLogDet`, whose shape is the
                dynamic shape of `x` without its last `value_ndims` axes.
        """
        # `axis` must be a negative index which lies within the value axes
        assert (0 > self.axis >= -self.value_ndims >= -len(get_static_shape(x)))
        assert (get_static_shape(x)[self.axis] == self._n_features)

        y = tf.gather(x, permutation, axis=self.axis) if compute_y else None

        log_det = None
        if compute_log_det:
            batch_shape = get_shape(x)[:-self.value_ndims]
            log_det = ZeroLogDet(batch_shape, x.dtype.base_dtype)

        return y, log_det
コード例 #10
0
ファイル: reshape.py プロジェクト: mengyuan404/tfsnippet
    def _transform(self, x, compute_y, compute_log_det):
        """
        Reshape the last `x_value_ndims` axes of `x` into the y value shape.

        Args:
            x: The input tensor; its static shape must have at least
                `x_value_ndims` axes.
            compute_y (bool): Whether or not to compute the output `y`?
            compute_log_det (bool): Whether or not to compute the
                log-determinant?

        Returns:
            (y, log_det): Each element is None if not requested.  The
                log-determinant is a `ZeroLogDet`, whose shape is the
                dynamic shape of `x` without its last `x_value_ndims` axes.
        """
        assert (len(get_static_shape(x)) >= self.x_value_ndims)

        # compute y
        y = (reshape_tail(x, self.x_value_ndims, self._y_value_shape)
             if compute_y else None)

        # compute log_det
        log_det = None
        if compute_log_det:
            batch_shape = get_shape(x)
            if self.x_value_ndims > 0:
                batch_shape = batch_shape[:-self.x_value_ndims]
            log_det = ZeroLogDet(batch_shape, dtype=x.dtype.base_dtype)

        return y, log_det
コード例 #11
0
ファイル: dropout_.py プロジェクト: shliujing/tfsnippet
def dropout(input, rate=.5, noise_shape=None, training=False, name=None):
    """
    Apply dropout on `input`.

    In the training branch, a uniform noise in `[0, 1)` is compared against
    `1 - rate` to build the keep mask, and the kept elements are scaled
    by `1 / (1 - rate)`.  Otherwise `input` is returned as is.

    Args:
        input (Tensor): The input tensor.
        rate (float or tf.Tensor): The rate of dropout.
        noise_shape (tuple[int] or tf.Tensor): Shape of the noise.
            If not specified, use the shape of `input`.
        training (bool or tf.Tensor): Whether or not the model is under
            training stage?
        name: Name of the TensorFlow name scope; the default name is
            "dropout".

    Returns:
        tf.Tensor: The dropout transformed tensor.
    """
    input = tf.convert_to_tensor(input)

    with tf.name_scope(name, default_name='dropout', values=[input]):
        dtype = input.dtype.base_dtype
        keep_prob = convert_to_tensor_and_cast(1. - rate, dtype=dtype)
        keep_scale = 1. / keep_prob
        if noise_shape is None:
            noise_shape = get_shape(input)

        def drop_and_rescale():
            uniform_noise = tf.random_uniform(
                shape=noise_shape, minval=0., maxval=1., dtype=dtype)
            keep_mask = tf.cast(uniform_noise < keep_prob, dtype=dtype)
            return input * keep_mask * keep_scale

        def passthrough():
            return input

        return smart_cond(training, drop_and_rescale, passthrough)
コード例 #12
0
ファイル: pixelcnn.py プロジェクト: shliujing/tfsnippet
def pixelcnn_2d_input(input,
                      channels_last=True,
                      auxiliary_channel=True,
                      name=None):
    """
    Prepare the input for a PixelCNN 2D network (Tim Salimans, 2017).

    This method must be applied on the input once before any other PixelCNN
    2D layers, for example::

        input = ...  # the input x

        # prepare for the convolution stack
        output = spt.layers.pixelcnn_2d_input(input)

        # apply the PixelCNN 2D layers.
        for i in range(5):
            output = spt.layers.pixelcnn_conv2d_resnet(
                output,
                out_channels=64,
                vertical_kernel_size=(2, 3),
                horizontal_kernel_size=(2, 2),
                activation_fn=tf.nn.leaky_relu,
                normalizer_fn=spt.layers.batch_norm
            )

        # get the final output of the PixelCNN 2D network.
        output = pixelcnn_2d_output(output)

    The input (optionally concatenated with an all-one channel) is shifted
    by one along the height axis to form the vertical stack, and by one along
    the width axis to form the horizontal stack.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        auxiliary_channel (bool): Whether or not to add a channel to `input`,
            with all elements set to `1`?
        name: Name of the TensorFlow name scope; the default name is
            "pixelcnn_input".

    Returns:
        PixelCNN2DOutput: The PixelCNN layer output.
    """
    input, _, _ = validate_conv2d_input(input, channels_last)
    # the axes are addressed by negative indices, counted from the end
    h_axis, w_axis, c_axis = (-3, -2, -1) if channels_last else (-2, -1, -3)
    rank = len(get_static_shape(input))

    def one_step_along(axis):
        # the shift argument: 1 at `axis`, 0 at every other axis
        steps = [0] * rank
        steps[axis] = 1
        return steps

    with tf.name_scope(name, default_name='pixelcnn_input', values=[input]):
        # add a channels with all `1`s
        if auxiliary_channel:
            ones_static_shape = [None] * rank
            ones_dynamic_shape = [None] * rank
            ones_dynamic_shape[c_axis] = 1

            # every axis but the channel axis copies the dynamic input size
            if None in ones_dynamic_shape:
                x_dynamic_shape = get_shape(input)
                ones_dynamic_shape = [
                    x_dynamic_shape[i] if size is None else size
                    for i, size in enumerate(ones_dynamic_shape)
                ]

            ones = tf.ones(shape=tf.stack(ones_dynamic_shape, axis=0),
                           dtype=input.dtype.base_dtype)
            ones.set_shape(tf.TensorShape(ones_static_shape))
            input = tf.concat([input, ones],
                              axis=c_axis,
                              name='auxiliary_input')

        # derive the vertical and horizontal convolution stacks
        vertical = shift(input, shift=one_step_along(h_axis), name='vertical')
        horizontal = shift(input,
                           shift=one_step_along(w_axis),
                           name='horizontal')
        return PixelCNN2DOutput(vertical=vertical, horizontal=horizontal)
コード例 #13
0
ファイル: shape_utils.py プロジェクト: shliujing/tfsnippet
 def get_dynamic_shape(t):
     # Memoized `get_shape`: the result for each `t` is computed once and
     # then served from `dynamic_shape_cache`.
     if t in dynamic_shape_cache:
         return dynamic_shape_cache[t]
     dynamic_shape = get_shape(t)
     dynamic_shape_cache[t] = dynamic_shape
     return dynamic_shape
コード例 #14
0
ファイル: shape_utils.py プロジェクト: shliujing/tfsnippet
def reshape_tail(input, ndims, shape, name=None):
    """
    Reshape the tail (last) `ndims` into specified `shape`.

    Usage::

        x = tf.zeros([2, 3, 4, 5, 6])
        reshape_tail(x, 3, [-1])  # output: zeros([2, 3, 120])
        reshape_tail(x, 1, [3, 2])  # output: zeros([2, 3, 4, 5, 3, 2])

    Args:
        input (Tensor): The input tensor, at least `ndims` dimensions.
        ndims (int): To reshape this number of dimensions at tail.
        shape (Iterable[int] or tf.Tensor): The shape of the new tail.
            If not a tensor, every element must be positive, except for
            at most one `-1`.
        name: Name of the TensorFlow name scope; defaults to "reshape_tail".

    Returns:
        tf.Tensor: The reshaped tensor.

    Raises:
        ValueError: If a non-tensor `shape` is invalid, or if the statically
            known tail of `input` cannot be reshaped into `shape`.
    """
    input = tf.convert_to_tensor(input)

    # validate a non-tensor `shape`
    if not is_tensor_object(shape):
        shape = [int(s) for s in shape]
        wildcard_seen = False
        for size in shape:
            if size > 0:
                continue
            if size != -1:
                raise ValueError('`shape` is not a valid shape: {} is '
                                 'not allowed.'.format(size))
            if wildcard_seen:
                raise ValueError('`shape` is not a valid shape: at '
                                 'most one `-1` can be specified.')
            wildcard_seen = True

    with tf.name_scope(name or 'reshape_tail', values=[input]):
        # assert the dimension
        rank_check = assert_rank_at_least(
            input, ndims, message='rank(input) must be at least ndims')
        with assert_deps([rank_check]) as asserted:
            if asserted:  # pragma: no cover
                input = tf.identity(input)

        # compute the static shape
        static_input_shape = get_static_shape(input)
        static_output_shape = None

        if static_input_shape is not None:
            if ndims > 0:
                kept_dims = static_input_shape[:-ndims]
                tail_dims = static_input_shape[-ndims:]
            else:
                kept_dims, tail_dims = static_input_shape, ()

            if isinstance(shape, list):
                # With a fully known tail, the requested shape can be checked
                # right now, and its "-1" can be resolved.
                if None not in tail_dims:
                    has_wildcard = -1 in shape
                    known_size = int(np.prod([s for s in shape if s != -1]))
                    tail_size = int(np.prod(tail_dims))

                    if has_wildcard:
                        reshapable = (tail_size % known_size == 0)
                    else:
                        reshapable = (known_size == tail_size)

                    if not reshapable:
                        raise ValueError(
                            'Cannot reshape the tail dimensions of '
                            '`input` into `shape`: input {!r}, ndims '
                            '{}, shape {}.'.format(input, ndims, shape))

                    if has_wildcard:
                        shape[shape.index(-1)] = tail_size // known_size

                # an unresolved "-1" is an unknown static dimension
                static_output_shape = kept_dims + tuple(
                    None if s == -1 else s for s in shape)

        static_output_shape = tf.TensorShape(static_output_shape)

        # compute the dynamic shape
        input_shape = get_shape(input)
        kept_shape = input_shape[:-ndims] if ndims > 0 else input_shape
        output_shape = concat_shapes([kept_shape, shape])

        # do reshape
        output = tf.reshape(input, output_shape)
        output.set_shape(static_output_shape)
        return output
コード例 #15
0
def is_log_det_shape_matches_input(log_det, input, value_ndims, name=None):
    """
    Check whether or not the shape of `log_det` matches the shape of `input`.

    Basically, the shapes of `log_det` and `input` should satisfy::

        if value_ndims > 0:
            assert(log_det.shape == input.shape[:-value_ndims])
        else:
            assert(log_det.shape == input.shape)

    The answer is a Python boolean if it can be decided from the static
    shapes alone; otherwise it is a boolean tensor.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each values sample.
        name: Name of the TensorFlow name scope; defaults to
            "is_log_det_shape_matches_input".

    Returns:
        bool or tf.Tensor: A boolean or a tensor, indicating whether or not
            the shape of `log_det` matches the shape of `input`.
    """
    if not is_tensor_object(log_det):
        log_det = tf.convert_to_tensor(log_det)
    if not is_tensor_object(input):
        input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'is_log_det_shape_matches_input'):
        log_det_static = get_static_shape(log_det)
        input_static = get_static_shape(input)

        if log_det_static is None or input_static is None:
            # At least one of the ranks is not deterministic: do a fully
            # dynamic check, which includes the check on
            # ``log_det.ndims + value_ndims == input_shape.ndims``
            ranks_match = tf.equal(
                tf.rank(log_det) + value_ndims, tf.rank(input))
            log_det_dynamic = get_shape(log_det)
            input_dynamic = get_shape(input)
            if value_ndims > 0:
                input_dynamic = input_dynamic[:-value_ndims]

            def shapes_equal():
                # Concatenating the two shapes in both orders ensures that
                # the two operands of `equal` have the same shape, such as
                # to avoid some potential issues about the cond operation.
                return tf.reduce_all(
                    tf.equal(
                        tf.concat([log_det_dynamic, input_dynamic], 0),
                        tf.concat([input_dynamic, log_det_dynamic], 0),
                    ))

            def ranks_differ():
                return tf.constant(False, dtype=tf.bool)

            return tf.cond(ranks_match, shapes_equal, ranks_differ)

        # both ranks are deterministic, thus each axis is compared separately
        if len(log_det_static) + value_ndims != len(input_static):
            return False

        axis_pairs = list(zip(log_det_static, input_static))
        if any(a is not None and b is not None and a != b
               for a, b in axis_pairs):
            return False

        # the axes which can only be compared at runtime
        dynamic_axes = [
            i for i, (a, b) in enumerate(axis_pairs)
            if a is None or b is None
        ]
        if not dynamic_axes:
            return True

        log_det_dynamic = get_shape(log_det)
        input_dynamic = get_shape(input)
        return tf.reduce_all([
            tf.equal(log_det_dynamic[i], input_dynamic[i])
            for i in dynamic_axes
        ])
コード例 #16
0
def pixelcnn_2d_sample(fn,
                       inputs,
                       height,
                       width,
                       channels_last=True,
                       start=0,
                       end=None,
                       back_prop=False,
                       parallel_iterations=1,
                       swap_memory=False,
                       name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    At iteration `i`, only the `i`-th pixel (in the flattened
    `height * width` order) of every tensor is replaced by the output
    of `fn`; all the other pixels keep their current values.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            through iteration `0` to iteration `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.
        name: Name of the TensorFlow name scope; the default name is
            "pixelcnn_2d_sample".

    Returns:
        tuple[tf.Tensor]: The final outputs.

    Raises:
        ValueError: If `inputs` is empty, or `fn` does not produce the
            same number of outputs as the inputs.
        TypeError: If the dtype of an output of `fn` does not match the
            dtype of the corresponding input.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def as_int(value):
        if not is_tensor_object(value):
            return int(value)
        return convert_to_tensor_and_cast(value, dtype=tf.int32)

    height = as_int(height)
    width = as_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(inputs[0],
                                            channels_last=channels_last,
                                            arg_name='inputs[0]')
    # all the other inputs are validated against the shape of the first one
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for pos in range(1, len(inputs)):
        inputs[pos] = input_spec.validate('inputs[{}]'.format(pos),
                                          inputs[pos])

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        end = convert_to_tensor_and_cast(
            total_size if end is None else end, dtype=tf.int32)

        # the mask shape, which broadcasts against the spatial axes
        mask_shape = [height, width, 1] if channels_last else [height, width]
        if any(is_tensor_object(size) for size in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, current):
            current = tuple(current)

            # The selector is true at every pixel but the `idx`-th one:
            # true keeps the current value, false takes the new output.
            flat_selector = tf.concat(
                [
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ],
                axis=0
            )
            selector = tf.reshape(flat_selector, mask_shape)
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, current))
            if len(outputs) != len(current):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(current)))

            # mask the outputs
            for i, (previous, produced) in enumerate(zip(current, outputs)):
                input_dtype = previous.dtype.base_dtype
                output_dtype = produced.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(idx=i,
                                                     output=output_dtype,
                                                     input=input_dtype))
                outputs[i] = tf.where(selector, previous, produced)

            return idx + 1, tuple(outputs)

        _, final_outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(start, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return final_outputs
コード例 #17
0
ファイル: conv2d_.py プロジェクト: shliujing/tfsnippet
def deconv2d(input,
             out_channels,
             kernel_size,
             strides=(1, 1),
             padding='same',
             channels_last=True,
             output_shape=None,
             activation_fn=None,
             normalizer_fn=None,
             weight_norm=False,
             gated=False,
             gate_sigmoid_bias=2.,
             kernel=None,
             kernel_initializer=None,
             kernel_regularizer=None,
             kernel_constraint=None,
             use_bias=None,
             bias=None,
             bias_initializer=tf.zeros_initializer(),
             bias_regularizer=None,
             bias_constraint=None,
             trainable=True,
             name=None,
             scope=None):
    """
    2D deconvolutional layer.

    The input is flattened to 4-d, passed through
    :func:`tf.nn.conv2d_transpose`, optionally followed by bias, normalizer,
    gate and activation, and then restored to its original leading dimensions.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The channel numbers of the deconvolution output.
        kernel_size (int or (int, int)): Kernel size over spatial dimensions.
        strides (int or (int, int)): Strides over spatial dimensions.
        padding: One of {"valid", "same"}, case in-sensitive.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        output_shape: If specified, use this as the shape of the
            deconvolution output (a pair ``(height, width)``, or a 1-d tensor
            with two elements); otherwise compute the size of each dimension
            by::

                output_size = input_size * strides
                if padding == 'valid':
                    output_size += max(kernel_size - strides, 0)

        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm (bool or (tf.Tensor) -> tf.Tensor):
            If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on
            `kernel`.  `use_scale` will be :obj:`True` if `normalizer_fn`
            is not specified, and :obj:`False` otherwise.  The axis reduction
            will be determined by the layer.

            If it is a callable function, then it will be used to normalize
            the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`.
            The user must ensure the axis reduction is correct by themselves.
        gated (bool): Whether or not to use gate on output?
            `output = activation_fn(output) * sigmoid(gate)`.
            When enabled, twice as many channels are computed internally
            and split into the output half and the gate half.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        kernel (Tensor): Instead of creating a new variable, use this tensor.
        kernel_initializer: The initializer for `kernel`.
            Would be ``default_kernel_initializer(...)`` if not specified.
        kernel_regularizer: The regularizer for `kernel`.
        kernel_constraint: The constraint for `kernel`.
        use_bias (bool or None): Whether or not to use `bias`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias (Tensor): Instead of creating a new variable, use this tensor.
        bias_initializer: The initializer for `bias`.
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        trainable (bool): Whether or not the parameters are trainable?
        name (str): Default name of the variable scope, used when `scope`
            is not specified.  Falls back to ``'deconv2d'``.
        scope (str): The name of the variable scope.

    Returns:
        tf.Tensor: The output tensor.
    """
    input, in_channels, data_format = \
        validate_conv2d_input(input, channels_last)
    out_channels = validate_positive_int_arg('out_channels', out_channels)
    dtype = input.dtype.base_dtype
    if gated:
        # one half of the channels is the output, the other half the gate
        out_channels *= 2

    # check functional arguments
    padding = validate_enum_arg('padding',
                                str(padding).upper(), ['VALID', 'SAME'])
    strides = validate_conv2d_strides_tuple('strides', strides, channels_last)

    weight_norm_fn = validate_weight_norm_arg(weight_norm,
                                              axis=-1,
                                              use_scale=normalizer_fn is None)
    if use_bias is None:
        use_bias = normalizer_fn is None

    # get the specification of outputs and parameters
    kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size)
    # the filter of `conv2d_transpose` is (height, width, out, in)
    kernel_shape = kernel_size + (out_channels, in_channels)
    bias_shape = (out_channels, )

    given_h, given_w = None, None
    given_output_shape = output_shape

    if is_tensor_object(given_output_shape):
        given_output_shape = tf.convert_to_tensor(given_output_shape)
    elif given_output_shape is not None:
        given_h, given_w = given_output_shape

    # validate the parameters
    if kernel is not None:
        kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype)
        kernel = kernel_spec.validate('kernel', kernel)
    if kernel_initializer is None:
        kernel_initializer = default_kernel_initializer(weight_norm)
    if bias is not None:
        bias_spec = ParamSpec(shape=bias_shape, dtype=dtype)
        bias = bias_spec.validate('bias', bias)

    # the main part of the deconv2d layer
    with tf.variable_scope(scope, default_name=name or 'deconv2d'):
        with tf.name_scope('output_shape'):
            # detect the input shape and axis arrangements
            input_shape = get_static_shape(input)
            if channels_last:
                c_axis, h_axis, w_axis = -1, -3, -2
            else:
                c_axis, h_axis, w_axis = -3, -2, -1

            # the 4-d output shape; `None` marks a not-yet-known dimension
            output_shape = [None, None, None, None]
            output_shape[c_axis] = out_channels
            if given_output_shape is None:
                if input_shape[h_axis] is not None:
                    output_shape[h_axis] = get_deconv_output_length(
                        input_shape[h_axis], kernel_shape[0], strides[h_axis],
                        padding)
                if input_shape[w_axis] is not None:
                    output_shape[w_axis] = get_deconv_output_length(
                        input_shape[w_axis], kernel_shape[1], strides[w_axis],
                        padding)
            else:
                if not is_tensor_object(given_output_shape):
                    output_shape[h_axis] = given_h
                    output_shape[w_axis] = given_w

            # infer the batch shape in 4-d
            batch_shape = input_shape[:-3]
            if None not in batch_shape:
                output_shape[0] = int(np.prod(batch_shape))

            # now the static output shape is ready
            output_static_shape = tf.TensorShape(output_shape)

            # prepare for the dynamic batch shape
            if output_shape[0] is None:
                output_shape[0] = tf.reduce_prod(get_shape(input)[:-3])

            # prepare for the dynamic spatial dimensions
            if output_shape[h_axis] is None or output_shape[w_axis] is None:
                if given_output_shape is None:
                    input_shape = get_shape(input)
                    if output_shape[h_axis] is None:
                        output_shape[h_axis] = get_deconv_output_length(
                            input_shape[h_axis], kernel_shape[0],
                            strides[h_axis], padding)
                    if output_shape[w_axis] is None:
                        output_shape[w_axis] = get_deconv_output_length(
                            input_shape[w_axis], kernel_shape[1],
                            strides[w_axis], padding)
                else:
                    assert (is_tensor_object(given_output_shape))
                    with assert_deps([
                            assert_rank(given_output_shape, 1),
                            assert_scalar_equal(tf.size(given_output_shape), 2)
                    ]):
                        output_shape[h_axis] = given_output_shape[0]
                        output_shape[w_axis] = given_output_shape[1]

            # compose the final dynamic shape
            if any(is_tensor_object(s) for s in output_shape):
                output_shape = tf.stack(output_shape)
            else:
                output_shape = tuple(output_shape)

        # create the variables
        if kernel is None:
            kernel = model_variable('kernel',
                                    shape=kernel_shape,
                                    dtype=dtype,
                                    initializer=kernel_initializer,
                                    regularizer=kernel_regularizer,
                                    constraint=kernel_constraint,
                                    trainable=trainable)

        if weight_norm_fn is not None:
            kernel = weight_norm_fn(kernel)

        maybe_add_histogram(kernel, 'kernel')
        kernel = maybe_check_numerics(kernel, 'kernel')

        if use_bias and bias is None:
            # create the bias in the dtype of the input, exactly as is
            # required of a user-specified `bias` and as is done for `kernel`
            bias = model_variable('bias',
                                  shape=bias_shape,
                                  dtype=dtype,
                                  initializer=bias_initializer,
                                  regularizer=bias_regularizer,
                                  constraint=bias_constraint,
                                  trainable=trainable)
            maybe_add_histogram(bias, 'bias')
            bias = maybe_check_numerics(bias, 'bias')

        # flatten to 4d
        output, s1, s2 = flatten_to_ndims(input, 4)

        # do convolution or deconvolution
        output = tf.nn.conv2d_transpose(value=output,
                                        filter=kernel,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding=padding,
                                        data_format=data_format)
        if output_static_shape is not None:
            output.set_shape(output_static_shape)

        # add bias
        if use_bias:
            output = tf.nn.bias_add(output, bias, data_format=data_format)

        # apply the normalization function if specified
        if normalizer_fn is not None:
            output = normalizer_fn(output)

        # split into halves if gated
        if gated:
            output, gate = tf.split(output, 2, axis=c_axis)

        # apply the activation function if specified
        if activation_fn is not None:
            output = activation_fn(output)

        # apply the gate if required
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate')

        # unflatten back to original shape
        output = unflatten_from_ndims(output, s1, s2)

        maybe_add_histogram(output, 'output')
        output = maybe_check_numerics(output, 'output')

    return output
コード例 #18
0
    def log_prob(self, given, group_ndims=0, name=None):
        """
        Compute the log-probability of `given`.

        Args:
            given: The samples; will be converted to a tensor.
            group_ndims: Passed to `reduce_group_ndims`, which sums the
                element-wise log-probabilities with `tf.reduce_sum`.
            name (str): Name of the TensorFlow name scope.
                Defaults to ``'DiscretizedLogistic.log_prob'``.

        Returns:
            The log-probability tensor, after `group_ndims` reduction.
        """
        given = tf.convert_to_tensor(given)

        with tf.name_scope(name,
                           default_name='DiscretizedLogistic.log_prob',
                           values=[given]):
            # inv_scale = 1. / scale
            inv_scale = maybe_check_numerics(
                tf.exp(-self.log_scale, name='inv_scale'), 'inv_scale')
            # half_bin = bin_size / 2
            half_bin = self._bin_size * .5
            # delta = bin_size / scale, half_delta = delta / 2
            half_delta = half_bin * inv_scale
            # log(delta) = log(bin_size) - log(scale)
            log_delta = tf.log(self._bin_size) - self.log_scale

            # standardized center and both borders of the bin of `given`
            x_mid = (given - self.mean) * inv_scale
            x_low = x_mid - half_delta
            x_high = x_mid + half_delta

            cdf_low = tf.sigmoid(x_low, name='cdf_low')
            cdf_high = tf.sigmoid(x_high, name='cdf_high')

            # the middle bins cases:
            #   log(sigmoid(x_high) - sigmoid(x_low))
            # but in extreme cases where `sigmoid(x_high) - sigmoid(x_low)`
            # is very small, we use an alternative form, as in PixelCNN++.
            cdf_delta = cdf_high - cdf_low
            middle_bins_pdf = tf.where(
                cdf_delta > self._epsilon,
                # to avoid NaNs pollute the select statement, we have to use
                # `maximum(cdf_delta, 1e-12)`
                tf.log(tf.maximum(cdf_delta, 1e-12)),
                # the alternative form.  basically it can be derived by using
                # the mean value theorem for integration.
                x_mid + log_delta - 2. * tf.nn.softplus(x_mid)
            )
            log_prob = maybe_check_numerics(middle_bins_pdf, 'middle_bins_pdf')

            # broadcasted given, shape == x_mid
            broadcast_given = broadcast_to_shape(given, get_shape(x_mid))

            # the left-edge bin case
            #   log(sigmoid(x_high) - sigmoid(-infinity))
            if self._biased_edges and self.min_val is not None:
                left_edge = self._min_val + half_bin
                left_edge_pdf = maybe_check_numerics(
                    -tf.nn.softplus(-x_high), 'left_edge_pdf')
                log_prob = tf.where(
                    broadcast_given < left_edge, left_edge_pdf, log_prob)

            # the right-edge bin case
            #   log(sigmoid(infinity) - sigmoid(x_low))
            if self._biased_edges and self.max_val is not None:
                right_edge = self._max_val - half_bin
                right_edge_pdf = maybe_check_numerics(
                    -tf.nn.softplus(x_low), 'right_edge_pdf')
                log_prob = tf.where(
                    broadcast_given >= right_edge, right_edge_pdf, log_prob)

            # now reduce the group_ndims
            log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)

        return log_prob
コード例 #19
0
ファイル: branch.py プロジェクト: shliujing/tfsnippet
    def _transform(self, x, compute_y, compute_log_det):
        """
        Split `x`, transform the two parts separately, and join the results.

        On the first call (i.e., while `self._y_n_left` is still None) the
        join axis is resolved, the sizes of both output parts along that
        axis are recorded, and the input spec of `y` is built.

        Args:
            x: The input tensor, split along `self._split_axis`.
            compute_y (bool): Whether or not to compute the joined `y`?
            compute_log_det (bool): Whether or not to compute the
                log-determinant?

        Returns:
            tuple: `(y, log_det)`; `y` is None if `compute_y` is False,
                `log_det` is None if `compute_log_det` is False.

        Raises:
            RuntimeError: If the two outputs have different numbers of
                dimensions.
            ValueError: If the join axis is out of range, or is not covered
                by `y_value_ndims`.
        """
        # split the input x
        x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                          axis=self._split_axis)
        # the outputs are required to build the y spec on the first call,
        # even if the caller does not ask for `y`
        do_compute_y = compute_y or (self._y_n_left is None)

        # apply the left transformation
        y1, log_det1 = self._left.transform(x1,
                                            compute_y=do_compute_y,
                                            compute_log_det=compute_log_det)

        # apply the right transformation
        if self._right is not None:
            y2, log_det2 = self._right.transform(
                x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
        else:
            y2, log_det2 = x2, None

        # check the outputs.  a sub-flow may give no output when
        # `compute_y` is False (presumably None, as this method itself
        # does), in which case there is nothing to be checked.
        y1_shape = y2_shape = None
        if y1 is not None and y2 is not None:
            y1_shape = get_static_shape(y1)
            y2_shape = get_static_shape(y2)

            if len(y1_shape) != len(y2_shape):
                raise RuntimeError(
                    '`y_left.ndims` != `y_right.ndims`: y_left {} '
                    'vs y_right {}'.format(y1, y2))

        # build the y spec if not built
        join_axis = self._join_axis

        if self._y_n_left is None:
            # resolve the join axis
            if join_axis is None:
                join_axis = self._split_axis
            if join_axis < 0:
                join_axis += len(y1_shape)

            # `join_axis` is non-negative here if it is valid: it must be an
            # existing axis, and one of the last `y_value_ndims` axes
            if join_axis < 0 or join_axis >= len(y1_shape) or \
                    join_axis < len(y1_shape) - self.y_value_ndims:
                raise ValueError(
                    '`join_axis` out of range, or not covered by `y_value_ndims'
                    '`: split_axis {}, y_value_ndims {}, y_left {}, y_right {}'
                    .format(self._split_axis, self.y_value_ndims, y1, y2))
            # store the axis as a negative index (counted from the end)
            join_axis -= len(y1_shape)

            err_msg = (
                '`y_left.shape[join_axis] + y_right.shape[join_axis]` must '
                'be at least 2: y_left {}, y_right {} axis {}.'.format(
                    y1, y2, join_axis))
            y_n_left = y1_shape[join_axis]
            y_n_right = y2_shape[join_axis]
            if y_n_left is not None and y_n_right is not None:
                # both sizes are statically known
                y_n_features = y_n_left + y_n_right
                assert (y_n_features >= 2), err_msg
            else:
                # fall back to the dynamic sizes, checked at run time
                y_n_left = get_shape(y1)[join_axis]
                y_n_right = get_shape(y2)[join_axis]
                y_n_features = None
                with assert_deps([
                        tf.assert_greater_equal(y_n_left + y_n_right,
                                                2,
                                                message=err_msg)
                ]) as asserted:
                    if asserted:  # pragma: no cover
                        y_n_left = tf.identity(y_n_left)
                        y_n_right = tf.identity(y_n_right)

            self._join_axis = join_axis
            self._y_n_left = y_n_left
            self._y_n_right = y_n_right

            # build the y spec
            dtype = self._x_input_spec.dtype
            y_shape_spec = ['?'] * self.y_value_ndims
            if y_n_features is not None:
                # the two parts are joined along `join_axis` (a negative
                # index within the last `y_value_ndims` axes by now), thus
                # it is this axis whose size is known, not the split axis
                y_shape_spec[join_axis] = y_n_features
            assert (not self.require_batch_dims)
            y_shape_spec = ['...'] + y_shape_spec
            self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

        assert (join_axis is not None)

        # compute y
        y = None
        if compute_y:
            y = tf.concat([y1, y2], axis=join_axis)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = log_det1
            if log_det2 is not None:
                log_det = sum_log_det([log_det, log_det2])

        return y, log_det
コード例 #20
0
ファイル: shifting.py プロジェクト: shliujing/tfsnippet
 def get_dynamic_shape():
     """Return `get_shape(input)`, computing it at most once via `cached`."""
     shape = cached[0]
     if shape is None:
         # first request: compute the shape and remember it in slot 0
         shape = cached[0] = get_shape(input)
     return shape
コード例 #21
0
    def _transform_or_inverse_transform(self,
                                        x,
                                        compute_y,
                                        compute_log_det,
                                        reverse=False):
        """
        Shared implementation of the transform and the inverse transform.

        `x` is split into `x1` and `x2`.  `x1` is passed through unchanged,
        and is used to compute the shift (and optionally the scale) which
        is then applied on `x2`.

        Args:
            x: The input tensor.
            compute_y (bool): Whether or not to compute the output?
            compute_log_det (bool): Whether or not to compute the
                log-determinant?
            reverse (bool): If True, do the inverse transform
                (``x2 / scale - shift``) instead of the forward transform
                (``(x2 + shift) * scale``).

        Returns:
            tuple: `(y, log_det)`; `y` is None if `compute_y` is False,
                `log_det` is None if `compute_log_det` is False.

        Raises:
            RuntimeError: If `self._scale_type` is specified but no scale is
                computed by `self._shift_and_scale_fn`, or vice versa.
        """
        # Since the transform and inverse_transform are too similar, we
        # just implement these two methods by one super method, controlled
        # by `reverse == True/False`.

        # check the argument
        shape = get_static_shape(x)
        assert (len(shape) >= self.value_ndims)  # checked in `BaseFlow`

        # split the tensor
        x1, x2, n2 = self._split(x)

        # compute the scale and shift
        shift, pre_scale = self._shift_and_scale_fn(x1, n2)
        if self._scale_type is not None and pre_scale is None:
            raise RuntimeError('`scale_type` != None, but no scale is '
                               'computed.')
        elif self._scale_type is None and pre_scale is not None:
            raise RuntimeError('`scale_type` == None, but scale is computed.')

        if pre_scale is not None:
            pre_scale = self._check_scale_or_shift_shape(
                'scale', pre_scale, x2)
        shift = self._check_scale_or_shift_shape('shift', shift, x2)

        # derive the scale class
        if self._scale_type == 'sigmoid':
            scale = SigmoidScale(pre_scale + self._sigmoid_scale_bias,
                                 self._epsilon)
        elif self._scale_type == 'exp':
            scale = ExpScale(pre_scale, self._epsilon)
        elif self._scale_type == 'linear':
            scale = LinearScale(pre_scale, self._epsilon)
        else:
            assert (self._scale_type is None)
            scale = None

        # compute y
        y = None
        if compute_y:
            y1 = x1
            if reverse:
                y2 = x2
                if scale is not None:
                    y2 = y2 / scale
                y2 -= shift
            else:
                y2 = x2 + shift
                if scale is not None:
                    y2 = y2 * scale
            y = self._unsplit(y1, y2)

        # compute log_det
        log_det = None
        if compute_log_det:
            assert (self.value_ndims >= 0)  # checked in `_build`
            if scale is not None:
                log_det = tf.reduce_sum(
                    scale.neg_log_scale() if reverse else scale.log_scale(),
                    axis=list(range(-self.value_ndims, 0)))
            else:
                # strip the value dimensions only if there are any:
                # `shape[:-0]` is `shape[:0]`, which would wrongly discard
                # every dimension instead of none
                batch_shape = get_shape(x)
                if self.value_ndims > 0:
                    batch_shape = batch_shape[:-self.value_ndims]
                log_det = ZeroLogDet(batch_shape, x.dtype.base_dtype)

        return y, log_det
コード例 #22
0
ファイル: discretized.py プロジェクト: shliujing/tfsnippet
    def log_prob(self, given, group_ndims=0, name=None):
        """
        Compute the log-probability of `given`.

        Args:
            given: The samples; will be converted to a tensor, and passed
                through `self._discretize` if `self.discretize_given`.
            group_ndims: Passed to `reduce_group_ndims`, which sums the
                element-wise log-probabilities with `tf.reduce_sum`.
            name (str): Name of the TensorFlow name scope.
                Defaults to ``'DiscretizedLogistic.log_prob'``.

        Returns:
            The log-probability tensor, after `group_ndims` reduction.
        """
        given = tf.convert_to_tensor(given)

        with tf.name_scope(name,
                           default_name='DiscretizedLogistic.log_prob',
                           values=[given]):
            if self.discretize_given:
                given = self._discretize(given)

            # inv_scale = 1. / exp(log_scale)
            inv_scale = maybe_check_numerics(
                tf.exp(-self.log_scale, name='inv_scale'), 'inv_scale')
            # half_bin = bin_size / 2
            half_bin = self.bin_size * .5
            # delta = bin_size / scale, half_delta = delta / 2
            half_delta = half_bin * inv_scale

            # x_mid = (x - mean) / scale
            x_mid = (given - self.mean) * inv_scale

            # x_low = (x - mean - bin_size * 0.5) / scale
            x_low = x_mid - half_delta
            # x_high = (x - mean + bin_size * 0.5) / scale
            x_high = x_mid + half_delta

            cdf_low = tf.sigmoid(x_low, name='cdf_low')
            cdf_high = tf.sigmoid(x_high, name='cdf_high')
            cdf_delta = cdf_high - cdf_low

            # the middle bins cases:
            #   log(sigmoid(x_high) - sigmoid(x_low))
            # NOTE: no alternative form is used for a very small
            # `cdf_delta`; it is simply clamped at `self._epsilon` before
            # taking the logarithm, to keep the result finite.
            middle_bins_pdf = tf.log(tf.maximum(cdf_delta, self._epsilon))

            log_prob = maybe_check_numerics(middle_bins_pdf, 'middle_bins_pdf')

            if self.biased_edges:
                # each edge is handled only if its own bound is specified,
                # so that a missing `max_val` is never used in arithmetic
                has_min_val = self.min_val is not None
                has_max_val = self.max_val is not None

                if has_min_val or has_max_val:
                    # broadcasted given, shape == x_mid
                    broadcast_given = broadcast_to_shape(
                        given, get_shape(x_low))

                if has_min_val:
                    # the left-edge bin case
                    #   log(sigmoid(x_high) - sigmoid(-infinity))
                    left_edge = self.min_val + half_bin
                    left_edge_pdf = maybe_check_numerics(
                        -tf.nn.softplus(-x_high), 'left_edge_pdf')
                    log_prob = tf.where(tf.less(broadcast_given, left_edge),
                                        left_edge_pdf, log_prob)

                if has_max_val:
                    # the right-edge bin case
                    #   log(sigmoid(infinity) - sigmoid(x_low))
                    right_edge = self.max_val - half_bin
                    right_edge_pdf = maybe_check_numerics(
                        -tf.nn.softplus(x_low), 'right_edge_pdf')
                    log_prob = tf.where(
                        tf.greater_equal(broadcast_given, right_edge),
                        right_edge_pdf, log_prob)

            # now reduce the group_ndims
            log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)

        return log_prob