def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Adds (optionally dropped-out) positional encodings to `inputs`.

  Args:
    inputs: Activations tensor; positions are indexed along axis 1.
      # assumes shape (batch, seq_len, d_feature) -- TODO confirm
    weights: Positional-encoding table, sliced to the first `seq_len`
      positions along axis 1.
    state: Scalar position index, used and incremented only in 'predict'
      mode; passed through unchanged otherwise.
    rng: PRNG key for dropout sampling.
    **kwargs: Ignored; accepted for layer-API compatibility.

  Returns:
    A (outputs, state) pair.

  Raises:
    ValueError: in 'predict' mode when the dropout rate is nonzero
      (dropout is a training-only feature).
  """
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      # Share one dropout decision across each broadcast axis.
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        # tie_in keeps the constant inside the traced XLA computation.
        keep_prob = jax.lax.tie_in(
            x, np.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    if self._dropout != 0:
      # Was a bare `assert` (stripped under -O); raise explicitly instead,
      # matching the newer variant of this layer.
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    return (inputs + np.expand_dims(weights[0, state, :], 1), state + 1)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Adds axial positional embeddings (with optional dropout) to `inputs`.

  Each entry of `weights` carries the embedding for one axis of the
  factored position grid `self._shape`; the per-axis embeddings are
  broadcast over the grid and concatenated along the feature dimension.

  Args:
    inputs: Activations tensor whose leading dimension is the batch size.
      # assumes inputs.shape == (batch, prod(self._shape), d) -- TODO confirm
    weights: Iterable of per-axis embedding tensors.
    state: Scalar position index, used and incremented only in 'predict'
      mode; passed through unchanged otherwise.
    rng: PRNG key for dropout sampling.
    **kwargs: Ignored; accepted for layer-API compatibility.

  Returns:
    A (outputs, state) pair.

  Raises:
    ValueError: in 'predict' mode when the dropout rate is nonzero.
  """
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
  emb = np.concatenate(embs, -1)
  if self._mode == 'predict':
    if self._dropout != 0.0:
      # Was a bare `assert` (stripped under -O); raise explicitly instead,
      # matching the non-axial positional-encoding layer.
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # Fast inference: consume one position per call, tracked by `state`.
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    return inputs + emb[:, state, :][:, None, :], state + 1
  elif self._dropout == 0:
    return inputs + np.reshape(emb, inputs.shape), state
  else:
    noise_shape = list(emb.shape)
    # Share one dropout decision across each broadcast axis.
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      # tie_in keeps the constant inside the traced XLA computation.
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
def forward_with_state(self, inputs, weights, state, rng):
  """Adds (optionally dropped-out) positional encodings to `inputs`.

  Args:
    inputs: Activations tensor; positions are indexed along axis 1.
    weights: Positional-encoding table, sliced to the current positions.
    state: Position index (scalar, or per-example array in multi-position
      'predict' calls); used and incremented only in 'predict' mode.
    rng: PRNG key for dropout sampling.

  Returns:
    A (outputs, state) pair.

  Raises:
    ValueError: in 'predict' mode when the dropout rate is nonzero.
  """
  if self._mode == 'predict':
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    if inputs.shape[1] == 1:
      # Single position: a plain gather suffices.
      return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
    # Several positions at once: slice a per-example window starting at
    # each example's own position index.
    window = inputs.shape[1]
    per_example = [
        jax.lax.dynamic_slice_in_dim(weights[0], state[i], window, axis=0)
        for i in range(inputs.shape[0])
    ]
    return inputs + jnp.stack(per_example, 0), state + window
  # Train/eval path: add encodings for the first seq_len positions.
  seq_len = jnp.shape(inputs)[1]
  encodings = weights[:, :seq_len, :]
  if self._dropout == 0:
    return (inputs + encodings, state)
  mask_shape = list(encodings.shape)
  # Share one dropout decision across each broadcast axis.
  for axis in self._dropout_broadcast_dims:
    mask_shape[axis] = 1
  keep_prob = 1.0 - self._dropout
  if math.backend_name() == 'jax':
    # tie_in keeps the constant inside the traced XLA computation.
    keep_prob = jax.lax.tie_in(
        inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
  keep_mask = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
  scaling = keep_mask.astype(inputs.dtype) / keep_prob
  return (inputs + encodings * scaling, state)
def F(x):
  """Returns the padding mask of `x` as if `x` had been right-shifted.

  Nonzero entries of `x` are treated as valid tokens; the first
  `n_shifts` positions along axis 1 are forced to True because they
  correspond to values shifted in from the left.
  """
  # TODO(afrozm): What to do in this case?
  if mode == 'predict':
    raise ValueError('MaskOfRightShiftedArray not implemented for predict.')
  valid = x != 0
  if n_shifts == 0:
    return valid
  # Force the shifted-in prefix (B, n_shifts, ...) to True.
  prefix_shape = (x.shape[0], n_shifts) + valid.shape[2:]
  prefix = jnp.full(prefix_shape, True)
  return jnp.concatenate([prefix, valid[:, n_shifts:, ...]], axis=1)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Adds axial positional embeddings (with optional dropout) to `inputs`.

  Each entry of `weights` carries the embedding for one axis of the
  factored position grid `self._shape`; the per-axis embeddings are
  broadcast over the grid and concatenated along the feature dimension.

  Args:
    inputs: Activations tensor whose leading dimension is the batch size.
    weights: Iterable of per-axis embedding tensors.
    state: Position index, used and incremented only in 'predict' mode.
    rng: PRNG key for dropout sampling.
    **kwargs: Ignored; accepted for layer-API compatibility.

  Returns:
    A (outputs, state) pair.
  """
  broadcast_embs = [
      np.broadcast_to(
          ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
      for ax_emb in weights
  ]
  if self._mode == 'predict':
    assert self._dropout == 0.0
    # Fast inference: slice out the window of positions starting at `state`.
    emb = np.concatenate(broadcast_embs, -1)
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    emb = jax.lax.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
    return inputs + emb, state + inputs.shape[1]
  if self._dropout == 0:
    # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
    # leads to memory blow-up on TPU, so reshape each axis embedding first
    # and concatenate last.
    reshaped = [
        np.reshape(e, inputs.shape[:-1] + (e.shape[-1],))
        for e in broadcast_embs
    ]
    return inputs + np.concatenate(reshaped, -1), state
  emb = np.concatenate(broadcast_embs, -1)
  mask_shape = list(emb.shape)
  # Share one dropout decision across each broadcast axis.
  for axis in self._dropout_broadcast_dims:
    mask_shape[axis] = 1
  keep_prob = 1.0 - self._dropout
  if math.backend_name() == 'jax':
    # tie_in keeps the constant inside the traced XLA computation.
    keep_prob = jax.lax.tie_in(
        inputs, np.full((), keep_prob, dtype=inputs.dtype))
  keep_mask = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
  scaling = keep_mask.astype(inputs.dtype) / keep_prob
  return inputs + np.reshape(emb * scaling, inputs.shape), state