Example #1
    def _sample_n(self, n, seed=None):
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=seed)  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[],
        # in which case ids is a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
        if mix_batch_size is None:
            mix_batch_size = math_ops.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(
            sample_shape=concat_vectors(
                [n],
                distribution_util.pick_vector(self.is_scalar_batch(),
                                              np.int32([]),
                                              [batch_size // mix_batch_size])),
            seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = array_ops.reshape(ids,
                                shape=concat_vectors(
                                    [n],
                                    distribution_util.pick_vector(
                                        self.is_scalar_batch(), np.int32([]),
                                        np.int32([-1]))))

        # Stride `components * quadrature_size` `batch_size` times.
        stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
        if stride is None:
            stride = math_ops.reduce_prod(array_ops.shape(self.grid)[-2:])
        offset = math_ops.range(start=0,
                                limit=batch_size * stride,
                                delta=stride,
                                dtype=ids.dtype)
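        # `offset` is broadcast against `ids` below ([n, batch_size] +
        # [batch_size]), turning each id into an index into the flattened grid.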

        weight = array_ops.gather(array_ops.reshape(self.grid, shape=[-1]),
                                  ids + offset)
        weight = weight[..., array_ops.newaxis]

        if len(x) != 2:
            # We actually should have already triggered this exception.
            # However, as a policy we put this exception wherever we exploit
            # the bimixture assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
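The last line uses the identity `w * x0 + (1 - w) * x1 == w * (x0 - x1) + x1`, trading a multiply for a subtract. A minimal NumPy sketch (shapes invented for illustration) checking the equivalence:

import numpy as np

rng = np.random.RandomState(0)
weight = rng.rand(4, 1)  # convex weights, broadcast over the event dim
x0 = rng.randn(4, 3)     # samples pushed through the first affine map
x1 = rng.randn(4, 3)     # samples pushed through the second affine map

fused = weight * (x0 - x1) + x1
direct = weight * x0 + (1. - weight) * x1
np.testing.assert_allclose(fused, direct)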
Example #2
  def _sample_n(self, n, seed=None):
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.get_batch_shape()
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape()
        batch_size = math_ops.reduce_prod(batch_shape)
      static_event_shape = self.get_event_shape()
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape()

      # Get indices into the raw cat sampling tensor.  We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index.  The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat(([n_class * batch_size], event_shape), 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat((samples_shape,
                                                self.event_shape()), 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.get_event_shape()))
      return ret
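The partition/stitch round trip described in the comments can be reproduced in isolation with the public `tf.dynamic_partition`/`tf.dynamic_stitch` ops. A standalone sketch (not part of the original class) using the comment's partition vector:

import tensorflow as tf

data = tf.constant([10, 11, 12, 13, 14, 15])
partitions = tf.constant([1, 1, 0, 0, 1, 1])

# Partition the data and its raw indices by component id.
parts = tf.dynamic_partition(data, partitions, num_partitions=2)
# parts[0] == [12, 13]; parts[1] == [10, 11, 14, 15]
indices = tf.dynamic_partition(tf.range(tf.size(data)), partitions,
                               num_partitions=2)

# Stitching the per-partition values back by their raw indices restores the
# original order, which is exactly how the per-component samples are
# reassembled above.
restored = tf.dynamic_stitch(indices, parts)
# restored == [10, 11, 12, 13, 14, 15]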
Example #3
def fill_lower_triangular(x,
                          validate_args=False,
                          name="fill_lower_triangular"):
    """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is roughly 5x slower than zeroing out the
  upper triangular, i.e., `tf.matrix_band_part(X, -1, 0)`.  This function
  becomes competitive only when several matmul/cholesky/etc. ops can be elided
  in constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function roughly halves the size of the final layer
  and may reduce the network architecture's complexity considerably.  In most
  cases it is better to simply build a full matrix and zero out the upper
  triangular elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`, rather
  than directly construct a lower triangular matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static checks
      only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a lower
      triangular matrix.
  """
    # TODO(jvdillon): Replace this code with dedicated op when it exists.
    with ops.name_scope(name, values=(x, )):
        x = ops.convert_to_tensor(x, name="x")
        if (x.get_shape().ndims is not None
                and x.get_shape()[-1].value is not None):
            d = x.get_shape()[-1].value
            # d = n(n+1)/2 implies n is:
            n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
            d_inferred = n * (n + 1) / 2
            if d != d_inferred:
                raise ValueError(
                    "Input cannot be mapped to a lower triangular; "
                    "n*(n+1)/2 = %d != %d" % (d_inferred, d))
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([n, n]))
        else:
            d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
            # d = n(n+1)/2 implies n is:
            n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                              dtype=dtypes.int32)
            if validate_args:
                is_valid_input_shape = check_ops.assert_equal(
                    n * (n + 1) / 2,
                    d,
                    message="Input cannot be mapped to a lower triangular.")
                n = control_flow_ops.with_dependencies([is_valid_input_shape],
                                                       n)
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([None, None]))

        def tril_ids(n):
            """Internal helper to create vector of linear indices into y."""
            # Build the ids statically; we chose 512 because 512**2 int32 ids
            # occupy 1MiB.
            if not contrib_framework.is_tensor(n) and n <= 512:
                ids = np.arange(n**2, dtype=np.int32)
                rows = (ids / n).astype(np.int32)  # Implicit floor.
                # We need to stop incrementing the index when we encounter
                # upper-triangular elements.  The idea here is to compute the
                # lower-right number of zeros then by "symmetry" subtract this from the
                # total number of zeros, n(n-1)/2.
                # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
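                # For example, n=3 gives rows = [0 0 0 1 1 1 2 2 2] and
                # offset = [0 0 0 2 2 2 3 3 3], so ids - offset is
                # [0 1 2 1 2 3 3 4 5]; after the final matrix_band_part only
                # the lower-triangular gathers [0], [1, 2], [3, 4, 5] survive.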
                offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
                # We could also zero out when (rows < cols) == (rows < ids-n*rows).
                # mask = (ids <= (n + 1) * rows).astype(np.int32)
            else:
                ids = math_ops.range(n**2)
                rows = math_ops.cast(ids / n, dtype=dtypes.int32)
                offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                                       dtype=dtypes.int32)
            return ids - offset

        # Special-case non-batch case.
        if x.get_shape().ndims == 1:
            y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
            y = array_ops.matrix_band_part(y, -1, 0)
            y.set_shape(y.get_shape().merge_with(final_shape))
            return y

        # Make ids for each batch dim.
        if (x.get_shape().ndims is not None
                and x.get_shape()[:-1].is_fully_defined()):
            batch_shape = np.asarray(x.get_shape()[:-1].as_list(),
                                     dtype=np.int32)
            m = np.prod(batch_shape).astype(np.int32)
        else:
            batch_shape = array_ops.shape(x)[:-1]
            m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
        batch_ids = math_ops.range(m)

        # Assemble the tril_ids into batch,tril_id pairs.
        idx = array_ops.pack([
            array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
            array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
        ])
        idx = array_ops.transpose(idx, [1, 2, 0])

        # Gather up, reshape, and return.
        y = array_ops.reshape(x, [-1, d])
        y = array_ops.gather_nd(y, idx)
        y = array_ops.reshape(y, array_ops.concat_v2([batch_shape, [n, n]], 0))
        y = array_ops.matrix_band_part(y, -1, 0)
        y.set_shape(y.get_shape().merge_with(final_shape))
        return y
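The docstring points at a pure NumPy counterpart in `distribution_util_test.py`. A sketch of such a reference implementation (illustrative; not necessarily the test's exact code):

import math

import numpy as np


def _fill_lower_triangular(x):
    """NumPy sketch of `fill_lower_triangular`."""
    x = np.asarray(x)
    d = x.shape[-1]
    # d = n(n+1)/2 implies n is:
    n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
    rows, cols = np.tril_indices(n)  # row-major lower-triangular positions
    y = np.zeros(list(x.shape[:-1]) + [n, n], dtype=x.dtype)
    y[..., rows, cols] = x
    return y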
Example #4
  def _sample_n(self, n, seed=None):
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n],
            self.batch_shape_tensor(),
            self.event_shape_tensor()),
        seed=seed)   # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[],
    # in which case ids is a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
      mix_batch_size = math_ops.reduce_prod(
          self.mixture_distribution.batch_shape_tensor())
    ids = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size // mix_batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "vector_diffeomixture"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = array_ops.reshape(ids, shape=concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(),
            np.int32([]),
            np.int32([-1]))))

    # Stride `components * quadrature_size` `batch_size` times.
    stride = self.grid.shape.with_rank_at_least(
        2)[-2:].num_elements()
    if stride is None:
      stride = math_ops.reduce_prod(
          array_ops.shape(self.grid)[-2:])
    offset = math_ops.range(start=0,
                            limit=batch_size * stride,
                            delta=stride,
                            dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.grid, shape=[-1]),
        ids + offset)
    # At this point, weight flattened all batch dims into one.
    # We also need to append a singleton to broadcast with event dims.
    if self.batch_shape.is_fully_defined():
      new_shape = [-1] + self.batch_shape.as_list() + [1]
    else:
      new_shape = array_ops.concat(
          ([-1], self.batch_shape_tensor(), [1]), axis=0)
    weight = array_ops.reshape(weight, shape=new_shape)

    if len(x) != 2:
      # We actually should have already triggered this exception. However, as
      # a policy we put this exception wherever we exploit the bimixture
      # assumption.
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
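Two helpers recur throughout these samplers without being defined in the snippets: `concat_vectors` and `distribution_util.pick_vector`. Sketches of their assumed behavior against the public TF API (the real versions also short-circuit statically known inputs):

import tensorflow as tf


def concat_vectors(*args):
  """Sketch: concatenate int32 shape vectors into a single vector."""
  return tf.concat([tf.convert_to_tensor(a, dtype=tf.int32) for a in args],
                   axis=0)


def pick_vector(cond, true_vector, false_vector):
  """Sketch: return `true_vector` when `cond` holds, else `false_vector`."""
  return tf.cond(tf.convert_to_tensor(cond),
                 lambda: tf.convert_to_tensor(true_vector, dtype=tf.int32),
                 lambda: tf.convert_to_tensor(false_vector, dtype=tf.int32))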
Example #5
def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is roughly 5x slower than zeroing out the
  upper triangular, i.e., `tf.matrix_band_part(X, -1, 0)`.  This function
  becomes competitive only when several matmul/cholesky/etc. ops can be elided
  in constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function roughly halves the size of the final layer
  and may reduce the network architecture's complexity considerably.  In most
  cases it is better to simply build a full matrix and zero out the upper
  triangular elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`, rather
  than directly construct a lower triangular matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static checks
      only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a lower
      triangular matrix.
  """
  # TODO(jvdillon): Replace this code with dedicated op when it exists.
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    if (x.get_shape().ndims is not None and
        x.get_shape()[-1].value is not None):
      d = x.get_shape()[-1].value
      # d = n(n+1)/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      d_inferred = n * (n + 1) / 2
      if d != d_inferred:
        raise ValueError("Input cannot be mapped to a lower triangular; "
                         "n*(n+1)/2 = %d != %d" % (d_inferred, d))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
      # d = n(n+1)/2 implies n is:
      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                        dtype=dtypes.int32)
      if validate_args:
        is_valid_input_shape = check_ops.assert_equal(
            n * (n + 1) / 2, d,
            message="Input cannot be mapped to a lower triangular.")
        n = control_flow_ops.with_dependencies([is_valid_input_shape], n)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    def tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      # Build the ids statically; we chose 512 because 512**2 int32 ids
      # occupy 1MiB.
      if not contrib_framework.is_tensor(n) and n <= 512:
        ids = np.arange(n**2, dtype=np.int32)
        rows = (ids / n).astype(np.int32)  # Implicit floor.
        # We need to stop incrementing the index when we encounter
        # upper-triangular elements.  The idea here is to compute the
        # lower-right number of zeros then by "symmetry" subtract this from the
        # total number of zeros, n(n-1)/2.
        # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
        offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
        # We could also zero out when (rows < cols) == (rows < ids-n*rows).
        # mask = (ids <= (n + 1) * rows).astype(np.int32)
      else:
        ids = math_ops.range(n**2)
        rows = math_ops.cast(ids / n, dtype=dtypes.int32)
        offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                               dtype=dtypes.int32)
      return ids - offset

    # Special-case non-batch case.
    if x.get_shape().ndims == 1:
      y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
      y = array_ops.matrix_band_part(y, -1, 0)
      y.set_shape(y.get_shape().merge_with(final_shape))
      return y

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape).astype(np.int32)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
    batch_ids = math_ops.range(m)

    # Assemble the tril_ids into batch,tril_id pairs.
    idx = array_ops.stack([
        array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
        array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
    ])
    idx = array_ops.transpose(idx, [1, 2, 0])

    # Gather up, reshape, and return.
    y = array_ops.reshape(x, [-1, d])
    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat([batch_shape, [n, n]], 0))
    y = array_ops.matrix_band_part(y, -1, 0)
    y.set_shape(y.get_shape().merge_with(final_shape))
    return y
Example #6
def reduce_prod(x):
  """Same as `math_ops.reduce_prod` but statically if possible."""
  x_ = static_value(x)
  if x_ is not None:
    return np.prod(x_, dtype=x.dtype.as_numpy_dtype)
  return math_ops.reduce_prod(x)
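`static_value` is assumed to be the companion helper that extracts a tensor's value at graph-construction time when it is statically known; a plausible sketch:

def static_value(x):
  """Sketch: the statically known value of `x`, or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))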
Example #7
    def _sample_n(self, n, seed=None):
        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample_n(n, seed=seed)

            static_samples_shape = cat_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(cat_samples)
                samples_size = array_ops.size(cat_samples)
            static_batch_shape = self.get_batch_shape()
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.get_event_shape()
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape()

            # Get indices into the raw cat sampling tensor.  We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            for c in range(self.num_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples_class_c = self.components[c].sample_n(n_class,
                                                              seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index.  The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat_v2(([n_class * batch_size], event_shape),
                                        0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat_v2((samples_shape, self.event_shape()), 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.get_event_shape()))
            return ret
Example #8
def reduce_prod(x):
    """Same as `math_ops.reduce_prod` but statically if possible."""
    x_ = static_value(x)
    if x_ is not None:
        return np.prod(x_, dtype=x.dtype.as_numpy_dtype)
    return math_ops.reduce_prod(x)
Example #9
def fill_lower_triangular(x, name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Note: This function is very slow; possibly 10x slower than zeroing out the
  upper-triangular portion of a full matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  Args:
    x: `Tensor` representing lower triangular elements.
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.
  """
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    ndims = x.get_shape().ndims
    if ndims is not None and x.get_shape()[-1].value is not None:
      d = x.get_shape()[-1].value
      # d = n^2/2 + n/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      ndims = array_ops.rank(x)
      d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
      # d = n^2/2 + n/2 implies n is:
      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                        dtype=dtypes.int32)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(batch_shape)

    # Flatten batch dims.
    y = array_ops.reshape(x, [-1, d])

    # Prepend a zero to each row.
    y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])
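    # Index 0 of each padded row now holds a zero; `make_tril_ids` below maps
    # every upper-triangular position to index 0, so gathering there produces
    # the zero fill.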

    # Make ids for each batch dim.
    if x.get_shape()[:-1].is_fully_defined():
      m = np.asarray(np.prod(x.get_shape()[:-1].as_list()), dtype=np.int32)
    else:
      m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
    batch_ids = math_ops.range(m)

    def make_tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n])
      rows = array_ops.tile(
          array_ops.expand_dims(math_ops.range(n), -1), [1, n])
      pred = math_ops.greater(cols, rows)
      tril_ids = array_ops.tile(array_ops.reshape(
          math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols
      tril_ids = math_ops.select(pred,
                                 array_ops.zeros([n, n], dtype=dtypes.int32),
                                 tril_ids + 1)
      tril_ids = array_ops.reshape(tril_ids, [-1])
      return tril_ids
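    # For n=3: cumsum(range(3)) = [0, 1, 3], so before the select the grid is
    # [[0, 1, 2], [1, 2, 3], [3, 4, 5]]; zeroing the upper triangle and adding
    # one gives [[1, 0, 0], [2, 3, 0], [4, 5, 6]], i.e. one-based indices into
    # the zero-padded rows of `y`.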
    tril_ids = make_tril_ids(n)

    # Assemble the ids into pairs.
    idx = array_ops.pack([
        array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]),
        array_ops.tile([tril_ids], [m, 1])])
    idx = array_ops.transpose(idx, [1, 2, 0])

    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))

    y.set_shape(y.get_shape().merge_with(final_shape))

    return y
Example #10
def fill_lower_triangular(x, name="fill_lower_triangular"):
    """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Note: This function is very slow; possibly 10x slower than zeroing out the
  upper-triangular portion of a full matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  Args:
    x: `Tensor` representing lower triangular elements.
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.
  """
    with ops.name_scope(name, values=(x, )):
        x = ops.convert_to_tensor(x, name="x")
        ndims = x.get_shape().ndims
        if ndims is not None and x.get_shape()[-1].value is not None:
            d = x.get_shape()[-1].value
            # d = n^2/2 + n/2 implies n is:
            n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([n, n]))
        else:
            ndims = array_ops.rank(x)
            d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
            # d = n^2/2 + n/2 implies n is:
            n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                              dtype=dtypes.int32)
            final_shape = x.get_shape()[:-1].concatenate(
                tensor_shape.TensorShape([None, None]))

        # Make ids for each batch dim.
        if (x.get_shape().ndims is not None
                and x.get_shape()[:-1].is_fully_defined()):
            batch_shape = np.asarray(x.get_shape()[:-1].as_list(),
                                     dtype=np.int32)
            m = np.prod(batch_shape)
        else:
            batch_shape = array_ops.shape(x)[:-1]
            m = math_ops.reduce_prod(batch_shape)

        # Flatten batch dims.
        y = array_ops.reshape(x, [-1, d])

        # Prepend a zero to each row.
        y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])

        # Make ids for each batch dim.
        if x.get_shape()[:-1].is_fully_defined():
            m = np.asarray(np.prod(x.get_shape()[:-1].as_list()),
                           dtype=np.int32)
        else:
            m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
        batch_ids = math_ops.range(m)

        def make_tril_ids(n):
            """Internal helper to create vector of linear indices into y."""
            cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]),
                                     [n, n])
            rows = array_ops.tile(array_ops.expand_dims(math_ops.range(n), -1),
                                  [1, n])
            pred = math_ops.greater(cols, rows)
            tril_ids = array_ops.tile(
                array_ops.reshape(math_ops.cumsum(math_ops.range(n)), [n, 1]),
                [1, n]) + cols
            tril_ids = math_ops.select(
                pred, array_ops.zeros([n, n], dtype=dtypes.int32),
                tril_ids + 1)
            tril_ids = array_ops.reshape(tril_ids, [-1])
            return tril_ids

        tril_ids = make_tril_ids(n)

        # Assemble the ids into pairs.
        idx = array_ops.pack([
            array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n * n]),
            array_ops.tile([tril_ids], [m, 1])
        ])
        idx = array_ops.transpose(idx, [1, 2, 0])

        y = array_ops.gather_nd(y, idx)
        y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))

        y.set_shape(y.get_shape().merge_with(final_shape))

        return y
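A quick usage sketch for either variant, assuming a TF 1.x-era runtime to match the `ops`/`array_ops` modules above. The batch case maps each length-3 row to its own 2x2 lower triangular:

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(1., 7.).reshape(2, 3))  # two batches; d = 3, so n = 2
tril = fill_lower_triangular(x)
with tf.Session() as sess:
    print(sess.run(tril))
# [[[1. 0.]
#   [2. 3.]]
#  [[4. 0.]
#   [5. 6.]]]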