def Chunk(layer, chunk_size, pass_unchunkable=True):
  """Executes `layer` using batch chunks of size `chunk_size` to save memory.

  The leading axis of every input is split into `[n_chunks, size]`, `layer` is
  run over axis 0 with `Scan` (using `remat=True`), and the two leading axes of
  every output are merged back into one.

  Args:
    layer: Layer to run chunk by chunk; its `n_in` and `n_out` determine the
      arity of the surrounding reshape layers.
    chunk_size: Number of leading-axis items per chunk. If smaller than 1,
      `layer` is returned as is.
    pass_unchunkable: If True, an input whose leading dimension is not a
      multiple of `chunk_size` is handled as one chunk holding the whole
      batch. If False, such an input raises `ValueError`.

  Returns:
    `layer` itself when `chunk_size < 1`; otherwise a `Serial` of a
    reshape-to-chunks layer, `Scan(layer, ...)`, and a reshape-from-chunks
    layer.

  Raises:
    ValueError: From the reshape-to-chunks step (not at construction time) if
      `chunk_size` does not divide the leading dimension and
      `pass_unchunkable` is False.
  """
  if chunk_size < 1:
    return layer

  def _split_leading_axis(x):
    # Turns [batch, ...] into [n_chunks, size, ...].
    batch = x.shape[0]
    if batch % chunk_size == 0:
      leading = [batch // chunk_size, chunk_size]
    elif pass_unchunkable:
      # Not divisible: fall back to one chunk containing everything.
      leading = [1, batch]
    else:
      raise ValueError(f'Chunk size {chunk_size} must divide batch '
                       f'size {batch}')
    return jnp.reshape(x, leading + list(x.shape[1:]))

  def _merge_leading_axes(x):
    # Turns [n_chunks, size, ...] back into [n_chunks * size, ...].
    merged = x.shape[0] * x.shape[1]
    return jnp.reshape(x, [merged] + list(x.shape[2:]))

  def _to_chunks(xs):
    return fastmath.nested_map(_split_leading_axis, xs)

  def _from_chunks(xs):
    return fastmath.nested_map(_merge_leading_axes, xs)

  to_chunks_layer = base.PureLayer(
      _to_chunks, n_in=layer.n_in, n_out=layer.n_in, name='ReshapeToChunks')
  from_chunks_layer = base.PureLayer(
      _from_chunks, n_in=layer.n_out, n_out=layer.n_out,
      name='ReshapeFromChunks')
  scanned = Scan(layer, axis=0, n_carry=0, remat=True)
  return Serial(to_chunks_layer, scanned, from_chunks_layer)
def Select(indices, n_in=None, name=None):
  """Copies, reorders, or deletes stack elements according to `indices`.

  Args:
    indices: A list or tuple of 0-based indices to select elements relative to
      the top of the stack.
    n_in: Number of input elements to pop from the stack, and replace with
      those specified by `indices`. If not specified, its value will be
      calculated as `max(indices) + 1`, or 0 when `indices` is empty.
    name: Descriptive name for this layer. Defaults to `Select` followed by
      `indices` rendered without spaces.

  Returns:
    Tensors, matching the number selected (`n_out = len(indices)`).
    Specifically:

      - n_out = 0: an empty tuple
      - n_out = 1: one tensor (NOT wrapped in a tuple)
      - n_out > 1: a tuple of tensors, with n_out items
  """
  if n_in is None:
    # `default=-1` lets an empty selection infer n_in == 0 instead of failing
    # with "max() arg is an empty sequence".
    n_in = max(indices, default=-1) + 1
  if name is None:
    name = f'Select{indices}'.replace(' ', '')

  def select(xs):  # pylint: disable=invalid-name
    # A lone input is not wrapped in a tuple; wrap it so indexing is uniform.
    if not isinstance(xs, (tuple, list)):
      xs = (xs,)
    selected = tuple(xs[i] for i in indices)
    # Mirror the input convention: a single output is returned unwrapped.
    return selected[0] if len(selected) == 1 else selected

  return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name)
def PrintShape(n_in=1, msg=''):
  """Prints the shapes of `n_in` inputs and returns them unchanged.

  Args:
    n_in: Number of inputs the layer takes and passes through.
    msg: Text placed after the 'PrintShape: ' prefix of the printed line.

  Returns:
    A `base.PureLayer` named `PrintShape_{n_in}` with `n_out == n_in` that
    prints and logs the input shapes, then returns its input as received.
  """
  def Fwd(xs):
    # With a single input the layer is handed a bare tensor rather than a
    # tuple; wrap it so we report that tensor's shape instead of iterating
    # over its leading axis.
    tensors = xs if isinstance(xs, (tuple, list)) else (xs,)
    info = 'PrintShape: ' + msg + ' ' + ' '.join(
        [str(x.shape) for x in tensors])
    print(info)
    logging.info(info)
    return xs  # Pass through exactly what was received (not the wrapper).
  return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}')
def PrintShape(n_in=1, msg=''):
  """Prints the shapes and dtypes of `n_in` inputs and returns them unchanged.

  Args:
    n_in: Number of inputs the layer takes and passes through.
    msg: Text placed after the 'PrintShape: ' prefix of the printed line.

  Returns:
    A `base.PureLayer` named `PrintShape_{n_in}` with `n_out == n_in` that
    prints and logs one `shape[dtype]` entry per input, then returns its input
    as received.
  """
  def Fwd(xs):
    # With a single input the layer is handed a bare tensor rather than a
    # tuple; wrap it so we describe that tensor instead of iterating over its
    # leading axis.
    tensors = xs if isinstance(xs, (tuple, list)) else (xs,)
    shapes_and_dtypes = ', '.join(
        [str(x.shape) + f'[{x.dtype}]' for x in tensors])
    info = f'PrintShape: {msg}: [{shapes_and_dtypes}]'
    print(info)
    logging.info(info)
    return xs  # Pass through exactly what was received (not the wrapper).
  return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}')
def test_forward(self):
  """Checks that a doubling `PureLayer` gives 2*x through each entry point."""
  doubler = base.PureLayer(lambda x: 2 * x)

  # Entry point 1: calling the layer object directly.
  via_call = doubler(np.array([1, 2]))
  self.assertEqual(via_call.tolist(), [2, 4])

  # Entry point 2: `forward`, passing the empty-weights constant explicitly.
  via_forward = doubler.forward(np.array([3, 4]), base.EMPTY_WEIGHTS)
  self.assertEqual(via_forward.tolist(), [6, 8])

  # Entry point 3: `forward_with_state`; the second returned item is ignored.
  via_state, _ = doubler.forward_with_state(np.array([5, 6]))
  self.assertEqual(via_state.tolist(), [10, 12])
def PrintShape(n_in=1, msg=''):
  """Prints the shapes and dtypes of `n_in` inputs and returns them unchanged.

  Args:
    n_in: Number of inputs the layer takes and passes through. When greater
      than 1 the input is iterated as a sequence; otherwise it is treated as
      a single tensor.
    msg: Text placed after the 'PrintShape: ' prefix of the printed line.

  Returns:
    A `base.PureLayer` named `PrintShape_{n_in}` with `n_out == n_in` that
    prints and logs one `shape[dtype]` entry per input, then returns its input
    as received.
  """
  def _describe(tensor):
    # Renders e.g. "(2, 3)[float32]".
    return f'{tensor.shape!s}[{tensor.dtype}]'

  def _print_and_pass(xs):
    described = (
        ', '.join(_describe(x) for x in xs) if n_in > 1 else _describe(xs))
    info = f'PrintShape: {msg}: [{described}]'
    print(info)
    logging.info(info)
    return xs

  return base.PureLayer(
      _print_and_pass, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}')