def assert_log_det_shape_matches_input(log_det, input, value_ndims, name=None): """ Assert the shape of `log_det` matches the shape of `input`. Args: log_det: Tensor, the log-determinant. input: Tensor, the input. value_ndims (int): The number of dimensions of each values sample. Returns: tf.Operation or None: The assertion operation, or None if the assertion can be made statically. """ if not is_tensor_object(log_det): log_det = tf.convert_to_tensor(log_det) if not is_tensor_object(input): input = tf.convert_to_tensor(input) value_ndims = int(value_ndims) with tf.name_scope(name or 'assert_log_det_shape_matches_input'): cmp_result = is_log_det_shape_matches_input(log_det, input, value_ndims) error_message = ( 'The shape of `log_det` does not match the shape of ' '`input`: log_det {!r} vs input {!r}, value_ndims is {!r}'.format( log_det, input, value_ndims)) if cmp_result is False: raise AssertionError(error_message) elif cmp_result is True: return None else: return tf.assert_equal(cmp_result, True, message=error_message)
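# Hypothetical usage sketch for `assert_log_det_shape_matches_input` (assumes
# TensorFlow 1.x as `tf` and the helper above in scope; not part of the library).
# With fully static shapes the check resolves at graph-construction time: a match
# returns None and a mismatch raises AssertionError immediately; when some
# dimension is unknown, a tf.assert_equal op is returned and should be attached
# as a control dependency.
log_det = tf.zeros([2, 3])
x = tf.zeros([2, 3, 4])
assert assert_log_det_shape_matches_input(log_det, x, value_ndims=1) is None

x_dyn = tf.placeholder(tf.float32, [2, None, 4])
assert_op = assert_log_det_shape_matches_input(log_det, x_dyn, value_ndims=1)
with tf.control_dependencies([assert_op]):
    log_det = tf.identity(log_det)   # the shape check runs at session time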
def validate_n_samples(value, name):
    """
    Validate the `n_samples` argument.

    Args:
        value: An int32 value, an int32 :class:`tf.Tensor`, or :obj:`None`.
        name (str): Name of the argument (used in the error message).

    Returns:
        int or tf.Tensor: The validated `n_samples` argument value.

    Raises:
        TypeError or ValueError: If the value cannot be validated.
    """
    if is_tensor_object(value):
        @contextlib.contextmanager
        def mkcontext():
            with tf.name_scope('validate_n_samples'):
                yield
    else:
        @contextlib.contextmanager
        def mkcontext():
            yield

    if value is not None:
        with mkcontext():
            validator = TensorArgValidator(name=name)
            value = validator.require_positive(
                validator.require_int32(value))

    return value
def broadcast_to_shape_strict(x, shape, name=None): """ Broadcast `x` to match `shape`. This method requires `rank(x)` to be less than or equal to `len(shape)`. You may use :func:`broadcast_to_shape` instead, to allow the cases where ``rank(x) > len(shape)``. Args: x: A tensor. shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape. Returns: tf.Tensor: The broadcasted tensor. """ # check the parameters x = tf.convert_to_tensor(x) x_shape = get_static_shape(x) ns_values = [x] if is_tensor_object(shape): shape = tf.convert_to_tensor(shape) ns_values.append(shape) else: shape = tuple(int(s) for s in shape) with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values): cannot_broadcast_msg = ( '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'. format(x, shape)) # assert ``rank(x) <= len(shape)`` if isinstance(shape, tuple) and x_shape is not None: if len(x_shape) > len(shape): raise ValueError(cannot_broadcast_msg) elif isinstance(shape, tuple): with assert_deps([ tf.assert_less_equal(tf.rank(x), len(shape), message=cannot_broadcast_msg) ]) as asserted: if asserted: # pragma: no cover x = tf.identity(x) else: with assert_deps( [assert_rank(shape, 1, message=cannot_broadcast_msg)]) as asserted: if asserted: # pragma: no cover shape = tf.identity(shape) with assert_deps([ tf.assert_less_equal(tf.rank(x), tf.size(shape), message=cannot_broadcast_msg) ]) as asserted: if asserted: # pragma: no cover x = tf.identity(x) # do broadcast return broadcast_to_shape(x, shape)
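# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helpers above
# in scope).  The strict variant only broadcasts "upwards": rank(x) must not
# exceed len(shape), whereas plain `broadcast_to_shape` tolerates extra leading
# dimensions in `x`.
x = tf.ones([3, 1])
y = broadcast_to_shape_strict(x, (2, 3, 5))      # -> shape (2, 3, 5)
# broadcast_to_shape_strict(tf.ones([2, 3, 5]), (3, 5))  # raises ValueError,
# while broadcast_to_shape would accept it and broadcast only the tail dims.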
def test_is_tensor_object(self): for obj in [ tf.constant(0.), # type: tf.Tensor tf.get_variable('x', dtype=tf.float32, shape=()), TensorWrapper(), StochasticTensor(Mock(is_reparameterized=False), tf.constant(0.)) ]: self.assertTrue( is_tensor_object(obj), msg='{!r} should be interpreted as a tensor object'.format( obj)) for obj in [1, '', object(), None, True, (), {}, [], np.zeros([1])]: self.assertFalse( is_tensor_object(obj), msg='{!r} should not be interpreted as a tensor object'.format( obj))
def assert_scalar_equal(a, b, message=None, name=None): """ Assert 0-d scalar `a` == `b`. Args: a: A 0-d tensor. b: A 0-d tensor. message: Message to display when assertion failed. Returns: tf.Operation or None: The TensorFlow assertion operation, or None if can be statically asserted. """ if not is_tensor_object(a) and not is_tensor_object(b): if a != b: raise _make_assertion_error('a == b', '{!r} != {!r}'.format(a, b), message) else: return tf.assert_equal(a, b, message=message, name=name)
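# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helper above
# in scope).  When both arguments are plain Python numbers the comparison happens
# immediately and None is returned; a tensor argument defers the check to run
# time via tf.assert_equal.
assert assert_scalar_equal(3, 3) is None                 # checked statically
op = assert_scalar_equal(tf.size(tf.zeros([4])), 4)      # returns an assert op
with tf.control_dependencies([op]):
    checked = tf.constant(1)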
def __init__(self, distribution, tensor, n_samples=None, group_ndims=0, is_reparameterized=None, flow_origin=None, log_prob=None): """ Construct the :class:`StochasticTensor`. Args: distribution (tfsnippet.distributions.Distribution): The distribution of this :class:`StochasticTensor`. tensor (tf.Tensor or TensorWrapper): The samples or observations of this :class:`StochasticTensor`. n_samples (tf.Tensor or int): The number of samples taken in :class:`Distribution.sample`. If not :obj:`None`, the first dimension of `tensor` should be the sampling dimension. group_ndims (int or tf.Tensor): The number of dimensions to be considered as events group in samples. (default 0) is_reparameterized (bool): Whether or not the samples are re-parameterized? If not specified, will inherit from :attr:`tfsnippet.distributions.Distribution.is_reparameterized`. log_prob (Tensor or None): Pre-computed log-density of `tensor`, given `group_ndims`. flow_origin (StochasticTensor): The original stochastic tensor from the base distribution of a :class:`tfsnippet.FlowDistribution`. """ from tfsnippet.utils import TensorArgValidator, validate_group_ndims_arg if is_reparameterized is None: is_reparameterized = distribution.is_reparameterized if log_prob is not None and not is_tensor_object(log_prob): log_prob = tf.convert_to_tensor(log_prob) n_samples = validate_n_samples_arg(n_samples, 'n_samples') if n_samples is not None: with tf.name_scope('validate_n_samples'): validator = TensorArgValidator('n_samples') n_samples = validator.require_non_negative( validator.require_int32(n_samples)) group_ndims = validate_group_ndims_arg(group_ndims) super(StochasticTensor, self).__init__() self._self_distribution = distribution self._self_tensor = tf.convert_to_tensor(tensor) self._self_n_samples = n_samples self._self_group_ndims = group_ndims self._self_is_reparameterized = is_reparameterized self._self_flow_origin = flow_origin self._self_log_prob = log_prob self._self_prob = None
def sample(self, n_samples=None, group_ndims=0, is_reparameterized=None, compute_density=None, name=None): self._validate_sample_is_reparameterized_arg(is_reparameterized) if is_reparameterized is None: is_reparameterized = self.is_reparameterized with tf.name_scope(name, default_name='DiscretizedLogistic.sample'): # sample from uniform distribution sample_shape = self.batch_shape static_sample_shape = self.get_batch_shape() if n_samples is not None: sample_shape = tf.concat([[n_samples], sample_shape], 0) static_sample_shape = tf.TensorShape( [None if is_tensor_object(n_samples) else n_samples]). \ concatenate(static_sample_shape) u = tf.random_uniform(shape=sample_shape, minval=self._epsilon, maxval=1. - self._epsilon, dtype=self._param_dtype) u.set_shape(static_sample_shape) # inverse CDF of the logistic inverse_logistic_cdf = maybe_check_numerics( tf.log(u) - tf.log(1. - u), 'inverse_logistic_cdf') # obtain the actual sample scale = maybe_check_numerics(tf.exp(self.log_scale, name='scale'), 'scale') sample = self.mean + scale * inverse_logistic_cdf if self.discretize_sample: sample = self._discretize(sample) sample = maybe_check_numerics(sample, 'sample') sample = convert_to_tensor_and_cast(sample, self.dtype) if not is_reparameterized: sample = tf.stop_gradient(sample) t = StochasticTensor(distribution=self, tensor=sample, n_samples=n_samples, group_ndims=group_ndims, is_reparameterized=is_reparameterized) # compute the density if compute_density: compute_density_immediately(t) return t
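# Note on the sampling trick above (hypothetical sanity check, not part of the
# library): the logistic CDF is sigmoid((x - mean) / scale), so inverse-CDF
# sampling draws u ~ Uniform(0, 1) and maps it through
# mean + scale * (log(u) - log(1 - u)), which is the `inverse_logistic_cdf`
# term.  A minimal NumPy check of the first two moments:
import numpy as np
np.random.seed(0)
u = np.random.uniform(1e-7, 1. - 1e-7, size=100000)
samples = 0.5 + 2.0 * (np.log(u) - np.log(1. - u))   # mean=0.5, scale=2.0
# the Logistic(mu, s) distribution has mean mu and variance (s * pi)**2 / 3
assert abs(samples.mean() - 0.5) < 0.1
assert abs(samples.var() - (2.0 * np.pi) ** 2 / 3.) < 1.0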
def __init__(self, shape, dtype): """ Construct a new :class:`ZeroLogDet`. Args: shape (tuple[int] or Tensor): The shape of the log-det. dtype (tf.DType): The data type. """ if not is_tensor_object(shape): shape = tuple(int(v) for v in shape) self._self_shape = shape self._self_dtype = tf.as_dtype(dtype) self._self_tensor = None
def average(self, tensors, batch_size=None): """ Take the average of given tensors from different devices. If `batch_size` is specified, the tensors will be averaged with respect to the size of data fed to each device. Args: tensors (list[list[tf.Tensor]]): List of tensors from each device. batch_size (None or int or tf.Tensor): The optional batch size. Returns: list[tf.Tensor]: The averaged tensors. """ # check the arguments and try the fast path: only one tensor tensors = list(tensors) if not tensors: return [] length = len(tensors[0]) if length == 0: raise ValueError('`tensors` must be list of non-empty Tensor ' 'lists.') for t in tensors[1:]: if len(t) != length: raise ValueError('`tensors` must be list of Tensor lists of ' 'the same length.') if length == 1: return [t[0] for t in tensors] # do the slow path: average all tensors with tf.device(self.main_device), tf.name_scope('average_tensors'): if batch_size is None: return [tf.reduce_mean(tf.stack(t), axis=0) for t in tensors] k = len(self.work_devices) slice_len = (batch_size + k - 1) // k last_slice_size = batch_size - (k - 1) * slice_len if is_tensor_object(batch_size): to_float = tf.to_float else: to_float = float float_batch_size = to_float(batch_size) weights = tf.stack([to_float(slice_len) / float_batch_size] * (k - 1) + [to_float(last_slice_size) / float_batch_size]) return [ tf.reduce_sum(tf.stack(t) * weights, axis=0) for t in tensors ]
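# Hypothetical illustration of the weighting above (plain Python, no TF needed;
# the variable names mirror the method body).  With k = 3 devices and
# batch_size = 10, the data is split into slices of ceil(10 / 3) = 4, 4 and
# 10 - 2 * 4 = 2 examples, so the per-device means are recombined with weights
# 0.4, 0.4 and 0.2 instead of a plain average.
k, batch_size = 3, 10
slice_len = (batch_size + k - 1) // k                  # 4
last_slice_size = batch_size - (k - 1) * slice_len     # 2
weights = [slice_len / batch_size] * (k - 1) + [last_slice_size / batch_size]
assert weights == [0.4, 0.4, 0.2] and abs(sum(weights) - 1.0) < 1e-12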
def __init__(self, loop, train_op, inputs, data_flow, feed_dict=None,
             metrics=None, summaries=None):
    """
    Args:
        loop (TrainLoop): The training loop object.
        train_op (tf.Operation): The training operation.
        inputs (list[tf.Tensor]): The input placeholders.  The number of
            tensors, and the order of tensors, should both match the arrays
            of each mini-batch data, provided by `data_flow`.
        data_flow (DataFlow): The training data flow.  Each mini-batch must
            contain one array for each placeholder in `inputs`.
        feed_dict: The feed dict for training.  It will be merged with the
            arrays provided by `data_flow` in each step.
            (default :obj:`None`)
        metrics (dict[str, tf.Tensor]): Metrics to be computed along with
            `train_op`.  The keys are the names of metrics.
        summaries (tf.Tensor or Iterable[tf.Tensor]): A tensor or a list of
            summaries to be run along with `train_op`, and later added to
            ``loop.summary_writer``.  If ``loop.summary_writer`` is None,
            then no summary will be run.
    """
    if loop.max_epoch is None and loop.max_step is None:
        raise ValueError('At least one of `max_epoch`, `max_step` should '
                         'be configured for `loop`.')
    if summaries is not None and is_tensor_object(summaries):
        summaries = [summaries]

    super(Trainer, self).__init__(loop=loop)

    # memorize the arguments
    self._inputs = tuple(inputs or ())
    self._data_flow = data_flow
    self._feed_dict = dict(feed_dict or ())
    self._train_op = train_op
    self._metrics = dict(metrics or ())
    self._summaries = list(summaries or ())
def reduce_group_ndims(operation, tensor, group_ndims, name=None):
    """
    Reduce the last `group_ndims` dimensions in `tensor`, using `operation`.

    In :class:`~tfsnippet.distributions.Distribution`, when computing the
    (log-)densities of certain `tensor`, the last few dimensions may represent
    a group of events, thus should be accounted for together.  This method can
    be used to reduce these dimensions, for example:

    .. code-block:: python

        log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)
        prob = reduce_group_ndims(tf.reduce_prod, prob, group_ndims)

    Args:
        operation: The operation for reducing the last `group_ndims`
            dimensions.  It must receive `tensor` as the 1st argument, and
            `axis` as the 2nd argument.
        tensor: The tensor to be reduced.
        group_ndims: The number of dimensions at the end of `tensor` to be
            reduced.  If it is a constant integer and is zero, then no
            operation will take place.
        name: TensorFlow name scope of the graph nodes.
            (default "reduce_group_ndims")

    Returns:
        tf.Tensor: The reduced tensor.

    Raises:
        ValueError: If `group_ndims` cannot be validated by
            :meth:`validate_group_ndims`.
    """
    group_ndims = validate_group_ndims(group_ndims)
    with tf.name_scope(name, default_name='reduce_group_ndims'):
        if is_tensor_object(group_ndims):
            tensor = tf.cond(
                group_ndims > 0,
                lambda: operation(tensor, tf.range(-group_ndims, 0)),
                lambda: tensor
            )
        else:
            if group_ndims > 0:
                tensor = operation(tensor, tf.range(-group_ndims, 0))
    return tensor
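# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helpers above
# in scope).  Reducing the last `group_ndims` axes turns per-element
# log-densities into per-group log-densities; with group_ndims == 0 the tensor
# is returned unchanged.
elem_log_prob = tf.zeros([16, 4, 4])                  # a batch of 4x4 events
group_log_prob = reduce_group_ndims(tf.reduce_sum, elem_log_prob, 2)
# group_log_prob has shape [16]: the trailing two axes were summed out
unchanged = reduce_group_ndims(tf.reduce_sum, elem_log_prob, 0)   # no-op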
def smart_cond(cond, true_fn, false_fn, name=None): """ Execute `true_fn` or `false_fn` according to `cond`. Args: cond (bool or tf.Tensor): A bool constant or a tensor. true_fn (() -> tf.Tensor): The function of the true branch. false_fn (() -> tf.Tensor): The function of the false branch. Returns: tf.Tensor: The output tensor. """ if is_tensor_object(cond): return tf.cond(cond, true_fn, false_fn, name=name) else: if cond: return true_fn() else: return false_fn()
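# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helper above
# in scope).  A Python bool picks the branch at graph-construction time, so only
# one branch is ever built; a boolean tensor falls back to tf.cond, which builds
# both branches and selects at run time.
a = smart_cond(True, lambda: tf.constant(1), lambda: tf.constant(0))   # constant 1
flag = tf.placeholder(tf.bool, [])
b = smart_cond(flag, lambda: tf.constant(1), lambda: tf.constant(0))   # tf.cond node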
def apply_log_det_factor(log_det, input, axis, value_ndims): shape = get_static_shape(input) assert (shape is not None) assert (len(shape) >= value_ndims) assert (value_ndims > 0) if axis < 0: axis = axis + len(shape) assert (axis >= 0) reduced_axis = [ a for a in range(-value_ndims, 0) if a + len(shape) != axis ] if reduced_axis: shape = get_dimensions_size(input, reduced_axis) if is_tensor_object(shape): log_det *= tf.cast(tf.reduce_prod(shape), log_det.dtype) else: log_det *= np.prod(shape) return log_det
def unflatten_from_ndims(x, static_front_shape, front_shape, name=None):
    """
    The inverse transformation of :func:`flatten_to_ndims`.

    If both `static_front_shape` and `front_shape` are None, `x` will be
    returned without any change.

    Args:
        x (Tensor): The tensor to be unflattened.
        static_front_shape (tuple[int or None] or None): The static front
            shape.
        front_shape (tuple[int] or tf.Tensor or None): The front shape.

    Returns:
        tf.Tensor: The unflattened x.
    """
    x = tf.convert_to_tensor(x)
    if static_front_shape is None and front_shape is None:
        return x

    if not x.get_shape():
        raise ValueError('`x` is required to have known number of '
                         'dimensions.')
    shape = get_static_shape(x)
    if len(shape) < 1:
        raise ValueError('`x` only has rank {}, required at least 1.'.format(
            len(shape)))
    if not is_tensor_object(front_shape):
        front_shape = tuple(front_shape)

    with tf.name_scope(name, default_name='unflatten', values=[x]):
        back_shape = shape[1:]
        static_back_shape = back_shape
        if None in back_shape:
            back_shape = tf.shape(x)[1:]
        if isinstance(front_shape, tuple) and isinstance(back_shape, tuple):
            x = tf.reshape(x, front_shape + back_shape)
        else:
            x = tf.reshape(x, tf.concat([front_shape, back_shape], axis=0))
            x.set_shape(tf.TensorShape(
                list(static_front_shape) + list(static_back_shape)))
        return x
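# Hypothetical round-trip sketch (assumes TensorFlow 1.x as `tf`;
# `flatten_to_ndims` is the forward transformation used elsewhere in this
# listing, e.g. inside deconv2d below).  The front dimensions removed by
# flattening are restored by `unflatten_from_ndims`.
x = tf.zeros([2, 3, 4, 5])
flat, static_front, front = flatten_to_ndims(x, 2)            # flat: [24, 5]
restored = unflatten_from_ndims(flat, static_front, front)    # back to [2, 3, 4, 5]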
def assert_rank(x, ndims, message=None, name=None): """ Assert the rank of `x` is `ndims`. Args: x: A tensor. ndims (int or tf.Tensor): An integer, or a 0-d integer tensor. message: Message to display when assertion failed. Returns: tf.Operation or None: The TensorFlow assertion operation, or None if can be statically asserted. """ if not is_tensor_object(ndims) and get_static_shape(x) is not None: ndims = int(ndims) x_ndims = len(get_static_shape(x)) if x_ndims != ndims: raise _make_assertion_error('rank(x) == ndims', '{!r} != {!r}'.format(x_ndims, ndims), message) else: return tf.assert_rank(x, ndims, message=message, name=name)
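# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helpers above
# in scope).  With a static shape and a constant `ndims` the check happens at
# graph-construction time and returns None; otherwise a tf.assert_rank op is
# returned for use as a control dependency.
assert assert_rank(tf.zeros([2, 3]), 2) is None
x = tf.placeholder(tf.float32, None)            # rank unknown at graph time
op = assert_rank(x, 2, message='x must be a matrix')
with tf.control_dependencies([op]):
    x = tf.identity(x)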
def apply(self, input): """ Apply the layer on `input`, to produce output. Args: input (Tensor or list[Tensor]): The input tensor, or a list of input tensors. Returns: The output tensor, or a list of output tensors. """ if is_tensor_object(input) or isinstance(input, np.ndarray): input = tf.convert_to_tensor(input) ns_values = [input] else: input = [tf.convert_to_tensor(i) for i in input] ns_values = input if not self._has_built: self.build(input) with tf.name_scope(get_default_scope_name('apply', self), values=ns_values): return self._apply(input)
def maybe_tile(t, tile, name): if any(s != 1 for s in tile): if any(is_tensor_object(s) for s in tile): tile = tf.stack(tile, axis=0) t = tf.tile(t, tile, name=name) return t
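# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helper above
# in scope).  `maybe_tile` avoids emitting a tf.tile node when every repeat
# factor is 1, and stacks the repeat factors into a tensor when any of them is
# dynamic.
t = tf.zeros([1, 4])
same = maybe_tile(t, [1, 1], name='noop')        # returned unchanged, no tile op
tiled = maybe_tile(t, [3, 1], name='tiled')      # shape [3, 4]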
def __init__(self, categorical, components, is_reparameterized=False): """ Construct a new :class:`Mixture`. Args: categorical (Categorical): The categorical distribution, indicating the probabilities of the mixture components. components (Iterable[Distribution]): The component distributions of the mixture. is_reparameterized (bool): Whether or not this mixture distribution is re-parameterized? If :obj:`True`, the `components` must all be re-parameterized. The `categorical` will be treated as constant, and the mixture samples will be composed by `one_hot(categorical samples) * stack([component samples])`, such that the gradients can be propagated back directly through these samples. If :obj:`False`, `tf.stop_gradient` will be applied on the mixture samples, such that no gradient will be propagated back through these samples. """ components = tuple(as_distribution(c) for c in components) is_reparameterized = bool(is_reparameterized) if not isinstance(categorical, Categorical): raise TypeError( '`categorical` must be a Categorical distribution: got {}'. format(categorical)) if is_tensor_object(categorical.n_categories): raise ValueError( 'Dynamic `categorical.n_categories` is not supported.') if not components: raise ValueError('`components` must not be empty.') if len(components) != categorical.n_categories: raise ValueError( '`len(components)` != `categorical.n_categories`: {} vs {}'. format(len(components), categorical.n_categories)) for i, c in enumerate(components): if is_reparameterized and not c.is_reparameterized: raise ValueError( '`is_reparameterized` is True, but the {}-th component ' 'is not re-parameterized: {}'.format(i, c)) for attr in ('dtype', 'is_continuous', 'value_ndims'): first_val = getattr(components[0], attr) for i, c in enumerate(components[1:], 1): c_val = getattr(c, attr) if c_val != first_val: raise ValueError( '`{}` of the {}-th component does not agree with the ' 'first component: {} vs {}'.format( attr, i, c_val, first_val)) # check the batch_shape of components, ensure they are equal batch_shape = components[0].batch_shape batch_static_shape = components[0].get_batch_shape() def is_static_batch_shape_match(c, batch_static_shape): batch_static_shape = batch_static_shape.as_list() c_batch_static_shape = c.get_batch_shape().as_list() equal = True if len(batch_static_shape) != len(c_batch_static_shape): equal = False else: for a, b in zip(batch_static_shape, c_batch_static_shape): if a is not None and b is not None and a != b: equal = False break return equal if not is_static_batch_shape_match(categorical, batch_static_shape): raise ValueError( 'Batch shape of `categorical` does not agree with ' 'the first component: {} vs {}'.format( categorical.get_batch_shape(), batch_static_shape)) for i, c in enumerate(components[1:], 1): if not is_static_batch_shape_match(c, batch_static_shape): raise ValueError( 'Batch shape of the {}-th component does not agree with ' 'the first component: {} vs {}'.format( i, c.get_batch_shape(), batch_static_shape)) def assert_batch_shape(c, batch_shape): c_batch_shape = c.batch_shape with assert_deps([ tf.assert_equal( tf.reduce_all( tf.equal( tf.concat([batch_shape, c_batch_shape], 0), tf.concat([c_batch_shape, batch_shape], 0))), True) ]) as asserted: if asserted: # pragma: no cover batch_shape = tf.identity(batch_shape) return batch_shape if settings.enable_assertions: with tf.name_scope('Mixture.init'): batch_shape = assert_batch_shape(categorical, batch_shape) for c in components[1:]: batch_shape = assert_batch_shape(c, 
batch_shape) self._categorical = categorical self._components = components super(Mixture, self).__init__( dtype=components[0].dtype, is_continuous=components[0].is_continuous, is_reparameterized=is_reparameterized, batch_shape=components[0].batch_shape, batch_static_shape=components[0].get_batch_shape(), value_ndims=components[0].value_ndims, )
def reshape_tail(input, ndims, shape, name=None): """ Reshape the tail (last) `ndims` into specified `shape`. Usage:: x = tf.zeros([2, 3, 4, 5, 6]) reshape_tail(x, 3, [-1]) # output: zeros([2, 3, 120]) reshape_tail(x, 1, [3, 2]) # output: zeros([2, 3, 4, 5, 3, 2]) Args: input (Tensor): The input tensor, at least `ndims` dimensions. ndims (int): To reshape this number of dimensions at tail. shape (Iterable[int] or tf.Tensor): The shape of the new tail. Returns: tf.Tensor: The reshaped tensor. """ input = tf.convert_to_tensor(input) if not is_tensor_object(shape): shape = list(int(s) for s in shape) neg_one_count = 0 for s in shape: if s <= 0: if s == -1: if neg_one_count > 0: raise ValueError('`shape` is not a valid shape: at ' 'most one `-1` can be specified.') else: neg_one_count += 1 else: raise ValueError('`shape` is not a valid shape: {} is ' 'not allowed.'.format(s)) with tf.name_scope(name or 'reshape_tail', values=[input]): # assert the dimension with assert_deps([ assert_rank_at_least( input, ndims, message='rank(input) must be at least ndims') ]) as asserted: if asserted: # pragma: no cover input = tf.identity(input) # compute the static shape static_input_shape = get_static_shape(input) static_output_shape = None if static_input_shape is not None: if ndims > 0: left_shape = static_input_shape[:-ndims] right_shape = static_input_shape[-ndims:] else: left_shape = static_input_shape right_shape = () # attempt to resolve "-1" in `shape` if isinstance(shape, list): if None not in right_shape: shape_size = int(np.prod([s for s in shape if s != -1])) right_shape_size = int(np.prod(right_shape)) if (-1 not in shape and shape_size != right_shape_size) or \ (-1 in shape and right_shape_size % shape_size != 0): raise ValueError( 'Cannot reshape the tail dimensions of ' '`input` into `shape`: input {!r}, ndims ' '{}, shape {}.'.format(input, ndims, shape)) if -1 in shape: pos = shape.index(-1) shape[pos] = right_shape_size // shape_size static_output_shape = left_shape + \ tuple(s if s != -1 else None for s in shape) static_output_shape = tf.TensorShape(static_output_shape) # compute the dynamic shape input_shape = get_shape(input) if ndims > 0: output_shape = concat_shapes([input_shape[:-ndims], shape]) else: output_shape = concat_shapes([input_shape, shape]) # do reshape output = tf.reshape(input, output_shape) output.set_shape(static_output_shape) return output
def is_log_det_shape_matches_input(log_det, input, value_ndims, name=None): """ Check whether or not the shape of `log_det` matches the shape of `input`. Basically, the shapes of `log_det` and `input` should satisfy:: if value_ndims > 0: assert(log_det.shape == input.shape[:-value_ndims]) else: assert(log_det.shape == input.shape) Args: log_det: Tensor, the log-determinant. input: Tensor, the input. value_ndims (int): The number of dimensions of each values sample. Returns: bool or tf.Tensor: A boolean or a tensor, indicating whether or not the shape of `log_det` matches the shape of `input`. """ if not is_tensor_object(log_det): log_det = tf.convert_to_tensor(log_det) if not is_tensor_object(input): input = tf.convert_to_tensor(input) value_ndims = int(value_ndims) with tf.name_scope(name or 'is_log_det_shape_matches_input'): log_det_shape = get_static_shape(log_det) input_shape = get_static_shape(input) # if both shapes have deterministic ndims, we can compare each axis # separately. if log_det_shape is not None and input_shape is not None: if len(log_det_shape) + value_ndims != len(input_shape): return False dynamic_axis = [] for i, (a, b) in enumerate(zip(log_det_shape, input_shape)): if a is None or b is None: dynamic_axis.append(i) elif a != b: return False if not dynamic_axis: return True log_det_shape = get_shape(log_det) input_shape = get_shape(input) return tf.reduce_all([ tf.equal(log_det_shape[i], input_shape[i]) for i in dynamic_axis ]) # otherwise we need to do a fully dynamic check, including check # ``log_det.ndims + value_ndims == input_shape.ndims`` is_ndims_matches = tf.equal( tf.rank(log_det) + value_ndims, tf.rank(input)) log_det_shape = get_shape(log_det) input_shape = get_shape(input) if value_ndims > 0: input_shape = input_shape[:-value_ndims] return tf.cond( is_ndims_matches, lambda: tf.reduce_all( tf.equal( # The following trick ensures we're comparing two tensors # with the same shape, such as to avoid some potential issues # about the cond operation. tf.concat([log_det_shape, input_shape], 0), tf.concat([input_shape, log_det_shape], 0), )), lambda: tf.constant(False, dtype=tf.bool))
def deconv2d(input, out_channels, kernel_size, strides=(1, 1), padding='same', channels_last=True, output_shape=None, activation_fn=None, normalizer_fn=None, weight_norm=False, gated=False, gate_sigmoid_bias=2., kernel=None, kernel_initializer=None, kernel_regularizer=None, kernel_constraint=None, use_bias=None, bias=None, bias_initializer=tf.zeros_initializer(), bias_regularizer=None, bias_constraint=None, trainable=True, name=None, scope=None): """ 2D deconvolutional layer. Args: input (Tensor): The input tensor, at least 4-d. out_channels (int): The channel numbers of the deconvolution output. kernel_size (int or (int, int)): Kernel size over spatial dimensions. strides (int or (int, int)): Strides over spatial dimensions. padding: One of {"valid", "same"}, case in-sensitive. channels_last (bool): Whether or not the channel axis is the last axis in `input`? (i.e., the data format is "NHWC") output_shape: If specified, use this as the shape of the deconvolution output; otherwise compute the size of each dimension by:: output_size = input_size * strides if padding == 'valid': output_size += max(kernel_size - strides, 0) activation_fn: The activation function. normalizer_fn: The normalizer function. weight_norm (bool or (tf.Tensor) -> tf.Tensor)): If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on `kernel`. `use_scale` will be :obj:`True` if `normalizer_fn` is not specified, and :obj:`False` otherwise. The axis reduction will be determined by the layer. If it is a callable function, then it will be used to normalize the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`. The user must ensure the axis reduction is correct by themselves. gated (bool): Whether or not to use gate on output? `output = activation_fn(output) * sigmoid(gate)`. gate_sigmoid_bias (Tensor): The bias added to `gate` before applying the `sigmoid` activation. kernel (Tensor): Instead of creating a new variable, use this tensor. kernel_initializer: The initializer for `kernel`. Would be ``default_kernel_initializer(...)`` if not specified. kernel_regularizer: The regularizer for `kernel`. kernel_constraint: The constraint for `kernel`. use_bias (bool or None): Whether or not to use `bias`? If :obj:`True`, will always use bias. If :obj:`None`, will use bias only if `normalizer_fn` is not given. If :obj:`False`, will never use bias. Default is :obj:`None`. bias (Tensor): Instead of creating a new variable, use this tensor. bias_initializer: The initializer for `bias`. bias_regularizer: The regularizer for `bias`. bias_constraint: The constraint for `bias`. trainable (bool): Whether or not the parameters are trainable? Returns: tf.Tensor: The output tensor. 
""" input, in_channels, data_format = \ validate_conv2d_input(input, channels_last) out_channels = validate_positive_int_arg('out_channels', out_channels) dtype = input.dtype.base_dtype if gated: out_channels *= 2 # check functional arguments padding = validate_enum_arg('padding', str(padding).upper(), ['VALID', 'SAME']) strides = validate_conv2d_strides_tuple('strides', strides, channels_last) weight_norm_fn = validate_weight_norm_arg(weight_norm, axis=-1, use_scale=normalizer_fn is None) if use_bias is None: use_bias = normalizer_fn is None # get the specification of outputs and parameters kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size) kernel_shape = kernel_size + (out_channels, in_channels) bias_shape = (out_channels, ) given_h, given_w = None, None given_output_shape = output_shape if is_tensor_object(given_output_shape): given_output_shape = tf.convert_to_tensor(given_output_shape) elif given_output_shape is not None: given_h, given_w = given_output_shape # validate the parameters if kernel is not None: kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype) kernel = kernel_spec.validate('kernel', kernel) if kernel_initializer is None: kernel_initializer = default_kernel_initializer(weight_norm) if bias is not None: bias_spec = ParamSpec(shape=bias_shape, dtype=dtype) bias = bias_spec.validate('bias', bias) # the main part of the conv2d layer with tf.variable_scope(scope, default_name=name or 'deconv2d'): with tf.name_scope('output_shape'): # detect the input shape and axis arrangements input_shape = get_static_shape(input) if channels_last: c_axis, h_axis, w_axis = -1, -3, -2 else: c_axis, h_axis, w_axis = -3, -2, -1 output_shape = [None, None, None, None] output_shape[c_axis] = out_channels if given_output_shape is None: if input_shape[h_axis] is not None: output_shape[h_axis] = get_deconv_output_length( input_shape[h_axis], kernel_shape[0], strides[h_axis], padding) if input_shape[w_axis] is not None: output_shape[w_axis] = get_deconv_output_length( input_shape[w_axis], kernel_shape[1], strides[w_axis], padding) else: if not is_tensor_object(given_output_shape): output_shape[h_axis] = given_h output_shape[w_axis] = given_w # infer the batch shape in 4-d batch_shape = input_shape[:-3] if None not in batch_shape: output_shape[0] = int(np.prod(batch_shape)) # now the static output shape is ready output_static_shape = tf.TensorShape(output_shape) # prepare for the dynamic batch shape if output_shape[0] is None: output_shape[0] = tf.reduce_prod(get_shape(input)[:-3]) # prepare for the dynamic spatial dimensions if output_shape[h_axis] is None or output_shape[w_axis] is None: if given_output_shape is None: input_shape = get_shape(input) if output_shape[h_axis] is None: output_shape[h_axis] = get_deconv_output_length( input_shape[h_axis], kernel_shape[0], strides[h_axis], padding) if output_shape[w_axis] is None: output_shape[w_axis] = get_deconv_output_length( input_shape[w_axis], kernel_shape[1], strides[w_axis], padding) else: assert (is_tensor_object(given_output_shape)) with assert_deps([ assert_rank(given_output_shape, 1), assert_scalar_equal(tf.size(given_output_shape), 2) ]): output_shape[h_axis] = given_output_shape[0] output_shape[w_axis] = given_output_shape[1] # compose the final dynamic shape if any(is_tensor_object(s) for s in output_shape): output_shape = tf.stack(output_shape) else: output_shape = tuple(output_shape) # create the variables if kernel is None: kernel = model_variable('kernel', shape=kernel_shape, dtype=dtype, 
initializer=kernel_initializer, regularizer=kernel_regularizer, constraint=kernel_constraint, trainable=trainable) if weight_norm_fn is not None: kernel = weight_norm_fn(kernel) maybe_add_histogram(kernel, 'kernel') kernel = maybe_check_numerics(kernel, 'kernel') if use_bias and bias is None: bias = model_variable('bias', shape=bias_shape, initializer=bias_initializer, regularizer=bias_regularizer, constraint=bias_constraint, trainable=trainable) maybe_add_histogram(bias, 'bias') bias = maybe_check_numerics(bias, 'bias') # flatten to 4d output, s1, s2 = flatten_to_ndims(input, 4) # do convolution or deconvolution output = tf.nn.conv2d_transpose(value=output, filter=kernel, output_shape=output_shape, strides=strides, padding=padding, data_format=data_format) if output_static_shape is not None: output.set_shape(output_static_shape) # add bias if use_bias: output = tf.nn.bias_add(output, bias, data_format=data_format) # apply the normalization function if specified if normalizer_fn is not None: output = normalizer_fn(output) # split into halves if gated if gated: output, gate = tf.split(output, 2, axis=c_axis) # apply the activation function if specified if activation_fn is not None: output = activation_fn(output) # apply the gate if required if gated: output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate') # unflatten back to original shape output = unflatten_from_ndims(output, s1, s2) maybe_add_histogram(output, 'output') output = maybe_check_numerics(output, 'output') return output
def get_training_loss(self, x, n_z=None):
    """
    Get the training loss for this VAE.

    The variational solver is automatically chosen according to
    `z.is_reparameterized`, and the argument `n_z`, by the following rules:

    1. If `z.is_reparameterized` is :obj:`True`, then:

        1. If `n_z` > 1, use `iwae`.
        2. If `n_z` == 1 or `n_z` is :obj:`None`, use `sgvb`.

    2. If `z.is_reparameterized` is :obj:`False`, then:

        1. If `n_z` > 1, use `vimco`.
        2. If `n_z` == 1 or `n_z` is :obj:`None`, use `reinforce`.

    Dynamic `n_z` is not supported by this method.  Also, the Reweighted
    Wake-Sleep algorithm is not a choice of this method.  To derive the
    training loss for either situation, use :meth:`chain` to obtain a
    :class:`~tfsnippet.variational.VariationalChain`, and further obtain
    the loss by `chain.vi.training.[algorithm]`.

    Args:
        x: The input observation `x`.
        n_z (int or None): Number of `z` samples to take.  Must be
            :obj:`None` or a constant integer.  Dynamic tensors are not
            accepted, since we cannot automatically choose a variational
            solver for a dynamic `n_z`. (default :obj:`None`)

    Returns:
        tf.Tensor: A 0-d tensor, the training loss which can be optimized
            by gradient descent.

    See Also:
        :class:`tfsnippet.variational.VariationalChain`,
        :class:`tfsnippet.variational.VariationalTrainingObjectives`
    """
    with tf.name_scope('VAE.get_training_loss'):
        if n_z is not None:
            if is_tensor_object(n_z):
                raise TypeError('Cannot choose the variational solver '
                                'automatically for dynamic `n_z`')
            n_z = validate_n_samples_arg(n_z, 'n_z')

        # derive the variational chain
        chain = self.chain(x, n_z)
        z = chain.variational['z']

        # auto choose a variational solver for training loss
        if n_z is not None and n_z > 1:
            if z.is_reparameterized:
                solver = chain.vi.training.iwae
            else:
                solver = chain.vi.training.vimco
        else:
            if z.is_reparameterized:
                solver = chain.vi.training.sgvb
            else:
                solver = chain.vi.training.reinforce

        # derive the training loss
        return tf.reduce_mean(solver())
def pixelcnn_2d_sample(fn, inputs, height, width, channels_last=True, start=0, end=None, back_prop=False, parallel_iterations=1, swap_memory=False, name=None): """ Sample output from a PixelCNN 2D network, pixel-by-pixel. Args: fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`, the function to derive the outputs of PixelCNN 2D network at iteration `i`. `inputs` are the pixel-by-pixel outputs gathered through iteration `0` to iteration `i - 1`. The iteration index `i` may range from `0` to `height * width - 1`. inputs (Iterable[tf.Tensor]): The initial input tensors. All the tensors must be at least 4-d, with identical shape. height (int or tf.Tensor): The height of the outputs. width (int or tf.Tensor): The width of the outputs. channels_last (bool): Whether or not the channel axis is the last axis in `input`? (i.e., the data format is "NHWC") start (int or tf.Tensor): The start iteration, default `0`. end (int or tf.Tensor): The end (exclusive) iteration. Default `height * width`. back_prop, parallel_iterations, swap_memory: Arguments passed to :func:`tf.while_loop`. Returns: tuple[tf.Tensor]: The final outputs. """ from tfsnippet.layers.convolutional.utils import validate_conv2d_input # check the arguments def to_int(t): if is_tensor_object(t): return convert_to_tensor_and_cast(t, dtype=tf.int32) return int(t) height = to_int(height) width = to_int(width) inputs = list(inputs) if not inputs: raise ValueError('`inputs` must not be empty.') inputs[0], _, _ = validate_conv2d_input(inputs[0], channels_last=channels_last, arg_name='inputs[0]') input_spec = InputSpec(shape=get_static_shape(inputs[0])) for i, input in enumerate(inputs[1:], 1): inputs[i] = input_spec.validate('inputs[{}]'.format(i), input) # do pixelcnn sampling with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs): # the total size, start and end index total_size = height * width start = convert_to_tensor_and_cast(start, dtype=tf.int32) if end is None: end = convert_to_tensor_and_cast(total_size, dtype=tf.int32) else: end = convert_to_tensor_and_cast(end, dtype=tf.int32) # the mask shape if channels_last: mask_shape = [height, width, 1] else: mask_shape = [height, width] if any(is_tensor_object(t) for t in mask_shape): mask_shape = tf.stack(mask_shape, axis=0) # the input dynamic shape input_shape = get_shape(inputs[0]) # the pixelcnn sampling loop def loop_cond(idx, _): return idx < end def loop_body(idx, inputs): inputs = tuple(inputs) # prepare for the output mask selector = tf.reshape( tf.concat([ tf.ones([idx], dtype=tf.uint8), tf.zeros([1], dtype=tf.uint8), tf.ones([total_size - idx - 1], dtype=tf.uint8) ], axis=0), mask_shape) selector = tf.cast(broadcast_to_shape(selector, input_shape), dtype=tf.bool) # obtain the outputs outputs = list(fn(idx, inputs)) if len(outputs) != len(inputs): raise ValueError( 'The length of outputs != inputs: {} vs {}'.format( len(outputs), len(inputs))) # mask the outputs for i, (input, output) in enumerate(zip(inputs, outputs)): input_dtype = inputs[i].dtype.base_dtype output_dtype = output.dtype.base_dtype if output_dtype != input_dtype: raise TypeError( '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: ' '{output} vs {input}'.format(idx=i, output=output_dtype, input=input_dtype)) outputs[i] = tf.where(selector, input, output) return idx + 1, tuple(outputs) i0 = start _, outputs = tf.while_loop( cond=loop_cond, body=loop_body, loop_vars=(i0, tuple(inputs)), back_prop=back_prop, parallel_iterations=parallel_iterations, swap_memory=swap_memory, ) return 
outputs
def to_int(t): if is_tensor_object(t): return convert_to_tensor_and_cast(t, dtype=tf.int32) return int(t)
def to_tensor(t): return tf.convert_to_tensor(t) if not is_tensor_object(t) else t
def to_tensor(x): x = list(x) if any(is_tensor_object(t) for t in x): return tf.stack(list(x)) else: return tuple(x)
def gen_name_scope(): if is_tensor_object(group_ndims): with tf.name_scope(name, default_name='validate_group_ndims'): yield else: yield
def shift(input, shift, name=None): """ Shift each axis of `input` according to `shift`, but keep identical size. The extra content will be discarded if shifted outside the original size. Zeros will be padded to the front or end of shifted axes. Args: input (Tensor): The tensor to be shifted. shift (Iterable[int]): The shift length for each axes. It must be equal to the rank of `input`. For each axis, if its corresponding shift < 0, then the `input` will be shifted to left by `-shift` at that axis. If its shift > 0, then the `input` will be shifted to right by `shift` at that axis. Returns: tf.Tensor: The output tensor. """ shift = tuple(int(s) for s in shift) input = tf.convert_to_tensor(input) shape = get_static_shape(input) if shape is None: raise ValueError( 'The rank of `shape` is required to be deterministic: ' 'got {}'.format(input)) if len(shift) != len(shape): raise ValueError('The length of `shift` is required to equal the rank ' 'of `input`: shift {} vs input {}'.format( shift, input)) # cache for the dynamic shape def get_dynamic_shape(): if cached[0] is None: cached[0] = get_shape(input) return cached[0] cached = [None] # main routine with tf.name_scope(name, default_name='shift', values=[input]): # compute the slicing and padding arguments has_shift = False assert_ops = [] slice_begin = [] slice_size = [] paddings = [] err_msg = ('Cannot shift `input`: input {} vs shift {}'.format( input, shift)) for i, (axis_shift, axis_size) in enumerate(zip(shift, shape)): # fast approach: shift is zero, no slicing at the axis if axis_shift == 0: slice_begin.append(0) slice_size.append(-1) paddings.append((0, 0)) continue # slow approach: shift is not zero, should slice at the axis axis_shift_abs = abs(axis_shift) # we first check whether or not the axis size is big enough if axis_size is None: dynamic_axis_size = get_dynamic_shape()[i] assert_ops.append( tf.assert_greater_equal(dynamic_axis_size, axis_shift_abs, message=err_msg)) else: if axis_size < axis_shift_abs: raise ValueError(err_msg) # next, we compose the slicing range if axis_shift < 0: # shift to left slice_begin.append(-axis_shift) slice_size.append(-1) paddings.append((0, -axis_shift)) else: # shift to right slice_begin.append(0) if axis_size is None: slice_size.append(get_dynamic_shape()[i] - axis_shift) else: slice_size.append(axis_size - axis_shift) paddings.append((axis_shift, 0)) # mark the flag to indicate that we've got any axis to shift has_shift = True if assert_ops: with assert_deps(assert_ops) as asserted: if asserted: input = tf.identity(input) # no axis to shift, directly return the input if not has_shift: return input # do slicing and padding if any(is_tensor_object(s) for s in slice_size): slice_size = tf.stack(slice_size, axis=0) output = tf.slice(input, slice_begin, slice_size) output = tf.pad(output, paddings) return output
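# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helper above
# in scope).  Negative entries shift an axis to the left, positive entries to
# the right; every axis keeps its size, with zeros padded into the vacated
# positions.
x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
y = shift(x, [0, 1])    # columns shifted right by 1:
                        # [[0, 1, 2],
                        #  [0, 4, 5]]
z = shift(x, [-1, 0])   # rows shifted left (up) by 1:
                        # [[4, 5, 6],
                        #  [0, 0, 0]]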
def broadcast_to_shape(x, shape, name=None): """ Broadcast `x` to match `shape`. If ``rank(x) > len(shape)``, only the tail dimensions will be broadcasted to match `shape`. Args: x: A tensor. shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape. Returns: tf.Tensor: The broadcasted tensor. """ # check the parameters x = tf.convert_to_tensor(x) x_shape = get_static_shape(x) ns_values = [x] if is_tensor_object(shape): shape = tf.convert_to_tensor(shape) ns_values.append(shape) else: shape = tuple(int(s) for s in shape) with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values): cannot_broadcast_msg = ( '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'. format(x, shape)) # fast routine: shape is tuple[int] and x_shape is all known, # we can use reshape + tile to do the broadcast, which should be faster # than using ``x * ones(shape)``. if isinstance(shape, tuple) and x_shape is not None and \ all(s is not None for s in x_shape): # reshape to have the same dimension if len(x_shape) < len(shape): x_shape = (1, ) * (len(shape) - len(x_shape)) + x_shape x = tf.reshape(x, x_shape) # tile to have the same shape tile = [] i = -1 while i > -len(shape) - 1: a, b = x_shape[i], shape[i] if a == 1 and b > 1: tile.append(b) elif a != b: raise ValueError(cannot_broadcast_msg) else: tile.append(1) i -= 1 tile = [1] * (len(x_shape) - len(shape)) + list(reversed(tile)) if any(s > 1 for s in tile): x = tf.tile(x, tile) return x # slow routine: we may need ``x * ones(shape)`` to do the broadcast assertions = [] post_assert_shape = False static_shape = tf.TensorShape(None) if isinstance(shape, tuple) and x_shape is not None: need_multiply_ones = False # it should always broadcast if len(x_shape) < len(shape) if len(x_shape) < len(shape): need_multiply_ones = True # check the consistency of x and shape static_shape_hint = [] # list to gather the static shape hint axis_to_check = [] # list to gather the axis to check i = -1 while i >= -len(shape) and i >= -len(x_shape): a, b = x_shape[i], shape[i] if a is None: axis_to_check.append(i) else: if a != b: if a == 1: need_multiply_ones = True else: raise ValueError(cannot_broadcast_msg) static_shape_hint.append(b) i -= 1 # compose the static shape hint if len(shape) < len(x_shape): static_shape = x_shape[:-len(shape)] elif len(shape) > len(x_shape): static_shape = shape[:-len(x_shape)] else: static_shape = () static_shape = tf.TensorShape(static_shape + tuple(reversed(static_shape_hint))) # compose the assertion operations and the multiply flag if axis_to_check: need_multiply_flags = [] x_dynamic_shape = tf.shape(x) for i in axis_to_check: assertions.append( tf.assert_equal(tf.logical_or( tf.equal(x_dynamic_shape[i], shape[i]), tf.equal(x_dynamic_shape[i], 1), ), True, message=cannot_broadcast_msg)) if len(x_shape) >= len(shape): need_multiply_flags.append( tf.not_equal(x_dynamic_shape[i], shape[i])) if not need_multiply_ones: need_multiply_ones = \ tf.reduce_any(tf.stack(need_multiply_flags)) else: # we have no ideal about what `shape` is here, thus we need to # assert the shape after ``x * ones(shape)``. 
need_multiply_ones = True post_assert_shape = True # do broadcast if `x_shape` != `shape` def multiply_branch(): with assert_deps(assertions): ones_template = tf.ones(shape, dtype=x.dtype.base_dtype) try: return x * ones_template except ValueError: # pragma: no cover raise ValueError(cannot_broadcast_msg) def identity_branch(): with assert_deps(assertions) as asserted: if asserted: return tf.identity(x) else: # pragma: no cover return x t = smart_cond(need_multiply_ones, multiply_branch, identity_branch) t.set_shape(static_shape) if post_assert_shape: post_assert_op = tf.assert_equal(tf.reduce_all( tf.equal(tf.shape(t)[-tf.size(shape):], shape)), True, message=cannot_broadcast_msg) with assert_deps([post_assert_op]) as asserted: if asserted: t = tf.identity(t) return t
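# Hypothetical usage sketch (assumes TensorFlow 1.x as `tf` and the helpers above
# in scope).  Unlike the strict variant earlier in this listing,
# `broadcast_to_shape` leaves extra leading dimensions of `x` alone and only
# broadcasts the tail dimensions against `shape`.
x = tf.ones([2, 1, 5])
y = broadcast_to_shape(x, (3, 5))                      # -> shape (2, 3, 5)
z = broadcast_to_shape(tf.zeros([1, 5]), (4, 3, 5))    # -> shape (4, 3, 5)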