def __call__(self, inputs, is_training):
  input_channels = inputs.shape[-1]

  self._blocks = []
  for id_block in range(self._num_blocks):
    use_projection = id_block == 0 and self._channels != input_channels
    self._blocks.append(
        self._block_module(channels=self._channels,
                           stride=self._stride if id_block == 0 else 1,
                           use_projection=use_projection,
                           normalize_fn=self._normalize_fn,
                           name='block_%d' % id_block))

  net = inputs
  for block in self._blocks:
    if self._remat:
      # Note: we can ignore cell-var-from-loop because the lambda is
      # evaluated inside every iteration of the loop. This is needed to
      # work around the way variables are passed to jax.remat.
      net = hk.remat(lambda x: block(x, is_training=is_training))(
          net)  # pylint: disable=cell-var-from-loop
    else:
      net = block(net, is_training=is_training)
  return net
def g(x, remat=False):
  mod = module_fn()
  if remat:
    mod = hk.remat(mod)
  out = mod(x)
  if isinstance(out, dict):
    out = out['loss']
  return jnp.mean(out)
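# A minimal, self-contained sketch of how a helper like `g` above is
# typically exercised. `module_fn` below is a hypothetical stand-in for
# whatever module factory the surrounding code provides, and the helper is
# wrapped in `hk.transform` so it can be initialised and differentiated.
# Gradients with and without `hk.remat` should agree; remat only changes
# whether activations are stored or recomputed during the backward pass.
import haiku as hk
import jax
import jax.numpy as jnp


def module_fn():
  return hk.Linear(8)  # hypothetical module factory


def g(x, remat=False):
  mod = module_fn()
  if remat:
    mod = hk.remat(mod)
  return jnp.mean(mod(x))


f = hk.transform(g)
x = jnp.ones([4, 16])
params = f.init(jax.random.PRNGKey(0), x)

# Same parameters, same math; only the activation storage strategy differs.
grads_plain = jax.grad(lambda p: f.apply(p, None, x, remat=False))(params)
grads_remat = jax.grad(lambda p: f.apply(p, None, x, remat=True))(params)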
def eval(self, context, target, z_loss=0., mask=0.0):
  input_len = context.shape[0]

  if self.rpe is not None:
    attn_bias = self.rpe(input_len, input_len, self.heads_per_shard, 32)
  else:
    attn_bias = 0

  attn_bias += mask

  x = hk.remat(self.embed)(context)

  for l in self.transformer_layers:
    x = x + hk.remat(l)(x, attn_bias)

  return hk.remat(self.proj.loss)(x, target, z_loss)
def __init__(self,
             depth=50,
             num_classes: Optional[int] = 1000,
             width_mult: int = 1,
             normalize_fn: Optional[types.NormalizeFn] = None,
             name: Optional[Text] = None,
             remat: bool = False):
  """Creates ResNetV2 Haiku module.

  Args:
    depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
    num_classes: (int) Number of outputs in final layer. If None, will not
      add a classification head and will return the output embedding.
    width_mult: multiplier for channel width.
    normalize_fn: normalization function, see helpers/utils.py.
    name: Name of the module.
    remat: Whether to rematerialize intermediate activations (saves memory).
  """
  super(ResNetV2, self).__init__(name=name)
  self._normalize_fn = normalize_fn
  self._num_classes = num_classes
  self._width_mult = width_mult

  self._strides = [1, 2, 2, 2]
  num_blocks = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3],
  }
  if depth not in num_blocks:
    raise ValueError(
        f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
  self._num_blocks = num_blocks[depth]

  if depth >= 50:
    self._block_module = BottleneckBlock
    self._channels = [256, 512, 1024, 2048]
  else:
    self._block_module = BasicBlock
    self._channels = [64, 128, 256, 512]

  self._initial_conv = hk.Conv2D(output_channels=64 * self._width_mult,
                                 kernel_shape=7,
                                 stride=2,
                                 with_bias=False,
                                 padding='SAME',
                                 name='initial_conv')
  if remat:
    self._initial_conv = hk.remat(self._initial_conv)

  self._block_groups = []
  for i in range(4):
    self._block_groups.append(
        ResNetUnit(channels=self._channels[i] * self._width_mult,
                   num_blocks=self._num_blocks[i],
                   block_module=self._block_module,
                   stride=self._strides[i],
                   normalize_fn=self._normalize_fn,
                   name='block_group_%d' % i,
                   remat=remat))

  if num_classes is not None:
    self._logits_layer = hk.Linear(output_size=num_classes,
                                   w_init=jnp.zeros,
                                   name='logits')
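# A hedged usage sketch for the constructor above: build the network with
# remat=True inside hk.transform_with_state (the normalization function may
# carry state such as batch-norm statistics). The ResNetV2 class, its
# normalize_fn, and the `is_training` call signature are assumptions here;
# they are expected to be in scope and are not defined in this snippet.
import haiku as hk
import jax
import jax.numpy as jnp


def forward(images, is_training):
  net = ResNetV2(depth=50, num_classes=1000, remat=True)  # assumed in scope
  return net(images, is_training=is_training)             # assumed signature


model = hk.transform_with_state(forward)
images = jnp.ones([2, 224, 224, 3])
params, state = model.init(jax.random.PRNGKey(0), images, is_training=True)
logits, state = model.apply(params, state, None, images, is_training=True)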
def g(x, remat=False):
  mod = module_fn()
  if remat:
    mod = hk.remat(mod)
  return jnp.mean(mod(x))
def transformer(x, mask):
  return hk.remat(residual)(x, mask)
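# A minimal sketch of the pattern in the snippet above, with a hypothetical
# `residual` block (a small MLP with a skip connection) standing in for the
# real transformer block so the remat call is runnable end to end. Wrapping
# the whole block in hk.remat means only its inputs are kept live during the
# forward pass; its internal activations are recomputed when the backward
# pass needs them.
import haiku as hk
import jax
import jax.numpy as jnp


def residual(x, mask):
  # Hypothetical residual block; a real one would hold attention, etc.
  h = hk.Linear(4 * x.shape[-1])(x)
  h = hk.Linear(x.shape[-1])(jax.nn.gelu(h))
  return x + h * mask


def transformer(x, mask):
  return hk.remat(residual)(x, mask)


f = hk.transform(transformer)
x = jnp.ones([2, 16, 32])
mask = jnp.ones([2, 16, 1])
params = f.init(jax.random.PRNGKey(0), x, mask)
out = f.apply(params, None, x, mask)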