Example #1
    def __init__(self, axis, value_ndims, **kwargs):
        """
        Construct a new :class:`FeatureMappingFlow`.

        Args:
            axis (int or Iterable[int]): The feature axis/axes on which to
                apply the transformation.
            value_ndims (int): Number of value dimensions in both `x` and `y`.
                `x.ndims - value_ndims == log_det.ndims` and
                `y.ndims - value_ndims == log_det.ndims`.
            \\**kwargs: Other named arguments passed to :class:`BaseFlow`.
        """
        # check the arguments
        if is_integer(axis):
            axis = int(axis)
        else:
            axis = tuple(int(a) for a in axis)
            if not axis:
                raise ValueError('`axis` must not be empty.')

        if 'x_value_ndims' in kwargs or 'y_value_ndims' in kwargs:
            raise ValueError('Specifying `x_value_ndims` or `y_value_ndims` '
                             'for a `FeatureMappingFlow` is not allowed.')
        value_ndims = int(value_ndims)

        # construct the layer
        super(FeatureMappingFlow, self).__init__(x_value_ndims=value_ndims,
                                                 y_value_ndims=value_ndims,
                                                 **kwargs)
        self._axis = axis
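The argument check above is a small reusable pattern: a scalar stays an int, an iterable becomes a non-empty tuple of ints. A standalone sketch of just that normalization (the `normalize_axis` name is hypothetical; `is_integer` is assumed to behave as pinned down by the tests in Example #5):

def normalize_axis(axis):
    """Coerce `axis` into an int, or a non-empty tuple of ints."""
    if is_integer(axis):
        return int(axis)
    axis = tuple(int(a) for a in axis)
    if not axis:
        raise ValueError('`axis` must not be empty.')
    return axis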
Example #2
    def __init__(self,
                 hidden_net_p_x_z,
                 hidden_net_q_z_x,
                 x_dims,
                 z_dims,
                 std_epsilon=1e-4,
                 name=None,
                 scope=None):
        if not is_integer(x_dims) or x_dims <= 0:
            raise ValueError('`x_dims` must be a positive integer')
        if not is_integer(z_dims) or z_dims <= 0:
            raise ValueError('`z_dims` must be a positive integer')

        super(Donut, self).__init__(name=name, scope=scope)
        with reopen_variable_scope(self.variable_scope):
            # construct the VAE
            self._vae = VAE(
                # p(z): univariate normal with zero mean and unit std, both of size z_dims
                p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])),
                # p(x|h(z)): univariate normal
                p_x_given_z=Normal,
                # q(z|h(x)): univariate normal
                q_z_given_x=Normal,
                # hidden network for p(x|h(z)): mean and std, derived from the p(x|z) hidden network
                h_for_p_x=Lambda(partial(wrap_params_net,
                                         h_for_dist=hidden_net_p_x_z,
                                         mean_layer=partial(tf.layers.dense,
                                                            units=x_dims,
                                                            name='x_mean'),
                                         std_layer=partial(softplus_std,
                                                           units=x_dims,
                                                           epsilon=std_epsilon,
                                                           name='x_std')),
                                 name='p_x_given_z'),
                # hidden network for q(z|h(x)): mean and std, derived from the q(z|x) hidden network
                h_for_q_z=Lambda(partial(wrap_params_net,
                                         h_for_dist=hidden_net_q_z_x,
                                         mean_layer=partial(tf.layers.dense,
                                                            units=z_dims,
                                                            name='z_mean'),
                                         std_layer=partial(softplus_std,
                                                           units=z_dims,
                                                           epsilon=std_epsilon,
                                                           name='z_std')),
                                 name='q_z_given_x'))
        self._x_dims = x_dims
        self._z_dims = z_dims
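`wrap_params_net` and `softplus_std` are used here but not defined in this section; judging from the call sites above and the equivalent `DictMapper` construction in Example #4, a plausible minimal sketch is (an assumption, not the library's actual code):

def softplus_std(inputs, units, epsilon, name):
    # std = epsilon + softplus(dense(inputs)), keeping std strictly positive
    return epsilon + tf.nn.softplus(
        tf.layers.dense(inputs, units=units, name=name))

def wrap_params_net(inputs, h_for_dist, mean_layer, std_layer):
    # run the shared hidden network once, then branch into both parameters
    h = h_for_dist(inputs)
    return {'mean': mean_layer(h), 'std': std_layer(h)}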
Example #3
 def validate_size_tuple(n, s):
     if is_integer(s):
         # Do not change a single integer into a tuple!
         # This is because we do not know the dimensionality of the
         # convolution operation here, so we cannot build the size
         # tuple with correct number of elements from the integer notation.
         return int(s)
     return validate_int_tuple_arg(n, s)
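Usage, under the assumption that `validate_int_tuple_arg` coerces any iterable of integers into a tuple:

validate_size_tuple('kernel_size', 3)       # -> 3, left as a plain int
validate_size_tuple('kernel_size', [3, 5])  # -> (3, 5)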
Example #4
    def __init__(self,
                 h_for_p_x,
                 h_for_q_z,
                 x_dims,
                 z_dims,
                 std_epsilon=1e-4,
                 name=None,
                 scope=None):
        if not is_integer(x_dims) or x_dims <= 0:
            raise ValueError('`x_dims` must be a positive integer')
        if not is_integer(z_dims) or z_dims <= 0:
            raise ValueError('`z_dims` must be a positive integer')

        super(Donut, self).__init__(name=name, scope=scope)
        with reopen_variable_scope(self.variable_scope):
            self._vae = VAE(
                p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])),
                p_x_given_z=Normal,
                q_z_given_x=Normal,
                h_for_p_x=Sequential([
                    h_for_p_x,
                    DictMapper(
                        {'mean': K.layers.Dense(x_dims),
                         'std': lambda x: (std_epsilon + K.layers.Dense(
                             x_dims, activation=tf.nn.softplus)(x))},
                        name='p_x_given_z')
                ]),
                h_for_q_z=Sequential([
                    h_for_q_z,
                    DictMapper(
                        {'mean': K.layers.Dense(z_dims),
                         'std': lambda z: (std_epsilon + K.layers.Dense(
                             z_dims, activation=tf.nn.softplus)(z))},
                        name='q_z_given_x')
                ]),
            )
        self._x_dims = x_dims
        self._z_dims = z_dims
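`DictMapper` here turns one hidden tensor into a dict of distribution parameters. A hypothetical minimal reimplementation, only to show the contract this construction relies on (not the library's actual class):

class DictMapper(object):
    """Apply a dict of callables to the same input and return a dict."""

    def __init__(self, mapper, name=None):
        self._mapper = dict(mapper)
        self._name = name

    def __call__(self, x):
        return {key: fn(x) for key, fn in self._mapper.items()}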
Example #5
 def test_is_integer(self):
     if six.PY2:
         self.assertTrue(is_integer(long(1)))
     self.assertTrue(is_integer(int(1)))
     for dtype in [np.int, np.int8, np.int16, np.int32, np.int64,
                   np.uint, np.uint8, np.uint16, np.uint32, np.uint64]:
         v = np.asarray([1], dtype=dtype)[0]
         self.assertTrue(
             is_integer(v),
             msg='%r should be interpreted as integer.' % (v,)
         )
     self.assertFalse(is_integer(np.asarray(0, dtype=np.int)))
     for v in [float(1.0), '', object(), None, True, (), {}, []]:
         self.assertFalse(
             is_integer(v),
             msg='%r should not be interpreted as integer.' % (v,)
         )
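An `is_integer` consistent with every assertion above is short; a sketch for Python 3 (a PY2 version would also need to accept `long`):

import numpy as np

def is_integer(x):
    # `bool` is a subclass of `int`, so reject it explicitly; a 0-d ndarray
    # is an array rather than a scalar, so `np.integer` does not match it
    return not isinstance(x, bool) and isinstance(x, (int, np.integer))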
Example #6
    def _build_input_spec(self, input):
        super(FeatureMappingFlow, self)._build_input_spec(input)

        dtype = input.dtype.base_dtype
        shape = get_static_shape(input)

        # These facts should have been checked in `BaseFlow.build`.
        assert (shape is not None)
        assert (len(shape) >= self.value_ndims)

        # validate the feature axis, ensure it is covered by `value_ndims`.
        axis = self._axis
        axis_is_int = is_integer(axis)
        if axis_is_int:
            axis = [axis]
        else:
            axis = list(axis)

        for i, a in enumerate(axis):
            if a < 0:
                a += len(shape)
            if a < 0 or a < len(shape) - self.value_ndims:
                raise ValueError('`axis` out of range, or not covered by '
                                 '`value_ndims`: axis {}, value_ndims {}, '
                                 'input {}'.format(self._axis,
                                                   self.value_ndims, input))
            if shape[a] is None:
                raise ValueError('The feature axis of `input` is not '
                                 'deterministic: input {}, axis {}'.format(
                                     input, self._axis))

            # Store the negative axis, so that even if future inputs have
            # more dimensions than this `input`, the axis can still be
            # resolved correctly.
            axis[i] = a - len(shape)

        if axis_is_int:
            assert (len(axis) == 1)
            self._axis = axis[0]
        else:
            axis_len = len(axis)
            axis = tuple(sorted(set(axis)))
            if len(axis) != axis_len:
                raise ValueError(
                    'Duplicated elements after resolving negative '
                    '`axis` with respect to the `input`: '
                    'input {}, axis {}'.format(input, self._axis))
            self._axis = tuple(axis)

        # re-build the input spec
        batch_ndims = int(self.require_batch_dims)
        shape_spec = ['...'] + ['?'] * (self.value_ndims + batch_ndims)
        for a in axis:
            shape_spec[a] = shape[a]
        self._y_input_spec = self._x_input_spec = InputSpec(shape=shape_spec,
                                                            dtype=dtype)
        self._x_input_spec.validate('input', input)
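A quick numeric trace of the negative-axis trick, with hypothetical shapes: for an input of static shape `(None, 32, 32, 3)`, `axis=-1` and `value_ndims=3`:

shape = (None, 32, 32, 3)   # len(shape) == 4
a = -1 + len(shape)         # resolves to 3
assert a >= len(shape) - 3  # covered by value_ndims == 3
stored = a - len(shape)     # -1: resolves correctly for any future input rank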
Example #7
    def __init__(self, h_for_p_x, h_for_q_z, x_dims, z_dims, std_epsilon=1e-4,
                 name=None, scope=None):
        if not is_integer(x_dims) or x_dims <= 0:
            raise ValueError('`x_dims` must be a positive integer')
        if not is_integer(z_dims) or z_dims <= 0:
            raise ValueError('`z_dims` must be a positive integer')

        super(Donut, self).__init__(name=name, scope=scope)
        with reopen_variable_scope(self.variable_scope):
            self._vae = VAE(
                p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])),
                p_x_given_z=Normal,
                q_z_given_x=Normal,
                h_for_p_x=Lambda(
                    partial(
                        wrap_params_net,
                        h_for_dist=h_for_p_x,
                        mean_layer=partial(
                            tf.layers.dense, units=x_dims, name='x_mean'
                        ),
                        std_layer=partial(
                            softplus_std, units=x_dims, epsilon=std_epsilon,
                            name='x_std'
                        )
                    ),
                    name='p_x_given_z'
                ),
                h_for_q_z=Lambda(
                    partial(
                        wrap_params_net,
                        h_for_dist=h_for_q_z,
                        mean_layer=partial(
                            tf.layers.dense, units=z_dims, name='z_mean'
                        ),
                        std_layer=partial(
                            softplus_std, units=z_dims, epsilon=std_epsilon,
                            name='z_std'
                        )
                    ),
                    name='q_z_given_x'
                )
            )
        self._x_dims = x_dims
        self._z_dims = z_dims
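A minimal construction sketch against this signature; the two hidden networks below are placeholders (any callable mapping a tensor to a hidden tensor should work), and the dimension numbers are arbitrary:

hidden = lambda x: tf.layers.dense(
    tf.layers.dense(x, 100, activation=tf.nn.relu),
    100, activation=tf.nn.relu)

model = Donut(h_for_p_x=hidden, h_for_q_z=hidden, x_dims=120, z_dims=5)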
Example #8
    def log_prob(self, x, group_event_ndims=None, name=None):
        """Compute the log-probability of `x` against the distribution.

        If `group_event_ndims` is configured, then the likelihoods of
        one group of events will be summed together.

        Parameters
        ----------
        x : tf.Tensor
            The samples to be tested.

        group_event_ndims : int | tf.Tensor
            If specified, will override the attribute `group_event_ndims`
            of this distribution object.

        name : str
            Optional name of this operation.

        Returns
        -------
        tf.Tensor
            The log-probability of `x`.
        """
        with tf.name_scope(name, default_name='log_prob'):
            # determine the number of group event dimensions
            if group_event_ndims is None:
                group_event_ndims = self.group_event_ndims

            # compute the log-likelihood
            log_prob = self._log_prob(x)
            log_prob_shape = log_prob.get_shape()
            try:
                tf.broadcast_static_shape(log_prob_shape,
                                          self.static_batch_shape)
            except ValueError:
                raise RuntimeError(
                    'The shape of computed log-prob does not match '
                    '`batch_shape`, which could be a bug in the distribution '
                    'implementation. (%r vs %r)' %
                    (log_prob_shape, self.static_batch_shape))

            # reduce the dimensions of group event
            def f(ndims):
                return tf.reduce_sum(log_prob, axis=tf.range(-ndims, 0))

            if group_event_ndims is not None:
                if is_integer(group_event_ndims):
                    if group_event_ndims > 0:
                        log_prob = f(group_event_ndims)
                else:
                    group_event_ndims = tf.convert_to_tensor(group_event_ndims,
                                                             dtype=tf.int32)
                    return tf.cond(group_event_ndims > 0,
                                   lambda: f(group_event_ndims),
                                   lambda: log_prob)
            return log_prob
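To make the group reduction concrete: with `log_prob` of shape `(batch, 3, 4)` and `group_event_ndims == 2`, `f` sums over axes `(-2, -1)` and yields shape `(batch,)`. A NumPy analogue of `f` (illustrative only):

import numpy as np

def group_sum(log_prob, ndims):
    # mirrors tf.reduce_sum(log_prob, axis=tf.range(-ndims, 0))
    return log_prob.sum(axis=tuple(range(-ndims, 0))) if ndims > 0 else log_prob

assert group_sum(np.ones((8, 3, 4)), 2).shape == (8,)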
Example #9
    def pool2d_ans(pool_fn, input, pool_size, padding, strides):
        """Produce the expected answer of ?_pool2d."""
        strides = (strides, ) * 2 if is_integer(strides) else tuple(strides)
        strides = (1, ) + strides + (1, )
        ksize = (pool_size, ) * 2 if is_integer(pool_size) else tuple(
            pool_size)
        ksize = (1, ) + ksize + (1, )

        session = tf.get_default_session()
        input, s1, s2 = flatten_to_ndims(input, 4)
        padding = padding.upper()

        output = pool_fn(
            value=input,
            ksize=ksize,
            strides=strides,
            padding=padding,
            data_format='NHWC',
        )

        output = unflatten_from_ndims(output, s1, s2)
        output = session.run(output)
        return output
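The scalar-or-tuple normalization at the top expands arguments into `tf.nn.*_pool`'s NHWC convention; extracted as a one-liner:

to_nhwc = lambda v: (1,) + ((v, v) if is_integer(v) else tuple(v)) + (1,)
assert to_nhwc(3) == (1, 3, 3, 1)
assert to_nhwc((2, 1)) == (1, 2, 1, 1)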
Example #10
def validate_strides_or_kernel_size(arg_name, arg_value):
    """
    Validate the `strides` or `filter` arg, to ensure it is a tuple of
    two integers.

    Args:
        arg_name (str): The name of the argument, used in error messages.
        arg_value: The value of the argument.

    Returns:
        (int, int): The validated argument.
    """

    if not is_integer(arg_value) and (not isinstance(arg_value, tuple)
                                      or len(arg_value) != 2
                                      or not is_integer(arg_value[0])
                                      or not is_integer(arg_value[1])):
        raise TypeError(
            '`{}` must be an int or a tuple (int, int).'.format(arg_name))
    if not isinstance(arg_value, tuple):
        arg_value = (arg_value, arg_value)
    arg_value = tuple(int(v) for v in arg_value)
    return arg_value
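Since the function is fully shown above, its behaviour follows directly:

validate_strides_or_kernel_size('strides', 2)       # -> (2, 2)
validate_strides_or_kernel_size('strides', (2, 1))  # -> (2, 1)
validate_strides_or_kernel_size('strides', [2, 1])  # raises TypeError: lists are rejected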
Example #11
 def test_is_float(self):
     float_types = [float, np.float, np.float16, np.float32, np.float64]
     for extra_type in ['float8', 'float128', 'float256']:
         if hasattr(np, extra_type):
             float_types.append(getattr(np, extra_type))
     for dtype in float_types:
         v = np.asarray([1], dtype=dtype)[0]
         self.assertTrue(
             is_float(v),
             msg='{!r} should be interpreted as float'.format(v))
     self.assertFalse(is_float(np.asarray(0., dtype=np.float32)))
     for v in [int(1), '', object(), None, True, (), {}, []]:
         self.assertFalse(
             is_float(v),
             msg='{!r} should not be interpreted as float'.format(v))
Example #12
    def conv2d_ans(input,
                   padding,
                   kernel,
                   bias,
                   strides,
                   dilations,
                   activation_fn=None,
                   normalizer_fn=None,
                   gated=False,
                   gate_sigmoid_bias=2.):
        """Produce the expected answer of conv2d."""
        strides = (strides, ) * 2 if is_integer(strides) else tuple(strides)
        strides = (1, ) + strides + (1, )

        session = tf.get_default_session()
        input, s1, s2 = flatten_to_ndims(input, 4)
        padding = padding.upper()

        if dilations > 1:
            assert (not any(i > 1 for i in strides))
            output = tf.nn.atrous_conv2d(value=input,
                                         filters=kernel,
                                         rate=dilations,
                                         padding=padding)
        else:
            output = tf.nn.conv2d(input=input,
                                  filter=kernel,
                                  strides=strides,
                                  padding=padding,
                                  data_format='NHWC',
                                  dilations=[1] * 4)
        if bias is not None:
            output += bias
        if normalizer_fn:
            output = normalizer_fn(output)
        if gated:
            output, gate = tf.split(output, 2, axis=-1)
        if activation_fn:
            output = activation_fn(output)
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias)

        output = unflatten_from_ndims(output, s1, s2)
        output = session.run(output)
        return output
Example #13
    def check(self, x, padding, kernel, bias, strides):
        """Integrated tests for specific argument combinations."""
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            rtol=1e-5,
                                            atol=1e-5)
        strides = (strides, ) * 2 if is_integer(strides) else tuple(strides)

        x_shape = (x.shape[-3], x.shape[-2])
        x_channels = x.shape[-1]
        kernel_size = kernel.shape[0], kernel.shape[1]

        # compute the input for the deconv
        y = Conv2dTestCase.conv2d_ans(x, padding, kernel, None, strides, 1)
        y_shape = (y.shape[-3], y.shape[-2])
        y_channels = y.shape[-1]

        # test explicit output_shape, NHWC
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   x_shape,
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=True,
                                                   use_bias=False)
        self.assertEqual(deconv_out.shape, x.shape)

        # memorize the linear output for later tests
        linear_out = np.copy(deconv_out)

        # test explicit output_shape, NCHW
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   x_shape,
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=False,
                                                   use_bias=False)
        assert_allclose(deconv_out, linear_out)

        # test explicit dynamic output_shape, NHWC
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   tf.constant(x_shape),
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=True,
                                                   use_bias=False)
        assert_allclose(deconv_out, linear_out)

        # test explicit dynamic output_shape, NCHW
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   tf.constant(x_shape),
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=False,
                                                   use_bias=False)
        assert_allclose(deconv_out, linear_out)

        # test dynamic input, explicit dynamic output_shape, NHWC
        ph = tf.placeholder(dtype=tf.float32,
                            shape=(None, ) * (len(y.shape) - 3) +
                            (None, None, y_channels))
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   tf.constant(x_shape),
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=True,
                                                   ph=ph,
                                                   use_bias=False)
        assert_allclose(deconv_out, linear_out)

        # test dynamic input, explicit dynamic output_shape, NCHW
        ph = tf.placeholder(dtype=tf.float32,
                            shape=(None, ) * (len(y.shape) - 3) +
                            (y_channels, None, None))
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   tf.constant(x_shape),
                                                   padding,
                                                   kernel,
                                                   None,
                                                   strides,
                                                   channels_last=False,
                                                   ph=ph,
                                                   use_bias=False)
        assert_allclose(deconv_out, linear_out)

        # if the given payload shape matches the auto-inferred shape,
        # further test without giving an explicit output_shape
        def axis_matches(i):
            return x_shape[i] == get_deconv_output_length(
                y_shape[i], kernel_size[i], strides[i], padding)

        if all(axis_matches(i) for i in (0, 1)):
            # test static input, implicit output_shape, NHWC
            deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                       x_channels,
                                                       kernel_size,
                                                       None,
                                                       padding,
                                                       kernel,
                                                       None,
                                                       strides,
                                                       channels_last=True,
                                                       use_bias=False)
            assert_allclose(deconv_out, linear_out)

            # test static input, implicit output_shape, NCHW
            deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                       x_channels,
                                                       kernel_size,
                                                       None,
                                                       padding,
                                                       kernel,
                                                       None,
                                                       strides,
                                                       channels_last=False,
                                                       use_bias=False)
            assert_allclose(deconv_out, linear_out)

            # test dynamic input, implicit output_shape, NHWC
            ph = tf.placeholder(dtype=tf.float32,
                                shape=(None, ) * (len(y.shape) - 3) +
                                (None, None, y_channels))
            deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                       x_channels,
                                                       kernel_size,
                                                       None,
                                                       padding,
                                                       kernel,
                                                       None,
                                                       strides,
                                                       channels_last=True,
                                                       ph=ph,
                                                       use_bias=False)
            assert_allclose(deconv_out, linear_out)

            # test dynamic input, implicit output_shape, NCHW
            ph = tf.placeholder(dtype=tf.float32,
                                shape=(None, ) * (len(y.shape) - 3) +
                                (y_channels, None, None))
            deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                       x_channels,
                                                       kernel_size,
                                                       None,
                                                       padding,
                                                       kernel,
                                                       None,
                                                       strides,
                                                       channels_last=False,
                                                       ph=ph,
                                                       use_bias=False)
            assert_allclose(deconv_out, linear_out)

        # test normalization and activation
        activation_fn = lambda x: x * 2. + 1.
        normalizer_fn = lambda x: x * 1.5 - 3.
        ans = activation_fn(normalizer_fn(linear_out))

        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   x_shape,
                                                   padding,
                                                   kernel,
                                                   bias,
                                                   strides,
                                                   channels_last=True,
                                                   normalizer_fn=normalizer_fn,
                                                   activation_fn=activation_fn)
        assert_allclose(deconv_out, ans)

        # test normalization and activation and force using bias
        ans = activation_fn(normalizer_fn(linear_out + bias))
        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels,
                                                   kernel_size,
                                                   x_shape,
                                                   padding,
                                                   kernel,
                                                   bias,
                                                   strides,
                                                   channels_last=False,
                                                   use_bias=True,
                                                   normalizer_fn=normalizer_fn,
                                                   activation_fn=activation_fn)
        assert_allclose(deconv_out, ans)

        # test weight norm
        normalized_kernel = l2_normalize(kernel, axis=(0, 1, 2))
        ans = Deconv2dTestCase.run_deconv2d(y,
                                            x_channels,
                                            kernel_size,
                                            x_shape,
                                            padding,
                                            normalized_kernel,
                                            None,
                                            strides,
                                            channels_last=True,
                                            use_bias=False)
        deconv_out = Deconv2dTestCase.run_deconv2d(
            y,
            x_channels,
            kernel_size,
            x_shape,
            padding,
            kernel,
            None,
            strides,
            channels_last=False,
            use_bias=False,
            weight_norm=True,
            # this can force not using scale in weight_norm
            normalizer_fn=(lambda x: x))
        assert_allclose(deconv_out, ans)

        # test gated
        activation_fn = lambda x: x * 2. + 1.
        normalizer_fn = lambda x: x * 1.5 - 3.
        output, gate = np.split(normalizer_fn(linear_out), 2, axis=-1)
        ans = activation_fn(output) * safe_sigmoid(gate + 1.1)

        deconv_out = Deconv2dTestCase.run_deconv2d(y,
                                                   x_channels // 2,
                                                   kernel_size,
                                                   x_shape,
                                                   padding,
                                                   kernel,
                                                   bias,
                                                   strides,
                                                   channels_last=True,
                                                   normalizer_fn=normalizer_fn,
                                                   activation_fn=activation_fn,
                                                   gated=True,
                                                   gate_sigmoid_bias=1.1)
        assert_allclose(deconv_out, ans)
Example #14
 def has_non_unit_item(x):
     if is_integer(x):
         return x != 1
     else:
         return any(i != 1 for i in x)
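Behaviour follows directly from the definition:

assert has_non_unit_item(2) is True
assert has_non_unit_item(1) is False
assert has_non_unit_item((1, 1, 2)) is True
assert has_non_unit_item([1, 1]) is False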