def fn(dist_inputs, actions, q_values, act_log_probs, mask):
  """Computes advantages as Q-value samples minus their per-state mean."""
  # Only the Q-value estimates (and log-probs) are consumed here.
  del dist_inputs, actions, mask
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
  q_values, act_log_probs = (
      jnp.swapaxes(t, 0, 1) for t in (q_values, act_log_probs))
  baseline = jnp.mean(q_values, axis=0)
  # The baseline broadcasts over the leading n_samples axis.
  advantages = q_values - baseline
  if not preprocess:
    return advantages
  return self._preprocess_advantages(advantages)
def fn(dist_inputs, actions, q_values, act_log_probs, mask):
  """Computes (optionally preprocessed) advantages from Q-value samples."""
  del dist_inputs, actions, mask  # Unused by this estimator.
  # Move the sample axis to the front:
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...).
  q_values = jnp.swapaxes(q_values, 0, 1)
  act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
  if self._sample_all_discrete_actions:
    # All discrete actions were enumerated: weight each Q-value by the
    # policy probability of its action.
    baseline = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
  else:
    # Monte-Carlo estimate: plain average over the sampled actions.
    baseline = jnp.mean(q_values, axis=0)
  # The baseline broadcasts over the leading n_samples axis.
  advantages = q_values - baseline
  return self._preprocess_advantages(advantages) if preprocess else advantages
def LossInput(dist_inputs, actions, advantages, old_dist_inputs, mask):  # pylint: disable=invalid-name
  """Calculates action log probabilities and normalizes advantages."""
  del old_dist_inputs  # Not used by this loss.
  # Tile the distribution parameters once per Q-value sample, then score
  # the sampled actions under the current policy.
  tiled_shape = (self._q_value_n_samples,) + dist_inputs.shape
  log_probs = self._policy_dist.log_prob(
      jnp.broadcast_to(dist_inputs, tiled_shape), actions)
  # Normalize advantages, then put the sample axis first:
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...).
  norm_advantages = jnp.swapaxes(
      self._preprocess_advantages(advantages), 0, 1)
  mask = jnp.swapaxes(mask, 0, 1)
  # The same log-probs serve as both "new" and "old" policy log-probs.
  return (log_probs, norm_advantages, log_probs, mask)
def gather_fn(x):
  """Gather slices for a single tensor."""
  if x.ndim == 0:
    # Scalars (e.g. the cache index) are passed through untouched.
    return x
  if x.shape[0] == batch_size:
    return x[batch_indices, beam_indices]
  # The leading axis folds the batch with an extra factor: unfold it,
  # gather along the beam axis, then fold back.
  assert x.shape[0] % batch_size == 0
  unfolded = np.swapaxes(x.reshape((batch_size, -1,) + x.shape[1:]), 1, 2)
  gathered = np.swapaxes(unfolded[batch_indices, beam_indices], 1, 2)
  return gathered.reshape((-1,) + gathered.shape[2:])
def _run_value_model(self, observations, dist_inputs):
  """Runs the value network; returns (values, actions, log_probs)."""
  if dist_inputs is None:
    # No distribution parameters were provided - feed zeros of the
    # matching shape instead.
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs,))
  if self._q_value:
    # Tile the distribution parameters once per Q-value sample, then swap
    # the n_samples and batch_size axes, so the input is split between
    # accelerators along the batch_size axis.
    tiled = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    tiled = jnp.swapaxes(tiled, 0, 1)
    actions = self._policy_dist.sample(tiled)
    log_probs = self._policy_dist.log_prob(tiled, actions)
    # Insert a singleton sample axis into the observations.
    obs = jnp.reshape(
        observations,
        [observations.shape[0], 1] + list(observations.shape[1:]))
    inputs = (obs, actions)
  else:
    actions = None
    log_probs = None
    inputs = (observations,)
  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  # Rescale and drop the singleton depth dimension.
  values = jnp.squeeze(values * self._value_network_scale, axis=-1)
  return (values, actions, log_probs)
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate; None or 0.0 disables dropout
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.

  Raises:
    ValueError: if dropout is not None and >= 1.0.
  """
  depth = np.shape(query)[-1]
  # Scaled dot-product attention scores.
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  # Fix: guard against None before comparing. The dropout-application branch
  # below already treated None as "no dropout", but this validation check
  # raised TypeError ('>=' on NoneType) instead.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Inverted dropout: rescale surviving weights to keep expectation fixed.
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
  """Calculates action log probabilities and normalizes advantages."""
  # Put the sample axis first on every per-sample tensor:
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...).
  q_values, mask, actions, act_log_probs = (
      jnp.swapaxes(t, 0, 1)
      for t in (q_values, mask, actions, act_log_probs)
  )
  # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
  # Reweight: values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
  baseline = jnp.mean(q_values, axis=0)
  # The baseline broadcasts over the leading n_samples axis.
  advantages = self._preprocess_advantages(q_values - baseline)
  # Broadcast the distribution parameters over samples and score actions.
  tiled_dist_inputs = jnp.broadcast_to(
      dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
  log_probs = self._policy_dist.log_prob(tiled_dist_inputs, actions)
  return (log_probs, advantages, act_log_probs, mask)
def _run_value_model(self, observations, dist_inputs):
  """Runs the value network; returns (values, actions, log_probs).

  When `self._q_value` is set, actions are either sampled from the policy
  or (if `self._sample_all_discrete_actions`) enumerated exhaustively over
  the action vocabulary, and Q-values are evaluated for each.
  """
  if dist_inputs is None:
    # No distribution parameters were provided - feed zeros of the
    # matching shape instead.
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs, ))
  actions = None
  if self._q_value:
    if self._sample_all_discrete_actions:
      # Since we want to sample all actions, start by creating their list.
      act = np.arange(self._vocab_size)
      # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
      # Add extra dimensions so it's the same dimensionality as dist_inputs.
      act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
      # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
    if self._sample_all_discrete_actions:
      # Tile the enumerated actions over the batch via broadcasting, then
      # swap axes here - before dist_inputs itself is swapped below.
      actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
      actions = jnp.swapaxes(actions, 0, 1)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    if not self._sample_all_discrete_actions:
      # Sampled actions come from the already-swapped dist_inputs, so they
      # are born with the (batch_size, n_samples, ...) layout.
      actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    # Insert a singleton sample axis into the observations.
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations, )
  # Replicate model weights/state across devices for the jitted eval call.
  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
def unflatten_beam_dim(x, batch_size, beam_size):
  """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
  if x.ndim == 0:
    # Scalars (e.g. the cache index) carry no beam dimension.
    return x
  leading = x.shape[0]
  if leading == batch_size * beam_size:
    # Simple case: the leading axis is exactly batch * beam.
    return x.reshape((batch_size, beam_size) + x.shape[1:])
  # The leading axis also folds in an extra factor; pull that factor out,
  # merge it with the batch axis, and keep the beam axis second.
  assert leading % (batch_size * beam_size) == 0
  unfolded = x.reshape((batch_size, beam_size, -1) + x.shape[1:])
  unfolded = np.swapaxes(unfolded, 1, 2)
  return unfolded.reshape((-1, beam_size) + unfolded.shape[3:])
def flatten_beam_dim(x, batch_size=None):
  """Flattens the first two dimensions of a non-scalar array."""
  if x.ndim == 0:
    # Scalars (e.g. the cache index) have nothing to flatten.
    return x
  if batch_size is None or x.shape[0] == batch_size:
    # Plain case: merge the leading two axes.
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
  # The leading axis folds the batch with an extra factor: unfold it, move
  # the factor behind the second axis, then merge the first three axes.
  assert x.shape[0] % batch_size == 0
  unfolded = x.reshape((batch_size, -1, x.shape[1]) + x.shape[2:])
  unfolded = np.swapaxes(unfolded, 1, 2)
  merged_len = unfolded.shape[0] * unfolded.shape[1] * unfolded.shape[2]
  return unfolded.reshape((merged_len,) + unfolded.shape[3:])
def policy_batches_stream(self):
  """Use the RLTask self._task to create inputs to the policy model."""
  # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
  trajectory_stream = self._task.trajectory_batch_stream(
      self._policy_batch_size,
      epochs=self._replay_epochs,
      max_slice_length=self._max_slice_length,
      include_final_state=False,
  )
  for np_trajectory in trajectory_stream:
    (q_values, actions) = self._run_value_model(
        np_trajectory.observations, np_trajectory.dist_inputs)
    # TODO(pkozakowski): Try max here.
    values = jnp.mean(q_values, axis=0)
    if len(values.shape) != 2:
      raise ValueError(
          'Values are expected to have shape [batch_size, length], got: %s'
          % str(values.shape))
    if values.shape[0] != self._policy_batch_size:
      raise ValueError(
          'Values first dimension should = policy batch size, %d != %d'
          % (values.shape[0], self._policy_batch_size))
    # q_values: (n_samples, batch_size, length); values: (batch_size, length).
    # Broadcasting subtracts the state value from every Q-value sample.
    advantages = q_values - values
    mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)
    shapes.assert_shape_equals(
        advantages, (self._q_value_n_samples,) + values.shape)
    shapes.assert_same_shape(mask, advantages)
    # Swap the n_samples and batch_size axes, so the input is split between
    # accelerators along the batch_size axis.
    advantages = jnp.swapaxes(advantages, 0, 1)
    mask = jnp.swapaxes(mask, 0, 1)
    yield (np_trajectory.observations, actions, advantages, mask, mask)
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values. None or
        0.0 disables dropout.
    mode: Either 'train' or eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.

  Raises:
    ValueError: if dropout is not None and >= 1.0.
  """
  d_feature = queries.shape[-1]
  # Scaled dot-product attention scores.
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  # Fix: guard against None before comparing. The dropout-application branch
  # below already treated None as "no dropout", but this validation check
  # raised TypeError ('>=' on NoneType) instead.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Inverted dropout: rescale surviving weights to keep expectation fixed.
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  return out
def attend(
    q, k=None, v=None,
    q_chunk_len=None, kv_chunk_len=None,
    n_chunks_before=0, n_chunks_after=0,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None,
    ):
  """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None for shared-QK attention
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
  assert v is not None
  share_qk = (k is None)  # Shared-QK mode: queries double as keys.

  # Default positional metadata: position indices along the time axis.
  if q_info is None:
    q_info = np.arange(q.shape[-2])
  if kv_info is None and not share_qk:
    kv_info = np.arange(v.shape[-2])

  # Split q/k/v into chunks along the time axis, if desired.
  if q_chunk_len is not None:
    q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
    q_info = np.reshape(q_info, (-1, q_chunk_len))

  if share_qk:
    # In shared-QK mode the keys are the (chunked) queries, so key chunking
    # must mirror query chunking.
    assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
    k = q
    kv_chunk_len = q_chunk_len
    kv_info = q_info
  elif kv_chunk_len is not None:
    k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
    kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
  if kv_chunk_len is not None:
    v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

  if share_qk:
    # NOTE(review): presumably normalizes keys along the feature axis —
    # confirm against `length_normalized`'s definition.
    k = length_normalized(k)
  # Fold the 1/sqrt(d) attention scaling into the keys.
  k = k / np.sqrt(k.shape[-1])

  # Optionally include adjacent chunks.
  if q_chunk_len is not None or kv_chunk_len is not None:
    assert q_chunk_len is not None and kv_chunk_len is not None
  else:
    # Without chunking there are no adjacent chunks to attend to.
    assert n_chunks_before == 0 and n_chunks_after == 0
  k = look_adjacent(k, n_chunks_before, n_chunks_after)
  v = look_adjacent(v, n_chunks_before, n_chunks_after)
  kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

  # Dot-product attention.
  dots = np.matmul(q, np.swapaxes(k, -1, -2))

  # Masking
  if mask_fn is not None:
    dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

  # Softmax.
  dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if dropout > 0.0:
    assert rng is not None
    # Dropout is broadcast across the bin dimension
    dropout_shape = (dots.shape[-2], dots.shape[-1])
    # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
    keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
    keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
  out = np.matmul(dots, v)
  out = np.reshape(out, (-1, out.shape[-1]))
  dots_logsumexp = np.reshape(dots_logsumexp, (-1, ))
  return out, dots_logsumexp