Code example #1
    def test_prepend_dims(self):
        with pytest.raises(ValueError, match='`ndims` must be >= 0: got -1'):
            _ = prepend_dims(tf.constant(0.), ndims=-1)

        x = tf.zeros([2, 3])
        self.assertIs(prepend_dims(x, ndims=0), x)

        with self.test_session() as sess:
            # test static shape
            x = np.random.normal(size=[2, 3])
            y = prepend_dims(x, ndims=1)
            self.assertEqual(get_static_shape(y), (1, 2, 3))
            np.testing.assert_allclose(sess.run(y), x.reshape([1, 2, 3]))

            # test partially dynamic shape
            t = tf.placeholder(shape=[2, None], dtype=tf.float64)
            y = prepend_dims(t, ndims=2)
            self.assertEqual(get_static_shape(y), (1, 1, 2, None))
            np.testing.assert_allclose(sess.run(y, feed_dict={t: x}),
                                       x.reshape([1, 1, 2, 3]))

            # test fully dynamic shape
            t = tf.placeholder(shape=None, dtype=tf.float64)
            y = prepend_dims(t, ndims=3)
            self.assertEqual(get_static_shape(y), None)
            np.testing.assert_allclose(sess.run(y, feed_dict={t: x}),
                                       x.reshape([1, 1, 1, 2, 3]))
Code example #2
        def run_check(x, k, dynamic_shape):
            if dynamic_shape:
                t = tf.placeholder(tf.int32, [None] * len(x.shape))
                run = lambda sess, *args: sess.run(*args, feed_dict={t: x})
            else:
                t = tf.constant(x, dtype=tf.int32)
                run = lambda sess, *args: sess.run(*args)

            if len(x.shape) == k:
                self.assertEqual(flatten_to_ndims(t, k), (t, None, None))
                self.assertEqual(unflatten_from_ndims(t, None, None), t)

            else:
                if k == 1:
                    front_shape = tuple(x.shape)
                    static_front_shape = get_static_shape(t)
                    xx = x.reshape([-1])
                else:
                    front_shape = tuple(x.shape)[:-(k - 1)]
                    static_front_shape = get_static_shape(t)[:-(k - 1)]
                    xx = x.reshape([-1] + list(x.shape)[-(k - 1):])

                with self.test_session() as sess:
                    tt, s1, s2 = flatten_to_ndims(t, k)
                    self.assertEqual(s1, static_front_shape)
                    if not dynamic_shape:
                        self.assertEqual(s2, front_shape)
                    else:
                        self.assertEqual(tuple(run(sess, s2)), front_shape)
                    np.testing.assert_equal(run(sess, tt), xx)
                    np.testing.assert_equal(
                        run(sess, unflatten_from_ndims(tt, s1, s2)), x)
Code example #3
 def _unsplit(self, x1, x2):
     n1 = self._n_features // 2
     n2 = self._n_features - n1
     if self._secondary:
         x1, x2 = x2, x1
     assert (get_static_shape(x1)[self.axis] == n1)
     assert (get_static_shape(x2)[self.axis] == n2)
     return tf.concat([x1, x2], axis=self.axis)
Code example #4
 def _check_scale_or_shift_shape(self, name, tensor, x2):
     assert_op = assert_shape_equal(
         tensor,
         x2,
         message='`{}.shape` expected to be {}, but got {}'.format(
             name, get_static_shape(x2), get_static_shape(tensor)))
     with assert_deps([assert_op]) as asserted:
         if asserted:  # pragma: no cover
             tensor = tf.identity(tensor)
     return tensor
Code example #5
File: utils.py  Project: shliujing/tfsnippet
def validate_conv2d_input(input, channels_last, arg_name='input'):
    """
    Validate the input for 2-d convolution.

    Args:
        input: The input tensor, must be at least 4-d.
        channels_last (bool): Whether or not the last dimension is the
            channels dimension (i.e., whether `data_format` is "NHWC").
        arg_name (str): Name of the input argument.

    Returns:
        (tf.Tensor, int, str): The validated input tensor, the number of input
            channels, and the data format.
    """
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        channel_axis = -1
        data_format = 'NHWC'
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        channel_axis = -3
        data_format = 'NCHW'
    input = input_spec.validate(arg_name, input)
    input_shape = get_static_shape(input)
    in_channels = input_shape[channel_axis]

    return input, in_channels, data_format
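
A minimal usage sketch (assuming TF 1.x and the imports used throughout these examples; the placeholder shape is illustrative):

images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])  # NHWC batch
images, in_channels, data_format = validate_conv2d_input(
    images, channels_last=True)
# in_channels == 3, data_format == 'NHWC'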
Code example #6
File: shape_utils.py  Project: shliujing/tfsnippet
def prepend_dims(x, ndims=1, name=None):
    """
    Prepend `[1] * ndims` to the beginning of the shape of `x`.

    Args:
        x: The tensor `x`.
        ndims: Number of `1` to prepend.
        name: TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: The tensor with prepended dimensions.
    """
    ndims = int(ndims)
    if ndims < 0:
        raise ValueError('`ndims` must be >= 0: got {}'.format(ndims))

    x = tf.convert_to_tensor(x)
    if ndims == 0:
        return x

    with tf.name_scope(name, default_name='prepend_dims', values=[x]):
        static_shape = get_static_shape(x)
        if static_shape is not None:
            static_shape = tf.TensorShape([1] * ndims + list(static_shape))

        dynamic_shape = concat_shapes([[1] * ndims, get_shape(x)])

        y = tf.reshape(x, dynamic_shape)
        if static_shape is not None:
            y.set_shape(static_shape)

        return y
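
A minimal usage sketch, mirroring the behavior exercised in code example #1:

x = tf.zeros([2, 3])
y = prepend_dims(x, ndims=2)
# get_static_shape(y) == (1, 1, 2, 3)
assert prepend_dims(x, ndims=0) is x  # ndims == 0 returns `x` unchanged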
Code example #7
File: planar_nf.py  Project: shliujing/tfsnippet
    def _build(self, input=None):
        dtype = input.dtype.base_dtype
        n_units = get_static_shape(input)[self.axis]

        w = model_variable('w',
                           shape=[1, n_units],
                           dtype=dtype,
                           initializer=self._w_initializer,
                           regularizer=self._w_regularizer,
                           trainable=self._trainable)
        b = model_variable('b',
                           shape=[1],
                           dtype=dtype,
                           initializer=self._b_initializer,
                           regularizer=self._b_regularizer,
                           trainable=self._trainable)
        u = model_variable('u',
                           shape=[1, n_units],
                           dtype=dtype,
                           initializer=self._u_initializer,
                           regularizer=self._u_regularizer,
                           trainable=self._trainable)
        wu = tf.matmul(w, u, transpose_b=True)  # wu.shape == [1, 1]
        u_hat = u + (-1 + tf.nn.softplus(wu) - wu) * \
            w / tf.reduce_sum(tf.square(w))  # shape == [1, n_units]

        self._w, self._b, self._u, self._u_hat = w, b, u, u_hat
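
The `u_hat` construction is the usual planar-flow trick for enforcing the invertibility condition `w·u_hat >= -1`: with `m(a) = -1 + softplus(a)`, the update above gives `w·u_hat = m(w·u) > -1`. A quick numpy check of this identity (a sketch, independent of the class above):

import numpy as np

rng = np.random.RandomState(0)
w, u = rng.randn(1, 5), rng.randn(1, 5)
wu = np.dot(w, u.T).item()               # scalar w·u
m = lambda a: -1. + np.log1p(np.exp(a))  # m(a) = -1 + softplus(a)
u_hat = u + (m(wu) - wu) * w / np.sum(w ** 2)
assert np.dot(w, u_hat.T).item() > -1.   # w·u_hat == m(w·u) > -1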
Code example #8
        def check(x, shape, x_ph=None, shape_ph=None, static_shape=None):
            # compute the expected answer
            try:
                y = x * np.ones(tuple(shape), dtype=x.dtype)
                if y.shape != tuple(shape):
                    raise ValueError()
            except ValueError:
                y = None

            # call the function and get output
            feed_dict = {}
            if x_ph is not None:
                feed_dict[x_ph] = x
                x = x_ph
            if shape_ph is not None:
                feed_dict[shape_ph] = np.asarray(shape)
                shape = shape_ph

            if y is None:
                with pytest.raises(Exception,
                                   match='`x` cannot be broadcasted '
                                   'to match `shape`'):
                    t = broadcast_to_shape_strict(x, shape)
                    _ = sess.run(t, feed_dict=feed_dict)
            else:
                t = broadcast_to_shape_strict(x, shape)
                if static_shape is not None:
                    self.assertTupleEqual(get_static_shape(t), static_shape)

                out = sess.run(t, feed_dict=feed_dict)
                self.assertTupleEqual(out.shape, y.shape)
                np.testing.assert_equal(out, y)
Code example #9
        def check(x,
                  ndims,
                  shape,
                  expected_shape,
                  static_shape=None,
                  x_ph=None,
                  shape_ph=None):
            # compute the answer
            assert (len(x.shape) >= ndims)
            if ndims > 0:
                y = np.reshape(x, x.shape[:-ndims] + tuple(shape))
            else:
                y = np.reshape(x, x.shape + tuple(shape))
            self.assertEqual(y.shape, expected_shape)

            # validate the output
            feed_dict = {}
            if x_ph is not None:
                feed_dict[x_ph] = x
                x = x_ph
            if shape_ph is not None:
                feed_dict[shape_ph] = shape
                shape = shape_ph

            y_tensor = reshape_tail(x, ndims, shape)
            if static_shape is not None:
                self.assertTupleEqual(get_static_shape(y_tensor), static_shape)
            y_out = sess.run(y_tensor, feed_dict=feed_dict)

            self.assertTupleEqual(y_out.shape, y.shape)
            np.testing.assert_equal(y_out, y)
Code example #10
 def _build(self, input=None):
     n_features = get_static_shape(input)[self.axis]
     if n_features < 2:
         raise ValueError('The feature axis of `input` must be at least 2: '
                          'got {}, input {}, axis {}.'.format(
                              n_features, input, self.axis))
     self._n_features = n_features
Code example #11
    def test_sample(self):
        tf.set_random_seed(123456)

        mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        normal = Normal(mean=mean, std=tf.constant(1., dtype=tf.float64))
        flow = QuadraticFlow(2., 5.)
        distrib = FlowDistribution(normal, flow)

        # test ordinary sample, is_reparameterized = None
        y = distrib.sample(n_samples=5)
        self.assertTrue(y.is_reparameterized)
        grad = tf.gradients(y * 1., mean)[0]
        self.assertIsNotNone(grad)
        self.assertEqual(get_static_shape(y), (5, 3))
        self.assertIsNotNone(y._self_log_prob)

        x, log_det = flow.inverse_transform(y)
        log_py = normal.log_prob(x) + log_det

        with self.test_session() as sess:
            np.testing.assert_allclose(*sess.run([log_py, y.log_prob()]),
                                       rtol=1e-5)

        # test stop gradient sample, is_reparameterized = False
        y = distrib.sample(n_samples=5, is_reparameterized=False)
        self.assertFalse(y.is_reparameterized)
        grad = tf.gradients(y * 1., mean)[0]
        self.assertIsNone(grad)
Code example #12
    def _build(self, input=None):
        # check the input.
        input = tf.convert_to_tensor(input)
        dtype = input.dtype.base_dtype
        shape = get_static_shape(input)

        # These facts should have been checked in `BaseFlow.build`.
        assert (shape is not None)
        assert (len(shape) >= self.value_ndims)

        # compute var spec and input spec
        min_axis = min(self.axis)
        shape_spec = [None] * len(shape)
        for a in self.axis:
            shape_spec[a] = shape[a]
        shape_spec = shape_spec[min_axis:]
        assert (len(shape_spec) > 0)
        assert (self.value_ndims >= len(shape_spec))

        self._y_input_spec = self._x_input_spec = InputSpec(
            shape=(('...', ) + ('?', ) * (self.value_ndims - len(shape_spec)) +
                   tuple(shape_spec)),
            dtype=dtype)
        # The shape of the variables must only have the necessary dimensions,
        # such that we can switch freely between `channels_last = True`
        # (in which case `input.shape = (..., *)`) and `channels_last = False`
        # (in which case `input.shape = (..., *, 1, 1)`).
        self._var_shape = tuple(s for s in shape_spec if s is not None)
        # and we still need to compute the aligned variable shape, such that
        # we can immediately reshape the variables into this aligned shape,
        # then compute `scale * input + bias`.
        self._var_shape_aligned = tuple(s or 1 for s in shape_spec)
        self._var_spec = ParamSpec(self._var_shape)

        # validate the input
        self._x_input_spec.validate('input', input)

        # build the variables
        self._bias = model_variable('bias',
                                    dtype=dtype,
                                    shape=self._var_shape,
                                    regularizer=self._bias_regularizer,
                                    constraint=self._bias_constraint,
                                    trainable=self._trainable)
        if self._scale_type == 'exp':
            self._pre_scale = model_variable(
                'log_scale',
                dtype=dtype,
                shape=self._var_shape,
                regularizer=self._log_scale_regularizer,
                constraint=self._log_scale_constraint,
                trainable=self._trainable)
        else:
            self._pre_scale = model_variable(
                'scale',
                dtype=dtype,
                shape=self._var_shape,
                regularizer=self._scale_regularizer,
                constraint=self._scale_constraint,
                trainable=self._trainable)
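
To make `_var_shape` versus `_var_shape_aligned` concrete, a hypothetical trace, assuming `axis = (-3,)` on an NCHW input whose channel dimension is 64 and whose spatial dimensions are dynamic:

# Sketch of the shape computation above (pure Python, no TF needed).
shape_spec = [64, None, None]                              # from axis = (-3,)
var_shape = tuple(s for s in shape_spec if s is not None)  # (64,)
var_shape_aligned = tuple(s or 1 for s in shape_spec)      # (64, 1, 1)
assert var_shape == (64,) and var_shape_aligned == (64, 1, 1)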
Code example #13
File: test_flow.py  Project: shliujing/tfsnippet
    def test_sample_value_and_group_ndims(self):
        tf.set_random_seed(123456)

        mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        normal = Normal(mean=mean, std=tf.constant(1., dtype=tf.float64))

        with self.test_session() as sess:
            # test value_ndims = 0, group_ndims = 1
            flow = QuadraticFlow(2., 5.)
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 0)

            y = distrib.sample(n_samples=5, group_ndims=1)
            self.assertTupleEqual(get_static_shape(y), (5, 3))
            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (5, 3))
            log_py = tf.reduce_sum(normal.log_prob(x) + log_det, axis=-1)

            np.testing.assert_allclose(*sess.run([y.log_prob(), log_py]),
                                       rtol=1e-5)

            # test value_ndims = 1, group_ndims = 0
            flow = QuadraticFlow(2., 5., value_ndims=1)
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 1)

            y = distrib.sample(n_samples=5, group_ndims=0)
            self.assertTupleEqual(get_static_shape(y), (5, 3))
            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (5,))
            log_py = log_det + tf.reduce_sum(normal.log_prob(x), axis=-1)

            np.testing.assert_allclose(*sess.run([y.log_prob(), log_py]),
                                       rtol=1e-5)

            # test value_ndims = 1, group_ndims = 1
            flow = QuadraticFlow(2., 5., value_ndims=1)
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 1)

            y = distrib.sample(n_samples=5, group_ndims=1)
            self.assertTupleEqual(get_static_shape(y), (5, 3))
            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (5,))
            log_py = tf.reduce_sum(
                log_det + tf.reduce_sum(normal.log_prob(x), axis=-1))

            np.testing.assert_allclose(*sess.run([y.log_prob(), log_py]),
                                       rtol=1e-5)
Code example #14
File: reshape.py  Project: mengyuan404/tfsnippet
    def _build(self, input=None):
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype
        assert (shape is not None and len(shape) >= self.x_value_ndims)

        # re-build the x input spec
        x_shape_spec = []
        if self.x_value_ndims > 0:
            x_shape_spec = list(shape)[-self.x_value_ndims:]
        if self.require_batch_dims:
            x_shape_spec = ['?'] + x_shape_spec
        x_shape_spec = ['...'] + x_shape_spec
        self._x_input_spec = InputSpec(shape=x_shape_spec, dtype=dtype)

        # infer the dynamic value shape of x, and store it for inverse transform
        x_value_shape = []
        if self.x_value_ndims > 0:
            x_value_shape = list(shape)[-self.x_value_ndims:]

        neg_one_count = 0
        for i, s in enumerate(x_value_shape):
            if s is None:
                if neg_one_count > 0:
                    x_value_shape = get_shape(input)
                    if self.x_value_ndims > 0:
                        x_value_shape = x_value_shape[-self.x_value_ndims:]
                    break
                else:
                    x_value_shape[i] = -1
                    neg_one_count += 1

        self._x_value_shape = x_value_shape

        # now infer the y value shape according to new info obtained from x
        y_value_shape = list(self._y_value_shape)
        if isinstance(x_value_shape, list) and -1 not in x_value_shape:
            x_value_size = int(np.prod(x_value_shape))
            y_value_size = int(np.prod([s for s in y_value_shape if s != -1]))

            if (-1 in y_value_shape and x_value_size % y_value_size != 0) or \
                    (-1 not in y_value_shape and x_value_size != y_value_size):
                raise ValueError(
                    'Cannot reshape the tail dimensions of `x` into `y`: '
                    'x value shape {!r}, y value shape {!r}.'.format(
                        x_value_shape, y_value_shape))

            if -1 in y_value_shape:
                y_value_shape[y_value_shape.index(-1)] = \
                    x_value_size // y_value_size

            assert (-1 not in y_value_shape)
            self._y_value_shape = tuple(y_value_shape)

        # re-build the y input spec
        y_shape_spec = list(y_value_shape)
        if self.require_batch_dims:
            y_shape_spec = ['?'] + y_shape_spec
        y_shape_spec = ['...'] + y_shape_spec
        self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)
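
A hypothetical trace of the `-1` inference above, assuming `x_value_ndims = 2`, a static x tail shape of `(3, 4)`, and a requested `y_value_shape` of `[-1]`:

x_value_shape = [3, 4]
y_value_shape = [-1]
x_value_size = 3 * 4   # product of the static x tail shape
y_value_size = 1       # product of the non-(-1) entries of y
y_value_shape[y_value_shape.index(-1)] = x_value_size // y_value_size
assert y_value_shape == [12]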
Code example #15
File: shape_utils.py  Project: shliujing/tfsnippet
def broadcast_to_shape_strict(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    This method requires `rank(x)` to be less than or equal to `len(shape)`.
    You may use :func:`broadcast_to_shape` instead, to allow the cases where
    ``rank(x) > len(shape)``.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.
        name: TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # assert ``rank(x) <= len(shape)``
        if isinstance(shape, tuple) and x_shape is not None:
            if len(x_shape) > len(shape):
                raise ValueError(cannot_broadcast_msg)
        elif isinstance(shape, tuple):
            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         len(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)
        else:
            with assert_deps(
                [assert_rank(shape, 1,
                             message=cannot_broadcast_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    shape = tf.identity(shape)

            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         tf.size(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)

        # do broadcast
        return broadcast_to_shape(x, shape)
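
A minimal usage sketch, consistent with the tests in code example #8:

x = tf.constant([1., 2., 3.])
y = broadcast_to_shape_strict(x, (2, 3))   # (3,) broadcasts against (2, 3)
# get_static_shape(y) should report (2, 3); a `shape` shorter than rank(x),
# e.g. broadcast_to_shape_strict(tf.zeros([2, 3]), (3,)), raises ValueError.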
Code example #16
File: rearrangement.py  Project: shliujing/tfsnippet
    def _transform_or_inverse_transform(self, x, compute_y, compute_log_det,
                                        permutation):
        assert (0 > self.axis >= -self.value_ndims >= -len(get_static_shape(x)))
        assert (get_static_shape(x)[self.axis] == self._n_features)

        # compute y
        y = None
        if compute_y:
            y = tf.gather(x, permutation, axis=self.axis)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = ZeroLogDet(get_shape(x)[:-self.value_ndims],
                                 x.dtype.base_dtype)

        return y, log_det
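
Why `ZeroLogDet`: `tf.gather` with a permutation merely reorders entries, and a permutation matrix has `|det| == 1`, so the log-determinant of the Jacobian is identically zero. A quick numpy check (illustrative):

import numpy as np

perm = np.array([2, 0, 1])
P = np.eye(3)[perm]  # the permutation matrix realized by tf.gather
assert abs(abs(np.linalg.det(P)) - 1.0) < 1e-12  # |det| == 1, log|det| == 0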
Code example #17
File: base.py  Project: mengyuan404/tfsnippet
    def _build_input_spec(self, input):
        super(FeatureMappingFlow, self)._build_input_spec(input)

        dtype = input.dtype.base_dtype
        shape = get_static_shape(input)

        # These facts should have been checked in `BaseFlow.build`.
        assert (shape is not None)
        assert (len(shape) >= self.value_ndims)

        # validate the feature axis, ensure it is covered by `value_ndims`.
        axis = self._axis
        axis_is_int = is_integer(axis)
        if axis_is_int:
            axis = [axis]
        else:
            axis = list(axis)

        for i, a in enumerate(axis):
            if a < 0:
                a += len(shape)
            if a < 0 or a < len(shape) - self.value_ndims:
                raise ValueError('`axis` out of range, or not covered by '
                                 '`value_ndims`: axis {}, value_ndims {}, '
                                 'input {}'.format(self._axis,
                                                   self.value_ndims, input))
            if shape[a] is None:
                raise ValueError('The feature axis of `input` is not '
                                 'deterministic: input {}, axis {}'.format(
                                     input, self._axis))

            # Store the negative axis, so that the axis can still be resolved
            # correctly even when new inputs have more dimensions than this
            # `input`.
            axis[i] = a - len(shape)

        if axis_is_int:
            assert (len(axis) == 1)
            self._axis = axis[0]
        else:
            axis_len = len(axis)
            axis = tuple(sorted(set(axis)))
            if len(axis) != axis_len:
                raise ValueError(
                    'Duplicated elements after resolving negative '
                    '`axis` with respect to the `input`: '
                    'input {}, axis {}'.format(input, self._axis))
            self._axis = tuple(axis)

        # re-build the input spec
        batch_ndims = int(self.require_batch_dims)
        shape_spec = ['...'] + ['?'] * (self.value_ndims + batch_ndims)
        for a in axis:
            shape_spec[a] = shape[a]
        self._y_input_spec = self._x_input_spec = InputSpec(shape=shape_spec,
                                                            dtype=dtype)
        self._x_input_spec.validate('input', input)
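
A hypothetical trace of the negative-axis resolution above, assuming a rank-3 input of static shape `(None, 4, 6)` with `axis = 1` and `value_ndims = 2`:

shape = (None, 4, 6)
value_ndims = 2
a = 1                                 # the user-specified axis
assert a >= len(shape) - value_ndims  # covered by `value_ndims`
a = a - len(shape)                    # stored as -2, so later inputs with
assert a == -2                        # extra leading dims still resolve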
Code example #18
File: rearrangement.py  Project: shliujing/tfsnippet
    def _build(self, input=None):
        n_features = self._n_features = get_static_shape(input)[self.axis]
        permutation = np.arange(n_features, dtype=np.int32)
        self._random_state.shuffle(permutation)

        self._permutation = model_variable(
            'permutation', dtype=tf.int32, initializer=permutation,
            trainable=False
        )
        self._inv_permutation = tf.invert_permutation(self._permutation)
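
For reference, `tf.invert_permutation` computes the inverse index map used by the inverse transform; in numpy terms (a sketch):

import numpy as np

perm = np.array([2, 0, 1])
inv = np.argsort(perm)  # numpy analogue of tf.invert_permutation
assert np.array_equal(inv[perm], np.arange(3))  # inv undoes perm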
Code example #19
 def _split(self, x):
     n_features = get_static_shape(x)[self.axis]
     assert (self._n_features == n_features)
     n1 = n_features // 2
     n2 = n_features - n1
     x1, x2 = tf.split(x, [n1, n2], self.axis)
     if self._secondary:
         return x2, x1, n1
     else:
         return x1, x2, n2
Code example #20
File: linear.py  Project: shliujing/tfsnippet
    def _build(self, input=None):
        dtype = input.dtype.base_dtype
        n_features = get_static_shape(input)[self.axis]

        self._kernel_matrix = InvertibleMatrix(size=n_features,
                                               strict=self._strict_invertible,
                                               dtype=dtype,
                                               trainable=self._trainable,
                                               random_state=self._random_state,
                                               scope='kernel')
Code example #21
def assert_rank(x, ndims, message=None, name=None):
    """
    Assert the rank of `x` is `ndims`.

    Args:
        x: A tensor.
        ndims (int or tf.Tensor): An integer, or a 0-d integer tensor.
        message: Message to display when the assertion fails.
        name: TensorFlow name scope of the graph nodes.

    Returns:
        tf.Operation or None: The TensorFlow assertion operation,
            or None if the rank can be verified statically.
    """
    if not is_tensor_object(ndims) and get_static_shape(x) is not None:
        ndims = int(ndims)
        x_ndims = len(get_static_shape(x))
        if x_ndims != ndims:
            raise _make_assertion_error('rank(x) == ndims',
                                        '{!r} != {!r}'.format(x_ndims,
                                                              ndims), message)
    else:
        return tf.assert_rank(x, ndims, message=message, name=name)
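
A minimal sketch of the two paths: a fully static shape is checked immediately (returning None), while an unknown rank yields an assertion op to be wrapped with `assert_deps`, as in `broadcast_to_shape_strict` above:

x = tf.zeros([2, 3])
assert assert_rank(x, 2) is None            # statically verified, nothing to run
t = tf.placeholder(tf.float32, shape=None)  # rank unknown at graph-build time
op = assert_rank(t, 2)                      # an op, checked at session run time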
Code example #22
File: test_utils.py  Project: shliujing/tfsnippet
        def check(input_size, kernel_size, strides, padding):
            output_size = get_deconv_output_length(input_size, kernel_size,
                                                   strides, padding)
            self.assertGreater(output_size, 0)

            # assert input <- output
            x = tf.nn.conv2d(np.zeros([1, output_size, output_size, 1],
                                      dtype=np.float32),
                             filter=np.zeros([kernel_size, kernel_size, 1, 1],
                                             dtype=np.float32),
                             strides=[1, strides, strides, 1],
                             padding=padding.upper(),
                             data_format='NHWC')
            h, w = get_static_shape(x)[1:3]
            self.assertEqual(input_size, h)
Code example #23
File: shape_utils.py  Project: shliujing/tfsnippet
def flatten_to_ndims(x, ndims, name=None):
    """
    Flatten the front dimensions of `x`, such that the resulting tensor
    will have at most `ndims` dimensions.

    Args:
        x (Tensor): The tensor to be flattened.
        ndims (int): The maximum number of dimensions for the resulting tensor.
        name: TensorFlow name scope of the graph nodes.

    Returns:
        (tf.Tensor, tuple[int or None], tuple[int] or tf.Tensor) or (tf.Tensor, None, None):
            (The flatten tensor, the static front shape, and the front shape),
            or (the original tensor, None, None)
    """
    x = tf.convert_to_tensor(x)
    if ndims < 1:
        raise ValueError('`ndims` must be greater than or equal to 1.')
    if not x.get_shape():
        raise ValueError('`x` is required to have known number of '
                         'dimensions.')
    shape = get_static_shape(x)
    if len(shape) < ndims:
        raise ValueError('`ndims` is {}, but `x` only has rank {}.'.format(
            ndims, len(shape)))
    if len(shape) == ndims:
        return x, None, None

    with tf.name_scope(name, default_name='flatten', values=[x]):
        if ndims == 1:
            static_shape = shape
            if None in shape:
                shape = tf.shape(x)
            return tf.reshape(x, [-1]), static_shape, shape
        else:
            front_shape = shape[:-(ndims - 1)]
            back_shape = shape[-(ndims - 1):]
            static_front_shape = front_shape
            static_back_shape = back_shape
            if None in front_shape or None in back_shape:
                dynamic_shape = tf.shape(x)
                if None in front_shape:
                    front_shape = dynamic_shape[:-(ndims - 1)]
                if None in back_shape:
                    back_shape = dynamic_shape[-(ndims - 1):]
            if isinstance(back_shape, tuple):
                x = tf.reshape(x, [-1] + list(back_shape))
            else:
                x = tf.reshape(x, tf.concat([[-1], back_shape], axis=0))
                x.set_shape(tf.TensorShape([None] + list(static_back_shape)))
            return x, static_front_shape, front_shape
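
A round-trip usage sketch, paired with `unflatten_from_ndims` (the pairing exercised in code example #2):

x = tf.zeros([2, 3, 4])
xx, static_front, front = flatten_to_ndims(x, ndims=2)
# get_static_shape(xx) == (6, 4); static_front == front == (2, 3)
x2 = unflatten_from_ndims(xx, static_front, front)
# get_static_shape(x2) == (2, 3, 4)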
Code example #24
File: branch.py  Project: shliujing/tfsnippet
    def _build(self, input=None):
        shape = get_static_shape(input)
        dtype = input.dtype.base_dtype

        # resolve the split axis
        split_axis = self._split_axis
        if split_axis < 0:
            split_axis += len(shape)
        if split_axis < 0 or split_axis < len(shape) - self.x_value_ndims:
            raise ValueError(
                '`split_axis` out of range, or not covered by `x_value_ndims`: '
                'split_axis {}, x_value_ndims {}, input {}'.format(
                    self._split_axis, self.x_value_ndims, input))
        split_axis -= len(shape)

        err_msg = (
            'The split axis of `input` must be at least 2: input {}, axis {}.'.
            format(input, split_axis))
        if shape[split_axis] is not None:
            x_n_features = shape[split_axis]
            if x_n_features < 2:
                raise ValueError(err_msg)
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left
        else:
            x_n_features = get_shape(input)[split_axis]
            x_n_left = x_n_features // 2
            x_n_right = x_n_features - x_n_left

            with assert_deps(
                [tf.assert_greater_equal(x_n_features, 2,
                                         message=err_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    x_n_left = tf.identity(x_n_left)
                    x_n_right = tf.identity(x_n_right)

            x_n_features = None

        self._split_axis = split_axis
        self._x_n_left = x_n_left
        self._x_n_right = x_n_right

        # build the x spec
        shape_spec = ['?'] * self.x_value_ndims
        if x_n_features is not None:
            shape_spec[self._split_axis] = x_n_features
        assert (not self.require_batch_dims)
        shape_spec = ['...'] + shape_spec
        self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
Code example #25
File: test_planar_nf.py  Project: shliujing/tfsnippet
    def test_u_hat(self):
        n_units = 5
        tf.set_random_seed(1234)
        flow = PlanarNormalizingFlow(name='planar_nf')
        flow.apply(tf.random_normal(shape=[1, n_units]))

        # ensure these parameters exist
        w, b, u, u_hat = flow._w, flow._b, flow._u, flow._u_hat
        for v in [w, u, b, u_hat]:
            self.assertIn('planar_nf/', v.name)

        # ensure these parameters have expected shapes
        self.assertEqual(get_static_shape(w), (1, n_units))
        self.assertEqual(get_static_shape(u), (1, n_units))
        self.assertEqual(get_static_shape(b), (1, ))
        self.assertEqual(get_static_shape(u_hat), (1, n_units))

        with self.test_session() as sess:
            ensure_variables_initialized()
            w, b, u, u_hat = sess.run([w, b, u, u_hat])
            m = lambda a: -1 + np.log(1 + np.exp(a))
            wu = np.dot(w, u.T)  # shape: [1, 1]
            np.testing.assert_allclose(u + w * (m(wu) - wu) / np.sum(w**2),
                                       u_hat)
Code example #26
File: test_flow.py  Project: shliujing/tfsnippet
    def test_log_prob_value_and_group_ndims(self):
        tf.set_random_seed(123456)

        mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        normal = Normal(mean=mean, std=tf.constant(1., dtype=tf.float64))
        y = tf.random_normal(shape=[2, 5, 3], dtype=tf.float64)

        with self.test_session() as sess:
            # test value_ndims = 0, group_ndims = 1
            flow = QuadraticFlow(2., 5.)
            flow.build(tf.zeros([2, 5, 3], dtype=tf.float64))
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 0)

            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (2, 5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (2, 5, 3))
            log_py = tf.reduce_sum(normal.log_prob(x) + log_det, axis=-1)

            np.testing.assert_allclose(
                *sess.run([distrib.log_prob(y, group_ndims=1), log_py]),
                rtol=1e-5
            )

            # test value_ndims = 1, group_ndims = 0
            flow = QuadraticFlow(2., 5., value_ndims=1)
            flow.build(tf.zeros([2, 5, 3], dtype=tf.float64))
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 1)

            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (2, 5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (2, 5))
            log_py = normal.log_prob(x, group_ndims=1) + log_det

            np.testing.assert_allclose(
                *sess.run([distrib.log_prob(y, group_ndims=0), log_py]),
                rtol=1e-5
            )

            # test value_ndims = 1, group_ndims = 2
            flow = QuadraticFlow(2., 5., value_ndims=1)
            flow.build(tf.zeros([2, 5, 3], dtype=tf.float64))
            distrib = FlowDistribution(normal, flow)
            self.assertEqual(distrib.value_ndims, 1)

            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (2, 5, 3))
            self.assertTupleEqual(get_static_shape(log_det), (2, 5))
            log_py = tf.reduce_sum(
                log_det + tf.reduce_sum(normal.log_prob(x), axis=-1))

            np.testing.assert_allclose(
                *sess.run([distrib.log_prob(y, group_ndims=2), log_py]),
                rtol=1e-5
            )
Code example #27
File: reshape.py  Project: mengyuan404/tfsnippet
    def _transform(self, x, compute_y, compute_log_det):
        assert (len(get_static_shape(x)) >= self.x_value_ndims)

        # compute y
        y = None
        if compute_y:
            y = reshape_tail(x, self.x_value_ndims, self._y_value_shape)

        # compute log_det
        log_det = None
        if compute_log_det:
            dst_shape = get_shape(x)
            if self.x_value_ndims > 0:
                dst_shape = dst_shape[:-self.x_value_ndims]
            log_det = ZeroLogDet(dst_shape, dtype=x.dtype.base_dtype)

        return y, log_det
Code example #28
File: reshape.py  Project: mengyuan404/tfsnippet
    def _inverse_transform(self, y, compute_x, compute_log_det):
        assert (len(get_static_shape(y)) >= self.y_value_ndims)

        # compute y
        x = None
        if compute_x:
            x = reshape_tail(y, self.y_value_ndims, self._x_value_shape)

        # compute log_det
        log_det = None
        if compute_log_det:
            dst_shape = get_shape(y)
            if self.y_value_ndims > 0:
                dst_shape = dst_shape[:-self.y_value_ndims]
            log_det = ZeroLogDet(dst_shape, dtype=y.dtype.base_dtype)

        return x, log_det
Code example #29
File: linear.py  Project: shliujing/tfsnippet
    def _transform(self, x, compute_y, compute_log_det):
        # compute y
        y = None
        if compute_y:
            n_features = get_static_shape(x)[self.axis]
            y = dense(x,
                      n_features,
                      kernel=self._kernel_matrix.matrix,
                      use_bias=False)

        # compute log_det
        log_det = None
        if compute_log_det:
            log_det = apply_log_det_factor(self._kernel_matrix.log_det, x,
                                           self.axis, self.value_ndims)
            log_det = broadcast_log_det_against_input(
                log_det, x, value_ndims=self.value_ndims)

        return y, log_det
Code example #30
def classification_accuracy(y_pred, y_true, name=None):
    """
    Compute the classification accuracy for `y_pred` and `y_true`.

    Args:
        y_pred: The predicted labels.
        y_true: The ground truth labels.  Its shape must match `y_pred`.
        name: TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: The accuracy.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = InputSpec(shape=get_static_shape(y_pred)). \
        validate('y_true', y_true)
    with tf.name_scope(name,
                       default_name='classification_accuracy',
                       values=[y_pred, y_true]):
        return tf.reduce_mean(
            tf.cast(tf.equal(y_pred, y_true), dtype=tf.float32))
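
A minimal usage sketch (TF 1.x); with two of the three labels matching, the accuracy evaluates to 2/3:

acc = classification_accuracy([0, 1, 2], [0, 1, 1])
with tf.Session() as sess:
    print(sess.run(acc))  # ~0.6667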