def forward(self, x):
    """Dropout, with broadcasting to save memory.

    In 'train' mode with a positive rate, a Bernoulli keep-mask is sampled
    using `self.rng`. The mask has the shape of `x`, except that every axis
    listed in `self._broadcast_dims` has size 1, so one mask value is shared
    (broadcast) along those axes. Kept entries are scaled by `1 / keep_prob`.

    Args:
      x: Input array.

    Returns:
      `x * mask / keep_prob` in 'train' mode with `self._rate > 0.0`;
      otherwise `x` itself, unchanged.
    """
    if not (self._mode == 'train' and self._rate > 0.0):
        return x

    # Mask shape: same as x, but collapsed to size 1 on the broadcast axes.
    noise_shape = list(x.shape)
    for dim in self._broadcast_dims:
        noise_shape[dim] = 1

    # NOTE: the former `jax.lax.tie_in` wrappers were dropped. The function is
    # a no-op under omnistaging and is absent from recent JAX releases, where
    # calling it raises AttributeError. A plain float gives the same result.
    keep_prob = 1.0 - self._rate
    keep = random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(x.dtype) / keep_prob
    return x * multiplier
def forward_with_state(self, x, weights, state, rng):
    """Dropout, with broadcasting to save memory.

    In 'train' mode with a positive rate, a Bernoulli keep-mask is sampled
    using `rng`. The mask has the shape of `x`, except that every axis listed
    in `self._broadcast_dims` has size 1, so one mask value is shared
    (broadcast) along those axes. Kept entries are scaled by `1 / keep_prob`.

    Args:
      x: Input array.
      weights: Unused; this computation reads no weights.
      state: Passed through to the output unchanged.
      rng: Random key used to sample the keep-mask. Must not be None,
        regardless of mode.

    Returns:
      A pair `(output, state)`. `output` is `x * mask / keep_prob` in 'train'
      mode with `self._rate > 0.0`; otherwise it is `x` itself.

    Raises:
      ValueError: If `rng` is None.
    """
    del weights  # Unused.
    if rng is None:
        raise ValueError('BroadcastedDropout requires rng kwarg.')

    if not (self._mode == 'train' and self._rate > 0.0):
        return x, state

    # Mask shape: same as x, but collapsed to size 1 on the broadcast axes.
    noise_shape = list(x.shape)
    for dim in self._broadcast_dims:
        noise_shape[dim] = 1

    # NOTE: the former `jax.lax.tie_in` wrappers were dropped. The function is
    # a no-op under omnistaging and is absent from recent JAX releases, where
    # calling it raises AttributeError. A plain float gives the same result.
    keep_prob = 1.0 - self._rate
    keep = random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(x.dtype) / keep_prob
    return x * multiplier, state