Example #1
def assert_log_det_shape_matches_input(log_det, input, value_ndims, name=None):
    """
    Assert the shape of `log_det` matches the shape of `input`.

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each value sample.

    Returns:
        tf.Operation or None: The assertion operation, or None if the
            assertion can be made statically.
    """
    if not is_tensor_object(log_det):
        log_det = tf.convert_to_tensor(log_det)
    if not is_tensor_object(input):
        input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'assert_log_det_shape_matches_input'):
        cmp_result = is_log_det_shape_matches_input(log_det, input,
                                                    value_ndims)
        error_message = (
            'The shape of `log_det` does not match the shape of '
            '`input`: log_det {!r} vs input {!r}, value_ndims is {!r}'.format(
                log_det, input, value_ndims))

        if cmp_result is False:
            raise AssertionError(error_message)

        elif cmp_result is True:
            return None

        else:
            return tf.assert_equal(cmp_result, True, message=error_message)
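
A minimal usage sketch for the helper above (TensorFlow 1.x graph mode is assumed, with the function in scope as defined here); the shapes are illustrative only:

import tensorflow as tf

x = tf.zeros([4, 3, 5])        # input samples
log_det = tf.zeros([4, 3])     # per-sample log-det, so value_ndims = 1

# all shapes are static here, so the check succeeds at graph-construction
# time and no assertion operation is returned
assert_op = assert_log_det_shape_matches_input(log_det, x, value_ndims=1)
assert assert_op is None
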
Example #2
def validate_n_samples(value, name):
    """
    Validate the `n_samples` argument.

    Args:
        value: An int32 value, an int32 :class:`tf.Tensor`, or :obj:`None`.
        name (str): Name of the argument (used in error messages).

    Returns:
        int or tf.Tensor: The validated `n_samples` argument value.

    Raises:
        TypeError or ValueError: If the value cannot be validated.
    """
    if is_tensor_object(value):

        @contextlib.contextmanager
        def mkcontext():
            with tf.name_scope('validate_n_samples'):
                yield
    else:

        @contextlib.contextmanager
        def mkcontext():
            yield

    if value is not None:
        with mkcontext():
            validator = TensorArgValidator(name=name)
            value = validator.require_positive(validator.require_int32(value))
    return value
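
A short illustrative sketch of the argument forms accepted above, assuming `TensorArgValidator` behaves as in tfsnippet.utils:

import tensorflow as tf

n1 = validate_n_samples(16, 'n_samples')                        # a positive int passes through
n2 = validate_n_samples(tf.constant(8, tf.int32), 'n_samples')  # validated as an int32 tensor
n3 = validate_n_samples(None, 'n_samples')                      # None is returned unchanged
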
Example #3
def broadcast_to_shape_strict(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    This method requires `rank(x)` to be less than or equal to `len(shape)`.
    You may use :func:`broadcast_to_shape` instead, to allow the cases where
    ``rank(x) > len(shape)``.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # assert ``rank(x) <= len(shape)``
        if isinstance(shape, tuple) and x_shape is not None:
            if len(x_shape) > len(shape):
                raise ValueError(cannot_broadcast_msg)
        elif isinstance(shape, tuple):
            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         len(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)
        else:
            with assert_deps(
                [assert_rank(shape, 1,
                             message=cannot_broadcast_msg)]) as asserted:
                if asserted:  # pragma: no cover
                    shape = tf.identity(shape)

            with assert_deps([
                    tf.assert_less_equal(tf.rank(x),
                                         tf.size(shape),
                                         message=cannot_broadcast_msg)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    x = tf.identity(x)

        # do broadcast
        return broadcast_to_shape(x, shape)
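
A hypothetical usage sketch (TF 1.x graph mode assumed), contrasting the strict check with :func:`broadcast_to_shape`:

import tensorflow as tf

x = tf.ones([3, 1])
y = broadcast_to_shape_strict(x, (2, 3, 5))   # static shape should be (2, 3, 5)

# rank(x) > len(shape) is rejected here, unlike broadcast_to_shape
z = tf.ones([2, 3, 5])
try:
    broadcast_to_shape_strict(z, (3, 5))
except ValueError:
    pass  # '`x` cannot be broadcasted to match `shape`: ...'
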
Example #4
    def test_is_tensor_object(self):
        for obj in [
                tf.constant(0.),  # type: tf.Tensor
                tf.get_variable('x', dtype=tf.float32, shape=()),
                TensorWrapper(),
                StochasticTensor(Mock(is_reparameterized=False),
                                 tf.constant(0.))
        ]:
            self.assertTrue(
                is_tensor_object(obj),
                msg='{!r} should be interpreted as a tensor object'.format(
                    obj))

        for obj in [1, '', object(), None, True, (), {}, [], np.zeros([1])]:
            self.assertFalse(
                is_tensor_object(obj),
                msg='{!r} should not be interpreted as a tensor object'.format(
                    obj))
Example #5
def assert_scalar_equal(a, b, message=None, name=None):
    """
    Assert 0-d scalar `a` == `b`.

    Args:
        a: A 0-d tensor.
        b: A 0-d tensor.
        message: Message to display when the assertion fails.

    Returns:
        tf.Operation or None: The TensorFlow assertion operation,
            or :obj:`None` if the assertion can be made statically.
    """
    if not is_tensor_object(a) and not is_tensor_object(b):
        if a != b:
            raise _make_assertion_error('a == b', '{!r} != {!r}'.format(a, b),
                                        message)
    else:
        return tf.assert_equal(a, b, message=message, name=name)
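
A small usage sketch (TF 1.x assumed): constants are compared immediately, while tensors produce an assertion operation to be used as a control dependency:

import tensorflow as tf

assert assert_scalar_equal(3, 3) is None       # both static, checked right away

a = tf.constant(3)
op = assert_scalar_equal(a, 3, message='a must equal 3')  # a tf.Operation
with tf.control_dependencies([op]):
    b = tf.identity(a)
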
Example #6
    def __init__(self,
                 distribution,
                 tensor,
                 n_samples=None,
                 group_ndims=0,
                 is_reparameterized=None,
                 flow_origin=None,
                 log_prob=None):
        """
        Construct the :class:`StochasticTensor`.

        Args:
            distribution (tfsnippet.distributions.Distribution): The
                distribution of this :class:`StochasticTensor`.
            tensor (tf.Tensor or TensorWrapper): The samples or observations
                of this :class:`StochasticTensor`.
            n_samples (tf.Tensor or int): The number of samples taken in
                :class:`Distribution.sample`.  If not :obj:`None`, the first
                dimension of `tensor` should be the sampling dimension.
            group_ndims (int or tf.Tensor): The number of dimensions to be
                considered as events group in samples. (default 0)
            is_reparameterized (bool): Whether or not the samples are
                re-parameterized?  If not specified, will inherit from
                :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
            log_prob (Tensor or None): Pre-computed log-density of `tensor`,
                given `group_ndims`.
            flow_origin (StochasticTensor): The original stochastic tensor
                from the base distribution of a
                :class:`tfsnippet.FlowDistribution`.
        """
        from tfsnippet.utils import TensorArgValidator, validate_group_ndims_arg

        if is_reparameterized is None:
            is_reparameterized = distribution.is_reparameterized
        if log_prob is not None and not is_tensor_object(log_prob):
            log_prob = tf.convert_to_tensor(log_prob)

        n_samples = validate_n_samples_arg(n_samples, 'n_samples')
        if n_samples is not None:
            with tf.name_scope('validate_n_samples'):
                validator = TensorArgValidator('n_samples')
                n_samples = validator.require_non_negative(
                    validator.require_int32(n_samples))

        group_ndims = validate_group_ndims_arg(group_ndims)

        super(StochasticTensor, self).__init__()
        self._self_distribution = distribution
        self._self_tensor = tf.convert_to_tensor(tensor)
        self._self_n_samples = n_samples
        self._self_group_ndims = group_ndims
        self._self_is_reparameterized = is_reparameterized
        self._self_flow_origin = flow_origin
        self._self_log_prob = log_prob
        self._self_prob = None
Example #7
    def sample(self,
               n_samples=None,
               group_ndims=0,
               is_reparameterized=None,
               compute_density=None,
               name=None):
        self._validate_sample_is_reparameterized_arg(is_reparameterized)
        if is_reparameterized is None:
            is_reparameterized = self.is_reparameterized

        with tf.name_scope(name, default_name='DiscretizedLogistic.sample'):
            # sample from uniform distribution
            sample_shape = self.batch_shape
            static_sample_shape = self.get_batch_shape()
            if n_samples is not None:
                sample_shape = tf.concat([[n_samples], sample_shape], 0)
                static_sample_shape = tf.TensorShape(
                    [None if is_tensor_object(n_samples) else n_samples]). \
                    concatenate(static_sample_shape)

            u = tf.random_uniform(shape=sample_shape,
                                  minval=self._epsilon,
                                  maxval=1. - self._epsilon,
                                  dtype=self._param_dtype)
            u.set_shape(static_sample_shape)

            # inverse CDF of the logistic
            inverse_logistic_cdf = maybe_check_numerics(
                tf.log(u) - tf.log(1. - u), 'inverse_logistic_cdf')

            # obtain the actual sample
            scale = maybe_check_numerics(tf.exp(self.log_scale, name='scale'),
                                         'scale')
            sample = self.mean + scale * inverse_logistic_cdf
            if self.discretize_sample:
                sample = self._discretize(sample)
            sample = maybe_check_numerics(sample, 'sample')
            sample = convert_to_tensor_and_cast(sample, self.dtype)

            if not is_reparameterized:
                sample = tf.stop_gradient(sample)

            t = StochasticTensor(distribution=self,
                                 tensor=sample,
                                 n_samples=n_samples,
                                 group_ndims=group_ndims,
                                 is_reparameterized=is_reparameterized)

            # compute the density
            if compute_density:
                compute_density_immediately(t)

            return t
Example #8
    def __init__(self, shape, dtype):
        """
        Construct a new :class:`ZeroLogDet`.

        Args:
            shape (tuple[int] or Tensor): The shape of the log-det.
            dtype (tf.DType): The data type.
        """
        if not is_tensor_object(shape):
            shape = tuple(int(v) for v in shape)
        self._self_shape = shape
        self._self_dtype = tf.as_dtype(dtype)
        self._self_tensor = None
Example #9
    def average(self, tensors, batch_size=None):
        """
        Take the average of given tensors from different devices.

        If `batch_size` is specified, the tensors will be averaged with respect
        to the size of data fed to each device.

        Args:
            tensors (list[list[tf.Tensor]]): List of tensors from each device.
            batch_size (None or int or tf.Tensor): The optional batch size.

        Returns:
            list[tf.Tensor]: The averaged tensors.
        """
        # check the arguments and try the fast path: only one tensor
        tensors = list(tensors)
        if not tensors:
            return []
        length = len(tensors[0])
        if length == 0:
            raise ValueError('`tensors` must be list of non-empty Tensor '
                             'lists.')
        for t in tensors[1:]:
            if len(t) != length:
                raise ValueError('`tensors` must be list of Tensor lists of '
                                 'the same length.')
        if length == 1:
            return [t[0] for t in tensors]

        # do the slow path: average all tensors
        with tf.device(self.main_device), tf.name_scope('average_tensors'):
            if batch_size is None:
                return [tf.reduce_mean(tf.stack(t), axis=0) for t in tensors]

            k = len(self.work_devices)
            slice_len = (batch_size + k - 1) // k
            last_slice_size = batch_size - (k - 1) * slice_len

            if is_tensor_object(batch_size):
                to_float = tf.to_float
            else:
                to_float = float

            float_batch_size = to_float(batch_size)
            weights = tf.stack([to_float(slice_len) / float_batch_size] *
                               (k - 1) +
                               [to_float(last_slice_size) / float_batch_size])

            return [
                tf.reduce_sum(tf.stack(t) * weights, axis=0) for t in tensors
            ]
Example #10
    def __init__(self,
                 loop,
                 train_op,
                 inputs,
                 data_flow,
                 feed_dict=None,
                 metrics=None,
                 summaries=None):
        """

        Args:
            loop (TrainLoop): The training loop object.
            train_op (tf.Operation): The training operation.
            inputs (list[tf.Tensor]): The input placeholders.
                The number and order of these tensors should match the
                arrays of each mini-batch provided by `data_flow`.
            data_flow (DataFlow): The training data flow.
                Each mini-batch must contain one array for each placeholder
                in `inputs`.
            feed_dict: The feed dict for training.  It will be merged with
                the arrays provided by `data_flow` in each step.
                (default :obj:`None`)
            metrics (dict[str, tf.Tensor]): Metrics to be computed along with
                `train_op`.  The keys are the names of metrics.
            summaries (tf.Tensor or Iterable[tf.Tensor]): A tensor or a list
                of summaries to be run along with `train_op`, and later
                added to ``loop.summary_writer``.
                If ``loop.summary_writer`` is None, then no summary will be run.
        """
        if loop.max_epoch is None and loop.max_step is None:
            raise ValueError('At least one of `max_epoch`, `max_step` should '
                             'be configured for `loop`.')
        if summaries is not None and is_tensor_object(summaries):
            summaries = [summaries]
        super(Trainer, self).__init__(loop=loop)

        # memorize the arguments
        self._inputs = tuple(inputs or ())
        self._data_flow = data_flow
        self._feed_dict = dict(feed_dict or ())
        self._train_op = train_op
        self._metrics = dict(metrics or ())
        self._summaries = list(summaries or ())
Example #11
def reduce_group_ndims(operation, tensor, group_ndims, name=None):
    """
    Reduce the last `group_ndims` dimensions in `tensor`, using `operation`.

    In :class:`~tfsnippet.distributions.Distribution`, when computing the
    (log-)densities of a certain `tensor`, the last few dimensions may
    represent a group of events, and thus should be accounted for together.
    This method can be used to reduce these dimensions, for example:

    .. code-block:: python

         log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims)
         prob = reduce_group_ndims(tf.reduce_prod, prob, group_ndims)

    Args:
        operation: The operation for reducing the last `group_ndims`
            dimensions. It must receive `tensor` as the 1st argument, and
            `axis` as the 2nd argument.
        tensor: The tensor to be reduced.
        group_ndims: The number of dimensions at the end of `tensor` to be
            reduced.  If it is a constant integer and is zero, then no
            operation will take place.
        name: TensorFlow name scope of the graph nodes. (default
            "reduce_group_ndims")

    Returns:
        tf.Tensor: The reduced tensor.

    Raises:
        ValueError: If `group_ndims` cannot be validated by
            :meth:`validate_group_ndims`.
    """
    group_ndims = validate_group_ndims(group_ndims)
    with tf.name_scope(name, default_name='reduce_group_ndims'):
        if is_tensor_object(group_ndims):
            tensor = tf.cond(
                group_ndims > 0,
                lambda: operation(tensor, tf.range(-group_ndims, 0)),
                lambda: tensor
            )
        else:
            if group_ndims > 0:
                tensor = operation(tensor, tf.range(-group_ndims, 0))
    return tensor
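
An illustrative sketch (TF 1.x assumed) of reducing the trailing event dimensions of a log-density tensor:

import tensorflow as tf

log_p = tf.zeros([4, 3, 5])   # per-element log-densities

# sum over the last two dimensions, treating them as one event group;
# the result should have static shape (4,)
grouped = reduce_group_ndims(tf.reduce_sum, log_p, group_ndims=2)

# group_ndims == 0 leaves the tensor untouched
same = reduce_group_ndims(tf.reduce_sum, log_p, group_ndims=0)
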
Example #12
def smart_cond(cond, true_fn, false_fn, name=None):
    """
    Execute `true_fn` or `false_fn` according to `cond`.

    Args:
        cond (bool or tf.Tensor): A bool constant or a tensor.
        true_fn (() -> tf.Tensor): The function of the true branch.
        false_fn (() -> tf.Tensor): The function of the false branch.

    Returns:
        tf.Tensor: The output tensor.
    """
    if is_tensor_object(cond):
        return tf.cond(cond, true_fn, false_fn, name=name)
    else:
        if cond:
            return true_fn()
        else:
            return false_fn()
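
A minimal sketch (TF 1.x assumed): with a Python bool the branch is chosen at graph-construction time, while a tensor condition falls back to :func:`tf.cond`:

import tensorflow as tf

y1 = smart_cond(True, lambda: tf.zeros([2]), lambda: tf.ones([2]))    # static choice

flag = tf.placeholder(tf.bool, shape=())
y2 = smart_cond(flag, lambda: tf.zeros([2]), lambda: tf.ones([2]))    # resolved at run time
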
Example #13
def apply_log_det_factor(log_det, input, axis, value_ndims):
    shape = get_static_shape(input)
    assert (shape is not None)
    assert (len(shape) >= value_ndims)
    assert (value_ndims > 0)
    if axis < 0:
        axis = axis + len(shape)
        assert (axis >= 0)
    reduced_axis = [
        a for a in range(-value_ndims, 0) if a + len(shape) != axis
    ]

    if reduced_axis:
        shape = get_dimensions_size(input, reduced_axis)
        if is_tensor_object(shape):
            log_det *= tf.cast(tf.reduce_prod(shape), log_det.dtype)
        else:
            log_det *= np.prod(shape)

    return log_det
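
A hypothetical usage sketch (TF 1.x assumed): the helper above scales a log-det that was computed along a single axis by the size of the remaining value dimensions:

import tensorflow as tf

x = tf.zeros([8, 4, 5, 6])   # value_ndims = 3, i.e. event shape (4, 5, 6)
log_det = tf.zeros([8])

# the log-det was derived along axis=-1 only, so it is multiplied by the
# product of the other value dimensions (4 * 5 = 20)
scaled = apply_log_det_factor(log_det, x, axis=-1, value_ndims=3)
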
Example #14
def unflatten_from_ndims(x, static_front_shape, front_shape, name=None):
    """
    The inverse transformation of :func:`flatten_to_ndims`.

    If both `static_front_shape` and `front_shape` are None, `x` will be
    returned without any change.

    Args:
        x (Tensor): The tensor to be unflattened.
        static_front_shape (tuple[int or None] or None): The static front shape.
        front_shape (tuple[int] or tf.Tensor or None): The front shape.

    Returns:
        tf.Tensor: The unflattened x.
    """
    x = tf.convert_to_tensor(x)
    if static_front_shape is None and front_shape is None:
        return x
    if not x.get_shape():
        raise ValueError('`x` is required to have known number of '
                         'dimensions.')
    shape = get_static_shape(x)
    if len(shape) < 1:
        raise ValueError('`x` only has rank {}, required at least 1.'.format(
            len(shape)))
    if not is_tensor_object(front_shape):
        front_shape = tuple(front_shape)

    with tf.name_scope(name, default_name='unflatten', values=[x]):
        back_shape = shape[1:]
        static_back_shape = back_shape
        if None in back_shape:
            back_shape = tf.shape(x)[1:]
        if isinstance(front_shape, tuple) and isinstance(back_shape, tuple):
            x = tf.reshape(x, front_shape + back_shape)
        else:
            x = tf.reshape(x, tf.concat([front_shape, back_shape], axis=0))
            x.set_shape(
                tf.TensorShape(
                    list(static_front_shape) + list(static_back_shape)))
        return x
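
A round-trip sketch with the companion :func:`flatten_to_ndims` helper (used in Example #21), which is assumed to return the flattened tensor together with the static and dynamic front shapes:

import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])

flat, static_front, front = flatten_to_ndims(x, 2)   # flat should be (24, 5)
y = unflatten_from_ndims(flat, static_front, front)  # back to (2, 3, 4, 5)
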
Example #15
def assert_rank(x, ndims, message=None, name=None):
    """
    Assert the rank of `x` is `ndims`.

    Args:
        x: A tensor.
        ndims (int or tf.Tensor): An integer, or a 0-d integer tensor.
        message: Message to display when the assertion fails.

    Returns:
        tf.Operation or None: The TensorFlow assertion operation,
            or :obj:`None` if the assertion can be made statically.
    """
    if not is_tensor_object(ndims) and get_static_shape(x) is not None:
        ndims = int(ndims)
        x_ndims = len(get_static_shape(x))
        if x_ndims != ndims:
            raise _make_assertion_error('rank(x) == ndims',
                                        '{!r} != {!r}'.format(x_ndims,
                                                              ndims), message)
    else:
        return tf.assert_rank(x, ndims, message=message, name=name)
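
A minimal usage sketch (TF 1.x assumed): statically known ranks are checked immediately, otherwise a TensorFlow assertion operation is returned:

import tensorflow as tf

x = tf.zeros([2, 3])
assert assert_rank(x, 2) is None      # static shape, verified at once

y = tf.placeholder(tf.float32, shape=None)             # unknown rank
op = assert_rank(y, 2, message='y must be a matrix')   # a tf.Operation
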
Example #16
    def apply(self, input):
        """
        Apply the layer on `input`, to produce output.

        Args:
            input (Tensor or list[Tensor]): The input tensor, or a list of
                input tensors.

        Returns:
            The output tensor, or a list of output tensors.
        """
        if is_tensor_object(input) or isinstance(input, np.ndarray):
            input = tf.convert_to_tensor(input)
            ns_values = [input]
        else:
            input = [tf.convert_to_tensor(i) for i in input]
            ns_values = input

        if not self._has_built:
            self.build(input)

        with tf.name_scope(get_default_scope_name('apply', self),
                           values=ns_values):
            return self._apply(input)
Example #17
def maybe_tile(t, tile, name):
    if any(s != 1 for s in tile):
        if any(is_tensor_object(s) for s in tile):
            tile = tf.stack(tile, axis=0)
        t = tf.tile(t, tile, name=name)
    return t
Example #18
    def __init__(self, categorical, components, is_reparameterized=False):
        """
        Construct a new :class:`Mixture`.

        Args:
            categorical (Categorical): The categorical distribution,
                indicating the probabilities of the mixture components.
            components (Iterable[Distribution]): The component distributions
                of the mixture.
            is_reparameterized (bool): Whether or not this mixture distribution
                is re-parameterized?  If :obj:`True`, the `components` must
                all be re-parameterized.  The `categorical` will be treated
                as constant, and the mixture samples will be composed by
                `one_hot(categorical samples) * stack([component samples])`,
                such that the gradients can be propagated back directly
                through these samples.  If :obj:`False`, `tf.stop_gradient`
                will be applied on the mixture samples, such that no gradient
                will be propagated back through these samples.
        """
        components = tuple(as_distribution(c) for c in components)
        is_reparameterized = bool(is_reparameterized)

        if not isinstance(categorical, Categorical):
            raise TypeError(
                '`categorical` must be a Categorical distribution: got {}'.
                format(categorical))
        if is_tensor_object(categorical.n_categories):
            raise ValueError(
                'Dynamic `categorical.n_categories` is not supported.')

        if not components:
            raise ValueError('`components` must not be empty.')
        if len(components) != categorical.n_categories:
            raise ValueError(
                '`len(components)` != `categorical.n_categories`: {} vs {}'.
                format(len(components), categorical.n_categories))
        for i, c in enumerate(components):
            if is_reparameterized and not c.is_reparameterized:
                raise ValueError(
                    '`is_reparameterized` is True, but the {}-th component '
                    'is not re-parameterized: {}'.format(i, c))

        for attr in ('dtype', 'is_continuous', 'value_ndims'):
            first_val = getattr(components[0], attr)
            for i, c in enumerate(components[1:], 1):
                c_val = getattr(c, attr)
                if c_val != first_val:
                    raise ValueError(
                        '`{}` of the {}-th component does not agree with the '
                        'first component: {} vs {}'.format(
                            attr, i, c_val, first_val))

        # check the batch_shape of components, ensure they are equal
        batch_shape = components[0].batch_shape
        batch_static_shape = components[0].get_batch_shape()

        def is_static_batch_shape_match(c, batch_static_shape):
            batch_static_shape = batch_static_shape.as_list()
            c_batch_static_shape = c.get_batch_shape().as_list()
            equal = True

            if len(batch_static_shape) != len(c_batch_static_shape):
                equal = False
            else:
                for a, b in zip(batch_static_shape, c_batch_static_shape):
                    if a is not None and b is not None and a != b:
                        equal = False
                        break

            return equal

        if not is_static_batch_shape_match(categorical, batch_static_shape):
            raise ValueError(
                'Batch shape of `categorical` does not agree with '
                'the first component: {} vs {}'.format(
                    categorical.get_batch_shape(), batch_static_shape))

        for i, c in enumerate(components[1:], 1):
            if not is_static_batch_shape_match(c, batch_static_shape):
                raise ValueError(
                    'Batch shape of the {}-th component does not agree with '
                    'the first component: {} vs {}'.format(
                        i, c.get_batch_shape(), batch_static_shape))

        def assert_batch_shape(c, batch_shape):
            c_batch_shape = c.batch_shape
            with assert_deps([
                    tf.assert_equal(
                        tf.reduce_all(
                            tf.equal(
                                tf.concat([batch_shape, c_batch_shape], 0),
                                tf.concat([c_batch_shape, batch_shape], 0))),
                        True)
            ]) as asserted:
                if asserted:  # pragma: no cover
                    batch_shape = tf.identity(batch_shape)
            return batch_shape

        if settings.enable_assertions:
            with tf.name_scope('Mixture.init'):
                batch_shape = assert_batch_shape(categorical, batch_shape)
                for c in components[1:]:
                    batch_shape = assert_batch_shape(c, batch_shape)

        self._categorical = categorical
        self._components = components

        super(Mixture, self).__init__(
            dtype=components[0].dtype,
            is_continuous=components[0].is_continuous,
            is_reparameterized=is_reparameterized,
            batch_shape=components[0].batch_shape,
            batch_static_shape=components[0].get_batch_shape(),
            value_ndims=components[0].value_ndims,
        )
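
A hypothetical construction sketch; it assumes tfsnippet.distributions provides `Categorical` and `Normal` with the argument names used below (TF 1.x graph mode assumed):

import tensorflow as tf
from tfsnippet.distributions import Categorical, Normal

categorical = Categorical(logits=tf.zeros([1, 3]))   # 3 mixture components
mixture = Mixture(
    categorical=categorical,
    components=[
        Normal(mean=tf.constant([-2.]), logstd=tf.constant([0.])),
        Normal(mean=tf.constant([0.]), logstd=tf.constant([0.])),
        Normal(mean=tf.constant([2.]), logstd=tf.constant([0.])),
    ],
    is_reparameterized=True,
)
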
Example #19
def reshape_tail(input, ndims, shape, name=None):
    """
    Reshape the tail (last) `ndims` dimensions of `input` into `shape`.

    Usage::

        x = tf.zeros([2, 3, 4, 5, 6])
        reshape_tail(x, 3, [-1])  # output: zeros([2, 3, 120])
        reshape_tail(x, 1, [3, 2])  # output: zeros([2, 3, 4, 5, 3, 2])

    Args:
        input (Tensor): The input tensor, with at least `ndims` dimensions.
        ndims (int): The number of tail dimensions to reshape.
        shape (Iterable[int] or tf.Tensor): The shape of the new tail.

    Returns:
        tf.Tensor: The reshaped tensor.
    """
    input = tf.convert_to_tensor(input)
    if not is_tensor_object(shape):
        shape = list(int(s) for s in shape)
        neg_one_count = 0
        for s in shape:
            if s <= 0:
                if s == -1:
                    if neg_one_count > 0:
                        raise ValueError('`shape` is not a valid shape: at '
                                         'most one `-1` can be specified.')
                    else:
                        neg_one_count += 1
                else:
                    raise ValueError('`shape` is not a valid shape: {} is '
                                     'not allowed.'.format(s))

    with tf.name_scope(name or 'reshape_tail', values=[input]):
        # assert the dimension
        with assert_deps([
                assert_rank_at_least(
                    input, ndims, message='rank(input) must be at least ndims')
        ]) as asserted:
            if asserted:  # pragma: no cover
                input = tf.identity(input)

        # compute the static shape
        static_input_shape = get_static_shape(input)
        static_output_shape = None

        if static_input_shape is not None:
            if ndims > 0:
                left_shape = static_input_shape[:-ndims]
                right_shape = static_input_shape[-ndims:]
            else:
                left_shape = static_input_shape
                right_shape = ()

            # attempt to resolve "-1" in `shape`
            if isinstance(shape, list):
                if None not in right_shape:
                    shape_size = int(np.prod([s for s in shape if s != -1]))
                    right_shape_size = int(np.prod(right_shape))

                    if (-1 not in shape and shape_size != right_shape_size) or \
                            (-1 in shape and right_shape_size % shape_size != 0):
                        raise ValueError(
                            'Cannot reshape the tail dimensions of '
                            '`input` into `shape`: input {!r}, ndims '
                            '{}, shape {}.'.format(input, ndims, shape))

                    if -1 in shape:
                        pos = shape.index(-1)
                        shape[pos] = right_shape_size // shape_size

                static_output_shape = left_shape + \
                    tuple(s if s != -1 else None for s in shape)

        static_output_shape = tf.TensorShape(static_output_shape)

        # compute the dynamic shape
        input_shape = get_shape(input)
        if ndims > 0:
            output_shape = concat_shapes([input_shape[:-ndims], shape])
        else:
            output_shape = concat_shapes([input_shape, shape])

        # do reshape
        output = tf.reshape(input, output_shape)
        output.set_shape(static_output_shape)
        return output
Example #20
def is_log_det_shape_matches_input(log_det, input, value_ndims, name=None):
    """
    Check whether or not the shape of `log_det` matches the shape of `input`.

    Basically, the shapes of `log_det` and `input` should satisfy::

        if value_ndims > 0:
            assert(log_det.shape == input.shape[:-value_ndims])
        else:
            assert(log_det.shape == input.shape)

    Args:
        log_det: Tensor, the log-determinant.
        input: Tensor, the input.
        value_ndims (int): The number of dimensions of each value sample.

    Returns:
        bool or tf.Tensor: A boolean or a tensor, indicating whether or not
            the shape of `log_det` matches the shape of `input`.
    """
    if not is_tensor_object(log_det):
        log_det = tf.convert_to_tensor(log_det)
    if not is_tensor_object(input):
        input = tf.convert_to_tensor(input)
    value_ndims = int(value_ndims)

    with tf.name_scope(name or 'is_log_det_shape_matches_input'):
        log_det_shape = get_static_shape(log_det)
        input_shape = get_static_shape(input)

        # if both shapes have deterministic ndims, we can compare each axis
        # separately.
        if log_det_shape is not None and input_shape is not None:
            if len(log_det_shape) + value_ndims != len(input_shape):
                return False
            dynamic_axis = []

            for i, (a, b) in enumerate(zip(log_det_shape, input_shape)):
                if a is None or b is None:
                    dynamic_axis.append(i)
                elif a != b:
                    return False

            if not dynamic_axis:
                return True

            log_det_shape = get_shape(log_det)
            input_shape = get_shape(input)
            return tf.reduce_all([
                tf.equal(log_det_shape[i], input_shape[i])
                for i in dynamic_axis
            ])

        # otherwise we need to do a fully dynamic check, including checking
        # ``log_det.ndims + value_ndims == input.ndims``
        is_ndims_matches = tf.equal(
            tf.rank(log_det) + value_ndims, tf.rank(input))
        log_det_shape = get_shape(log_det)
        input_shape = get_shape(input)
        if value_ndims > 0:
            input_shape = input_shape[:-value_ndims]

        return tf.cond(
            is_ndims_matches,
            lambda: tf.reduce_all(
                tf.equal(
                    # The following trick ensures we're comparing two tensors
                    # with the same shape, so as to avoid potential issues
                    # with the cond operation.
                    tf.concat([log_det_shape, input_shape], 0),
                    tf.concat([input_shape, log_det_shape], 0),
                )),
            lambda: tf.constant(False, dtype=tf.bool))
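
A short sketch (TF 1.x assumed): fully static shapes yield a plain bool, while any dynamic dimension turns the result into a boolean tensor:

import tensorflow as tf

x = tf.zeros([4, 3, 5])
ok = is_log_det_shape_matches_input(tf.zeros([4, 3]), x, value_ndims=1)     # True
bad = is_log_det_shape_matches_input(tf.zeros([4, 2]), x, value_ndims=1)    # False

y = tf.placeholder(tf.float32, shape=[None, 3, 5])
maybe = is_log_det_shape_matches_input(tf.zeros([4, 3]), y, value_ndims=1)  # a bool tensor
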
Example #21
def deconv2d(input,
             out_channels,
             kernel_size,
             strides=(1, 1),
             padding='same',
             channels_last=True,
             output_shape=None,
             activation_fn=None,
             normalizer_fn=None,
             weight_norm=False,
             gated=False,
             gate_sigmoid_bias=2.,
             kernel=None,
             kernel_initializer=None,
             kernel_regularizer=None,
             kernel_constraint=None,
             use_bias=None,
             bias=None,
             bias_initializer=tf.zeros_initializer(),
             bias_regularizer=None,
             bias_constraint=None,
             trainable=True,
             name=None,
             scope=None):
    """
    2D deconvolutional layer.

    Args:
        input (Tensor): The input tensor, at least 4-d.
        out_channels (int): The number of output channels of the deconvolution.
        kernel_size (int or (int, int)): Kernel size over spatial dimensions.
        strides (int or (int, int)): Strides over spatial dimensions.
        padding: One of {"valid", "same"}, case-insensitive.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        output_shape: If specified, use this as the shape of the
            deconvolution output; otherwise compute the size of each dimension
            by::

                output_size = input_size * strides
                if padding == 'valid':
                    output_size += max(kernel_size - strides, 0)

        activation_fn: The activation function.
        normalizer_fn: The normalizer function.
        weight_norm (bool or (tf.Tensor) -> tf.Tensor):
            If :obj:`True`, apply :func:`~tfsnippet.layers.weight_norm` on
            `kernel`.  `use_scale` will be :obj:`True` if `normalizer_fn`
            is not specified, and :obj:`False` otherwise.  The axis reduction
            will be determined by the layer.

            If it is a callable function, then it will be used to normalize
            the `kernel` instead of :func:`~tfsnippet.layers.weight_norm`.
            The user must ensure the axis reduction is correct by themselves.
        gated (bool): Whether or not to use gate on output?
            `output = activation_fn(output) * sigmoid(gate)`.
        gate_sigmoid_bias (Tensor): The bias added to `gate` before applying
            the `sigmoid` activation.
        kernel (Tensor): Instead of creating a new variable, use this tensor.
        kernel_initializer: The initializer for `kernel`.
            Would be ``default_kernel_initializer(...)`` if not specified.
        kernel_regularizer: The regularizer for `kernel`.
        kernel_constraint: The constraint for `kernel`.
        use_bias (bool or None): Whether or not to use `bias`?
            If :obj:`True`, will always use bias.
            If :obj:`None`, will use bias only if `normalizer_fn` is not given.
            If :obj:`False`, will never use bias.
            Default is :obj:`None`.
        bias (Tensor): Instead of creating a new variable, use this tensor.
        bias_initializer: The initializer for `bias`.
        bias_regularizer: The regularizer for `bias`.
        bias_constraint: The constraint for `bias`.
        trainable (bool): Whether or not the parameters are trainable?

    Returns:
        tf.Tensor: The output tensor.
    """
    input, in_channels, data_format = \
        validate_conv2d_input(input, channels_last)
    out_channels = validate_positive_int_arg('out_channels', out_channels)
    dtype = input.dtype.base_dtype
    if gated:
        out_channels *= 2

    # check functional arguments
    padding = validate_enum_arg('padding',
                                str(padding).upper(), ['VALID', 'SAME'])
    strides = validate_conv2d_strides_tuple('strides', strides, channels_last)

    weight_norm_fn = validate_weight_norm_arg(weight_norm,
                                              axis=-1,
                                              use_scale=normalizer_fn is None)
    if use_bias is None:
        use_bias = normalizer_fn is None

    # get the specification of outputs and parameters
    kernel_size = validate_conv2d_size_tuple('kernel_size', kernel_size)
    kernel_shape = kernel_size + (out_channels, in_channels)
    bias_shape = (out_channels, )

    given_h, given_w = None, None
    given_output_shape = output_shape

    if is_tensor_object(given_output_shape):
        given_output_shape = tf.convert_to_tensor(given_output_shape)
    elif given_output_shape is not None:
        given_h, given_w = given_output_shape

    # validate the parameters
    if kernel is not None:
        kernel_spec = ParamSpec(shape=kernel_shape, dtype=dtype)
        kernel = kernel_spec.validate('kernel', kernel)
    if kernel_initializer is None:
        kernel_initializer = default_kernel_initializer(weight_norm)
    if bias is not None:
        bias_spec = ParamSpec(shape=bias_shape, dtype=dtype)
        bias = bias_spec.validate('bias', bias)

    # the main part of the conv2d layer
    with tf.variable_scope(scope, default_name=name or 'deconv2d'):
        with tf.name_scope('output_shape'):
            # detect the input shape and axis arrangements
            input_shape = get_static_shape(input)
            if channels_last:
                c_axis, h_axis, w_axis = -1, -3, -2
            else:
                c_axis, h_axis, w_axis = -3, -2, -1

            output_shape = [None, None, None, None]
            output_shape[c_axis] = out_channels
            if given_output_shape is None:
                if input_shape[h_axis] is not None:
                    output_shape[h_axis] = get_deconv_output_length(
                        input_shape[h_axis], kernel_shape[0], strides[h_axis],
                        padding)
                if input_shape[w_axis] is not None:
                    output_shape[w_axis] = get_deconv_output_length(
                        input_shape[w_axis], kernel_shape[1], strides[w_axis],
                        padding)
            else:
                if not is_tensor_object(given_output_shape):
                    output_shape[h_axis] = given_h
                    output_shape[w_axis] = given_w

            # infer the batch shape in 4-d
            batch_shape = input_shape[:-3]
            if None not in batch_shape:
                output_shape[0] = int(np.prod(batch_shape))

            # now the static output shape is ready
            output_static_shape = tf.TensorShape(output_shape)

            # prepare for the dynamic batch shape
            if output_shape[0] is None:
                output_shape[0] = tf.reduce_prod(get_shape(input)[:-3])

            # prepare for the dynamic spatial dimensions
            if output_shape[h_axis] is None or output_shape[w_axis] is None:
                if given_output_shape is None:
                    input_shape = get_shape(input)
                    if output_shape[h_axis] is None:
                        output_shape[h_axis] = get_deconv_output_length(
                            input_shape[h_axis], kernel_shape[0],
                            strides[h_axis], padding)
                    if output_shape[w_axis] is None:
                        output_shape[w_axis] = get_deconv_output_length(
                            input_shape[w_axis], kernel_shape[1],
                            strides[w_axis], padding)
                else:
                    assert (is_tensor_object(given_output_shape))
                    with assert_deps([
                            assert_rank(given_output_shape, 1),
                            assert_scalar_equal(tf.size(given_output_shape), 2)
                    ]):
                        output_shape[h_axis] = given_output_shape[0]
                        output_shape[w_axis] = given_output_shape[1]

            # compose the final dynamic shape
            if any(is_tensor_object(s) for s in output_shape):
                output_shape = tf.stack(output_shape)
            else:
                output_shape = tuple(output_shape)

        # create the variables
        if kernel is None:
            kernel = model_variable('kernel',
                                    shape=kernel_shape,
                                    dtype=dtype,
                                    initializer=kernel_initializer,
                                    regularizer=kernel_regularizer,
                                    constraint=kernel_constraint,
                                    trainable=trainable)

        if weight_norm_fn is not None:
            kernel = weight_norm_fn(kernel)

        maybe_add_histogram(kernel, 'kernel')
        kernel = maybe_check_numerics(kernel, 'kernel')

        if use_bias and bias is None:
            bias = model_variable('bias',
                                  shape=bias_shape,
                                  initializer=bias_initializer,
                                  regularizer=bias_regularizer,
                                  constraint=bias_constraint,
                                  trainable=trainable)
            maybe_add_histogram(bias, 'bias')
            bias = maybe_check_numerics(bias, 'bias')

        # flatten to 4d
        output, s1, s2 = flatten_to_ndims(input, 4)

        # do convolution or deconvolution
        output = tf.nn.conv2d_transpose(value=output,
                                        filter=kernel,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding=padding,
                                        data_format=data_format)
        if output_static_shape is not None:
            output.set_shape(output_static_shape)

        # add bias
        if use_bias:
            output = tf.nn.bias_add(output, bias, data_format=data_format)

        # apply the normalization function if specified
        if normalizer_fn is not None:
            output = normalizer_fn(output)

        # split into halves if gated
        if gated:
            output, gate = tf.split(output, 2, axis=c_axis)

        # apply the activation function if specified
        if activation_fn is not None:
            output = activation_fn(output)

        # apply the gate if required
        if gated:
            output = output * tf.sigmoid(gate + gate_sigmoid_bias, name='gate')

        # unflatten back to original shape
        output = unflatten_from_ndims(output, s1, s2)

        maybe_add_histogram(output, 'output')
        output = maybe_check_numerics(output, 'output')

    return output
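
A hypothetical usage sketch (TF 1.x graph mode assumed; the scope name is illustrative), upsampling an NHWC feature map by a factor of two:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 16, 16, 32])

y = deconv2d(x, out_channels=64, kernel_size=3, strides=2,
             activation_fn=tf.nn.leaky_relu, scope='up1')
# with 'same' padding the static shape should be (None, 32, 32, 64)
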
Example #22
    def get_training_loss(self, x, n_z=None):
        """
        Get the training loss for this VAE.

        The variational solver is automatically chosen according to
        `z.is_reparameterized`, and the argument `n_z`, by the following rules:

        1. If `z.is_reparameterized` is :obj:`True`, then:

            1. If `n_z` > 1, use `iwae`.
            2. If `n_z` == 1 or `n_z` is :obj:`None`, use `sgvb`.

        2. If `z.is_reparameterized` is :obj:`False`, then:

            1. If `n_z` > 1, use `vimco`.
            2. If `n_z` == 1 or `n_z` is :obj:`None`, use `reinforce`.

        Dynamic `n_z` is not supported by this method, and the Reweighted
        Wake-Sleep algorithm is not offered here.  To derive the training
        loss in either situation, use :meth:`chain` to obtain a
        :class:`~tfsnippet.variational.VariationalChain`, and then obtain
        the loss via `chain.vi.training.[algorithm]`.

        Args:
            x: The input observation `x`.
            n_z (int or None): Number of `z` samples to take.  Must be
                :obj:`None` or a constant integer.  Dynamic tensors are not
                accepted, since we cannot automatically choose a variational
                solver for a non-deterministic `n_z`. (default :obj:`None`)

        Returns:
            tf.Tensor: A 0-d tensor, the training loss which can be optimized
                by gradient descent.

        See Also:
            :class:`tfsnippet.variational.VariationalChain`,
            :class:`tfsnippet.variational.VariationalTrainingObjectives`
        """
        with tf.name_scope('VAE.get_training_loss'):
            if n_z is not None:
                if is_tensor_object(n_z):
                    raise TypeError('Cannot choose the variational solver '
                                    'automatically for dynamic `n_z`')
                n_z = validate_n_samples_arg(n_z, 'n_z')

            # derive the variational chain
            chain = self.chain(x, n_z)
            z = chain.variational['z']

            # auto choose a variational solver for training loss
            if n_z is not None and n_z > 1:
                if z.is_reparameterized:
                    solver = chain.vi.training.iwae
                else:
                    solver = chain.vi.training.vimco
            else:
                if z.is_reparameterized:
                    solver = chain.vi.training.sgvb
                else:
                    solver = chain.vi.training.reinforce

            # derive the training loss
            return tf.reduce_mean(solver())
Example #23
def pixelcnn_2d_sample(fn,
                       inputs,
                       height,
                       width,
                       channels_last=True,
                       start=0,
                       end=None,
                       back_prop=False,
                       parallel_iterations=1,
                       swap_memory=False,
                       name=None):
    """
    Sample output from a PixelCNN 2D network, pixel-by-pixel.

    Args:
        fn: `(i: tf.Tensor, inputs: tuple[tf.Tensor]) -> tuple[tf.Tensor]`,
            the function to derive the outputs of the PixelCNN 2D network at
            iteration `i`.  `inputs` are the pixel-by-pixel outputs gathered
            through iteration `0` to iteration `i - 1`.  The iteration index
            `i` may range from `0` to `height * width - 1`.
        inputs (Iterable[tf.Tensor]): The initial input tensors.
            All the tensors must be at least 4-d, with identical shape.
        height (int or tf.Tensor): The height of the outputs.
        width (int or tf.Tensor): The width of the outputs.
        channels_last (bool): Whether or not the channel axis is the last
            axis in `input`? (i.e., the data format is "NHWC")
        start (int or tf.Tensor): The start iteration, default `0`.
        end (int or tf.Tensor): The end (exclusive) iteration.
            Default `height * width`.
        back_prop, parallel_iterations, swap_memory: Arguments passed to
            :func:`tf.while_loop`.

    Returns:
        tuple[tf.Tensor]: The final outputs.
    """
    from tfsnippet.layers.convolutional.utils import validate_conv2d_input

    # check the arguments
    def to_int(t):
        if is_tensor_object(t):
            return convert_to_tensor_and_cast(t, dtype=tf.int32)
        return int(t)

    height = to_int(height)
    width = to_int(width)

    inputs = list(inputs)
    if not inputs:
        raise ValueError('`inputs` must not be empty.')
    inputs[0], _, _ = validate_conv2d_input(inputs[0],
                                            channels_last=channels_last,
                                            arg_name='inputs[0]')
    input_spec = InputSpec(shape=get_static_shape(inputs[0]))
    for i, input in enumerate(inputs[1:], 1):
        inputs[i] = input_spec.validate('inputs[{}]'.format(i), input)

    # do pixelcnn sampling
    with tf.name_scope(name, default_name='pixelcnn_2d_sample', values=inputs):
        # the total size, start and end index
        total_size = height * width
        start = convert_to_tensor_and_cast(start, dtype=tf.int32)
        if end is None:
            end = convert_to_tensor_and_cast(total_size, dtype=tf.int32)
        else:
            end = convert_to_tensor_and_cast(end, dtype=tf.int32)

        # the mask shape
        if channels_last:
            mask_shape = [height, width, 1]
        else:
            mask_shape = [height, width]

        if any(is_tensor_object(t) for t in mask_shape):
            mask_shape = tf.stack(mask_shape, axis=0)

        # the input dynamic shape
        input_shape = get_shape(inputs[0])

        # the pixelcnn sampling loop
        def loop_cond(idx, _):
            return idx < end

        def loop_body(idx, inputs):
            inputs = tuple(inputs)

            # prepare for the output mask
            selector = tf.reshape(
                tf.concat([
                    tf.ones([idx], dtype=tf.uint8),
                    tf.zeros([1], dtype=tf.uint8),
                    tf.ones([total_size - idx - 1], dtype=tf.uint8)
                ],
                          axis=0), mask_shape)
            selector = tf.cast(broadcast_to_shape(selector, input_shape),
                               dtype=tf.bool)

            # obtain the outputs
            outputs = list(fn(idx, inputs))
            if len(outputs) != len(inputs):
                raise ValueError(
                    'The length of outputs != inputs: {} vs {}'.format(
                        len(outputs), len(inputs)))

            # mask the outputs
            for i, (input, output) in enumerate(zip(inputs, outputs)):
                input_dtype = inputs[i].dtype.base_dtype
                output_dtype = output.dtype.base_dtype
                if output_dtype != input_dtype:
                    raise TypeError(
                        '`outputs[{idx}].dtype` != `inputs[{idx}].dtype`: '
                        '{output} vs {input}'.format(idx=i,
                                                     output=output_dtype,
                                                     input=input_dtype))
                outputs[i] = tf.where(selector, input, output)

            return idx + 1, tuple(outputs)

        i0 = start
        _, outputs = tf.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=(i0, tuple(inputs)),
            back_prop=back_prop,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
        )
        return outputs
Example #24
def to_int(t):
    if is_tensor_object(t):
        return convert_to_tensor_and_cast(t, dtype=tf.int32)
    return int(t)
Example #25
def to_tensor(t):
    return tf.convert_to_tensor(t) if not is_tensor_object(t) else t
Example #26
def to_tensor(x):
    x = list(x)
    if any(is_tensor_object(t) for t in x):
        return tf.stack(list(x))
    else:
        return tuple(x)
Example #27
@contextlib.contextmanager
def gen_name_scope():
    if is_tensor_object(group_ndims):
        with tf.name_scope(name, default_name='validate_group_ndims'):
            yield
    else:
        yield
Example #28
def shift(input, shift, name=None):
    """
    Shift each axis of `input` according to `shift`, keeping the size identical.
    Content shifted outside the original extent is discarded, and zeros are
    padded at the front or end of the shifted axes.

    Args:
        input (Tensor): The tensor to be shifted.
        shift (Iterable[int]): The shift length for each axis.
            Its length must equal the rank of `input`.
            For each axis, if its corresponding shift < 0, then `input`
            will be shifted to the left by `-shift` at that axis; if its
            shift > 0, then `input` will be shifted to the right by
            `shift` at that axis.

    Returns:
        tf.Tensor: The output tensor.
    """
    shift = tuple(int(s) for s in shift)
    input = tf.convert_to_tensor(input)
    shape = get_static_shape(input)

    if shape is None:
        raise ValueError(
            'The rank of `input` is required to be deterministic: '
            'got {}'.format(input))
    if len(shift) != len(shape):
        raise ValueError('The length of `shift` is required to equal the rank '
                         'of `input`: shift {} vs input {}'.format(
                             shift, input))

    # cache for the dynamic shape
    def get_dynamic_shape():
        if cached[0] is None:
            cached[0] = get_shape(input)
        return cached[0]

    cached = [None]

    # main routine
    with tf.name_scope(name, default_name='shift', values=[input]):
        # compute the slicing and padding arguments
        has_shift = False
        assert_ops = []
        slice_begin = []
        slice_size = []
        paddings = []
        err_msg = ('Cannot shift `input`: input {} vs shift {}'.format(
            input, shift))

        for i, (axis_shift, axis_size) in enumerate(zip(shift, shape)):
            # fast approach: shift is zero, no slicing at the axis
            if axis_shift == 0:
                slice_begin.append(0)
                slice_size.append(-1)
                paddings.append((0, 0))
                continue

            # slow approach: shift is not zero, should slice at the axis
            axis_shift_abs = abs(axis_shift)

            # we first check whether or not the axis size is big enough
            if axis_size is None:
                dynamic_axis_size = get_dynamic_shape()[i]
                assert_ops.append(
                    tf.assert_greater_equal(dynamic_axis_size,
                                            axis_shift_abs,
                                            message=err_msg))
            else:
                if axis_size < axis_shift_abs:
                    raise ValueError(err_msg)

            # next, we compose the slicing range
            if axis_shift < 0:  # shift to left
                slice_begin.append(-axis_shift)
                slice_size.append(-1)
                paddings.append((0, -axis_shift))

            else:  # shift to right
                slice_begin.append(0)
                if axis_size is None:
                    slice_size.append(get_dynamic_shape()[i] - axis_shift)
                else:
                    slice_size.append(axis_size - axis_shift)
                paddings.append((axis_shift, 0))

            # mark the flag to indicate that we've got any axis to shift
            has_shift = True

        if assert_ops:
            with assert_deps(assert_ops) as asserted:
                if asserted:
                    input = tf.identity(input)

        # no axis to shift, directly return the input
        if not has_shift:
            return input

        # do slicing and padding
        if any(is_tensor_object(s) for s in slice_size):
            slice_size = tf.stack(slice_size, axis=0)

        output = tf.slice(input, slice_begin, slice_size)
        output = tf.pad(output, paddings)

        return output
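
A concrete sketch (TF 1.x assumed) of shifting the last axis to the right by one, which pads zeros at the front and discards the right-most column:

import tensorflow as tf

x = tf.reshape(tf.range(6), [2, 3])   # [[0, 1, 2], [3, 4, 5]]
y = shift(x, [0, 1])
# expected value:
# [[0, 0, 1],
#  [0, 3, 4]]
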
Example #29
def broadcast_to_shape(x, shape, name=None):
    """
    Broadcast `x` to match `shape`.

    If ``rank(x) > len(shape)``, only the tail dimensions will be broadcasted
    to match `shape`.

    Args:
        x: A tensor.
        shape (tuple[int] or tf.Tensor): Broadcast `x` to match this shape.

    Returns:
        tf.Tensor: The broadcasted tensor.
    """
    # check the parameters
    x = tf.convert_to_tensor(x)
    x_shape = get_static_shape(x)
    ns_values = [x]
    if is_tensor_object(shape):
        shape = tf.convert_to_tensor(shape)
        ns_values.append(shape)
    else:
        shape = tuple(int(s) for s in shape)

    with tf.name_scope(name=name or 'broadcast_to_shape', values=ns_values):
        cannot_broadcast_msg = (
            '`x` cannot be broadcasted to match `shape`: x {!r} vs shape {!r}'.
            format(x, shape))

        # fast routine: shape is tuple[int] and x_shape is all known,
        # we can use reshape + tile to do the broadcast, which should be faster
        # than using ``x * ones(shape)``.
        if isinstance(shape, tuple) and x_shape is not None and \
                all(s is not None for s in x_shape):
            # reshape to have the same dimension
            if len(x_shape) < len(shape):
                x_shape = (1, ) * (len(shape) - len(x_shape)) + x_shape
                x = tf.reshape(x, x_shape)

            # tile to have the same shape
            tile = []
            i = -1
            while i > -len(shape) - 1:
                a, b = x_shape[i], shape[i]
                if a == 1 and b > 1:
                    tile.append(b)
                elif a != b:
                    raise ValueError(cannot_broadcast_msg)
                else:
                    tile.append(1)
                i -= 1
            tile = [1] * (len(x_shape) - len(shape)) + list(reversed(tile))
            if any(s > 1 for s in tile):
                x = tf.tile(x, tile)

            return x

        # slow routine: we may need ``x * ones(shape)`` to do the broadcast
        assertions = []
        post_assert_shape = False
        static_shape = tf.TensorShape(None)

        if isinstance(shape, tuple) and x_shape is not None:
            need_multiply_ones = False

            # it should always broadcast if len(x_shape) < len(shape)
            if len(x_shape) < len(shape):
                need_multiply_ones = True

            # check the consistency of x and shape
            static_shape_hint = []  # list to gather the static shape hint
            axis_to_check = []  # list to gather the axis to check
            i = -1
            while i >= -len(shape) and i >= -len(x_shape):
                a, b = x_shape[i], shape[i]
                if a is None:
                    axis_to_check.append(i)
                else:
                    if a != b:
                        if a == 1:
                            need_multiply_ones = True
                        else:
                            raise ValueError(cannot_broadcast_msg)
                static_shape_hint.append(b)
                i -= 1

            # compose the static shape hint
            if len(shape) < len(x_shape):
                static_shape = x_shape[:-len(shape)]
            elif len(shape) > len(x_shape):
                static_shape = shape[:-len(x_shape)]
            else:
                static_shape = ()
            static_shape = tf.TensorShape(static_shape +
                                          tuple(reversed(static_shape_hint)))

            # compose the assertion operations and the multiply flag
            if axis_to_check:
                need_multiply_flags = []
                x_dynamic_shape = tf.shape(x)

                for i in axis_to_check:
                    assertions.append(
                        tf.assert_equal(tf.logical_or(
                            tf.equal(x_dynamic_shape[i], shape[i]),
                            tf.equal(x_dynamic_shape[i], 1),
                        ),
                                        True,
                                        message=cannot_broadcast_msg))
                    if len(x_shape) >= len(shape):
                        need_multiply_flags.append(
                            tf.not_equal(x_dynamic_shape[i], shape[i]))

                if not need_multiply_ones:
                    need_multiply_ones = \
                        tf.reduce_any(tf.stack(need_multiply_flags))

        else:
            # we have no idea what `shape` is here, thus we need to
            # assert the shape after ``x * ones(shape)``.
            need_multiply_ones = True
            post_assert_shape = True

        # do broadcast if `x_shape` != `shape`
        def multiply_branch():
            with assert_deps(assertions):
                ones_template = tf.ones(shape, dtype=x.dtype.base_dtype)
            try:
                return x * ones_template
            except ValueError:  # pragma: no cover
                raise ValueError(cannot_broadcast_msg)

        def identity_branch():
            with assert_deps(assertions) as asserted:
                if asserted:
                    return tf.identity(x)
                else:  # pragma: no cover
                    return x

        t = smart_cond(need_multiply_ones, multiply_branch, identity_branch)
        t.set_shape(static_shape)

        if post_assert_shape:
            post_assert_op = tf.assert_equal(tf.reduce_all(
                tf.equal(tf.shape(t)[-tf.size(shape):], shape)),
                                             True,
                                             message=cannot_broadcast_msg)
            with assert_deps([post_assert_op]) as asserted:
                if asserted:
                    t = tf.identity(t)

        return t
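
A minimal sketch (TF 1.x assumed) of the non-strict behaviour: leading dimensions of `x` beyond `len(shape)` are kept, and `shape` may also be a dynamic 1-d tensor:

import tensorflow as tf

x = tf.ones([3, 1])
y = broadcast_to_shape(x, (2, 3, 5))     # static shape should be (2, 3, 5)

z = tf.ones([2, 3, 5])
w = broadcast_to_shape(z, (3, 5))        # extra leading dim kept: (2, 3, 5)

s = tf.placeholder(tf.int32, shape=[3])  # dynamic target shape
v = broadcast_to_shape(x, s)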