def prepend_dims(x, ndims=1, name=None):
    """
    Prepend `[1] * ndims` to the beginning of the shape of `x`.

    Args:
        x: The tensor `x`.
        ndims: Number of `1` to prepend.
        name: Optional name of the graph scope. Defaults to "prepend_dims".

    Returns:
        tf.Tensor: The tensor with prepended dimensions.  If `ndims` is
            zero, `x` (converted to a tensor) is returned unchanged.

    Raises:
        ValueError: If `ndims` is negative.
    """
    ndims = int(ndims)
    if ndims < 0:
        raise ValueError('`ndims` must be >= 0: got {}'.format(ndims))

    x = tf.convert_to_tensor(x)
    if not ndims:
        # nothing to prepend
        return x

    leading_ones = [1] * ndims
    with tf.name_scope(name, default_name='prepend_dims', values=[x]):
        x_static_shape = get_static_shape(x)
        output = tf.reshape(x, concat_shapes([leading_ones, get_shape(x)]))
        # reshaping by a (possibly dynamic) shape may drop static shape
        # information, so restore it whenever the static rank is known
        if x_static_shape is not None:
            output.set_shape(
                tf.TensorShape(leading_ones + list(x_static_shape)))
        return output
def broadcast_log_det_against_input(log_det, input, value_ndims, name=None):
    """
    Broadcast the shape of `log_det` to match the shape of `input`.

    The target shape is ``shape(input)[:-value_ndims]`` when `value_ndims`
    is positive, and the whole ``shape(input)`` otherwise.  For positive
    `value_ndims`, an assertion that ``rank(input) >= value_ndims`` is
    attached to the slicing of the target shape.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each values sample.
        name: Optional name of the graph scope.

    Returns:
        tf.Tensor: The broadcasted log-determinant.
    """
    log_det = tf.convert_to_tensor(log_det)
    input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    scope_name = name or 'broadcast_log_det_to_input_shape'
    with tf.name_scope(scope_name, values=[log_det, input]):
        target_shape = get_shape(input)
        if value_ndims <= 0:
            return broadcast_to_shape_strict(log_det, target_shape)

        err_msg = (
            'Cannot broadcast `log_det` against `input`: log_det is {}, '
            'input is {}, value_ndims is {}.'.format(
                log_det, input, value_ndims)
        )
        rank_check = assert_rank_at_least(input, value_ndims, message=err_msg)
        with assert_deps([rank_check]):
            target_shape = target_shape[:-value_ndims]
        return broadcast_to_shape_strict(log_det, target_shape)
def _build(self, input=None):
    """
    Build the x / y input specs of this reshape flow from `input`.

    Besides the input specs, this method records `self._x_value_shape`
    (the tail shape of `x`, used by the inverse reshape) and, when the
    tail shape of `x` is fully static, resolves any `-1` in
    `self._y_value_shape`.

    Raises:
        ValueError: If the static tail size of `x` cannot be reshaped into
            the y value shape.
    """
    shape = get_static_shape(input)
    dtype = input.dtype.base_dtype
    # the static rank must be known and cover `x_value_ndims`
    assert (shape is not None and len(shape) >= self.x_value_ndims)

    # re-build the x input spec
    x_shape_spec = []
    if self.x_value_ndims > 0:
        x_shape_spec = list(shape)[-self.x_value_ndims:]
    if self.require_batch_dims:
        x_shape_spec = ['?'] + x_shape_spec
    x_shape_spec = ['...'] + x_shape_spec
    self._x_input_spec = InputSpec(shape=x_shape_spec, dtype=dtype)

    # infer the dynamic value shape of x, and store it for inverse transform
    x_value_shape = []
    if self.x_value_ndims > 0:
        x_value_shape = list(shape)[-self.x_value_ndims:]

    # A single unknown dimension is encoded as `-1`; as soon as a second
    # unknown dimension is met, fall back to the dynamic shape of `input`.
    neg_one_count = 0
    for i, s in enumerate(x_value_shape):
        if s is None:
            if neg_one_count > 0:
                x_value_shape = get_shape(input)
                if self.x_value_ndims > 0:
                    x_value_shape = x_value_shape[-self.x_value_ndims:]
                break
            else:
                x_value_shape[i] = -1
                neg_one_count += 1

    self._x_value_shape = x_value_shape

    # now infer the y value shape according to new info obtained from x
    y_value_shape = list(self._y_value_shape)
    # only when the x value shape is a list without `-1`, i.e., fully static
    if isinstance(x_value_shape, list) and -1 not in x_value_shape:
        x_value_size = int(np.prod(x_value_shape))
        y_value_size = int(np.prod([s for s in y_value_shape if s != -1]))

        if (-1 in y_value_shape and x_value_size % y_value_size != 0) or \
                (-1 not in y_value_shape and x_value_size != y_value_size):
            raise ValueError(
                'Cannot reshape the tail dimensions of `x` into `y`: '
                'x value shape {!r}, y value shape {!r}.'.format(
                    x_value_shape, y_value_shape)
            )

        # resolve the (at most one) `-1` in the y value shape
        if -1 in y_value_shape:
            y_value_shape[y_value_shape.index(-1)] = \
                x_value_size // y_value_size

        assert (-1 not in y_value_shape)
        self._y_value_shape = tuple(y_value_shape)

    # re-build the y input spec
    y_shape_spec = list(y_value_shape)
    if self.require_batch_dims:
        y_shape_spec = ['?'] + y_shape_spec
    y_shape_spec = ['...'] + y_shape_spec
    self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)
def _build(self, input=None):
    """
    Build the x input spec, and resolve the split axis and split sizes.

    The split axis is normalized to a negative index, which must refer to
    one of the last `x_value_ndims` dimensions of `input`.  If the axis has
    size `n`, the left part gets `n // 2` features and the right part gets
    `n - n // 2` features.  `n` must be at least 2.

    Raises:
        ValueError: If `split_axis` is out of range of the input rank, is
            not covered by `x_value_ndims`, or the split axis is statically
            known to have size less than 2.
    """
    shape = get_static_shape(input)
    dtype = input.dtype.base_dtype
    rank = len(shape)

    # resolve the split axis
    split_axis = self._split_axis
    if split_axis < 0:
        split_axis += rank
    # `split_axis >= rank` must be rejected too; otherwise the subtraction
    # below would silently map it onto an unrelated (leading) axis.
    if split_axis < 0 or split_axis >= rank or \
            split_axis < rank - self.x_value_ndims:
        raise ValueError(
            '`split_axis` out of range, or not covered by `x_value_ndims`: '
            'split_axis {}, x_value_ndims {}, input {}'.format(
                self._split_axis, self.x_value_ndims, input)
        )
    split_axis -= rank

    err_msg = (
        'The split axis of `input` must be at least 2: input {}, axis {}.'.
        format(input, split_axis)
    )
    if shape[split_axis] is not None:
        # static size: validate right now
        x_n_features = shape[split_axis]
        if x_n_features < 2:
            raise ValueError(err_msg)
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left
    else:
        # dynamic size: validate at runtime
        x_n_features = get_shape(input)[split_axis]
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left

        with assert_deps([
                    tf.assert_greater_equal(
                        x_n_features, 2, message=err_msg)
                ]) as asserted:
            if asserted:  # pragma: no cover
                x_n_left = tf.identity(x_n_left)
                x_n_right = tf.identity(x_n_right)

        # the feature count is not statically known
        x_n_features = None

    self._split_axis = split_axis
    self._x_n_left = x_n_left
    self._x_n_right = x_n_right

    # build the x spec
    shape_spec = ['?'] * self.x_value_ndims
    if x_n_features is not None:
        shape_spec[self._split_axis] = x_n_features
    assert (not self.require_batch_dims)
    shape_spec = ['...'] + shape_spec
    self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
def _transform(self, x, compute_y, compute_log_det):
    """
    Forward transform: apply `space_to_depth` on `x`.

    The log-determinant is a `ZeroLogDet` whose shape is the shape of `x`
    without its last 3 dimensions.

    Returns:
        (y, log_det): Each is :obj:`None` if not requested.
    """
    y = (space_to_depth(x, block_size=self._block_size,
                        channels_last=self._channels_last)
         if compute_y else None)
    log_det = (ZeroLogDet(shape=get_shape(x)[:-3], dtype=x.dtype.base_dtype)
               if compute_log_det else None)
    return y, log_det
def _inverse_transform(self, y, compute_x, compute_log_det):
    """
    Inverse transform: apply `depth_to_space` on `y`.

    The log-determinant is a `ZeroLogDet` whose shape is the shape of `y`
    without its last 3 dimensions.

    Returns:
        (x, log_det): Each is :obj:`None` if not requested.
    """
    x = (depth_to_space(y, block_size=self._block_size,
                        channels_last=self._channels_last)
         if compute_x else None)
    log_det = (ZeroLogDet(shape=get_shape(y)[:-3], dtype=y.dtype.base_dtype)
               if compute_log_det else None)
    return x, log_det
def sample(self, n_samples=None, group_ndims=0, is_reparameterized=None,
           compute_density=None, name=None):
    """
    Sample from the mixture distribution.

    A component index is drawn from `self.categorical`, turned into a
    one-hot mask (with gradient stopped), and used to select among the
    samples drawn from every component.

    Args:
        n_samples: Number of samples, passed to every underlying sampler.
        group_ndims (int): Passed to the returned :class:`StochasticTensor`.
        is_reparameterized: If :obj:`False`, the gradient through the
            samples is stopped even if the mixture is reparameterized.
        compute_density: Whether or not to compute the density of the
            samples immediately.
        name: Optional name of the graph scope.
            Defaults to "Mixture.sample".

    Returns:
        StochasticTensor: The samples.
    """
    self._validate_sample_is_reparameterized_arg(is_reparameterized)

    #######################################################################
    # slow routine: generate the mixture by one_hot * stack([c.sample()]) #
    #######################################################################
    with tf.name_scope(name or 'Mixture.sample'):
        cat = self.categorical.sample(n_samples, group_ndims=0)
        mask = tf.one_hot(cat, self.n_components, dtype=self.dtype, axis=-1)
        if self.value_ndims > 0:
            # append `value_ndims` singleton axes, so that the mask can be
            # broadcast against the stacked component samples
            static_shape = (mask.get_shape().as_list() +
                            [1] * self.value_ndims)
            dynamic_shape = concat_shapes(
                [get_shape(mask), [1] * self.value_ndims])
            mask = tf.reshape(mask, dynamic_shape)
            mask.set_shape(static_shape)
        mask = tf.stop_gradient(mask)

        # derive the mixture samples
        c_samples = [
            c.sample(n_samples, group_ndims=0)
            for c in self.components
        ]
        samples = tf.reduce_sum(
            mask * tf.stack(c_samples, axis=-self.value_ndims - 1),
            axis=-self.value_ndims - 1
        )

        # Stop the gradient if the mixture is not reparameterized, or if the
        # caller explicitly asked for non-reparameterized samples (which is
        # what the returned StochasticTensor will claim).
        if not self.is_reparameterized or is_reparameterized is False:
            samples = tf.stop_gradient(samples)

        t = StochasticTensor(
            distribution=self,
            tensor=samples,
            n_samples=n_samples,
            group_ndims=group_ndims,
            is_reparameterized=is_reparameterized
        )
        if compute_density:
            compute_density_immediately(t)
        return t
def _inverse_transform(self, y, compute_x, compute_log_det):
    """
    Inverse transform: reshape the last `y_value_ndims` dimensions of `y`
    into `self._x_value_shape`.

    The log-determinant is a `ZeroLogDet` whose shape is the shape of `y`
    without its last `y_value_ndims` dimensions.

    Returns:
        (x, log_det): Each is :obj:`None` if not requested.
    """
    assert (len(get_static_shape(y)) >= self.y_value_ndims)
    n_tail = self.y_value_ndims

    # compute x
    x = reshape_tail(y, n_tail, self._x_value_shape) if compute_x else None

    # compute log_det
    if not compute_log_det:
        return x, None
    y_shape = get_shape(y)
    log_det_shape = y_shape[:-n_tail] if n_tail > 0 else y_shape
    return x, ZeroLogDet(log_det_shape, dtype=y.dtype.base_dtype)
def _transform_or_inverse_transform(self, x, compute_y, compute_log_det,
                                    permutation):
    """
    Gather `x` along `self.axis` by the indices in `permutation`.

    The caller supplies the permutation indices for the direction being
    computed.  The log-determinant is a `ZeroLogDet` whose shape is the
    shape of `x` without its last `value_ndims` dimensions.

    Returns:
        (y, log_det): Each is :obj:`None` if not requested.
    """
    x_static_shape = get_static_shape(x)
    assert (0 > self.axis >= -self.value_ndims >= -len(x_static_shape))
    assert (x_static_shape[self.axis] == self._n_features)

    y = tf.gather(x, permutation, axis=self.axis) if compute_y else None

    log_det = None
    if compute_log_det:
        log_det = ZeroLogDet(get_shape(x)[:-self.value_ndims],
                             x.dtype.base_dtype)
    return y, log_det
def _transform(self, x, compute_y, compute_log_det):
    """
    Forward transform: reshape the last `x_value_ndims` dimensions of `x`
    into `self._y_value_shape`.

    The log-determinant is a `ZeroLogDet` whose shape is the shape of `x`
    without its last `x_value_ndims` dimensions.

    Returns:
        (y, log_det): Each is :obj:`None` if not requested.
    """
    assert (len(get_static_shape(x)) >= self.x_value_ndims)
    n_tail = self.x_value_ndims

    # compute y
    y = reshape_tail(x, n_tail, self._y_value_shape) if compute_y else None

    # compute log_det
    if not compute_log_det:
        return y, None
    x_shape = get_shape(x)
    log_det_shape = x_shape[:-n_tail] if n_tail > 0 else x_shape
    return y, ZeroLogDet(log_det_shape, dtype=x.dtype.base_dtype)
def dropout(input, rate=.5, noise_shape=None, training=False, name=None):
    """
    Apply dropout on `input`.

    In the training branch, a uniform noise of shape `noise_shape` is
    drawn, elements with noise ``< 1 - rate`` are kept, and the result is
    scaled by ``1 / (1 - rate)``.  The mask is multiplied with `input`, so
    it is broadcast against `input` when `noise_shape` differs.  In the
    testing branch, `input` is returned as-is.

    Args:
        input (Tensor): The input tensor.
        rate (float or tf.Tensor): The rate of dropout.
        noise_shape (tuple[int] or tf.Tensor): Shape of the noise.
            If not specified, use the shape of `input`.
        training (bool or tf.Tensor): Whether or not the model is under
            training stage?
        name: Optional name of the graph scope. Defaults to "dropout".

    Returns:
        tf.Tensor: The dropout transformed tensor.
    """
    input = tf.convert_to_tensor(input)

    with tf.name_scope(name, default_name='dropout', values=[input]):
        dtype = input.dtype.base_dtype
        keep_prob = convert_to_tensor_and_cast(1. - rate, dtype=dtype)
        keep_scale = 1. / keep_prob
        mask_shape = get_shape(input) if noise_shape is None else noise_shape

        def apply_dropout():
            uniform_noise = tf.random_uniform(
                shape=mask_shape, minval=0., maxval=1., dtype=dtype)
            keep_mask = tf.cast(uniform_noise < keep_prob, dtype=dtype)
            return input * keep_mask * keep_scale

        return smart_cond(training, apply_dropout, lambda: input)
def pixelcnn_2d_input(input, channels_last=True, auxiliary_channel=True,
                      name=None):
    """
    Prepare the input for a PixelCNN 2D network (Tim Salimans, 2017).

    This method must be applied on the input once before any other PixelCNN
    2D layers, for example::

        input = ...  # the input x

        # prepare for the convolution stack
        output = spt.layers.pixelcnn_2d_input(input)

        # apply the PixelCNN 2D layers.
        for i in range(5):
            output = spt.layers.pixelcnn_conv2d_resnet(
                output,
                out_channels=64,
                vertical_kernel_size=(2, 3),
                horizontal_kernel_size=(2, 2),
                activation_fn=tf.nn.leaky_relu,
                normalizer_fn=spt.layers.batch_norm
            )

        # get the final output of the PixelCNN 2D network.
        output = spt.layers.pixelcnn_2d_output(output)

    Args:
        input (Tensor): The input tensor, at least 4-d.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        auxiliary_channel (bool): Whether or not to add a channel to `input`,
            with all elements set to `1`?
        name: Optional name of the graph scope.
            Defaults to "pixelcnn_input".

    Returns:
        PixelCNN2DOutput: The PixelCNN layer output.  Its `vertical` stack
            is `input` shifted by 1 along the height axis, and its
            `horizontal` stack is `input` shifted by 1 along the width axis
            (both via `shift`).
    """
    input, _, _ = validate_conv2d_input(input, channels_last)
    if channels_last:
        h_axis, w_axis, c_axis = -3, -2, -1
    else:
        c_axis, h_axis, w_axis = -3, -2, -1
    input_static_shape = get_static_shape(input)
    rank = len(input_static_shape)

    with tf.name_scope(name, default_name='pixelcnn_input', values=[input]):
        # add a channels with all `1`s
        if auxiliary_channel:
            # The `1`s tensor shares every dimension of `input` except the
            # channel axis, which has size 1.  Start from the static shape of
            # `input`, so that static information is preserved.
            ones_static_shape = list(input_static_shape)
            ones_static_shape[c_axis] = 1
            ones_dynamic_shape = list(ones_static_shape)
            if None in ones_dynamic_shape:
                # fill the statically unknown dimensions from the dynamic
                # shape of `input`
                x_dynamic_shape = get_shape(input)
                for i, s in enumerate(ones_dynamic_shape):
                    if s is None:
                        ones_dynamic_shape[i] = x_dynamic_shape[i]
                ones_shape = tf.stack(ones_dynamic_shape, axis=0)
            else:
                ones_shape = ones_dynamic_shape
            ones = tf.ones(shape=ones_shape, dtype=input.dtype.base_dtype)
            ones.set_shape(tf.TensorShape(ones_static_shape))
            input = tf.concat([input, ones], axis=c_axis,
                              name='auxiliary_input')

        # derive the vertical and horizontal convolution stacks
        down_shift = [0] * rank
        down_shift[h_axis] = 1
        right_shift = [0] * rank
        right_shift[w_axis] = 1

        return PixelCNN2DOutput(
            vertical=shift(input, shift=down_shift, name='vertical'),
            horizontal=shift(input, shift=right_shift, name='horizontal')
        )
def get_dynamic_shape(t):
    # Memoize `get_shape(t)` in `dynamic_shape_cache` (from the enclosing
    # scope), so that `get_shape` is called at most once per tensor.
    if t in dynamic_shape_cache:
        return dynamic_shape_cache[t]
    shape = get_shape(t)
    dynamic_shape_cache[t] = shape
    return shape
def reshape_tail(input, ndims, shape, name=None):
    """
    Reshape the tail (last) `ndims` into specified `shape`.

    Usage::

        x = tf.zeros([2, 3, 4, 5, 6])
        reshape_tail(x, 3, [-1])  # output: zeros([2, 3, 120])
        reshape_tail(x, 1, [3, 2])  # output: zeros([2, 3, 4, 5, 3, 2])

    Args:
        input (Tensor): The input tensor, at least `ndims` dimensions.
        ndims (int): To reshape this number of dimensions at tail.
        shape (Iterable[int] or tf.Tensor): The shape of the new tail.
            If not a tensor, every element must be positive, except for at
            most one `-1`.
        name: Optional name of the graph scope. Defaults to "reshape_tail".

    Returns:
        tf.Tensor: The reshaped tensor.

    Raises:
        ValueError: If `shape` is not a valid shape, or if the statically
            known size of the tail cannot be reshaped into `shape`.
    """
    input = tf.convert_to_tensor(input)
    if not is_tensor_object(shape):
        shape = list(int(s) for s in shape)
        # validate `shape`: only positive values, plus at most one `-1`
        neg_one_count = 0
        for s in shape:
            if s <= 0:
                if s == -1:
                    if neg_one_count > 0:
                        raise ValueError('`shape` is not a valid shape: at '
                                         'most one `-1` can be specified.')
                    else:
                        neg_one_count += 1
                else:
                    raise ValueError('`shape` is not a valid shape: {} is '
                                     'not allowed.'.format(s))

    with tf.name_scope(name or 'reshape_tail', values=[input]):
        # assert the dimension
        with assert_deps([
                    assert_rank_at_least(
                        input, ndims,
                        message='rank(input) must be at least ndims')
                ]) as asserted:
            if asserted:  # pragma: no cover
                input = tf.identity(input)

        # compute the static shape
        static_input_shape = get_static_shape(input)
        static_output_shape = None
        if static_input_shape is not None:
            # split the static shape into the kept head and reshaped tail
            if ndims > 0:
                left_shape = static_input_shape[:-ndims]
                right_shape = static_input_shape[-ndims:]
            else:
                left_shape = static_input_shape
                right_shape = ()

            # attempt to resolve "-1" in `shape`
            if isinstance(shape, list):
                # only possible when the tail is fully static
                if None not in right_shape:
                    shape_size = int(np.prod([s for s in shape if s != -1]))
                    right_shape_size = int(np.prod(right_shape))

                    if (-1 not in shape and shape_size != right_shape_size) or \
                            (-1 in shape and right_shape_size % shape_size != 0):
                        raise ValueError(
                            'Cannot reshape the tail dimensions of '
                            '`input` into `shape`: input {!r}, ndims '
                            '{}, shape {}.'.format(input, ndims, shape)
                        )

                    if -1 in shape:
                        pos = shape.index(-1)
                        shape[pos] = right_shape_size // shape_size

                # an unresolved `-1` becomes an unknown static dimension
                static_output_shape = left_shape + \
                    tuple(s if s != -1 else None for s in shape)

        # `None` here produces a fully unknown TensorShape
        static_output_shape = tf.TensorShape(static_output_shape)

        # compute the dynamic shape
        input_shape = get_shape(input)
        if ndims > 0:
            output_shape = concat_shapes([input_shape[:-ndims], shape])
        else:
            output_shape = concat_shapes([input_shape, shape])

        # do reshape
        output = tf.reshape(input, output_shape)
        output.set_shape(static_output_shape)
        return output
def is_log_det_shape_matches_input(log_det, input, value_ndims, name=None):
    """
    Check whether or not the shape of `log_det` matches the shape of `input`.

    Basically, the shapes of `log_det` and `input` should satisfy::

        if value_ndims > 0:
            assert(log_det.shape == input.shape[:-value_ndims])
        else:
            assert(log_det.shape == input.shape)

    A Python boolean is returned whenever the answer can be decided from
    the static shapes alone; otherwise a boolean tensor is returned.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each values sample.
        name: Optional name of the graph scope.

    Returns:
        bool or tf.Tensor: A boolean or a tensor, indicating whether or not
            the shape of `log_det` matches the shape of `input`.
    """
    if not is_tensor_object(log_det):
        log_det = tf.convert_to_tensor(log_det)
    if not is_tensor_object(input):
        input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'is_log_det_shape_matches_input'):
        log_det_shape = get_static_shape(log_det)
        input_shape = get_static_shape(input)

        # if both shapes have deterministic ndims, we can compare each axis
        # separately.
        if log_det_shape is not None and input_shape is not None:
            if len(log_det_shape) + value_ndims != len(input_shape):
                return False
            dynamic_axis = []

            # `zip` stops at the end of `log_det_shape`, so only the leading
            # (non-value) axes of `input` are compared here
            for i, (a, b) in enumerate(zip(log_det_shape, input_shape)):
                if a is None or b is None:
                    dynamic_axis.append(i)
                elif a != b:
                    return False

            if not dynamic_axis:
                return True

            # compare the statically unknown axes at runtime
            log_det_shape = get_shape(log_det)
            input_shape = get_shape(input)
            return tf.reduce_all([
                tf.equal(log_det_shape[i], input_shape[i])
                for i in dynamic_axis
            ])

        # otherwise we need to do a fully dynamic check, including check
        # ``log_det.ndims + value_ndims == input_shape.ndims``
        is_ndims_matches = tf.equal(
            tf.rank(log_det) + value_ndims, tf.rank(input))
        log_det_shape = get_shape(log_det)
        input_shape = get_shape(input)
        if value_ndims > 0:
            input_shape = input_shape[:-value_ndims]

        return tf.cond(
            is_ndims_matches,
            lambda: tf.reduce_all(tf.equal(
                # The following trick ensures we're comparing two tensors
                # with the same shape, such as to avoid some potential issues
                # about the cond operation.
                tf.concat([log_det_shape, input_shape], 0),
                tf.concat([input_shape, log_det_shape], 0),
            )),
            lambda: tf.constant(False, dtype=tf.bool)
        )
def pixelcnn_2d_sample(fn, inputs, height, width, channels_last=True,
                       start=0, end=None, back_prop=False,
                       parallel_iterations=1, swap_memory=False, name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    At iteration `i`, only the pixel at flat position `i` (counting over
    ``height * width`` in row-major order, as given by reshaping a flat
    mask into ``[height, width]``) takes the value from `fn`'s outputs.
    Every other pixel keeps its previous value.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            through iteration `0` to iteration `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.
        name: Optional name of the graph scope.
            Defaults to "pixelcnn_2d_sample".

    Returns:
        tuple[tf.Tensor]: The final outputs.

    Raises:
        ValueError: If `inputs` is empty, or if `fn` returns a different
            number of outputs than `inputs`.
        TypeError: If the dtype of an output of `fn` differs from the dtype
            of the corresponding input.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(
        inputs[0], channels_last=channels_last, arg_name='inputs[0]')
    # every other input must have the same static shape as `inputs[0]`
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]

        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare for the output mask: `True` (keep input) everywhere,
            # except `False` (take output) at flat position `idx`
            selector = tf.reshape(
                tf.concat([
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ], axis=0),
                mask_shape
            )
            selector = tf.cast(
                broadcast_to_shape(selector, input_shape), dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs))
                )

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = inputs[i].dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(
                            idx=i, output=output_dtype, input=input_dtype)
                    )
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return outputs
def deconv2d(input,
             out_channels,
             kernel_size,
             strides=(1, 1),
             padding='same',
             channels_last=True,
             output_shape=None,
             activation_fn=None,
             normalizer_fn=None,
             weight_norm=False,
             gated=False,
             gate_sigmoid_bias=2.,
             kernel=None,
             kernel_initializer=None,
             kernel_regularizer=None,
             kernel_constraint=None,
             use_bias=None,
             bias=None,
             bias_initializer=tf.zeros_initializer(),
             bias_regularizer=None,
             bias_constraint=None,
             trainable=True,
             name=None,
             scope=None):
    """
    2D deconvolutional layer.

    Inputs with more than 4 dimensions are supported: all the leading
    dimensions before the last 3 are flattened into a single batch
    dimension for the deconvolution, and restored afterwards.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The channel numbers of the deconvolution output.
        kernel_size (int or (int, int)): Kernel size over spatial dimensions.
        strides (int or (int, int)): Strides over spatial dimensions.
        padding: One of {"valid", "same"}, case in-sensitive.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        output_shape: If specified, use this as the shape of the
            deconvolution output; otherwise compute the size of each dimension
            by::

                output_size = input_size * strides
                if padding == 'valid':
                    output_size += max(kernel_size - strides, 0)

        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm (bool or (tf.Tensor) -> tf.Tensor)):
            If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on
            `kernel`.  `use_scale` will be :obj:`True` if `normalizer_fn`
            is not specified, and :obj:`False` otherwise.  The axis reduction
            will be determined by the layer.

            If it is a callable function, then it will be used to normalize
            the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`.
            The user must ensure the axis reduction is correct by themselves.
        gated (bool): Whether or not to use gate on output?
            `output = activation_fn(output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        kernel (Tensor): Instead of creating a new variable, use this tensor.
        kernel_initializer: The initializer for `kernel`.
            Would be ``default_kernel_initializer(...)`` if not specified.
        kernel_regularizer: The regularizer for `kernel`.
        kernel_constraint: The constraint for `kernel`.
        use_bias (bool or None): Whether or not to use `bias`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias (Tensor): Instead of creating a new variable, use this tensor.
            A created `bias` variable has the same dtype as `input`.
        bias_initializer: The initializer for `bias`.
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        trainable (bool): Whether or not the parameters are trainable?

    Returns:
        tf.Tensor: The output tensor.
    """
    input, in_channels, data_format = \
        validate_conv2d_input(input, channels_last)
    out_channels = validate_positive_int_arg('out_channels', out_channels)
    dtype = input.dtype.base_dtype
    if gated:
        # the second half of the channels is used as the gate
        out_channels *= 2

    # check functional arguments
    padding = validate_enum_arg(
        'padding', str(padding).upper(), ['VALID', 'SAME'])
    strides = validate_conv2d_strides_tuple('strides', strides, channels_last)

    weight_norm_fn = validate_weight_norm_arg(
        weight_norm, axis=-1, use_scale=normalizer_fn is None)
    if use_bias is None:
        use_bias = normalizer_fn is None

    # get the specification of outputs and parameters
    kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size)
    # the filter of `conv2d_transpose` is (h, w, out_channels, in_channels)
    kernel_shape = kernel_size + (out_channels, in_channels)
    bias_shape = (out_channels,)

    given_h, given_w = None, None
    given_output_shape = output_shape

    if is_tensor_object(given_output_shape):
        given_output_shape = tf.convert_to_tensor(given_output_shape)
    elif given_output_shape is not None:
        given_h, given_w = given_output_shape

    # validate the parameters
    if kernel is not None:
        kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype)
        kernel = kernel_spec.validate('kernel', kernel)
    if kernel_initializer is None:
        kernel_initializer = default_kernel_initializer(weight_norm)
    if bias is not None:
        bias_spec = ParamSpec(shape=bias_shape, dtype=dtype)
        bias = bias_spec.validate('bias', bias)

    # the main part of the conv2d layer
    with tf.variable_scope(scope, default_name=name or 'deconv2d'):
        with tf.name_scope('output_shape'):
            # detect the input shape and axis arrangements
            input_shape = get_static_shape(input)
            if channels_last:
                c_axis, h_axis, w_axis = -1, -3, -2
            else:
                c_axis, h_axis, w_axis = -3, -2, -1

            # the static output shape, in 4-d
            output_shape = [None, None, None, None]
            output_shape[c_axis] = out_channels
            if given_output_shape is None:
                if input_shape[h_axis] is not None:
                    output_shape[h_axis] = get_deconv_output_length(
                        input_shape[h_axis], kernel_shape[0], strides[h_axis],
                        padding
                    )
                if input_shape[w_axis] is not None:
                    output_shape[w_axis] = get_deconv_output_length(
                        input_shape[w_axis], kernel_shape[1], strides[w_axis],
                        padding
                    )
            else:
                if not is_tensor_object(given_output_shape):
                    output_shape[h_axis] = given_h
                    output_shape[w_axis] = given_w

            # infer the batch shape in 4-d
            batch_shape = input_shape[:-3]
            if None not in batch_shape:
                output_shape[0] = int(np.prod(batch_shape))

            # now the static output shape is ready
            output_static_shape = tf.TensorShape(output_shape)

            # prepare for the dynamic batch shape
            if output_shape[0] is None:
                output_shape[0] = tf.reduce_prod(get_shape(input)[:-3])

            # prepare for the dynamic spatial dimensions
            if output_shape[h_axis] is None or output_shape[w_axis] is None:
                if given_output_shape is None:
                    input_shape = get_shape(input)
                    if output_shape[h_axis] is None:
                        output_shape[h_axis] = get_deconv_output_length(
                            input_shape[h_axis], kernel_shape[0],
                            strides[h_axis], padding
                        )
                    if output_shape[w_axis] is None:
                        output_shape[w_axis] = get_deconv_output_length(
                            input_shape[w_axis], kernel_shape[1],
                            strides[w_axis], padding
                        )
                else:
                    assert (is_tensor_object(given_output_shape))
                    with assert_deps([
                                assert_rank(given_output_shape, 1),
                                assert_scalar_equal(
                                    tf.size(given_output_shape), 2)
                            ]):
                        output_shape[h_axis] = given_output_shape[0]
                        output_shape[w_axis] = given_output_shape[1]

            # compose the final dynamic shape
            if any(is_tensor_object(s) for s in output_shape):
                output_shape = tf.stack(output_shape)
            else:
                output_shape = tuple(output_shape)

        # create the variables
        if kernel is None:
            kernel = model_variable(
                'kernel',
                shape=kernel_shape,
                dtype=dtype,
                initializer=kernel_initializer,
                regularizer=kernel_regularizer,
                constraint=kernel_constraint,
                trainable=trainable
            )

        if weight_norm_fn is not None:
            kernel = weight_norm_fn(kernel)

        maybe_add_histogram(kernel, 'kernel')
        kernel = maybe_check_numerics(kernel, 'kernel')

        if use_bias and bias is None:
            # use the input dtype, consistent with `kernel` and `bias_spec`;
            # otherwise a non-float32 input would fail in `bias_add`
            bias = model_variable(
                'bias',
                shape=bias_shape,
                dtype=dtype,
                initializer=bias_initializer,
                regularizer=bias_regularizer,
                constraint=bias_constraint,
                trainable=trainable
            )
            maybe_add_histogram(bias, 'bias')
            bias = maybe_check_numerics(bias, 'bias')

        # flatten to 4d
        output, s1, s2 = flatten_to_ndims(input, 4)

        # do convolution or deconvolution
        output = tf.nn.conv2d_transpose(
            value=output,
            filter=kernel,
            output_shape=output_shape,
            strides=strides,
            padding=padding,
            data_format=data_format
        )
        output.set_shape(output_static_shape)

        # add bias
        if use_bias:
            output = tf.nn.bias_add(output, bias, data_format=data_format)

        # apply the normalization function if specified
        if normalizer_fn is not None:
            output = normalizer_fn(output)

        # split into halves if gated
        if gated:
            output, gate = tf.split(output, 2, axis=c_axis)

        # apply the activation function if specified
        if activation_fn is not None:
            output = activation_fn(output)

        # apply the gate if required
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate')

        # unflatten back to original shape
        output = unflatten_from_ndims(output, s1, s2)

        maybe_add_histogram(output, 'output')
        output = maybe_check_numerics(output, 'output')
    return output
def log_prob(self, given, group_ndims=0, name=None):
    """
    Compute the log-probability of `given` under the discretized logistic.

    For each value, the probability mass of the bin centered at `given`
    with width `bin_size` is ``sigmoid(x_high) - sigmoid(x_low)``.  When that
    difference is not greater than `epsilon`, the approximation
    ``log(pdf(x_mid) * delta)`` is used instead.  If `biased_edges` is
    enabled, values below ``min_val + bin_size / 2`` (or at/above
    ``max_val - bin_size / 2``) get the whole left (or right) tail mass.

    Args:
        given: The values whose log-probability is computed.
        group_ndims (int): Passed to `reduce_group_ndims` together with
            `tf.reduce_sum`.
        name: Optional name of the graph scope.
            Defaults to "DiscretizedLogistic.log_prob".

    Returns:
        tf.Tensor: The log-probability.
    """
    given = tf.convert_to_tensor(given)
    # honor `name` (previously ignored), keeping the old default
    with tf.name_scope(name or 'DiscretizedLogistic.log_prob',
                       values=[given]):
        # inv_scale = 1. / scale
        inv_scale = maybe_check_numerics(
            tf.exp(-self.log_scale, name='inv_scale'), 'inv_scale')
        # half_bin = bin_size / 2
        half_bin = self._bin_size * .5
        # delta = bin_size / scale, half_delta = delta / 2
        half_delta = half_bin * inv_scale
        # log(delta) = log(bin_size) - log(scale)
        log_delta = tf.log(self._bin_size) - self.log_scale

        x_mid = (given - self.mean) * inv_scale
        x_low = x_mid - half_delta
        x_high = x_mid + half_delta

        cdf_low = tf.sigmoid(x_low, name='cdf_low')
        cdf_high = tf.sigmoid(x_high, name='cdf_high')

        # the middle bins cases:
        #   log(sigmoid(x_high) - sigmoid(x_low))
        # but in extreme cases where `sigmoid(x_high) - sigmoid(x_low)`
        # is very small, we use an alternative form, as in PixelCNN++.
        cdf_delta = cdf_high - cdf_low
        middle_bins_pdf = tf.where(
            cdf_delta > self._epsilon,
            # to avoid NaNs pollute the select statement, we have to use
            # `maximum(cdf_delta, 1e-12)`
            tf.log(tf.maximum(cdf_delta, 1e-12)),
            # the alternative form.  basically it can be derived by using
            # the mean value theorem for integration:
            #   log(pdf(x_mid)) = x_mid - 2 * softplus(x_mid)
            x_mid + log_delta - 2. * tf.nn.softplus(x_mid)
        )
        log_prob = maybe_check_numerics(middle_bins_pdf, 'middle_bins_pdf')

        # broadcasted given, shape == x_mid
        broadcast_given = broadcast_to_shape(given, get_shape(x_mid))

        # the left-edge bin case
        #   log(sigmoid(x_high) - sigmoid(-infinity)) = -softplus(-x_high)
        if self._biased_edges and self.min_val is not None:
            left_edge = self._min_val + half_bin
            left_edge_pdf = maybe_check_numerics(
                -tf.nn.softplus(-x_high), 'left_edge_pdf')
            log_prob = tf.where(
                broadcast_given < left_edge, left_edge_pdf, log_prob)

        # the right-edge bin case
        #   log(sigmoid(infinity) - sigmoid(x_low)) = -softplus(x_low)
        if self._biased_edges and self.max_val is not None:
            right_edge = self._max_val - half_bin
            right_edge_pdf = maybe_check_numerics(
                -tf.nn.softplus(x_low), 'right_edge_pdf')
            log_prob = tf.where(
                broadcast_given >= right_edge, right_edge_pdf, log_prob)

        # now reduce the group_ndims
        log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)

        return log_prob
def _transform(self, x, compute_y, compute_log_det):
    """
    Split `x`, transform the two halves, and join them back.

    `x` is split along the split axis into `(x1, x2)`.  `x1` goes through
    the left flow.  `x2` goes through the right flow if one is given, and
    is kept as-is otherwise.  The results are concatenated along the join
    axis, which defaults to the split axis.  On the first call, the join
    axis and join sizes are resolved, and the y input spec is built.

    Returns:
        (y, log_det): Each is :obj:`None` if not requested.

    Raises:
        RuntimeError: If the two halves of `y` have different ranks.
        ValueError: If the join axis is out of range, or not covered by
            `y_value_ndims`.
    """
    # split the input x
    x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                      axis=self._split_axis)
    # y must be computed on the first call, to build the y spec
    do_compute_y = compute_y or (self._y_n_left is None)

    # apply the left transformation
    y1, log_det1 = self._left.transform(
        x1, compute_y=do_compute_y, compute_log_det=compute_log_det)

    # apply the right transformation
    if self._right is not None:
        y2, log_det2 = self._right.transform(
            x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
    else:
        y2, log_det2 = x2, None

    # check the outputs
    y1_shape = get_static_shape(y1)
    y2_shape = get_static_shape(y2)

    if len(y1_shape) != len(y2_shape):
        raise RuntimeError('`y_left.ndims` != `y_right.ndims`: y_left {} '
                           'vs y_right {}'.format(y1, y2))

    # build the y spec if not built
    join_axis = self._join_axis

    if self._y_n_left is None:
        y_rank = len(y1_shape)

        # resolve the join axis
        if join_axis is None:
            join_axis = self._split_axis
        if join_axis < 0:
            join_axis += y_rank

        # `join_axis >= y_rank` must be rejected too; otherwise the
        # subtraction below would map it onto an unrelated axis.
        if join_axis < 0 or join_axis >= y_rank or \
                join_axis < y_rank - self.y_value_ndims:
            raise ValueError(
                '`join_axis` out of range, or not covered by `y_value_ndims'
                '`: split_axis {}, y_value_ndims {}, y_left {}, y_right {}'.
                format(self._split_axis, self.y_value_ndims, y1, y2)
            )
        join_axis -= y_rank

        err_msg = (
            '`y_left.shape[join_axis] + y_right.shape[join_axis]` must '
            'be at least 2: y_left {}, y_right {} axis {}.'.
            format(y1, y2, join_axis)
        )
        y_n_left = y1_shape[join_axis]
        y_n_right = y2_shape[join_axis]
        if y_n_left is not None and y_n_right is not None:
            y_n_features = y_n_left + y_n_right
            assert (y_n_features >= 2)
        else:
            y_n_left = get_shape(y1)[join_axis]
            y_n_right = get_shape(y2)[join_axis]
            y_n_features = None
            with assert_deps([
                        tf.assert_greater_equal(
                            y_n_left + y_n_right, 2, message=err_msg)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    y_n_left = tf.identity(y_n_left)
                    y_n_right = tf.identity(y_n_right)

        self._join_axis = join_axis
        self._y_n_left = y_n_left
        self._y_n_right = y_n_right

        # build the y spec; the feature count belongs to the *join* axis,
        # which may differ from the split axis
        dtype = self._x_input_spec.dtype
        y_shape_spec = ['?'] * self.y_value_ndims
        if y_n_features is not None:
            y_shape_spec[join_axis] = y_n_features
        assert (not self.require_batch_dims)
        y_shape_spec = ['...'] + y_shape_spec
        self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

    assert (join_axis is not None)

    # compute y
    y = None
    if compute_y:
        y = tf.concat([y1, y2], axis=join_axis)

    # compute log_det
    log_det = None
    if compute_log_det:
        log_det = log_det1
        if log_det2 is not None:
            log_det = sum_log_det([log_det, log_det2])

    return y, log_det
def get_dynamic_shape():
    # Lazily compute `get_shape(input)` on first use and memoize it in
    # `cached[0]`.  `cached` presumably is a mutable container owned by the
    # enclosing function, so that this closure can update it.
    shape = cached[0]
    if shape is None:
        shape = cached[0] = get_shape(input)
    return shape
def _transform_or_inverse_transform(self, x, compute_y, compute_log_det,
                                    reverse=False):
    """
    Shared implementation of the forward and the inverse transformation.

    `x` is split by `self._split` into `(x1, x2, n2)`.  `x1` is passed
    through unchanged.  `x2` is transformed with a `shift` and an optional
    scale, both computed from `x1` by `self._shift_and_scale_fn`:

    * forward (`reverse=False`): `y2 = (x2 + shift) * scale`
    * inverse (`reverse=True`):  `y2 = x2 / scale - shift`

    The two halves are then joined back by `self._unsplit`.

    Args:
        x: The input tensor.
        compute_y (bool): Whether or not to compute `y`.
        compute_log_det (bool): Whether or not to compute `log_det`.
        reverse (bool): If True, compute the inverse transformation.

    Returns:
        (y, log_det): `y` is None unless `compute_y` is True; `log_det` is
            None unless `compute_log_det` is True.

    Raises:
        RuntimeError: If a scale is computed while `scale_type` is None, or
            no scale is computed while `scale_type` is not None.
    """
    # The transform and inverse transform differ only in a few lines, so
    # both are implemented here, controlled by `reverse`.

    # check the argument
    shape = get_static_shape(x)
    assert (len(shape) >= self.value_ndims)  # checked in `BaseFlow`

    # split the tensor
    x1, x2, n2 = self._split(x)

    # compute the scale and shift
    shift, pre_scale = self._shift_and_scale_fn(x1, n2)
    if self._scale_type is not None and pre_scale is None:
        raise RuntimeError('`scale_type` != None, but no scale is '
                           'computed.')
    elif self._scale_type is None and pre_scale is not None:
        raise RuntimeError('`scale_type` == None, but scale is computed.')

    if pre_scale is not None:
        pre_scale = self._check_scale_or_shift_shape(
            'scale', pre_scale, x2)
    shift = self._check_scale_or_shift_shape('shift', shift, x2)

    # derive the scale class
    if self._scale_type == 'sigmoid':
        scale = SigmoidScale(pre_scale + self._sigmoid_scale_bias,
                             self._epsilon)
    elif self._scale_type == 'exp':
        scale = ExpScale(pre_scale, self._epsilon)
    elif self._scale_type == 'linear':
        scale = LinearScale(pre_scale, self._epsilon)
    else:
        assert (self._scale_type is None)
        scale = None

    # compute y
    y = None
    if compute_y:
        y1 = x1
        if reverse:
            # inverse: undo the scale first, then the shift
            y2 = x2
            if scale is not None:
                y2 = y2 / scale
            y2 -= shift
        else:
            # forward: shift first, then scale
            y2 = x2 + shift
            if scale is not None:
                y2 = y2 * scale
        y = self._unsplit(y1, y2)

    # compute log_det
    log_det = None
    if compute_log_det:
        assert (self.value_ndims >= 0)  # checked in `_build`
        if scale is not None:
            log_det = tf.reduce_sum(
                scale.neg_log_scale() if reverse else scale.log_scale(),
                axis=list(range(-self.value_ndims, 0)))
        else:
            # Strip the value dims only when there are any:
            # `shape[:-0]` is `shape[:0]`, i.e. an empty shape.
            log_det_shape = get_shape(x)
            if self.value_ndims > 0:
                log_det_shape = log_det_shape[:-self.value_ndims]
            log_det = ZeroLogDet(log_det_shape, x.dtype.base_dtype)

    return y, log_det
def log_prob(self, given, group_ndims=0, name=None):
    """
    Compute the log-probability of `given` under this distribution.

    With `x_low = (given - mean - bin_size / 2) / scale` and
    `x_high = (given - mean + bin_size / 2) / scale`, each value gets the
    log-mass of its bin: `log(sigmoid(x_high) - sigmoid(x_low))`, clipped
    from below at `log(self._epsilon)`.

    If `self.biased_edges` is True and `self.min_val` is not None:

    * values below `min_val + bin_size / 2` get `log(sigmoid(x_high))`;
    * values at or above `max_val - bin_size / 2` get
      `log(1 - sigmoid(x_low))`.

    Args:
        given: The values whose log-probability is computed.  If
            `self.discretize_given` is True, they are first passed through
            `self._discretize`.
        group_ndims (int): Passed to `reduce_group_ndims`, which reduces the
            result with `tf.reduce_sum`.
        name (str): Optional name of the TensorFlow name scope.  Defaults to
            `'DiscretizedLogistic.log_prob'`.

    Returns:
        tf.Tensor: The log-probability.
    """
    given = tf.convert_to_tensor(given)
    # honor the `name` argument; it previously had no effect
    with tf.name_scope(name, default_name='DiscretizedLogistic.log_prob',
                       values=[given]):
        if self.discretize_given:
            given = self._discretize(given)

        # inv_scale = 1. / exp(log_scale)
        inv_scale = maybe_check_numerics(
            tf.exp(-self.log_scale, name='inv_scale'), 'inv_scale')
        # half_bin = bin_size / 2
        half_bin = self.bin_size * .5
        # delta = bin_size / scale, half_delta = delta / 2
        half_delta = half_bin * inv_scale

        # x_mid = (x - mean) / scale
        x_mid = (given - self.mean) * inv_scale
        # x_low = (x - mean - bin_size * 0.5) / scale
        x_low = x_mid - half_delta
        # x_high = (x - mean + bin_size * 0.5) / scale
        x_high = x_mid + half_delta

        cdf_low = tf.sigmoid(x_low, name='cdf_low')
        cdf_high = tf.sigmoid(x_high, name='cdf_high')
        cdf_delta = cdf_high - cdf_low

        # The middle bins case: log(sigmoid(x_high) - sigmoid(x_low)).
        # `cdf_delta` is clipped at `self._epsilon` so the log stays finite
        # when the bin mass is (numerically) zero.  An alternative used by
        # PixelCNN++ for tiny `cdf_delta` is the mean-value approximation
        # `x_mid + log(bin_size) - log_scale - 2 * softplus(x_mid)`.
        middle_bins_pdf = tf.log(tf.maximum(cdf_delta, self._epsilon))
        log_prob = maybe_check_numerics(middle_bins_pdf, 'middle_bins_pdf')

        if self.biased_edges and self.min_val is not None:
            # broadcasted given, shape == x_mid
            broadcast_given = broadcast_to_shape(given, get_shape(x_low))

            # the left-edge bin case
            # log(sigmoid(x_high) - sigmoid(-infinity))
            left_edge = self.min_val + half_bin
            left_edge_pdf = maybe_check_numerics(
                -tf.nn.softplus(-x_high), 'left_edge_pdf')
            log_prob = tf.where(
                tf.less(broadcast_given, left_edge), left_edge_pdf, log_prob)

            # the right-edge bin case
            # log(sigmoid(infinity) - sigmoid(x_low))
            right_edge = self.max_val - half_bin
            right_edge_pdf = maybe_check_numerics(
                -tf.nn.softplus(x_low), 'right_edge_pdf')
            log_prob = tf.where(
                tf.greater_equal(broadcast_given, right_edge),
                right_edge_pdf,
                log_prob
            )

        # now reduce the group_ndims
        log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)

    return log_prob