def Reorder(x, params, output=None, **kwargs):
  """Rearrange the elements of a tuple according to `output`.

  `output` is a nested tuple of integer indices into `x`; the result has the
  same nested structure as `output`, with each integer i replaced by x[i].
  Elements may be repeated, dropped, or regrouped. For input = (a, b, c):

    Reorder(input, output=(1, 0, 2))        -> (b, a, c)
    Reorder(input, output=(0, 0))           -> (a, a)
    Reorder(input, output=(0, (1, 1)))      -> (a, (b, b))
    Reorder(input, output=((2, 0), (1, 1))) -> ((c, a), (b, b))

  With output=None (the default) the input is returned unchanged (Identity).

  Args:
    x: input tuple whose elements are selected by index.
    params: layer parameters (ignored).
    output: nested tuple of ints describing the desired output structure,
      or None for identity.
    **kwargs: other arguments (ignored).

  Returns:
    A tuple shaped like `output`, built from elements of `x`.
  """
  del params, kwargs
  if output is None:
    return x

  def _pick(index):
    return x[index]

  return base.nested_map(output, _pick)
def new_parameters(self, input_shape, input_dtype, rng):
  """Initialize parameters for each sublayer in sequence.

  Threads a pseudo (shape/dtype-only) signature through the layers so that
  each layer is initialized with the shapes and dtypes it will actually
  receive at call time.

  Args:
    input_shape: (possibly nested) shape of the input.
    input_dtype: (possibly nested) dtype matching input_shape.
    rng: random key; split once per sublayer.

  Returns:
    A list with one parameter object per sublayer.
  """
  def _signature(shape, dtype):
    # Zip parallel nested shape/dtype structures into ShapeType leaves.
    if isinstance(dtype, (list, tuple)):
      return tuple(_signature(s, t) for s, t in zip(shape, dtype))
    return base.ShapeType(shape=shape, dtype=dtype)

  signature = _signature(input_shape, input_dtype)
  params = []
  for layer in self._layers:
    rng, layer_rng = backend.random.split(rng)
    shapes = base.nested_map(signature, lambda s: s.shape)
    dtypes = base.nested_map(signature, lambda s: s.dtype)
    params.append(layer.initialize(shapes, dtypes, layer_rng))
    layer_params = layer._params  # pylint: disable=protected-access
    # Advance the signature to what this layer emits.
    signature = layer.pseudo_call(signature, layer_params)
  return params
def new_parameters(self, input_shape, rng):
  """Initialize parameters for each sublayer, threading shapes through.

  Args:
    input_shape: shape of the input to the first sublayer.
    rng: random key; split once per sublayer.

  Returns:
    A list with one parameter object per sublayer.
  """
  shape_and_type = base.to_shape_and_type(
      input_shape, self.default_input_is_int())
  params = []
  for layer in self._layers:
    rng, layer_rng = backend.random.split(rng)
    shapes = base.nested_map(shape_and_type, lambda s: s.shape)
    params.append(layer.initialize(shapes, layer_rng))
    layer_params = layer._params  # pylint: disable=protected-access
    # Advance the shape/type signature to what this layer emits.
    shape_and_type = layer.output_shape(shape_and_type, layer_params)
  return params
def new_parameters(self, input_shape, input_dtype, rng):
  """Initialize params and states for all sublayers over a pseudo data stack.

  Runs a shape/dtype-only ("pseudo") forward pass: each sublayer is given
  the top `n_inputs` items of the pseudo stack, initialized with their
  shapes/dtypes, and its pseudo outputs are pushed back onto the stack.

  Args:
    input_shape: (possibly nested) shape of the stack input.
    input_dtype: (possibly nested) dtype structure parallel to input_shape.
    rng: random key; split once per sublayer.

  Returns:
    A pair (params, states), each a list with one entry per sublayer.
  """
  def MakeShapeType(shape, dtype):
    # Zip parallel nested shape/dtype structures into ShapeType leaves.
    if isinstance(dtype, (list, tuple)):
      return tuple(MakeShapeType(s, t) for s, t in zip(shape, dtype))
    return base.ShapeType(shape=shape, dtype=dtype)
  params = []
  states = []
  pseudo_xs = MakeShapeType(input_shape, input_dtype)
  for layer in self.sublayers:
    rng, layer_rng = backend.random.split(rng)
    # Give layer its args from pseudo_xs; treat 1-arg layer specially.
    # NOTE(review): _count_items presumably counts stack items — confirm.
    is_stack_just_one_item = (_count_items(pseudo_xs) == 1)
    n_in = layer.n_inputs
    if n_in == 1 and is_stack_just_one_item:
      # Single-item stack to a 1-input layer: pass it through unwrapped.
      inputs = pseudo_xs
    elif n_in == 1:
      inputs = pseudo_xs[0]
    else:
      inputs = pseudo_xs[:n_in]
    in_shape = base.nested_map(inputs, lambda x: x.shape)
    in_dtype = base.nested_map(inputs, lambda x: x.dtype)
    param, state = layer.initialize(in_shape, in_dtype, layer_rng)
    pparam = layer._params  # pylint: disable=protected-access
    # Shape/dtype-only evaluation; the second return value is discarded.
    outputs, _ = layer.pseudo_call(inputs, pparam, state)
    # Push outputs onto remaining pseudo_xs (if any).
    if n_in < _count_items(pseudo_xs):
      if layer.n_outputs == 1:
        # Wrap a lone output so tuple concatenation below works.
        outputs = (outputs, )
      pseudo_xs = outputs + pseudo_xs[n_in:]
    else:
      pseudo_xs = outputs  # NOTE: can be single value or tuple.
    params.append(param)
    states.append(state)
  return params, states
def call(self, x, params=(), **kwargs):
  """Apply this layer's stored reordering spec to `x`.

  Returns `x` unchanged when no spec was configured; otherwise builds a
  structure shaped like the spec by mapping each index through self._map.
  """
  del params, kwargs
  spec = self._output
  if spec is None:
    return x
  return base.nested_map(spec, lambda idx: self._map(x, idx))
def clip_grads(grad_tree, max_norm):
  """Scale a pytree of gradient arrays so its global L2 norm is at most `max_norm`."""
  norm = l2_norm(grad_tree)

  def normalize(g):
    # Leave g untouched when the global norm is below the threshold;
    # otherwise rescale every leaf so the global norm equals max_norm.
    return np.where(norm < max_norm, g, g * (max_norm / norm))

  return layers.nested_map(normalize, grad_tree)
def _reorder_shape(input_shape, output=None): # pylint: disable=invalid-name """Helper to determine the shape of reorder output.""" if output is None: return input_shape return base.nested_map(output, lambda i: input_shape[i])
def output_shape_fun(self, input_shape):
  """Return the output shape: `input_shape` mapped through the stored spec."""
  spec = self._output
  if spec is None:
    return input_shape
  return base.nested_map(spec, lambda idx: self._map(input_shape, idx))
def output_shape(self, input_shape):
  """Return the shape of the reordered output for the given input shape."""
  spec = self._output
  if spec is None:
    return input_shape
  return base.nested_map(spec, lambda idx: input_shape[idx])