def broadcast_to_shape_strict(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    This method requires `rank(x)` to be less than or equal to `len(shape)`.
    You may use :func:`broadcast_to_shape` instead, to allow the cases where
    ``rank(x) > len(shape)``.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # assert ``rank(x) <= len(shape)``
        if isinstance(shape, tuple) and x_shape is not None:
            if len(x_shape) > len(shape):
                raise ValueError(cannot_broadcast_msg)
        elif isinstance(shape, tuple):
            with assert_deps([
                        tf.assert_less_equal(
                            tf.rank(x), len(shape),
                            message=cannot_broadcast_msg)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)
        else:
            with assert_deps(
                    [assert_rank(shape, 1,
                                 message=cannot_broadcast_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    shape = tf.identity(shape)
            with assert_deps([
                        tf.assert_less_equal(
                            tf.rank(x), tf.size(shape),
                            message=cannot_broadcast_msg)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)

        # do broadcast
        return broadcast_to_shape(x, shape)
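# A minimal usage sketch for `broadcast_to_shape_strict` (not part of the
# original source).  It assumes TF 1.x graph mode and the module-level
# `tensorflow as tf` import used by the functions in this file; the helper
# name below is hypothetical.
def _example_broadcast_to_shape_strict():  # pragma: no cover
    x = tf.zeros([3, 1])                         # rank 2
    y = broadcast_to_shape_strict(x, (2, 3, 5))  # rank(x) <= len(shape) holds
    print(y.get_shape())  # expected: (2, 3, 5)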
def broadcast_log_det_against_input(log_det, input, value_ndims, name=None):
    """
    Broadcast the shape of `log_det` to match the shape of `input`.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each value sample.

    Returns:
        tf.Tensor: The broadcasted log-determinant.
    """
    log_det = tf.convert_to_tensor(log_det)
    input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'broadcast_log_det_to_input_shape',
                       values=[log_det, input]):
        shape = get_shape(input)
        if value_ndims > 0:
            err_msg = (
                'Cannot broadcast `log_det` against `input`: log_det is {}, '
                'input is {}, value_ndims is {}.'.format(
                    log_det, input, value_ndims))
            with assert_deps(
                    [assert_rank_at_least(
                        input, value_ndims, message=err_msg)]):
                shape = shape[:-value_ndims]

        return broadcast_to_shape_strict(log_det, shape)
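# A minimal usage sketch for `broadcast_log_det_against_input` (not part of
# the original source).  With `value_ndims=1`, a scalar log-determinant is
# broadcast to the sample-and-batch part of `input`'s shape.  The helper name
# below is hypothetical and TF 1.x graph mode is assumed.
def _example_broadcast_log_det_against_input():  # pragma: no cover
    log_det = tf.constant(0.)    # scalar log-det
    input = tf.zeros([2, 4, 3])  # value_ndims = 1 -> value shape is (3,)
    out = broadcast_log_det_against_input(log_det, input, value_ndims=1)
    print(out.get_shape())  # expected: (2, 4)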
def transform(self, x, compute_y=True, compute_log_det=True, name=None):
    """
    Transform `x` into `y`, and compute the log-determinant of `f` at `x`
    (i.e., :math:`\\log \\det \\frac{\\partial f(x)}{\\partial x}`).

    Args:
        x (Tensor): The samples of `x`.
        compute_y (bool): Whether or not to compute :math:`y = f(x)`?
            Default :obj:`True`.
        compute_log_det (bool): Whether or not to compute the
            log-determinant?  Default :obj:`True`.
        name (str): If specified, will use this name as the TensorFlow
            operational name scope.

    Returns:
        (tf.Tensor, tf.Tensor): `y` and the (maybe summed) log-determinant.
            The items in the returned tuple might be :obj:`None`
            if the corresponding `compute_?` argument is set to :obj:`False`.

    Raises:
        ValueError: If both `compute_y` and `compute_log_det` are set
            to :obj:`False`.
    """
    if not compute_y and not compute_log_det:
        raise ValueError('At least one of `compute_y` and '
                         '`compute_log_det` should be True.')

    x = tf.convert_to_tensor(x)
    if not self._has_built:
        self.build(x)

    x = self._x_input_spec.validate('x', x)

    with tf.name_scope(
            name,
            default_name=get_default_scope_name('transform', self),
            values=[x]):
        y, log_det = self._transform(x, compute_y, compute_log_det)

        if compute_log_det:
            with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=x,
                            value_ndims=self.x_value_ndims)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    log_det = tf.identity(log_det)

        if y is not None:
            maybe_add_histogram(y, 'y')
            y = maybe_check_numerics(y, 'y')

        if log_det is not None:
            maybe_add_histogram(log_det, 'log_det')
            log_det = maybe_check_numerics(log_det, 'log_det')

        return y, log_det
def _check_scale_or_shift_shape(self, name, tensor, x2):
    assert_op = assert_shape_equal(
        tensor, x2,
        message='`{}.shape` expected to be {}, but got {}'.format(
            name, get_static_shape(x2), get_static_shape(tensor)))
    with assert_deps([assert_op]) as asserted:
        if asserted:  # pragma: no cover
            tensor = tf.identity(tensor)
    return tensor
def __init__(self, distribution, ndims):
    """
    Construct a new :class:`BatchToValueDistribution`.

    Args:
        distribution (Distribution): The source distribution.
        ndims (int): The number of last batch dimensions to be converted
            into value dimensions.  Must be non-negative.
    """
    distribution = as_distribution(distribution)
    ndims = int(ndims)
    if ndims < 0:
        raise ValueError('`ndims` must be a non-negative integer: '
                         'got {!r}'.format(ndims))

    with tf.name_scope('BatchToValueDistribution.init'):
        # get new batch shape
        batch_shape = distribution.batch_shape
        batch_static_shape = distribution.get_batch_shape()
        if ndims > 0:
            # static shape
            if batch_static_shape.ndims < ndims:
                raise ValueError(
                    '`distribution.batch_shape.ndims` is less than `ndims`'
                    ': distribution {}, batch_shape.ndims {}, ndims {}'.
                    format(distribution, batch_static_shape.ndims, ndims))
            batch_static_shape = batch_static_shape[:-ndims]

            # dynamic shape
            batch_shape = batch_shape[:-ndims]
            with assert_deps([
                        tf.assert_greater_equal(
                            tf.size(distribution.batch_shape), ndims)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    batch_shape = tf.identity(batch_shape)

        # get new value ndims
        value_ndims = ndims + distribution.value_ndims

    self._distribution = distribution
    self._ndims = ndims
    super(BatchToValueDistribution, self).__init__(
        dtype=distribution.dtype,
        is_continuous=distribution.is_continuous,
        is_reparameterized=distribution.is_reparameterized,
        batch_shape=batch_shape,
        batch_static_shape=batch_static_shape,
        value_ndims=value_ndims,
    )
def assert_batch_shape(c, batch_shape):
    c_batch_shape = c.batch_shape
    with assert_deps([
                tf.assert_equal(
                    tf.reduce_all(
                        tf.equal(
                            tf.concat([batch_shape, c_batch_shape], 0),
                            tf.concat([c_batch_shape, batch_shape], 0))),
                    True)
            ]) as asserted:
        if asserted:  # pragma: no cover
            batch_shape = tf.identity(batch_shape)
    return batch_shape
def _build(self, input=None):
    shape = get_static_shape(input)
    dtype = input.dtype.base_dtype

    # resolve the split axis
    split_axis = self._split_axis
    if split_axis < 0:
        split_axis += len(shape)
    if split_axis < 0 or split_axis < len(shape) - self.x_value_ndims:
        raise ValueError(
            '`split_axis` out of range, or not covered by `x_value_ndims`: '
            'split_axis {}, x_value_ndims {}, input {}'.format(
                self._split_axis, self.x_value_ndims, input))
    split_axis -= len(shape)

    err_msg = (
        'The size of `input` along the split axis must be at least 2: '
        'input {}, axis {}.'.format(input, split_axis))
    if shape[split_axis] is not None:
        x_n_features = shape[split_axis]
        if x_n_features < 2:
            raise ValueError(err_msg)
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left
    else:
        x_n_features = get_shape(input)[split_axis]
        x_n_left = x_n_features // 2
        x_n_right = x_n_features - x_n_left
        with assert_deps(
                [tf.assert_greater_equal(
                    x_n_features, 2, message=err_msg)]) as asserted:
            if asserted:  # pragma: no cover
                x_n_left = tf.identity(x_n_left)
                x_n_right = tf.identity(x_n_right)
        x_n_features = None

    self._split_axis = split_axis
    self._x_n_left = x_n_left
    self._x_n_right = x_n_right

    # build the x spec
    shape_spec = ['?'] * self.x_value_ndims
    if x_n_features is not None:
        shape_spec[self._split_axis] = x_n_features
    assert not self.require_batch_dims
    shape_spec = ['...'] + shape_spec
    self._x_input_spec = InputSpec(shape=shape_spec, dtype=dtype)
def inverse_transform(self, y, compute_x=True, compute_log_det=True,
                      name=None):
    """
    Transform `y` into `x`, and compute the log-determinant of `f^{-1}` at
    `y` (i.e., :math:`\\log \\det \\frac{\\partial f^{-1}(y)}{\\partial y}`).

    Args:
        y (Tensor): The samples of `y`.
        compute_x (bool): Whether or not to compute :math:`x = f^{-1}(y)`?
            Default :obj:`True`.
        compute_log_det (bool): Whether or not to compute the
            log-determinant?  Default :obj:`True`.
        name (str): If specified, will use this name as the TensorFlow
            operational name scope.

    Returns:
        (tf.Tensor, tf.Tensor): `x` and the (maybe summed) log-determinant.
            The items in the returned tuple might be :obj:`None`
            if the corresponding `compute_?` argument is set to :obj:`False`.

    Raises:
        ValueError: If both `compute_x` and `compute_log_det` are set
            to :obj:`False`.
        RuntimeError: If the flow is not explicitly invertible.
    """
    if not self.explicitly_invertible:
        raise RuntimeError(
            'The flow is not explicitly invertible: {!r}'.format(self))
    if not compute_x and not compute_log_det:
        raise ValueError('At least one of `compute_x` and '
                         '`compute_log_det` should be True.')
    if not self._has_built:
        raise RuntimeError('`inverse_transform` cannot be called before '
                           'the flow has been built; it can be built by '
                           'calling `build`, `apply` or `transform`: '
                           '{!r}'.format(self))

    y = tf.convert_to_tensor(y)
    y = self._y_input_spec.validate('y', y)

    with tf.name_scope(
            name,
            default_name=get_default_scope_name('inverse_transform', self),
            values=[y]):
        x, log_det = self._inverse_transform(y, compute_x, compute_log_det)

        if compute_log_det:
            with assert_deps([
                        assert_log_det_shape_matches_input(
                            log_det=log_det,
                            input=y,
                            value_ndims=self.y_value_ndims)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    log_det = tf.identity(log_det)

        return x, log_det
def deconv2d(input,
             out_channels,
             kernel_size,
             strides=(1, 1),
             padding='same',
             channels_last=True,
             output_shape=None,
             activation_fn=None,
             normalizer_fn=None,
             weight_norm=False,
             gated=False,
             gate_sigmoid_bias=2.,
             kernel=None,
             kernel_initializer=None,
             kernel_regularizer=None,
             kernel_constraint=None,
             use_bias=None,
             bias=None,
             bias_initializer=tf.zeros_initializer(),
             bias_regularizer=None,
             bias_constraint=None,
             trainable=True,
             name=None,
             scope=None):
    """
    2D deconvolutional layer.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The channel numbers of the deconvolution output.
        kernel_size (int or (int, int)): Kernel size over spatial dimensions.
        strides (int or (int, int)): Strides over spatial dimensions.
        padding: One of {"valid", "same"}, case-insensitive.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        output_shape: If specified, use this as the shape of the
            deconvolution output; otherwise compute the size of each
            dimension by::

                output_size = input_size * strides
                if padding == 'valid':
                    output_size += max(kernel_size - strides, 0)

        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm (bool or (tf.Tensor) -> tf.Tensor):
            If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on
            `kernel`.  `use_scale` will be :obj:`True` if `normalizer_fn`
            is not specified, and :obj:`False` otherwise.  The axis
            reduction will be determined by the layer.

            If it is a callable function, then it will be used to normalize
            the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`.
            The user must ensure the axis reduction is correct by themselves.
        gated (bool): Whether or not to use gate on output?
            `output = activation_fn(output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        kernel (Tensor): Instead of creating a new variable, use this tensor.
        kernel_initializer: The initializer for `kernel`.  Would be
            ``default_kernel_initializer(...)`` if not specified.
        kernel_regularizer: The regularizer for `kernel`.
        kernel_constraint: The constraint for `kernel`.
        use_bias (bool or None): Whether or not to use `bias`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not
            given.  If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias (Tensor): Instead of creating a new variable, use this tensor.
        bias_initializer: The initializer for `bias`.
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        trainable (bool): Whether or not the parameters are trainable?

    Returns:
        tf.Tensor: The output tensor.
    """
    input, in_channels, data_format = \
        validate_conv2d_input(input, channels_last)
    out_channels = validate_positive_int_arg('out_channels', out_channels)
    dtype = input.dtype.base_dtype
    if gated:
        out_channels *= 2

    # check functional arguments
    padding = validate_enum_arg(
        'padding', str(padding).upper(), ['VALID', 'SAME'])
    strides = validate_conv2d_strides_tuple('strides', strides, channels_last)

    weight_norm_fn = validate_weight_norm_arg(
        weight_norm, axis=-1, use_scale=normalizer_fn is None)
    if use_bias is None:
        use_bias = normalizer_fn is None

    # get the specification of outputs and parameters
    kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size)
    kernel_shape = kernel_size + (out_channels, in_channels)
    bias_shape = (out_channels,)

    given_h, given_w = None, None
    given_output_shape = output_shape

    if is_tensor_object(given_output_shape):
        given_output_shape = tf.convert_to_tensor(given_output_shape)
    elif given_output_shape is not None:
        given_h, given_w = given_output_shape

    # validate the parameters
    if kernel is not None:
        kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype)
        kernel = kernel_spec.validate('kernel', kernel)
    if kernel_initializer is None:
        kernel_initializer = default_kernel_initializer(weight_norm)
    if bias is not None:
        bias_spec = ParamSpec(shape=bias_shape, dtype=dtype)
        bias = bias_spec.validate('bias', bias)

    # the main part of the conv2d layer
    with tf.variable_scope(scope, default_name=name or 'deconv2d'):
        with tf.name_scope('output_shape'):
            # detect the input shape and axis arrangements
            input_shape = get_static_shape(input)
            if channels_last:
                c_axis, h_axis, w_axis = -1, -3, -2
            else:
                c_axis, h_axis, w_axis = -3, -2, -1

            output_shape = [None, None, None, None]
            output_shape[c_axis] = out_channels
            if given_output_shape is None:
                if input_shape[h_axis] is not None:
                    output_shape[h_axis] = get_deconv_output_length(
                        input_shape[h_axis], kernel_shape[0],
                        strides[h_axis], padding)
                if input_shape[w_axis] is not None:
                    output_shape[w_axis] = get_deconv_output_length(
                        input_shape[w_axis], kernel_shape[1],
                        strides[w_axis], padding)
            else:
                if not is_tensor_object(given_output_shape):
                    output_shape[h_axis] = given_h
                    output_shape[w_axis] = given_w

            # infer the batch shape in 4-d
            batch_shape = input_shape[:-3]
            if None not in batch_shape:
                output_shape[0] = int(np.prod(batch_shape))

            # now the static output shape is ready
            output_static_shape = tf.TensorShape(output_shape)

            # prepare for the dynamic batch shape
            if output_shape[0] is None:
                output_shape[0] = tf.reduce_prod(get_shape(input)[:-3])

            # prepare for the dynamic spatial dimensions
            if output_shape[h_axis] is None or output_shape[w_axis] is None:
                if given_output_shape is None:
                    input_shape = get_shape(input)
                    if output_shape[h_axis] is None:
                        output_shape[h_axis] = get_deconv_output_length(
                            input_shape[h_axis], kernel_shape[0],
                            strides[h_axis], padding)
                    if output_shape[w_axis] is None:
                        output_shape[w_axis] = get_deconv_output_length(
                            input_shape[w_axis], kernel_shape[1],
                            strides[w_axis], padding)
                else:
                    assert is_tensor_object(given_output_shape)
                    with assert_deps([
                                assert_rank(given_output_shape, 1),
                                assert_scalar_equal(
                                    tf.size(given_output_shape), 2)
                            ]):
                        output_shape[h_axis] = given_output_shape[0]
                        output_shape[w_axis] = given_output_shape[1]

            # compose the final dynamic shape
            if any(is_tensor_object(s) for s in output_shape):
                output_shape = tf.stack(output_shape)
            else:
                output_shape = tuple(output_shape)

        # create the variables
        if kernel is None:
            kernel = model_variable(
                'kernel',
                shape=kernel_shape,
                dtype=dtype,
                initializer=kernel_initializer,
                regularizer=kernel_regularizer,
                constraint=kernel_constraint,
                trainable=trainable)

        if weight_norm_fn is not None:
            kernel = weight_norm_fn(kernel)

        maybe_add_histogram(kernel, 'kernel')
        kernel = maybe_check_numerics(kernel, 'kernel')

        if use_bias and bias is None:
            bias = model_variable(
                'bias',
                shape=bias_shape,
                initializer=bias_initializer,
                regularizer=bias_regularizer,
                constraint=bias_constraint,
                trainable=trainable)
            maybe_add_histogram(bias, 'bias')
            bias = maybe_check_numerics(bias, 'bias')

        # flatten to 4d
        output, s1, s2 = flatten_to_ndims(input, 4)

        # do convolution or deconvolution
        output = tf.nn.conv2d_transpose(
            value=output,
            filter=kernel,
            output_shape=output_shape,
            strides=strides,
            padding=padding,
            data_format=data_format)
        if output_static_shape is not None:
            output.set_shape(output_static_shape)

        # add bias
        if use_bias:
            output = tf.nn.bias_add(output, bias, data_format=data_format)

        # apply the normalization function if specified
        if normalizer_fn is not None:
            output = normalizer_fn(output)

        # split into halves if gated
        if gated:
            output, gate = tf.split(output, 2, axis=c_axis)

        # apply the activation function if specified
        if activation_fn is not None:
            output = activation_fn(output)

        # apply the gate if required
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate')

        # unflatten back to original shape
        output = unflatten_from_ndims(output, s1, s2)

        maybe_add_histogram(output, 'output')
        output = maybe_check_numerics(output, 'output')

        return output
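# A minimal usage sketch for `deconv2d` (not part of the original source).
# With "same" padding the spatial output size is `input_size * strides`, so a
# 16x16 feature map upsampled with strides 2 becomes 32x32.  The helper name
# below is hypothetical; TF 1.x graph mode is assumed, and the call creates
# `kernel`/`bias` variables under the current variable scope.
def _example_deconv2d():  # pragma: no cover
    x = tf.zeros([1, 16, 16, 8])  # NHWC input
    y = deconv2d(x, out_channels=4, kernel_size=3, strides=2, padding='same')
    print(y.get_shape())  # expected: (1, 32, 32, 4)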
def _transform(self, x, compute_y, compute_log_det):
    # split the input x
    x1, x2 = tf.split(x, [self._x_n_left, self._x_n_right],
                      axis=self._split_axis)
    do_compute_y = compute_y or (self._y_n_left is None)

    # apply the left transformation
    y1, log_det1 = self._left.transform(
        x1, compute_y=do_compute_y, compute_log_det=compute_log_det)

    # apply the right transformation
    if self._right is not None:
        y2, log_det2 = self._right.transform(
            x2, compute_y=do_compute_y, compute_log_det=compute_log_det)
    else:
        y2, log_det2 = x2, None

    # check the outputs
    y1_shape = get_static_shape(y1)
    y2_shape = get_static_shape(y2)
    if len(y1_shape) != len(y2_shape):
        raise RuntimeError('`y_left.ndims` != `y_right.ndims`: y_left {} '
                           'vs y_right {}'.format(y1, y2))

    # build the y spec if not built
    join_axis = self._join_axis
    if self._y_n_left is None:
        # resolve the join axis
        if join_axis is None:
            join_axis = self._split_axis
        if join_axis < 0:
            join_axis += len(y1_shape)
        if join_axis < 0 or join_axis < len(y1_shape) - self.y_value_ndims:
            raise ValueError(
                '`join_axis` out of range, or not covered by `y_value_ndims'
                '`: split_axis {}, y_value_ndims {}, y_left {}, y_right {}'
                .format(self._split_axis, self.y_value_ndims, y1, y2))
        join_axis -= len(y1_shape)

        err_msg = (
            '`y_left.shape[join_axis] + y_right.shape[join_axis]` must '
            'be at least 2: y_left {}, y_right {} axis {}.'.format(
                y1, y2, join_axis))
        y_n_left = y1_shape[join_axis]
        y_n_right = y2_shape[join_axis]
        if y_n_left is not None and y_n_right is not None:
            y_n_features = y_n_left + y_n_right
            assert y_n_features >= 2
        else:
            y_n_left = get_shape(y1)[join_axis]
            y_n_right = get_shape(y2)[join_axis]
            y_n_features = None
            with assert_deps([
                        tf.assert_greater_equal(
                            y_n_left + y_n_right, 2, message=err_msg)
                    ]) as asserted:
                if asserted:  # pragma: no cover
                    y_n_left = tf.identity(y_n_left)
                    y_n_right = tf.identity(y_n_right)

        self._join_axis = join_axis
        self._y_n_left = y_n_left
        self._y_n_right = y_n_right

        # build the y spec
        dtype = self._x_input_spec.dtype
        y_shape_spec = ['?'] * self.y_value_ndims
        if y_n_features is not None:
            y_shape_spec[self._split_axis] = y_n_features
        assert not self.require_batch_dims
        y_shape_spec = ['...'] + y_shape_spec
        self._y_input_spec = InputSpec(shape=y_shape_spec, dtype=dtype)

    assert join_axis is not None

    # compute y
    y = None
    if compute_y:
        y = tf.concat([y1, y2], axis=join_axis)

    # compute log_det
    log_det = None
    if compute_log_det:
        log_det = log_det1
        if log_det2 is not None:
            log_det = sum_log_det([log_det, log_det2])

    return y, log_det
def shift(input, shift, name=None):
    """
    Shift each axis of `input` according to `shift`, but keep identical size.
    The extra content will be discarded if shifted outside the original size.
    Zeros will be padded to the front or end of shifted axes.

    Args:
        input (Tensor): The tensor to be shifted.
        shift (Iterable[int]): The shift length for each axis.
            Its length must equal the rank of `input`.
            For each axis, if its corresponding shift < 0, then `input` will
            be shifted to the left by `-shift` at that axis.  If its shift
            > 0, then `input` will be shifted to the right by `shift` at
            that axis.

    Returns:
        tf.Tensor: The output tensor.
    """
    shift = tuple(int(s) for s in shift)
    input = tf.convert_to_tensor(input)
    shape = get_static_shape(input)

    if shape is None:
        raise ValueError(
            'The rank of `input` is required to be deterministic: '
            'got {}'.format(input))
    if len(shift) != len(shape):
        raise ValueError('The length of `shift` is required to equal the rank '
                         'of `input`: shift {} vs input {}'.format(
                             shift, input))

    # cache for the dynamic shape
    def get_dynamic_shape():
        if cached[0] is None:
            cached[0] = get_shape(input)
        return cached[0]
    cached = [None]

    # main routine
    with tf.name_scope(name, default_name='shift', values=[input]):
        # compute the slicing and padding arguments
        has_shift = False
        assert_ops = []
        slice_begin = []
        slice_size = []
        paddings = []
        err_msg = ('Cannot shift `input`: input {} vs shift {}'.format(
            input, shift))

        for i, (axis_shift, axis_size) in enumerate(zip(shift, shape)):
            # fast approach: shift is zero, no slicing at the axis
            if axis_shift == 0:
                slice_begin.append(0)
                slice_size.append(-1)
                paddings.append((0, 0))
                continue

            # slow approach: shift is not zero, should slice at the axis
            axis_shift_abs = abs(axis_shift)

            # we first check whether or not the axis size is big enough
            if axis_size is None:
                dynamic_axis_size = get_dynamic_shape()[i]
                assert_ops.append(
                    tf.assert_greater_equal(
                        dynamic_axis_size, axis_shift_abs, message=err_msg))
            else:
                if axis_size < axis_shift_abs:
                    raise ValueError(err_msg)

            # next, we compose the slicing range
            if axis_shift < 0:  # shift to left
                slice_begin.append(-axis_shift)
                slice_size.append(-1)
                paddings.append((0, -axis_shift))
            else:  # shift to right
                slice_begin.append(0)
                if axis_size is None:
                    slice_size.append(get_dynamic_shape()[i] - axis_shift)
                else:
                    slice_size.append(axis_size - axis_shift)
                paddings.append((axis_shift, 0))

            # mark the flag to indicate that we've got any axis to shift
            has_shift = True

        if assert_ops:
            with assert_deps(assert_ops) as asserted:
                if asserted:
                    input = tf.identity(input)

        # no axis to shift, directly return the input
        if not has_shift:
            return input

        # do slicing and padding
        if any(is_tensor_object(s) for s in slice_size):
            slice_size = tf.stack(slice_size, axis=0)

        output = tf.slice(input, slice_begin, slice_size)
        output = tf.pad(output, paddings)

        return output
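# A minimal usage sketch for `shift` (not part of the original source): a
# positive shift moves content to the right and pads zeros at the front,
# keeping the original size.  The helper name below is hypothetical and
# TF 1.x graph mode is assumed.
def _example_shift():  # pragma: no cover
    x = tf.constant([[1, 2, 3],
                     [4, 5, 6]])
    y = shift(x, [0, 1])  # shift the last axis to the right by 1
    with tf.Session() as sess:
        print(sess.run(y))  # expected: [[0, 1, 2], [0, 4, 5]]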
def reshape_tail(input, ndims, shape, name=None):
    """
    Reshape the tail (last) `ndims` into specified `shape`.

    Usage::

        x = tf.zeros([2, 3, 4, 5, 6])
        reshape_tail(x, 3, [-1])  # output: zeros([2, 3, 120])
        reshape_tail(x, 1, [3, 2])  # output: zeros([2, 3, 4, 5, 3, 2])

    Args:
        input (Tensor): The input tensor, with at least `ndims` dimensions.
        ndims (int): To reshape this number of dimensions at tail.
        shape (Iterable[int] or tf.Tensor): The shape of the new tail.

    Returns:
        tf.Tensor: The reshaped tensor.
    """
    input = tf.convert_to_tensor(input)
    if not is_tensor_object(shape):
        shape = list(int(s) for s in shape)
        neg_one_count = 0
        for s in shape:
            if s <= 0:
                if s == -1:
                    if neg_one_count > 0:
                        raise ValueError('`shape` is not a valid shape: at '
                                         'most one `-1` can be specified.')
                    else:
                        neg_one_count += 1
                else:
                    raise ValueError('`shape` is not a valid shape: {} is '
                                     'not allowed.'.format(s))

    with tf.name_scope(name or 'reshape_tail', values=[input]):
        # assert the dimension
        with assert_deps([
                    assert_rank_at_least(
                        input, ndims,
                        message='rank(input) must be at least ndims')
                ]) as asserted:
            if asserted:  # pragma: no cover
                input = tf.identity(input)

        # compute the static shape
        static_input_shape = get_static_shape(input)
        static_output_shape = None
        if static_input_shape is not None:
            if ndims > 0:
                left_shape = static_input_shape[:-ndims]
                right_shape = static_input_shape[-ndims:]
            else:
                left_shape = static_input_shape
                right_shape = ()

            # attempt to resolve "-1" in `shape`
            if isinstance(shape, list):
                if None not in right_shape:
                    shape_size = int(np.prod([s for s in shape if s != -1]))
                    right_shape_size = int(np.prod(right_shape))
                    if (-1 not in shape and shape_size != right_shape_size) or \
                            (-1 in shape and
                             right_shape_size % shape_size != 0):
                        raise ValueError(
                            'Cannot reshape the tail dimensions of '
                            '`input` into `shape`: input {!r}, ndims '
                            '{}, shape {}.'.format(input, ndims, shape))
                    if -1 in shape:
                        pos = shape.index(-1)
                        shape[pos] = right_shape_size // shape_size

                static_output_shape = left_shape + \
                    tuple(s if s != -1 else None for s in shape)

        static_output_shape = tf.TensorShape(static_output_shape)

        # compute the dynamic shape
        input_shape = get_shape(input)
        if ndims > 0:
            output_shape = concat_shapes([input_shape[:-ndims], shape])
        else:
            output_shape = concat_shapes([input_shape, shape])

        # do reshape
        output = tf.reshape(input, output_shape)
        output.set_shape(static_output_shape)
        return output
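# A minimal usage sketch for `reshape_tail` with a `-1` entry (not part of the
# original source): the `-1` is resolved against the product of the reshaped
# tail dimensions.  The helper name below is hypothetical and TF 1.x graph
# mode is assumed.
def _example_reshape_tail():  # pragma: no cover
    x = tf.zeros([2, 3, 4, 5, 6])
    print(reshape_tail(x, 2, [-1]).get_shape())      # expected: (2, 3, 4, 30)
    print(reshape_tail(x, 3, [10, -1]).get_shape())  # expected: (2, 3, 10, 12)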
def broadcast_concat(x, y, axis, name=None):
    """
    Broadcast `x` and `y`, then concat them along `axis`.

    This method cannot deal with all possible situations yet.
    `x` and `y` must have a known number of dimensions, and only the
    deterministic axes will be broadcasted.  You must ensure the
    non-deterministic axes are properly broadcasted by yourself.

    Args:
        x: The tensor `x`.
        y: The tensor `y`.
        axis: The axis to be concatenated.

    Returns:
        tf.Tensor: The broadcast and concatenated tensor.
    """
    x = tf.convert_to_tensor(x)
    y = tf.convert_to_tensor(y)

    # check the arguments
    x_static_shape = get_static_shape(x)
    if x_static_shape is None:
        raise ValueError('`x` with non-deterministic shape is not supported.')
    y_static_shape = get_static_shape(y)
    if y_static_shape is None:
        raise ValueError('`y` with non-deterministic shape is not supported.')

    x_rank = len(x_static_shape)
    y_rank = len(y_static_shape)
    out_ndims = max(x_rank, y_rank)
    min_axis = -out_ndims
    max_axis = out_ndims - 1

    if axis < min_axis or axis > max_axis:
        raise ValueError('Invalid axis: must >= {} and <= {}, got {}'.format(
            min_axis, max_axis, axis))
    if axis >= 0:
        axis = axis - out_ndims

    # compute the broadcast shape
    out_static_shape = [None] * out_ndims

    x_tile = [1] * out_ndims
    y_tile = [1] * out_ndims
    assertions = []

    dynamic_shape_cache = {}

    def get_dynamic_shape(t):
        if t not in dynamic_shape_cache:
            dynamic_shape_cache[t] = get_shape(t)
        return dynamic_shape_cache[t]

    def broadcast_axis(i, a, b, a_tile, b_tile, a_tensor, b_tensor):
        err_msg = ('`x` and `y` cannot be broadcast concat: {} vs {}'.
                   format(x, y))

        # validate whether or not a == b or can be broadcasted
        if a is None and b is None:
            # both dynamic, must be equal
            a = get_dynamic_shape(a_tensor)[i]
            b = get_dynamic_shape(b_tensor)[i]
            assertions.append(tf.assert_equal(a, b, message=err_msg))
        elif a is not None and b is not None:
            # both static, check immediately
            if a != 1 and b != 1 and a != b:
                raise ValueError(err_msg)

            if a == 1:
                a_tile[i] = b
            elif b == 1:
                b_tile[i] = a

            out_static_shape[i] = max(a, b)
        elif a is None:
            # a dynamic, b can be 1 or equal to a
            a = get_dynamic_shape(a_tensor)[i]
            if b == 1:
                b_tile[i] = a
            else:
                assertions.append(tf.assert_equal(a, b, message=err_msg))
                out_static_shape[i] = b
        else:
            broadcast_axis(i, b, a, b_tile, a_tile, b_tensor, a_tensor)

    def maybe_prepend_dims(t, rank, name):
        if rank < out_ndims:
            t = prepend_dims(t, out_ndims - rank, name=name)
        return t

    def maybe_tile(t, tile, name):
        if any(s != 1 for s in tile):
            if any(is_tensor_object(s) for s in tile):
                tile = tf.stack(tile, axis=0)
            t = tf.tile(t, tile, name=name)
        return t

    with tf.name_scope(name, default_name='broadcast_concat', values=[x, y]):
        # infer the configurations
        for i in range(-1, -out_ndims - 1, -1):
            a = x_static_shape[i] if i >= -x_rank else 1
            b = y_static_shape[i] if i >= -y_rank else 1
            if i != axis:
                broadcast_axis(i, a, b, x_tile, y_tile, x, y)
            else:
                if a is not None and b is not None:
                    out_static_shape[i] = a + b

        # do broadcast
        x = maybe_tile(
            maybe_prepend_dims(x, x_rank, name='prepend_dims_to_x'),
            x_tile, name='tile_x')
        y = maybe_tile(
            maybe_prepend_dims(y, y_rank, name='prepend_dims_to_y'),
            y_tile, name='tile_y')

        with assert_deps(assertions) as asserted:
            if asserted:
                x = tf.identity(x)
                y = tf.identity(y)

        # do concat
        ret = tf.concat([x, y], axis=axis)
        ret.set_shape(tf.TensorShape(out_static_shape))
        return ret
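# A minimal usage sketch for `broadcast_concat` (not part of the original
# source): the non-concatenated axes are broadcast against each other, and
# the concat axis sizes are added.  The helper name below is hypothetical and
# TF 1.x graph mode is assumed.
def _example_broadcast_concat():  # pragma: no cover
    x = tf.zeros([2, 1, 3])
    y = tf.zeros([1, 4, 3])
    z = broadcast_concat(x, y, axis=-1)
    print(z.get_shape())  # expected: (2, 4, 6)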
def broadcast_to_shape(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    If ``rank(x) > len(shape)``, only the tail dimensions will be broadcasted
    to match `shape`.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # fast routine: shape is tuple[int] and x_shape is all known,
        # we can use reshape + tile to do the broadcast, which should be
        # faster than using ``x * ones(shape)``.
        if isinstance(shape, tuple) and x_shape is not None and \
                all(s is not None for s in x_shape):
            # reshape to have the same dimension
            if len(x_shape) < len(shape):
                x_shape = (1,) * (len(shape) - len(x_shape)) + x_shape
                x = tf.reshape(x, x_shape)

            # tile to have the same shape
            tile = []
            i = -1
            while i > -len(shape) - 1:
                a, b = x_shape[i], shape[i]
                if a == 1 and b > 1:
                    tile.append(b)
                elif a != b:
                    raise ValueError(cannot_broadcast_msg)
                else:
                    tile.append(1)
                i -= 1
            tile = [1] * (len(x_shape) - len(shape)) + list(reversed(tile))
            if any(s > 1 for s in tile):
                x = tf.tile(x, tile)

            return x

        # slow routine: we may need ``x * ones(shape)`` to do the broadcast
        assertions = []
        post_assert_shape = False
        static_shape = tf.TensorShape(None)

        if isinstance(shape, tuple) and x_shape is not None:
            need_multiply_ones = False

            # it should always broadcast if len(x_shape) < len(shape)
            if len(x_shape) < len(shape):
                need_multiply_ones = True

            # check the consistency of x and shape
            static_shape_hint = []  # list to gather the static shape hint
            axis_to_check = []  # list to gather the axis to check
            i = -1
            while i >= -len(shape) and i >= -len(x_shape):
                a, b = x_shape[i], shape[i]
                if a is None:
                    axis_to_check.append(i)
                else:
                    if a != b:
                        if a == 1:
                            need_multiply_ones = True
                        else:
                            raise ValueError(cannot_broadcast_msg)
                static_shape_hint.append(b)
                i -= 1

            # compose the static shape hint
            if len(shape) < len(x_shape):
                static_shape = x_shape[:-len(shape)]
            elif len(shape) > len(x_shape):
                static_shape = shape[:-len(x_shape)]
            else:
                static_shape = ()
            static_shape = tf.TensorShape(
                static_shape + tuple(reversed(static_shape_hint)))

            # compose the assertion operations and the multiply flag
            if axis_to_check:
                need_multiply_flags = []
                x_dynamic_shape = tf.shape(x)

                for i in axis_to_check:
                    assertions.append(tf.assert_equal(
                        tf.logical_or(
                            tf.equal(x_dynamic_shape[i], shape[i]),
                            tf.equal(x_dynamic_shape[i], 1),
                        ),
                        True,
                        message=cannot_broadcast_msg))
                    if len(x_shape) >= len(shape):
                        need_multiply_flags.append(
                            tf.not_equal(x_dynamic_shape[i], shape[i]))

                if not need_multiply_ones:
                    need_multiply_ones = \
                        tf.reduce_any(tf.stack(need_multiply_flags))
        else:
            # we have no idea about what `shape` is here, thus we need to
            # assert the shape after ``x * ones(shape)``.
            need_multiply_ones = True
            post_assert_shape = True

        # do broadcast if `x_shape` != `shape`
        def multiply_branch():
            with assert_deps(assertions):
                ones_template = tf.ones(shape, dtype=x.dtype.base_dtype)
            try:
                return x * ones_template
            except ValueError:  # pragma: no cover
                raise ValueError(cannot_broadcast_msg)

        def identity_branch():
            with assert_deps(assertions) as asserted:
                if asserted:
                    return tf.identity(x)
                else:  # pragma: no cover
                    return x

        t = smart_cond(need_multiply_ones, multiply_branch, identity_branch)
        t.set_shape(static_shape)

        if post_assert_shape:
            post_assert_op = tf.assert_equal(
                tf.reduce_all(
                    tf.equal(tf.shape(t)[-tf.size(shape):], shape)),
                True,
                message=cannot_broadcast_msg)
            with assert_deps([post_assert_op]) as asserted:
                if asserted:
                    t = tf.identity(t)

        return t
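# A minimal usage sketch for `broadcast_to_shape` (not part of the original
# source): unlike the strict variant, `rank(x)` may exceed `len(shape)`, in
# which case only the tail dimensions are broadcast.  The helper name below
# is hypothetical and TF 1.x graph mode is assumed.
def _example_broadcast_to_shape():  # pragma: no cover
    x = tf.zeros([2, 3, 1])
    y = broadcast_to_shape(x, (4,))  # rank(x) > len(shape) is allowed here
    print(y.get_shape())  # expected: (2, 3, 4)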