コード例 #1
0
ファイル: reformer.py プロジェクト: yangcaot/trax
 def forward(self, x):
   """Dropout, with broadcasting to save memory.

   In 'train' mode with a positive `self._rate`, each entry of the random
   keep-mask is 1 with probability `1 - self._rate`, and kept values are
   rescaled by `1 / (1 - self._rate)` (inverted dropout). The mask has size 1
   along every axis listed in `self._broadcast_dims`, so a single mask value
   is shared across each of those axes. In any other mode, or when the rate
   is not positive, `x` is returned unchanged.

   Args:
     x: Input array. Axes in `self._broadcast_dims` index into `x.shape`
       (negative indices allowed).

   Returns:
     `x` itself when dropout is inactive; otherwise `x * multiplier`, where
     `multiplier` has dtype `x.dtype` and broadcasts against `x`.
   """
   if not (self._mode == 'train' and self._rate > 0.0):
     return x
   # `jax.lax.tie_in` is a no-op under omnistaging (it returns its second
   # argument) and is missing from newer JAX releases; fall back to the
   # identity so this runs on both old and new JAX.
   tie_in = getattr(jax.lax, 'tie_in', lambda _, y: y)
   rng = self.rng
   noise_shape = list(x.shape)
   for dim in self._broadcast_dims:
     # A size-1 axis makes one mask value broadcast across this dimension.
     noise_shape[dim] = 1
   keep_prob = tie_in(rng, 1.0 - self._rate)
   keep = random.bernoulli(rng, keep_prob, tuple(noise_shape))
   # Rescale survivors so the expected value of the output matches `x`.
   multiplier = keep.astype(x.dtype) / tie_in(keep, keep_prob)
   return x * multiplier
コード例 #2
0
 def forward_with_state(self, x, weights, state, rng):
   """Dropout, with broadcasting to save memory.

   In 'train' mode with a positive `self._rate`, each entry of the random
   keep-mask is 1 with probability `1 - self._rate`, and kept values are
   rescaled by `1 / (1 - self._rate)` (inverted dropout). The mask has size 1
   along every axis listed in `self._broadcast_dims`, so a single mask value
   is shared across each of those axes. In any other mode, or when the rate
   is not positive, `x` is passed through unchanged.

   Args:
     x: Input array. Axes in `self._broadcast_dims` index into `x.shape`
       (negative indices allowed).
     weights: Ignored.
     state: Returned unchanged.
     rng: PRNG key used to draw the mask. It is required in every mode,
       not only in 'train'.

   Returns:
     A tuple `(output, state)`. `output` is `x` itself when dropout is
     inactive, otherwise `x * multiplier` with `multiplier` of dtype
     `x.dtype`.

   Raises:
     ValueError: If `rng` is None.
   """
   del weights  # Not used by this layer's computation.
   if rng is None:
     raise ValueError('BroadcastedDropout requires rng kwarg.')
   if not (self._mode == 'train' and self._rate > 0.0):
     return x, state
   # `jax.lax.tie_in` is a no-op under omnistaging (it returns its second
   # argument) and is missing from newer JAX releases; fall back to the
   # identity so this runs on both old and new JAX.
   tie_in = getattr(jax.lax, 'tie_in', lambda _, y: y)
   noise_shape = list(x.shape)
   for dim in self._broadcast_dims:
     # A size-1 axis makes one mask value broadcast across this dimension.
     noise_shape[dim] = 1
   keep_prob = tie_in(rng, 1.0 - self._rate)
   keep = random.bernoulli(rng, keep_prob, tuple(noise_shape))
   # Rescale survivors so the expected value of the output matches `x`.
   multiplier = keep.astype(x.dtype) / tie_in(keep, keep_prob)
   return x * multiplier, state