Exemplo n.º 1
0
    def testUnzip(self):
        """functions.unzip should interleave input rows across the splits.

        Splitting the 4 rows of ``n1`` along dim 0 into 2 pieces is expected
        to yield rows 0 and 2 in the first output and rows 1 and 3 in the
        second.
        """
        n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]],
                         dtype=numpy.float32)
        t1 = tf.constant(n1)
        # Args mirror the call in unzip(): input, split_dim,
        # size of split_dim, num_splits.
        out = self.Run(functions.unzip(t1, 0, 4, 2))

        # NOTE(review): assumes `testing` is numpy.testing. assert_allclose
        # takes (actual, desired) and scales rtol by |desired|, so the
        # computed output goes first and the expected values second.
        expected = numpy.array([[1., 2.], [5., 6.]], dtype=numpy.float32)
        testing.assert_allclose(out[0], expected, rtol=TOLERANCE)
        expected = numpy.array([[3., 4.], [7., 8.]], dtype=numpy.float32)
        testing.assert_allclose(out[1], expected, rtol=TOLERANCE)
Exemplo n.º 2
0
  def testUnzip(self):
    """functions.unzip should interleave input rows across the splits.

    Splitting the 4 rows of `n1` along dim 0 into 2 pieces is expected to
    yield rows 0 and 2 in the first output and rows 1 and 3 in the second.
    """
    n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]],
                     dtype=numpy.float32)
    t1 = tf.constant(n1)
    # Args mirror the call in unzip(): input, split_dim, size of split_dim,
    # num_splits.
    out = self.Run(functions.unzip(t1, 0, 4, 2))

    # NOTE(review): assumes `testing` is numpy.testing. assert_allclose takes
    # (actual, desired) and scales rtol by |desired|, so the computed output
    # goes first and the expected values second.
    expected = numpy.array([[1., 2.], [5., 6.]], dtype=numpy.float32)
    testing.assert_allclose(out[0], expected, rtol=TOLERANCE)
    expected = numpy.array([[3., 4.], [7., 8.]], dtype=numpy.float32)
    testing.assert_allclose(out[1], expected, rtol=TOLERANCE)
def unzip(input_layer, split_dim=0, num_splits=2):
  """Splits the head Tensor along `split_dim` into `num_splits` equal parts.

  Elements are interleaved across the outputs rather than taken in
  contiguous blocks, as these examples show:

  * `[1, 2, 3, 4] -> [1, 3], [2, 4]`
  * `[[1, 1], [2, 2], [3, 3], [4, 4]] -> [[1, 1], [3, 3]], [[2, 2], [4, 4]]`

  Args:
    input_layer: The chainable object, supplied.
    split_dim: The dimension to split along. Defaults to 0 (batch).
    num_splits: The number of splits.
  Returns:
    A list of PrettyTensors, wrapped via `input_layer.with_sequence`.
  Raises:
    ValueError: If split_dim is out of range or isn't divided evenly by
      num_splits.
  """
  layer_shape = input_layer.shape
  # Validate before layer_shape is indexed by split_dim below; presumably
  # this is where the documented ValueError originates.
  _check_split_dims(num_splits, split_dim, layer_shape)
  split_dim_size = layer_shape[split_dim]
  pieces = functions.unzip(input_layer, split_dim, split_dim_size, num_splits)
  return input_layer.with_sequence(pieces)
Exemplo n.º 4
0
def unzip(input_layer, split_dim=0, num_splits=2):
    """Splits the head Tensor along `split_dim` into `num_splits` equal parts.

    Elements are interleaved across the outputs rather than taken in
    contiguous blocks, as these examples show:

    * `[1, 2, 3, 4] -> [1, 3], [2, 4]`
    * `[[1, 1], [2, 2], [3, 3], [4, 4]] -> [[1, 1], [3, 3]], [[2, 2], [4, 4]]`

    Args:
      input_layer: The chainable object, supplied.
      split_dim: The dimension to split along. Defaults to 0 (batch).
      num_splits: The number of splits.
    Returns:
      A list of PrettyTensors, wrapped via `input_layer.with_sequence`.
    Raises:
      ValueError: If split_dim is out of range or isn't divided evenly by
        num_splits.
    """
    layer_shape = input_layer.shape
    # Validate before layer_shape is indexed by split_dim below; presumably
    # this is where the documented ValueError originates.
    _check_split_dims(num_splits, split_dim, layer_shape)
    split_dim_size = layer_shape[split_dim]
    pieces = functions.unzip(
        input_layer, split_dim, split_dim_size, num_splits)
    return input_layer.with_sequence(pieces)