def validate_conv2d_input(input, channels_last, arg_name='input'):
    """
    Validate the input for 2-d convolution.

    Args:
        input: The input tensor, must be at least 4-d.
        channels_last (bool): Whether or not the last dimension is the
            channels dimension. (i.e., `data_format` is "NHWC")
        arg_name (str): Name of the input argument.

    Returns:
        (tf.Tensor, int, str): The validated input tensor, the number of
            input channels, and the data format.
    """
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        channel_axis = -1
        data_format = 'NHWC'
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        channel_axis = -3
        data_format = 'NCHW'

    input = input_spec.validate(arg_name, input)
    input_shape = get_static_shape(input)
    in_channels = input_shape[channel_axis]

    return input, in_channels, data_format

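# Usage sketch (illustrative, not part of the library): assuming TensorFlow
# 1.x graph mode, validate an "NHWC" image batch and recover its channel
# count.
#
#     images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#     images, in_channels, data_format = validate_conv2d_input(
#         images, channels_last=True)
#     # in_channels == 3, data_format == 'NHWC'
#
# In the `InputSpec` shape patterns above, '...' stands for arbitrary leading
# dimensions, '?' for a dimension of any size, and '*' for a dimension whose
# size must be statically known (here, the channels dimension).
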
def _build(self, input=None):
    shape = get_static_shape(input)
    dtype = input.dtype.base_dtype

    assert (shape is not None and len(shape) >= self.x_value_ndims)

    # re-build the x input spec
    x_shape_spec = []
    if self.x_value_ndims > 0:
        x_shape_spec = list(shape)[-self.x_value_ndims:]
    if self.require_batch_dims:
        x_shape_spec = ['?'] + x_shape_spec
    x_shape_spec = ['...'] + x_shape_spec
    self._x_input_spec = InputSpec(shape=x_shape_spec, dtype=dtype)

    # infer the dynamic value shape of x, and store it for inverse transform
    x_value_shape = []
    if self.x_value_ndims > 0:
        x_value_shape = list(shape)[-self.x_value_ndims:]

    neg_one_count = 0
    for i, s in enumerate(x_value_shape):
        if s is None:
            if neg_one_count > 0:
                # more than one dynamic dimension: fall back to the
                # dynamic (tensor) shape of x
                x_value_shape = get_shape(input)
                if self.x_value_ndims > 0:
                    x_value_shape = x_value_shape[-self.x_value_ndims:]
                break
            else:
                x_value_shape[i] = -1
                neg_one_count += 1

    self._x_value_shape = x_value_shape

    # now infer the y value shape according to new info obtained from x
    y_value_shape = list(self._y_value_shape)
    if isinstance(x_value_shape, list) and -1 not in x_value_shape:
        x_value_size = int(np.prod(x_value_shape))
        y_value_size = int(np.prod([s for s in y_value_shape if s != -1]))

        if (-1 in y_value_shape and x_value_size % y_value_size != 0) or \
                (-1 not in y_value_shape and x_value_size != y_value_size):
            raise ValueError(
                'Cannot reshape the tail dimensions of `x` into `y`: '
                'x value shape {!r}, y value shape {!r}.'.format(
                    x_value_shape, y_value_shape)
            )

        if -1 in y_value_shape:
            y_value_shape[y_value_shape.index(-1)] = \
                x_value_size // y_value_size

        assert (-1 not in y_value_shape)
        self._y_value_shape = tuple(y_value_shape)

    # re-build the y input spec
    y_shape_spec = list(y_value_shape)
    if self.require_batch_dims:
        y_shape_spec = ['?'] + y_shape_spec
    y_shape_spec = ['...'] + y_shape_spec
    self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

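# Usage sketch (hedged; assumes `ReshapeFlow` from `tfsnippet.layers` is the
# class owning this `_build`): flatten the trailing 3 dimensions of `x` into
# one feature dimension, letting `_build` infer the `-1` entry of
# `y_value_shape` from the static shape of `x` (4 * 4 * 8 == 128).
#
#     import tfsnippet as spt
#
#     flow = spt.layers.ReshapeFlow(x_value_ndims=3, y_value_shape=[-1])
#     x = tf.placeholder(tf.float32, [None, 4, 4, 8])
#     y, log_det = flow.transform(x)   # y has static shape (None, 128)
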
def _build(self, input=None):
    # check the input.
    input = tf.convert_to_tensor(input)
    dtype = input.dtype.base_dtype
    shape = get_static_shape(input)

    # These facts should have been checked in `BaseFlow.build`.
    assert (shape is not None)
    assert (len(shape) >= self.value_ndims)

    # compute var spec and input spec
    min_axis = min(self.axis)
    shape_spec = [None] * len(shape)
    for a in self.axis:
        shape_spec[a] = shape[a]
    shape_spec = shape_spec[min_axis:]
    assert shape_spec
    assert (self.value_ndims >= len(shape_spec))

    self._y_input_spec = self._x_input_spec = InputSpec(
        shape=(('...',) +
               ('?',) * (self.value_ndims - len(shape_spec)) +
               tuple(shape_spec)),
        dtype=dtype
    )

    # The shape of variables must only have necessary dimensions,
    # such that we can switch freely between `channels_last = True`
    # (in which case `input.shape = (..., *)`) and `channels_last = False`
    # (in which case `input.shape = (..., *, 1, 1)`).
    self._var_shape = tuple(s for s in shape_spec if s is not None)

    # And we still need to compute the aligned variable shape, such that
    # we can immediately reshape the variables into this aligned shape,
    # then compute `scale * input + bias`.
    self._var_shape_aligned = tuple(s or 1 for s in shape_spec)
    self._var_spec = ParamSpec(self._var_shape)

    # validate the input
    self._x_input_spec.validate('input', input)

    # build the variables
    self._bias = model_variable(
        'bias',
        dtype=dtype,
        shape=self._var_shape,
        regularizer=self._bias_regularizer,
        constraint=self._bias_constraint,
        trainable=self._trainable
    )
    if self._scale_type == 'exp':
        self._pre_scale = model_variable(
            'log_scale',
            dtype=dtype,
            shape=self._var_shape,
            regularizer=self._log_scale_regularizer,
            constraint=self._log_scale_constraint,
            trainable=self._trainable
        )
    else:
        self._pre_scale = model_variable(
            'scale',
            dtype=dtype,
            shape=self._var_shape,
            regularizer=self._scale_regularizer,
            constraint=self._scale_constraint,
            trainable=self._trainable
        )

def _build_input_spec(self, input):
    super(FeatureMappingFlow, self)._build_input_spec(input)

    dtype = input.dtype.base_dtype
    shape = get_static_shape(input)

    # These facts should have been checked in `BaseFlow.build`.
    assert (shape is not None)
    assert (len(shape) >= self.value_ndims)

    # validate the feature axis, ensure it is covered by `value_ndims`.
    axis = self._axis
    axis_is_int = is_integer(axis)
    if axis_is_int:
        axis = [axis]
    else:
        axis = list(axis)

    for i, a in enumerate(axis):
        if a < 0:
            a += len(shape)
        if a < 0 or a < len(shape) - self.value_ndims:
            raise ValueError(
                '`axis` out of range, or not covered by `value_ndims`: '
                'axis {}, value_ndims {}, input {}'.format(
                    self._axis, self.value_ndims, input)
            )
        if shape[a] is None:
            raise ValueError(
                'The feature axis of `input` is not deterministic: '
                'input {}, axis {}'.format(input, self._axis)
            )

        # Store the negative axis, such that the axis can still be
        # correctly resolved even if new inputs have more dimensions
        # than this `input`.
        axis[i] = a - len(shape)

    if axis_is_int:
        assert (len(axis) == 1)
        self._axis = axis[0]
    else:
        axis_len = len(axis)
        axis = tuple(sorted(set(axis)))
        if len(axis) != axis_len:
            raise ValueError(
                'Duplicated elements after resolving negative `axis` '
                'with respect to the `input`: input {}, axis {}'.format(
                    input, self._axis)
            )
        self._axis = tuple(axis)

    # re-build the input spec
    batch_ndims = int(self.require_batch_dims)
    shape_spec = ['...'] + ['?'] * (self.value_ndims + batch_ndims)
    for a in axis:
        shape_spec[a] = shape[a]

    self._y_input_spec = self._x_input_spec = InputSpec(shape=shape_spec,
                                                        dtype=dtype)
    self._x_input_spec.validate('input', input)

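# Worked example of the axis normalization above (illustrative): for an
# input with static shape (None, 32, 32, 6), `value_ndims = 3` and
# `axis = 3` (equivalently `-1`), the loop rewrites the axis as
# `3 - 4 == -1`.  A later input of shape (None, None, 32, 32, 6) then
# still resolves to the same feature axis, which a stored positive index
# would not.
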
def transpose_conv2d_axis(input, from_channels_last, to_channels_last,
                          name=None):
    """
    Ensure the channels axis of the `input` tensor is placed at the desired
    position.

    Args:
        input (tf.Tensor): The input tensor, at least 4-d.
        from_channels_last (bool): Whether or not the channels axis is the
            last axis in `input`. (i.e., the data format is "NHWC")
        to_channels_last (bool): Whether or not the channels axis should be
            the last axis in the output tensor.
        name (str): TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: The (maybe) transposed output tensor.
    """
    if from_channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
    input = input_spec.validate('input', input)
    input_shape = get_static_shape(input)
    sample_and_batch_axis = [i for i in range(len(input_shape) - 3)]

    # check whether or not the axis should be transposed
    if from_channels_last and not to_channels_last:
        transpose_axis = [-1, -3, -2]
    elif not from_channels_last and to_channels_last:
        transpose_axis = [-2, -1, -3]
    else:
        transpose_axis = None

    # transpose the axis
    if transpose_axis is not None:
        transpose_axis = [i + len(input_shape) for i in transpose_axis]
        input = tf.transpose(input, sample_and_batch_axis + transpose_axis,
                             name=name or 'transpose_conv2d_axis')

    return input

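# Usage sketch (illustrative, not part of the library): convert an "NHWC"
# tensor to "NCHW" by moving the channels axis to position -3.
#
#     x = tf.placeholder(tf.float32, [None, 32, 32, 3])        # NHWC
#     y = transpose_conv2d_axis(x, from_channels_last=True,
#                               to_channels_last=False)        # NCHW
#     # y has static shape (None, 3, 32, 32)
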
def _build(self, input=None):
    shape = get_static_shape(input)
    dtype = input.dtype.base_dtype

    # resolve the split axis
    split_axis = self._split_axis
    if split_axis < 0:
        split_axis += len(shape)
    if split_axis < 0 or split_axis < len(shape) - self.x_value_ndims:
        raise ValueError(
            '`split_axis` out of range, or not covered by `x_value_ndims`: '
            'split_axis {}, x_value_ndims {}, input {}'.format(
                self._split_axis, self.x_value_ndims, input)
        )
    split_axis -= len(shape)

    err_msg = ('The size of the split axis of `input` must be at least 2: '
               'input {}, axis {}.'.format(input, split_axis))
    if shape[split_axis] is not None:
        x_n_features = shape[split_axis]
        if x_n_features < 2:
            raise ValueError(err_msg)
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left
    else:
        x_n_features = get_shape(input)[split_axis]
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left

        with assert_deps(
                [tf.assert_greater_equal(
                    x_n_features, 2, message=err_msg)]) as asserted:
            if asserted:  # pragma: no cover
                x_n_left = tf.identity(x_n_left)
                x_n_right = tf.identity(x_n_right)

        x_n_features = None

    self._split_axis = split_axis
    self._x_n_left = x_n_left
    self._x_n_right = x_n_right

    # build the x spec
    shape_spec = ['?'] * self.x_value_ndims
    if x_n_features is not None:
        shape_spec[self._split_axis] = x_n_features
    assert (not self.require_batch_dims)
    shape_spec = ['...'] + shape_spec
    self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)

def softmax_classification_output(logits, name=None):
    """
    Get the most probable softmax classification output for each logit.

    Args:
        logits: The softmax logits.  Its last dimension will be treated
            as the softmax logits dimension, and will be reduced.
        name (str): TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: tf.int32 tensor, the class label for each logit.
    """
    logits = InputSpec(shape=('...', '?', '?')).validate('logits', logits)
    with tf.name_scope(name, default_name='softmax_classification_output',
                       values=[logits]):
        return tf.argmax(logits, axis=-1, output_type=tf.int32)

def classification_accuracy(y_pred, y_true, name=None):
    """
    Compute the classification accuracy for `y_pred` and `y_true`.

    Args:
        y_pred: The predicted labels.
        y_true: The ground truth labels.  Its shape must match `y_pred`.
        name (str): TensorFlow name scope of the graph nodes.

    Returns:
        tf.Tensor: The accuracy.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = InputSpec(shape=get_static_shape(y_pred)). \
        validate('y_true', y_true)
    with tf.name_scope(name, default_name='classification_accuracy',
                       values=[y_pred, y_true]):
        return tf.reduce_mean(
            tf.cast(tf.equal(y_pred, y_true), dtype=tf.float32))

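# Usage sketch combining the two metrics above (illustrative):
#
#     logits = tf.placeholder(tf.float32, [None, 10])
#     y_true = tf.placeholder(tf.int32, [None])
#     y_pred = softmax_classification_output(logits)
#     acc = classification_accuracy(y_pred, y_true)
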
class BaseFlow(BaseLayer):
    """
    The basic class for normalizing flows.

    A normalizing flow transforms a random variable `x` into `y` by an
    (implicitly) invertible mapping :math:`y = f(x)`, whose Jacobian matrix
    determinant :math:`\\det \\frac{\\partial f(x)}{\\partial x} \\neq 0`, thus
    can derive :math:`\\log p(y)` from given :math:`\\log p(x)`.
    """

    @add_name_and_scope_arg_doc
    def __init__(self, x_value_ndims, y_value_ndims=None,
                 require_batch_dims=False, name=None, scope=None):
        """
        Construct a new :class:`BaseFlow`.

        Args:
            x_value_ndims (int): Number of value dimensions in `x`.
                `x.ndims - x_value_ndims == log_det.ndims`.
            y_value_ndims (int): Number of value dimensions in `y`.
                `y.ndims - y_value_ndims == log_det.ndims`.
                If not specified, use `x_value_ndims`.
            require_batch_dims (bool): If :obj:`True`, `x` are required
                to have at least `x_value_ndims + 1` dimensions, and `y`
                are required to have at least `y_value_ndims + 1` dimensions.
                If :obj:`False`, `x` are required to have at least
                `x_value_ndims` dimensions, and `y` are required to have
                at least `y_value_ndims` dimensions.
        """
        x_value_ndims = int(x_value_ndims)
        if y_value_ndims is None:
            y_value_ndims = x_value_ndims
        else:
            y_value_ndims = int(y_value_ndims)

        super(BaseFlow, self).__init__(name=name, scope=scope)
        self._x_value_ndims = x_value_ndims
        self._y_value_ndims = y_value_ndims
        self._require_batch_dims = bool(require_batch_dims)

        self._x_input_spec = None  # type: InputSpec
        self._y_input_spec = None  # type: InputSpec

    def invert(self):
        """
        Get the inverted flow from this flow.

        The :meth:`transform()` will become the :meth:`inverse_transform()`
        in the inverted flow, and the :meth:`inverse_transform()` will become
        the :meth:`transform()` in the inverted flow.

        If the current flow has not been initialized, it must be initialized
        via :meth:`inverse_transform()` in the new flow.

        Returns:
            tfsnippet.layers.InvertFlow: The inverted flow.
        """
        from .invert import InvertFlow
        return InvertFlow(self)

    @property
    def x_value_ndims(self):
        """
        Get the number of value dimensions in `x`.

        Returns:
            int: The number of value dimensions in `x`.
        """
        return self._x_value_ndims

    @property
    def y_value_ndims(self):
        """
        Get the number of value dimensions in `y`.

        Returns:
            int: The number of value dimensions in `y`.
        """
        return self._y_value_ndims

    @property
    def require_batch_dims(self):
        """Whether or not this flow requires batch dimensions."""
        return self._require_batch_dims

    @property
    def explicitly_invertible(self):
        """
        Whether or not this flow is explicitly invertible.

        If a flow is not explicitly invertible, then it only supports to
        transform `x` into `y`, and corresponding :math:`\\log p(x)` into
        :math:`\\log p(y)`.  It cannot compute :math:`\\log p(y)` directly
        without knowing `x`, nor can it transform `y` back into `x`.

        Returns:
            bool: A boolean indicating whether or not the flow is explicitly
                invertible.
        """
        raise NotImplementedError()

    def _build_input_spec(self, input):
        batch_ndims = int(self.require_batch_dims)
        dtype = input.dtype.base_dtype

        x_input_shape = ['...'] + ['?'] * (self.x_value_ndims + batch_ndims)
        y_input_shape = ['...'] + ['?'] * (self.y_value_ndims + batch_ndims)

        self._x_input_spec = InputSpec(shape=x_input_shape, dtype=dtype)
        self._y_input_spec = InputSpec(shape=y_input_shape, dtype=dtype)

    def build(self, input=None):
        # check the input.
        if input is None:
            raise ValueError('`input` is required to build {}.'.format(
                self.__class__.__name__))
        input = tf.convert_to_tensor(input)
        shape = get_static_shape(input)
        require_ndims = self.x_value_ndims + int(self.require_batch_dims)
        require_ndims_text = ('x_value_ndims + 1' if self.require_batch_dims
                              else 'x_value_ndims')
        if shape is None or len(shape) < require_ndims:
            raise ValueError('`x.ndims` must be known and >= `{}`: x '
                             '{} vs ndims `{}`.'.format(
                                 require_ndims_text, input, require_ndims))

        # build the input spec
        self._build_input_spec(input)

        # build the layer
        return super(BaseFlow, self).build(input)

    def _transform(self, x, compute_y, compute_log_det):
        raise NotImplementedError()

    def transform(self, x, compute_y=True, compute_log_det=True, name=None):
        """
        Transform `x` into `y`, and compute the log-determinant of `f` at `x`
        (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

        Args:
            x (Tensor): The samples of `x`.
            compute_y (bool): Whether or not to compute :math:`y = f(x)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `y` and the (maybe summed)
                log-determinant.  The items in the returned tuple might be
                :obj:`None` if corresponding `compute_?` argument is set to
                :obj:`False`.

        Raises:
            ValueError: If both `compute_y` and `compute_log_det` are set
                to :obj:`False`.
        """
        if not compute_y and not compute_log_det:
            raise ValueError('At least one of `compute_y` and '
                             '`compute_log_det` should be True.')

        x = tf.convert_to_tensor(x)
        if not self._has_built:
            self.build(x)

        x = self._x_input_spec.validate('x', x)

        with tf.name_scope(
                name,
                default_name=get_default_scope_name('transform', self),
                values=[x]):
            y, log_det = self._transform(x, compute_y, compute_log_det)

            if compute_log_det:
                with assert_deps([
                            assert_log_det_shape_matches_input(
                                log_det=log_det,
                                input=x,
                                value_ndims=self.x_value_ndims)
                        ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return y, log_det

    def _inverse_transform(self, y, compute_x, compute_log_det):
        raise NotImplementedError()

    def inverse_transform(self, y, compute_x=True, compute_log_det=True,
                          name=None):
        """
        Transform `y` into `x`, and compute the log-determinant of `f^{-1}`
        at `y` (i.e.,
        :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

        Args:
            y (Tensor): The samples of `y`.
            compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`.
                Default :obj:`True`.
            compute_log_det (bool): Whether or not to compute the
                log-determinant.  Default :obj:`True`.
            name (str): If specified, will use this name as the TensorFlow
                operational name scope.

        Returns:
            (tf.Tensor, tf.Tensor): `x` and the (maybe summed)
                log-determinant.  The items in the returned tuple might be
                :obj:`None` if corresponding `compute_?` argument is set to
                :obj:`False`.

        Raises:
            ValueError: If both `compute_x` and `compute_log_det` are set
                to :obj:`False`.
            RuntimeError: If the flow is not explicitly invertible.
        """
        if not self.explicitly_invertible:
            raise RuntimeError(
                'The flow is not explicitly invertible: {!r}'.format(self))
        if not compute_x and not compute_log_det:
            raise ValueError('At least one of `compute_x` and '
                             '`compute_log_det` should be True.')
        if not self._has_built:
            raise RuntimeError('`inverse_transform` cannot be called before '
                               'the flow has been built; it can be built by '
                               'calling `build`, `apply` or `transform`: '
                               '{!r}'.format(self))

        y = tf.convert_to_tensor(y)
        y = self._y_input_spec.validate('y', y)

        with tf.name_scope(
                name,
                default_name=get_default_scope_name(
                    'inverse_transform', self),
                values=[y]):
            x, log_det = self._inverse_transform(y, compute_x,
                                                 compute_log_det)

            if compute_log_det:
                with assert_deps([
                            assert_log_det_shape_matches_input(
                                log_det=log_det,
                                input=y,
                                value_ndims=self.y_value_ndims)
                        ]) as asserted:
                    if asserted:  # pragma: no cover
                        log_det = tf.identity(log_det)

            return x, log_det

    def _apply(self, x):
        y, _ = self.transform(x, compute_y=True, compute_log_det=False)
        return y

def resnet_deconv2d_block(input,
                          out_channels,
                          kernel_size,
                          strides=(1, 1),
                          shortcut_kernel_size=(1, 1),
                          channels_last=True,
                          resize_at_exit=False,
                          activation_fn=None,
                          normalizer_fn=None,
                          weight_norm=False,
                          dropout_fn=None,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          kernel_constraint=None,
                          use_bias=None,
                          bias_initializer=tf.zeros_initializer(),
                          bias_regularizer=None,
                          bias_constraint=None,
                          trainable=True,
                          name=None,
                          scope=None):
    """
    2D deconvolutional ResNet block.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The channel numbers of the output.
        kernel_size (int or tuple[int]): Kernel size over spatial dimensions,
            for "conv_0" and "conv_1" deconvolutional layers.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for all three deconvolutional layers.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" deconvolutional layer.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`. (i.e., the data format is "NHWC")
        resize_at_exit (bool): See :func:`resnet_general_block`.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm: Passed to :func:`deconv2d`.
        dropout_fn: The dropout function.
        kernel_initializer: Passed to :func:`deconv2d`.
        kernel_regularizer: Passed to :func:`deconv2d`.
        kernel_constraint: Passed to :func:`deconv2d`.
        use_bias: Whether or not to use `bias` in :func:`deconv2d`.
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not
            given.  If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias_initializer: Passed to :func:`deconv2d`.
        bias_regularizer: Passed to :func:`deconv2d`.
        bias_constraint: Passed to :func:`deconv2d`.
        trainable: Passed to :func:`deconv2d`.

    Returns:
        tf.Tensor: The output tensor.

    See Also:
        :func:`resnet_general_block`
    """
    # check the input and infer the input shape
    if channels_last:
        input_spec = InputSpec(shape=('...', '?', '?', '?', '*'))
        c_axis = -1
    else:
        input_spec = InputSpec(shape=('...', '?', '*', '?', '?'))
        c_axis = -3
    input = input_spec.validate('input', input)
    in_channels = get_static_shape(input)[c_axis]

    # check the functional arguments
    if use_bias is None:
        use_bias = normalizer_fn is None

    # derive the convolution function
    conv_fn = partial(
        deconv2d,
        channels_last=channels_last,
        weight_norm=weight_norm,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_constraint=kernel_constraint,
        use_bias=use_bias,
        bias_initializer=bias_initializer,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        trainable=trainable,
    )

    # build the resnet block
    return resnet_general_block(conv_fn,
                                input=input,
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                strides=strides,
                                shortcut_kernel_size=shortcut_kernel_size,
                                resize_at_exit=resize_at_exit,
                                activation_fn=activation_fn,
                                normalizer_fn=normalizer_fn,
                                dropout_fn=dropout_fn,
                                name=name or 'resnet_deconv2d_block',
                                scope=scope)

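# Usage sketch (illustrative): one upsampling deconvolutional ResNet block;
# with stride 2 the deconvolution typically doubles the spatial size.
#
#     h = resnet_deconv2d_block(h, out_channels=32, kernel_size=(3, 3),
#                               strides=(2, 2), activation_fn=tf.nn.relu)
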
def pixelcnn_conv2d_resnet(input,
                           out_channels,
                           conv_fn=conv2d,
                           # the default values for the following two
                           # arguments are from the PixelCNN++ paper.
                           vertical_kernel_size=(2, 3),
                           horizontal_kernel_size=(2, 2),
                           strides=(1, 1),
                           channels_last=True,
                           use_shortcut_conv=None,
                           shortcut_conv_fn=None,
                           shortcut_kernel_size=(1, 1),
                           activation_fn=None,
                           normalizer_fn=None,
                           dropout_fn=None,
                           gated=False,
                           gate_sigmoid_bias=2.,
                           use_bias=None,
                           name=None,
                           scope=None,
                           **kwargs):
    """
    PixelCNN 2D convolutional ResNet block.

    Args:
        input (PixelCNN2DOutput): The output from the previous PixelCNN
            layer.
        out_channels (int): The channel numbers of the output.
        conv_fn: The convolution function for "conv_0" and "conv_1"
            convolutional layers.  See :func:`resnet_general_block`.
        vertical_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for "conv_0" and "conv_1" convolutional layers in
            the PixelCNN vertical stack.
        horizontal_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for "conv_0" and "conv_1" convolutional layers in
            the PixelCNN horizontal stack.
        strides (int or tuple[int]): Strides over spatial dimensions,
            for "conv_0", "conv_1" and "shortcut" convolutional layers.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`. (i.e., the data format is "NHWC")
        use_shortcut_conv (True or None): If :obj:`True`, force to apply a
            linear convolution transformation on the shortcut path.
            If :obj:`None` (by default), only use shortcut if necessary.
        shortcut_conv_fn: The convolution function for the "shortcut"
            convolutional layer.  If not specified, use `conv_fn`.
        shortcut_kernel_size (int or tuple[int]): Kernel size over spatial
            dimensions, for the "shortcut" convolutional layer.
        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        dropout_fn: The dropout function.
        gated (bool): Whether or not to use a gate on the output of
            "conv_1"?
            `conv_1_output = activation_fn(conv_1_output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        use_bias (bool or None): Whether or not to use `bias` in "conv_0"
            and "conv_1".  If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not
            given.  If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        \\**kwargs: Other named arguments passed to "conv_0", "conv_1" and
            "shortcut" convolutional layers.

    Returns:
        PixelCNN2DOutput: The PixelCNN layer output.
    """
    if not isinstance(input, PixelCNN2DOutput):
        raise TypeError('`input` is not an instance of `PixelCNN2DOutput`: '
                        'got {!r}'.format(input))
    vertical, in_channels, _ = validate_conv2d_input(
        input.vertical, channels_last, 'input.vertical')
    horizontal = InputSpec(shape=get_static_shape(vertical),
                           dtype=vertical.dtype.base_dtype). \
        validate('input.horizontal', input.horizontal)

    if shortcut_conv_fn is None:
        shortcut_conv_fn = conv_fn

    # derive the convolution functions
    vertical_conv_fn = partial(
        shifted_conv2d, conv_fn=conv_fn, spatial_shift=(1, 0))
    horizon_conv_fn = partial(
        shifted_conv2d, conv_fn=conv_fn, spatial_shift=(1, 1))

    with tf.variable_scope(scope,
                           default_name=name or 'pixelcnn_conv2d_resnet'):
        # first, derive the vertical stack output
        vertical = resnet_general_block(
            conv_fn=vertical_conv_fn,
            input=vertical,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=vertical_kernel_size,
            strides=strides,
            channels_last=channels_last,
            use_shortcut_conv=use_shortcut_conv,
            shortcut_conv_fn=shortcut_conv_fn,
            shortcut_kernel_size=shortcut_kernel_size,
            resize_at_exit=False,  # always resize at conv_0
            after_conv_0=None,
            after_conv_1=None,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            dropout_fn=dropout_fn,
            gated=gated,
            gate_sigmoid_bias=gate_sigmoid_bias,
            use_bias=use_bias,
            scope='vertical',
            **kwargs
        )

        # then, derive the horizontal stack output
        horizontal = resnet_general_block(
            conv_fn=horizon_conv_fn,
            input=horizontal,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=horizontal_kernel_size,
            strides=strides,
            channels_last=channels_last,
            use_shortcut_conv=use_shortcut_conv,
            shortcut_conv_fn=shortcut_conv_fn,
            shortcut_kernel_size=shortcut_kernel_size,
            resize_at_exit=False,  # always resize at conv_0
            after_conv_0=partial(pixelcnn_conv2d_resnet_after_conv_0,
                                 vertical=vertical,
                                 out_channels=out_channels,
                                 channels_last=channels_last,
                                 activation_fn=activation_fn,
                                 shortcut_conv_fn=shortcut_conv_fn,
                                 **kwargs),
            after_conv_1=None,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            dropout_fn=dropout_fn,
            gated=gated,
            gate_sigmoid_bias=gate_sigmoid_bias,
            use_bias=use_bias,
            scope='horizontal',
            **kwargs
        )

        return PixelCNN2DOutput(vertical=vertical, horizontal=horizontal)

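# Usage sketch (hedged; `pixelcnn_2d_input` / `pixelcnn_2d_output` are
# assumed to be the companion helpers that wrap a tensor into, and unwrap
# it from, `PixelCNN2DOutput`):
#
#     o = pixelcnn_2d_input(x, channels_last=True)
#     for _ in range(n_blocks):
#         o = pixelcnn_conv2d_resnet(o, out_channels=64,
#                                    activation_fn=tf.nn.leaky_relu)
#     logits = pixelcnn_2d_output(o)
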
def _transform(self, x, compute_y, compute_log_det):
    # split the input x
    x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                      axis=self._split_axis)
    # we must compute y when the y spec has not been built yet, since the
    # join axis is resolved from the actual outputs
    do_compute_y = compute_y or (self._y_n_left is None)

    # apply the left transformation
    y1, log_det1 = self._left.transform(x1,
                                        compute_y=do_compute_y,
                                        compute_log_det=compute_log_det)

    # apply the right transformation
    if self._right is not None:
        y2, log_det2 = self._right.transform(
            x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
    else:
        y2, log_det2 = x2, None

    # check the outputs
    y1_shape = get_static_shape(y1)
    y2_shape = get_static_shape(y2)
    if len(y1_shape) != len(y2_shape):
        raise RuntimeError('`y_left.ndims` != `y_right.ndims`: y_left {} '
                           'vs y_right {}'.format(y1, y2))

    # build the y spec if not built
    join_axis = self._join_axis

    if self._y_n_left is None:
        # resolve the join axis
        if join_axis is None:
            join_axis = self._split_axis
        if join_axis < 0:
            join_axis += len(y1_shape)
        if join_axis < 0 or join_axis < len(y1_shape) - self.y_value_ndims:
            raise ValueError(
                '`join_axis` out of range, or not covered by '
                '`y_value_ndims`: join_axis {}, y_value_ndims {}, '
                'y_left {}, y_right {}'.format(
                    self._join_axis, self.y_value_ndims, y1, y2)
            )
        join_axis -= len(y1_shape)

        err_msg = ('`y_left.shape[join_axis] + y_right.shape[join_axis]` '
                   'must be at least 2: y_left {}, y_right {} axis {}.'.
                   format(y1, y2, join_axis))
        y_n_left = y1_shape[join_axis]
        y_n_right = y2_shape[join_axis]
        if y_n_left is not None and y_n_right is not None:
            y_n_features = y_n_left + y_n_right
            assert (y_n_features >= 2)
        else:
            y_n_left = get_shape(y1)[join_axis]
            y_n_right = get_shape(y2)[join_axis]
            y_n_features = None

            with assert_deps([
                        tf.assert_greater_equal(
                            y_n_left + y_n_right, 2, message=err_msg)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    y_n_left = tf.identity(y_n_left)
                    y_n_right = tf.identity(y_n_right)

        self._join_axis = join_axis
        self._y_n_left = y_n_left
        self._y_n_right = y_n_right

        # build the y spec
        dtype = self._x_input_spec.dtype
        y_shape_spec = ['?'] * self.y_value_ndims
        if y_n_features is not None:
            y_shape_spec[self._join_axis] = y_n_features

        assert (not self.require_batch_dims)
        y_shape_spec = ['...'] + y_shape_spec
        self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

    assert (join_axis is not None)

    # compute y
    y = None
    if compute_y:
        y = tf.concat([y1, y2], axis=join_axis)

    # compute log_det
    log_det = None
    if compute_log_det:
        log_det = log_det1
        if log_det2 is not None:
            log_det = sum_log_det([log_det, log_det2])

    return y, log_det

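# Usage sketch (hedged; assumes `SplitFlow` from `tfsnippet.layers` is the
# class owning this `_transform`): split `x` along the feature axis,
# transform the left half with an inner flow, pass the right half through
# unchanged, then re-join along the same axis.
#
#     import tfsnippet as spt
#
#     flow = spt.layers.SplitFlow(
#         split_axis=-1, left=spt.layers.ActNorm(axis=-1, value_ndims=1))
#     x = tf.placeholder(tf.float32, [None, 6])
#     y, log_det = flow.transform(x)
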
def pixelcnn_2d_sample(fn, inputs, height, width, channels_last=True,
                       start=0, end=None, back_prop=False,
                       parallel_iterations=1, swap_memory=False, name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of the PixelCNN 2D network
            at iteration `i`.  `inputs` are the pixel-by-pixel outputs
            gathered through iteration `0` to iteration `i - 1`.  The
            iteration index `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`. (i.e., the data format is "NHWC")
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.
        name (str): TensorFlow name scope of the graph nodes.

    Returns:
        tuple[tf.Tensor]: The final outputs.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(
        inputs[0], channels_last=channels_last, arg_name='inputs[0]')
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample',
                       values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]

        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare for the output mask: keep all pixels before `idx`
            # from the inputs, and take pixel `idx` from the new outputs
            selector = tf.reshape(
                tf.concat([
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ], axis=0),
                mask_shape
            )
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs)))

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = inputs[i].dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(
                            idx=i, output=output_dtype, input=input_dtype)
                    )
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return outputs

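# Usage sketch (hedged; `build_pixelcnn_logits` and `sample_pixels` are
# hypothetical helpers standing in for a real PixelCNN network and its
# per-pixel sampler):
#
#     x0 = tf.zeros([batch_size, height, width, 1])
#
#     def fn(i, inputs):
#         logits = build_pixelcnn_logits(inputs[0])   # hypothetical
#         return (sample_pixels(logits),)             # hypothetical
#
#     images, = pixelcnn_2d_sample(fn, [x0], height, width,
#                                  channels_last=True)
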
def act_norm(input,
             axis=-1,
             initializing=False,
             scale_type='exp',
             bias_regularizer=None,
             bias_constraint=None,
             log_scale_regularizer=None,
             log_scale_constraint=None,
             scale_regularizer=None,
             scale_constraint=None,
             trainable=True,
             epsilon=1e-6,
             name=None,
             scope=None,
             value_ndims=None):
    """
    ActNorm proposed by (Kingma & Dhariwal, 2018).

    Examples::

        import tfsnippet as spt

        # apply act_norm on a dense layer
        x = spt.layers.dense(x, units, activation_fn=tf.nn.relu,
                             normalizer_fn=functools.partial(
                                 act_norm, initializing=initializing))

        # apply act_norm on a conv2d layer
        x = spt.layers.conv2d(x, out_channels, (3, 3),
                              channels_last=channels_last,
                              activation_fn=tf.nn.relu,
                              normalizer_fn=functools.partial(
                                  act_norm,
                                  axis=-1 if channels_last else -3,
                                  value_ndims=3,
                                  initializing=initializing,
                              ))

    Args:
        input (Tensor): The input tensor.
        axis (int or Iterable[int]): The axis to apply ActNorm.
            Dimensions not in `axis` will be averaged out when computing
            the mean of activations.  Default `-1`, the last dimension.
            All items of the `axis` should be covered by `value_ndims`.
        initializing (bool): Whether or not to use the input `x` to
            initialize the layer parameters. (default :obj:`False`)
        scale_type: One of {"exp", "linear"}.
            If "exp", ``y = (x + bias) * tf.exp(log_scale)``.
            If "linear", ``y = (x + bias) * scale``.
            Default is "exp".
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        log_scale_regularizer: The regularizer for `log_scale`.
        log_scale_constraint: The constraint for `log_scale`.
        scale_regularizer: The regularizer for `scale`.
        scale_constraint: The constraint for `scale`.
        trainable (bool): Whether or not the variables are trainable.
        epsilon: Small float to avoid dividing by zero or taking logarithm
            of zero.
        value_ndims: Ignored; the value ndims is inferred from `axis`.

    Returns:
        tf.Tensor: The output after the ActNorm has been applied.
    """
    input = InputSpec(shape=['...']).validate('input', input)
    rank = len(get_static_shape(input))
    axis = list(validate_int_tuple_arg('axis', axis))
    for i, a in enumerate(axis):
        if a >= 0:
            axis[i] = a - rank
    value_ndims = max(-a for a in axis)

    layer = ActNorm(
        axis=axis,
        value_ndims=value_ndims,
        initialized=not initializing,
        scale_type=scale_type,
        bias_regularizer=bias_regularizer,
        bias_constraint=bias_constraint,
        log_scale_regularizer=log_scale_regularizer,
        log_scale_constraint=log_scale_constraint,
        scale_regularizer=scale_regularizer,
        scale_constraint=scale_constraint,
        trainable=trainable,
        epsilon=epsilon,
        name=name,
        scope=scope
    )
    return layer.apply(input)