Exemplo n.º 1
0
def Chunk(layer, chunk_size, pass_unchunkable=True):
  """Wraps `layer` so it runs over batch chunks of `chunk_size` to save memory.

  Every input is reshaped from `[batch, ...]` to `[n_chunks, size, ...]`,
  `layer` is applied through `Scan` along axis 0, and every output is reshaped
  back by merging its two leading axes.

  Args:
    layer: The layer to run chunk by chunk.
    chunk_size: Number of batch entries per chunk. If smaller than 1, `layer`
        is returned as is, without any wrapping.
    pass_unchunkable: What to do when `chunk_size` does not divide the batch
        size: if True, the whole batch is processed as one single chunk; if
        False, a `ValueError` is raised.

  Returns:
    `layer` itself when `chunk_size < 1`; otherwise a `Serial` combinator made
    of the reshape-to-chunks layer, the scanned `layer`, and the
    reshape-from-chunks layer.

  Raises:
    ValueError: When the reshape-to-chunks step runs on a batch whose size is
        not a multiple of `chunk_size` and `pass_unchunkable` is False.
  """
  if chunk_size < 1:
    return layer

  def _split_batch(x):
    batch = x.shape[0]
    if batch % chunk_size == 0:
      n_chunks, per_chunk = batch // chunk_size, chunk_size
    elif pass_unchunkable:
      # Fall back to one chunk that holds the whole batch.
      n_chunks, per_chunk = 1, batch
    else:
      raise ValueError(f'Chunk size {chunk_size} must divide batch '
                       f'size {batch}')
    return jnp.reshape(x, [n_chunks, per_chunk] + list(x.shape[1:]))

  def _merge_chunks(x):
    merged = x.shape[0] * x.shape[1]
    return jnp.reshape(x, [merged] + list(x.shape[2:]))

  # The reshape layers mirror the wrapped layer's arity on each side.
  to_chunks = base.PureLayer(
      lambda xs: fastmath.nested_map(_split_batch, xs),
      n_in=layer.n_in, n_out=layer.n_in, name='ReshapeToChunks')
  from_chunks = base.PureLayer(
      lambda xs: fastmath.nested_map(_merge_chunks, xs),
      n_in=layer.n_out, n_out=layer.n_out, name='ReshapeFromChunks')
  scanned = Scan(layer, axis=0, n_carry=0, remat=True)
  return Serial(to_chunks, scanned, from_chunks)
Exemplo n.º 2
0
def Select(indices, n_in=None, name=None):
    """Builds a layer that copies, reorders, or drops stack elements.

    Args:
        indices: A list or tuple of 0-based indices, relative to the top of
            the stack, naming the elements to output (in that order).
        n_in: Number of input elements to pop from the stack and replace with
            those chosen by `indices`. Defaults to `max(indices) + 1`.
        name: Descriptive name for this layer. Defaults to `Select` followed
            by `indices` with all spaces removed.

    Returns:
        A layer whose output matches the number of selected items
        (`n_out = len(indices)`):

          - n_out = 0: an empty tuple
          - n_out = 1: one tensor (NOT wrapped in a tuple)
          - n_out > 1: a tuple of tensors, with n_out items
    """
    n_in = max(indices) + 1 if n_in is None else n_in
    name = f'Select{indices}'.replace(' ', '') if name is None else name

    def select(xs):  # pylint: disable=invalid-name
        # A lone input is not wrapped in a sequence; wrap it so it can be
        # indexed like the multi-input case.
        stack = xs if isinstance(xs, (tuple, list)) else (xs, )
        picked = tuple(stack[i] for i in indices)
        if len(picked) == 1:
            return picked[0]
        return picked

    return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name)
Exemplo n.º 3
0
def PrintShape(n_in=1, msg=''):
  """Prints the shapes of `n_in` inputs and returns them unchanged.

  Args:
    n_in: Number of inputs whose shapes are printed.
    msg: Text placed in the printed line before the shapes.

  Returns:
    A `PureLayer` with `n_in` inputs and `n_in` outputs that prints and logs
    the input shapes and passes the inputs through untouched.
  """
  def Fwd(xs):
    # A single input arrives as a bare tensor rather than a tuple; wrap it so
    # the shape of the tensor itself is reported instead of iterating over its
    # leading axis.
    inputs = xs if isinstance(xs, (tuple, list)) else (xs,)
    shapes = ' '.join(str(x.shape) for x in inputs)
    info = 'PrintShape: ' + msg + ' ' + shapes
    print(info)
    logging.info(info)
    return xs
  return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}')
Exemplo n.º 4
0
def PrintShape(n_in=1, msg=''):
  """Prints the shapes and dtypes of `n_in` inputs and returns them unchanged.

  Args:
    n_in: Number of inputs whose shapes and dtypes are printed.
    msg: Text placed in the printed line before the shapes.

  Returns:
    A `PureLayer` with `n_in` inputs and `n_in` outputs that prints and logs
    one `shape[dtype]` entry per input and passes the inputs through untouched.
  """
  def Fwd(xs):
    # A single input arrives as a bare tensor rather than a tuple; wrap it so
    # the tensor itself is described instead of iterating over its leading
    # axis.
    inputs = xs if isinstance(xs, (tuple, list)) else (xs,)
    shapes_and_dtypes = ', '.join(
        str(x.shape) + f'[{x.dtype}]' for x in inputs)
    info = f'PrintShape: {msg}: [{shapes_and_dtypes}]'
    print(info)
    logging.info(info)
    return xs
  return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}')
Exemplo n.º 5
0
    def test_forward(self):
        # Exercises the three entry points for running a PureLayer, using a
        # layer whose function doubles its input.
        doubler = base.PureLayer(lambda x: 2 * x)

        # Entry point 1: call the layer object directly (Layer.__call__).
        via_call = doubler(np.array([1, 2]))
        self.assertEqual(via_call.tolist(), [2, 4])

        # Entry point 2: PureLayer.forward, given explicit empty weights.
        via_forward = doubler.forward(np.array([3, 4]), base.EMPTY_WEIGHTS)
        self.assertEqual(via_forward.tolist(), [6, 8])

        # Entry point 3: Layer.forward_with_state; the second element of the
        # returned pair is ignored here.
        via_state, _ = doubler.forward_with_state(np.array([5, 6]))
        self.assertEqual(via_state.tolist(), [10, 12])
Exemplo n.º 6
0
def PrintShape(n_in=1, msg=''):
    """Prints the shapes and dtypes of `n_in` inputs, returning them unchanged.

    Args:
        n_in: Number of inputs. When greater than 1 the inputs are iterated
            and described one by one; otherwise the input is described
            directly as a single item.
        msg: Text placed in the printed line before the shapes.

    Returns:
        A `PureLayer` with `n_in` inputs and `n_in` outputs that prints and
        logs the description line and passes the inputs through untouched.
    """
    def Fwd(xs):
        def describe(x):  # pylint: disable = invalid-name
            return str(x.shape) + f'[{x.dtype}]'

        # Only the multi-input case is iterated; otherwise `xs` is treated as
        # the item to describe.
        shapes_and_dtypes = (', '.join(map(describe, xs))
                             if n_in > 1 else describe(xs))
        info = f'PrintShape: {msg}: [{shapes_and_dtypes}]'
        print(info)
        logging.info(info)
        return xs

    layer_name = f'PrintShape_{n_in}'
    return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=layer_name)