def forward_with_state(self, inputs, weights, state, rng=None):
    """Computes this layer's output as part of a forward pass through the model.

    Args:
        inputs: Layer inputs (subclasses may use different inputs). Either a
            single array or a tuple/list of arrays; the leading axis of
            every array is indexed per example.
        weights: Layer weights. In the reference path every leaf is indexed
            by head along axis 0.
        state: Complete state of the layer. In the reference path every leaf
            is indexed by `example_idx * n_heads + head_idx` along axis 0.
        rng: PRNG key. Not read by this method.

    Returns:
        A tuple (output, new_state).
    """
    if not self.use_reference_code:
        # By default, an efficient, batched implementation is used.
        output, new_state, _, _ = self.forward_and_or_backward(
            inputs, weights, state, compute_output=True, update_state=True)
        return output, new_state

    # The reference implementation below provides a more readable overview of
    # what this class does. It's not optimized, however, and should only be
    # used when testing this class for correctness.

    # The top-level `jax.tree_map` / `jax.tree_multimap` / `jax.tree_leaves`
    # aliases have been deprecated or removed from JAX; `jax.tree_util` is
    # the stable home. Older releases need `tree_multimap` for several trees,
    # newer ones accept several trees in `tree_map` directly.
    tree_map = jax.tree_util.tree_map
    tree_multimap = getattr(jax.tree_util, 'tree_multimap', tree_map)

    if not isinstance(inputs, (tuple, list)):
        inputs = (inputs,)
    batch_size = int(inputs[0].shape[0])
    seqlen = inputs[0].shape[-2]
    d_model = inputs[0].shape[-1]

    # One accumulator per example; the outputs of all heads are summed into it.
    output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)]
    new_state = []
    for example_idx in range(batch_size):
        for head_idx in range(self.n_heads):
            # The lambdas are called immediately, so capturing the loop
            # variables is safe here.
            # pylint: disable=cell-var-from-loop
            single_inputs = tree_map(lambda x: x[example_idx], inputs)
            single_weights = tree_map(lambda w: w[head_idx], weights)
            single_state = tree_map(
                lambda s: s[example_idx * self.n_heads + head_idx], state)
            # pylint: enable=cell-var-from-loop
            single_out, single_new_state = self.forward_unbatched(
                *single_inputs, weights=single_weights, state=single_state,
                update_state=True)
            new_state.append(single_new_state)
            output_accum[example_idx] = output_accum[example_idx] + single_out

    output = np.stack(output_accum, 0)
    if new_state and jax.tree_util.tree_leaves(new_state[0]):
        # Re-stack the per-(example, head) states along a new leading axis.
        new_state = tree_multimap(lambda *s: np.stack(s, 0), *new_state)
    else:
        # Nothing to stack (no leaves), so hand back the incoming state.
        new_state = state
    return output, new_state
def new_weights_and_state(self, input_signature):
    """Creates stacked per-head weights and per-(example, head) state.

    Args:
        input_signature: A signature object, or a tuple/list of them, exposing
            `shape` and `dtype` and constructible as
            `type(x)(shape=..., dtype=...)`. The leading axis of `shape` is
            read as the batch size.

    Returns:
        A tuple (weights, state). `weights` stacks the `n_heads` results of
        `create_weights_unbatched` along a new axis 0; `state` stacks the
        `n_heads * batch_size` results of `create_state_unbatched` the same
        way.
    """
    # The top-level `jax.tree_map` / `jax.tree_multimap` aliases have been
    # deprecated or removed from JAX; `jax.tree_util` is the stable home.
    # Older releases need `tree_multimap` for several trees, newer ones accept
    # several trees in `tree_map` directly.
    tree_map = jax.tree_util.tree_map
    tree_multimap = getattr(jax.tree_util, 'tree_multimap', tree_map)

    # Drop the leading (batch) axis from every signature.
    input_signature_unbatched = tree_map(
        lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature)
    if isinstance(input_signature, (tuple, list)):
        batch_size = int(input_signature[0].shape[0])
    else:
        batch_size = int(input_signature.shape[0])

    # Weights are requested first, then state, matching the original order of
    # `new_rngs` calls.
    weight_rngs = self.new_rngs(self.n_heads)
    weights = [
        self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i])
        for i in range(self.n_heads)
    ]

    n_states = self.n_heads * batch_size
    state_rngs = self.new_rngs(n_states)
    state = [
        self.create_state_unbatched(input_signature_unbatched, state_rngs[i])
        for i in range(n_states)
    ]

    def stack_along_axis_0(*x):
        return np.stack(x, axis=0)

    weights = tree_multimap(stack_along_axis_0, *weights)
    state = tree_multimap(stack_along_axis_0, *state)
    return weights, state
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None):
    """Adds positional encodings taken from `weights` to `inputs`.

    In 'train'/'eval' mode the first `inputs.shape[1]` positions of `weights`
    are added, optionally scaled by a dropout mask drawn with `rng`. In
    'predict' mode `state` holds the current position and is advanced by the
    number of positions consumed.

    Returns:
        A tuple (output, new_state).
    """
    if self._mode not in ('train', 'eval'):
        assert self._mode == 'predict'
        assert self._dropout == 0
        # State in this class is only used for fast inference. In that case,
        # the model is called with consecutive elements position-by-position.
        # This positional encoding layer needs to store the index of the
        # current position then and increment it on each call -- that's how
        # state is used and updated below.
        n_positions = inputs.shape[1]
        if n_positions == 1:
            return (inputs + jnp.expand_dims(weights[0, state, :], 1),
                    state + 1)
        emb = [
            jax.lax.dynamic_slice_in_dim(
                weights[0], state[i], n_positions, axis=0)
            for i in range(inputs.shape[0])
        ]
        return inputs + jnp.stack(emb, 0), state + n_positions

    x = inputs
    px = weights[:, :jnp.shape(x)[1], :]
    if self._dropout == 0:
        return (x + px, state)

    # Broadcast the dropout mask over the configured dimensions.
    noise_shape = list(px.shape)
    for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(x.dtype) / keep_prob
    return (x + px * multiplier, state)
def predict(x, weights, state, rng):
    """Predict function, JIT-compiled and parallelized as requested."""
    sharded_x = reshape_by_device(x, n_devices)
    # One PRNG key per device.
    device_rngs = np.stack(math.random.split(rng, n_devices))
    res, state = _combine_devices(
        model_predict(sharded_x, weights, state, device_rngs))
    # Average every leaf of the result over its leading axis.
    averaged = math.nested_map(lambda y: np.mean(y, axis=0), res)
    return averaged, state
def predict(x, weights, state, rng):
    """Predict function, JIT-compiled and parallelized as requested."""
    sharded_x = reshape_by_device(x, n_devices)
    # One PRNG key per device.
    device_rngs = jnp.stack(math.random.split(rng, n_devices))
    res, state = _combine_devices(
        model_predict(sharded_x, weights, state, device_rngs))
    if not do_mean:
        return res, state
    # Average every leaf of the result over its leading axis.
    return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
def forward(self, inputs):
    """Adds positional encodings taken from `self.weights` to `inputs`.

    Outside 'predict' mode the first `inputs.shape[1]` positions of the
    weights are added, optionally scaled by a dropout mask drawn with
    `self.rng`. In 'predict' mode `self.state` holds the current position and
    is advanced by the number of positions consumed.

    Raises:
        ValueError: in 'predict' mode when the dropout rate is not zero.
    """
    if self._mode == 'predict':
        if self._dropout != 0:
            raise ValueError(f'In predict mode, but dropout rate '
                             f'({self._dropout}) is not zero.')
        # State in this class is only used for fast inference. In that case,
        # the model is called with consecutive elements position-by-position.
        # This positional encoding layer needs to store the index of the
        # current position then and increment it on each call -- that's how
        # state is used and updated below.
        position = self.state
        n_positions = inputs.shape[1]
        if n_positions == 1:
            self.state = position + 1
            return inputs + jnp.expand_dims(self.weights[0, position, :], 1)
        slices = [
            jax.lax.dynamic_slice_in_dim(
                self.weights[0], position[i], n_positions, axis=0)
            for i in range(inputs.shape[0])
        ]
        self.state = position + n_positions
        return inputs + jnp.stack(slices, 0)

    x = inputs
    px = self.weights[:, :jnp.shape(x)[1], :]
    if self._dropout == 0:
        return x + px

    # Broadcast the dropout mask over the configured dimensions.
    noise_shape = list(px.shape)
    for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
    keep = math.random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(x.dtype) / keep_prob
    return x + px * multiplier
def _get_embeddings(self, t): """Get embeddings float[..., num_features]. Args: t: int[...] position (i.e. jnp.arange(..., jnp.int32)) Returns: embeddings: float[..., num_features] """ inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length) bin_parity = inter_bin_idx % 2 bin_fraction = intra_bin_idx / self._time_bin_length embeddings = jnp.stack([ 1 / (1 + inter_bin_idx), bin_fraction, bin_parity.astype(jnp.float32), ], -1) assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape return embeddings
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

    This uses the little-endian counter mode.

    Args:
        key: uint32[2] the key to split
        lo: the range to start extracting from
        hi: the range to stop extracting from

    Returns:
        keys: uint32[hi - lo, 2] the split keys

    Raises:
        ValueError: if `key` is not a uint32 array of shape (2,).
        NotImplementedError: if `hi` does not fit in 32 bits.
    """
    key_is_valid = key.shape == (2,) and key.dtype == np.uint32
    if not key_is_valid:
        raise ValueError('key must be uint32[2]')
    if hi >= 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Create a 64-bit counter as (low word, high word) pairs; the high word
    # is always zero.
    counter_lo = np.arange(lo, hi, dtype=np.uint32)
    counter = np.stack([counter_lo, np.zeros_like(counter_lo)], axis=-1)
    return threefry_2x32_prf(key, counter)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
    """Builds the train/eval models, optimizer, metrics and jitted functions.

    Args:
        model: Callable invoked as `model(mode='train')` and
            `model(mode='eval')` to build the models.
        loss_fn: Callable invoked as
            `loss_fn(has_weights=..., id_to_mask=...)` to build the loss layer.
        optimizer: Callable invoked as `optimizer(learning_rate=0.0)`; the
            real learning rate is set in `reset()`.
        lr_schedule: Stored as `self._lr_schedule`.
        inputs: Either an Inputs instance or a function that returns it.
        output_dir: Passed to `reset()`. When falsy, checkpoint saving and
            summary writing are both disabled.
        random_seed: Forwarded to `_init_host_and_devices`.
        n_devices: Forwarded to `_init_host_and_devices`.
        checkpoints_at: Stored as `self._checkpoints_at`; defaults to `[]`.
        should_save_checkpoints: Checkpoints are only saved when this is true
            and this host is the chief.
        should_write_summaries: Stored as `self._should_write_summaries`.
        has_weights: Forwarded to the loss and metric constructors.
        nontrainable_param_map: Stored; defaults to an empty dict.
        id_to_mask: Forwarded to the loss and metric constructors.
        metrics: Dict mapping metric name to a callable accepting
            `has_weights` and `id_to_mask`; defaults to `_DEFAULT_METRICS`.
        checkpoint_highest: Stored as `self._checkpoint_highest`.
        checkpoint_lowest: Stored as `self._checkpoint_lowest`.
    """
    self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
        n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    # Without an output directory there is nowhere to write to.
    if not output_dir:
        self._should_save_checkpoints = False
        self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._has_weights = has_weights
    self._id_to_mask = id_to_mask
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    # From here on `loss_fn` is the constructed loss layer, not the factory.
    loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
        self._inputs = inputs()
    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    # `init_rng` is reserved for initialization; the remaining key is split
    # into one key per device.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    # NOTE(review): `model_input_shape` and `model_target_shape` computed
    # below are not read anywhere else in this method -- confirm whether they
    # can be dropped.
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                               (list, tuple)):
        model_input_shape = tuple(
            tuple([None] + list(shape)) for shape in self._inputs.input_shape)
    else:  # Otherwise just add [None] to the input shape.
        model_input_shape = tuple([None] + list(self._inputs.input_shape))
    # Same for targets.
    if self._inputs.target_shape and isinstance(
            self._inputs.target_shape[0], (list, tuple)):
        model_target_shape = tuple(
            tuple([None] + list(shape)) for shape in self._inputs.target_shape)
    else:
        model_target_shape = tuple([None] + list(self._inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = math.nested_map(lambda x: x or 1, model_input_shape)
    model_target_shape = math.nested_map(lambda x: x or 1, model_target_shape)

    def new_opt_state_and_model_state(shape_dtype, rng):
        """Returns optimizer and model states suitable for training a model."""
        # Combine inputs and targets on the stack.
        shapes, dtypes = shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
        # We need to create a new model instance and not reuse `model_train`
        # here, because `m.initialize` puts cached parameter values in `m` and
        # hence the next call of `m.initialize` will give wrong results.
        m = tl.Serial(model(mode='train'), loss_fn)
        m._set_rng_recursive(rng)  # pylint: disable=protected-access
        weights, state = m.init(input_signature)
        (slots, opt_params) = opt.tree_init(weights)
        return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
        # JIT parameter initialization to avoid memory fragmentation
        new_opt_state_and_model_state = math.jit(
            new_opt_state_and_model_state, static_argnums=(0, ))
    # Zero-argument factory; it reads `example_shape_dtype` at call time and
    # always uses `init_rng`.
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            self._inputs.example_shape_dtype, init_rng))

    # Arrange and initialize metrics layers.
    # Metric names are sorted so their order is deterministic.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [
        self._metrics_dict[m](has_weights=self._has_weights,
                              id_to_mask=self._id_to_mask)
        for m in self._metrics
    ]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
    # The eval model is initialized so that its output signature can be used
    # to initialize the metrics.
    model_predict_eval.init(example_signature)
    output_signature = model_predict_eval.output_signature(
        example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                     self._n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                         self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
        nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)
def pad_trajectories(trajectories, boundary=20):
    """Pad trajectories to a bucket length that is a multiple of boundary.

    Args:
        trajectories: list[(observation, actions, rewards, infos)], where each
            observation is shaped (t+1,) + OBS and actions & rewards are
            shaped (t,), with the length of the list being B (batch size).
            `infos` is either falsy (e.g. None) or a dict of arrays that are
            padded along their first axis like the rewards.
        boundary: int, bucket length, the actions and rewards are padded to
            integer multiples of boundary.

    Returns:
        tuple: (padding lengths, reward_mask, padded_observations,
            padded_actions, padded_rewards, padded_infos)
        where padded_observations is shaped (B, T+1) + OBS and padded_actions,
        padded_rewards & reward_mask are shaped (B, T), with T being max(t)
        rounded up to an integer multiple of boundary. padding lengths is a
        list with the amount of padding added to each trajectory, reward_mask
        is 1s for actual rewards and 0s for the padding, and padded_infos is
        a dict of stacked info arrays, or None when no trajectory carried
        infos.
    """
    # Let's compute max(t) over all trajectories.
    t_max = max(r.shape[0] for (_, _, r, _) in trajectories)

    # t_max is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    # So all obs will be padded to bucket_length + 1 and actions and rewards
    # to bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_infos = collections.defaultdict(list)
    padded_lengths = []
    reward_masks = []

    for (o, a, r, i) in trajectories:
        # Determine the amount to pad, this holds true for obs, actions and
        # rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)
        if num_to_pad == 0:
            # Already at bucket length: keep the arrays as they are.
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(onp.ones_like(r, dtype=np.int32))
            if i:
                for k, v in i.items():
                    padded_infos[k].append(v)
            continue

        # First pad observations.
        padding_config = tuple(
            [(0, num_to_pad, 0)] + [(0, 0, 0)] * (o.ndim - 1))

        padding_value = get_padding_value(o.dtype)
        action_padding_value = get_padding_value(a.dtype)
        reward_padding_value = get_padding_value(r.dtype)

        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        padding_config = tuple(
            [(0, num_to_pad, 0)] + [(0, 0, 0)] * (a.ndim - 1))
        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)

        assert r.ndim == 1
        padding_config = ((0, num_to_pad, 0),)
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later. Use the same dtype as the
        # no-padding branch above so every mask in the stack agrees.
        reward_mask = onp.ones_like(r, dtype=np.int32)
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

        if i:
            for k, v in i.items():
                # Create a padding configuration for this value.
                padding_config = tuple(
                    [(0, num_to_pad, 0)] + [(0, 0, 0)] * (v.ndim - 1))
                padded_infos[k].append(lax.pad(v, 0.0, padding_config))

    # Now stack these padded_infos if they exist.
    stacked_padded_infos = None
    if padded_infos:
        stacked_padded_infos = {k: np.stack(v) for k, v in padded_infos.items()}

    return (padded_lengths,
            np.stack(reward_masks),
            np.stack(padded_observations),
            np.stack(padded_actions),
            np.stack(padded_rewards),
            stacked_padded_infos)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
    """Builds the train/eval models, optimizer, metrics and jitted functions.

    Args:
        model: Callable invoked as `model(mode='train')` and
            `model(mode='eval')` to build the models.
        loss_fn: Loss layer; appended to the training model and used as the
            'loss' metric when `metrics` is None.
        optimizer: Callable invoked as `optimizer(learning_rate=0.0)`; the
            real learning rate is set in `reset()`.
        lr_schedule: Stored as `self._lr_schedule`.
        inputs: Either an Inputs instance or a function that returns it.
        output_dir: Passed to `reset()`. When falsy, checkpoint saving and
            summary writing are both disabled.
        random_seed: Forwarded to `_init_host_and_devices`.
        n_devices: Forwarded to `_init_host_and_devices`.
        checkpoints_at: Stored as `self._checkpoints_at`; defaults to `[]`.
        should_save_checkpoints: Checkpoints are only saved when this is true
            and this host is the chief.
        should_write_summaries: Stored as `self._should_write_summaries`.
        metrics: Dict mapping metric name to metric layer. When None, a copy
            of `_DEFAULT_METRICS` with 'loss' set to `loss_fn` is used.
        checkpoint_highest: Stored as `self._checkpoint_highest`.
        checkpoint_lowest: Stored as `self._checkpoint_lowest`.
    """
    self._is_chief, self._n_devices, rng = (
        self._init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    # Without an output directory there is nowhere to write to.
    if not output_dir:
        self._should_save_checkpoints = False
        self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    if metrics is not None:
        self._metrics_dict = metrics
    else:
        # Work on a copy: writing 'loss' into the module-level defaults would
        # leak this trainer's loss into every other user of
        # `_DEFAULT_METRICS`.
        self._metrics_dict = dict(_DEFAULT_METRICS)
        self._metrics_dict['loss'] = loss_fn
    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
        self._inputs = inputs()
    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    self._model_with_loss = tl.Serial(model_train, loss_fn)

    # Setup state.
    # `init_rng` is reserved for initialization; the remaining key is split
    # into one key per device.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    shapes, dtypes = self._inputs.example_shape_dtype
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

    def new_opt_state_and_model_state(rng):
        """Returns optimizer and model states suitable for training a model."""
        weights, state = self._model_with_loss.init(input_signature, rng=rng)
        (slots, opt_params) = opt.tree_init(weights)
        return (OptState(weights, slots, opt_params), state)

    if math.backend_name() == 'jax':
        # JIT parameter initialization to avoid memory fragmentation
        new_opt_state_and_model_state = math.jit(new_opt_state_and_model_state)
    # Zero-argument factory that always initializes from `init_rng`.
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(init_rng))

    # Arrange and initialize metrics layers.
    # Metric names are sorted so their order is deterministic.
    self._metrics = sorted(self._metrics_dict)
    metrics_layers = [self._metrics_dict[m] for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel.rng = init_rng
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
    # The eval model is initialized so that its output signature can be used
    # to initialize the metrics.
    model_predict_eval.init(example_signature)
    self._input_signature = example_signature
    output_signature = model_predict_eval.output_signature(example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                     self._n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                         self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)