Ejemplo n.º 1
0
    def __call__(self, inputs, is_training):
        """Builds this unit's blocks and applies them sequentially to `inputs`.

        Only the first block uses `self._stride`. It also gets a projection
        shortcut when `self._channels` differs from the size of the last axis
        of `inputs`. Every later block uses stride 1 and no projection.

        Args:
          inputs: input array; `inputs.shape[-1]` is compared against
            `self._channels` to decide whether block 0 needs a projection.
          is_training: forwarded as the `is_training` keyword to every block.

        Returns:
          The output of the last block (or `inputs` if there are no blocks).
        """
        input_channels = inputs.shape[-1]

        # The block modules are (re)built on every call and kept on `self`.
        self._blocks = []
        for id_block in range(self._num_blocks):
            use_projection = id_block == 0 and self._channels != input_channels
            self._blocks.append(
                self._block_module(channels=self._channels,
                                   stride=self._stride if id_block == 0 else 1,
                                   use_projection=use_projection,
                                   normalize_fn=self._normalize_fn,
                                   name='block_%d' % id_block))

        net = inputs
        for block in self._blocks:
            if self._remat:
                # The lambda closes over the loop variable `block`. This is safe
                # because it is wrapped by hk.remat and called immediately within
                # the same iteration, so late binding never matters. Closing over
                # `is_training` (rather than passing it as an argument) works
                # around the way arguments are passed to jax.remat. The pylint
                # pragma must sit on the lambda's own line, where the warning is
                # reported.
                net = hk.remat(
                    lambda x: block(x, is_training=is_training)  # pylint: disable=cell-var-from-loop
                )(net)
            else:
                net = block(net, is_training=is_training)
        return net
Ejemplo n.º 2
0
 def g(x, remat=False):
   """Applies a freshly built `module_fn()` module to `x` and returns a mean.

   When `remat` is true the module is wrapped with `hk.remat` first. If the
   module returns a dict, only its 'loss' entry is averaged.
   """
   base = module_fn()
   apply_fn = hk.remat(base) if remat else base
   output = apply_fn(x)
   if isinstance(output, dict):
     return jnp.mean(output['loss'])
   return jnp.mean(output)
    def eval(self, context, target, z_loss=0., mask=0.0):
        """Runs the model on `context` and returns `self.proj.loss` vs `target`.

        The embedding, each transformer layer, and the final loss call are
        each wrapped in `hk.remat`. Every layer is applied as a residual
        update `x + layer(x, attn_bias)`.

        Args:
          context: model input; `context.shape[0]` is used as both length
            arguments to `self.rpe` (assumed to be the token axis - TODO
            confirm).
          target: forwarded to `self.proj.loss`.
          z_loss: forwarded to `self.proj.loss`.
          mask: added to the attention bias.

        Returns:
          Whatever `self.proj.loss(x, target, z_loss)` returns.
        """
        seq_len = context.shape[0]

        # Attention bias comes from `self.rpe` (presumably a relative position
        # encoding) when configured, else the scalar 0.
        # NOTE(review): the meaning of the literal 32 passed to `self.rpe` is
        # not visible here; confirm against the rpe implementation.
        if self.rpe is None:
            attn_bias = 0
        else:
            attn_bias = self.rpe(seq_len, seq_len, self.heads_per_shard, 32)
        attn_bias += mask

        hidden = hk.remat(self.embed)(context)
        for layer in self.transformer_layers:
            hidden = hidden + hk.remat(layer)(hidden, attn_bias)

        loss_fn = hk.remat(self.proj.loss)
        return loss_fn(hidden, target, z_loss)
Ejemplo n.º 4
0
    def __init__(self,
                 depth: int = 50,
                 num_classes: Optional[int] = 1000,
                 width_mult: int = 1,
                 normalize_fn: Optional[types.NormalizeFn] = None,
                 name: Optional[Text] = None,
                 remat: bool = False):
        """Creates ResNetV2 Haiku module.

        Args:
          depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
          num_classes: (int) Number of outputs in final layer. If None will not
            add a classification head and will return the output embedding.
          width_mult: multiplier for channel width; applied to the initial
            convolution and to the channels of every block group.
          normalize_fn: normalization function, see helpers/utils.py
          name: Name of the module.
          remat: Whether to rematerialize intermediate activations (saves
            memory). Wraps the initial convolution in `hk.remat` and is
            forwarded to every block group.

        Raises:
          ValueError: if `depth` is not one of the supported depths.
        """
        super(ResNetV2, self).__init__(name=name)
        self._normalize_fn = normalize_fn
        self._num_classes = num_classes
        self._width_mult = width_mult

        # Per-group stride; only the first group keeps the spatial resolution.
        self._strides = [1, 2, 2, 2]
        # Number of blocks in each of the four block groups, keyed by depth.
        num_blocks = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
            200: [3, 24, 36, 3],
        }
        if depth not in num_blocks:
            raise ValueError(
                f'`depth` should be in {list(num_blocks.keys())} ({depth} given).'
            )
        self._num_blocks = num_blocks[depth]

        # Depths >= 50 use bottleneck blocks with 4x wider group outputs;
        # shallower nets use basic blocks.
        if depth >= 50:
            self._block_module = BottleneckBlock
            self._channels = [256, 512, 1024, 2048]
        else:
            self._block_module = BasicBlock
            self._channels = [64, 128, 256, 512]

        self._initial_conv = hk.Conv2D(output_channels=64 * self._width_mult,
                                       kernel_shape=7,
                                       stride=2,
                                       with_bias=False,
                                       padding='SAME',
                                       name='initial_conv')

        if remat:
            self._initial_conv = hk.remat(self._initial_conv)

        self._block_groups = []
        for i in range(4):
            self._block_groups.append(
                ResNetUnit(channels=self._channels[i] * self._width_mult,
                           num_blocks=self._num_blocks[i],
                           block_module=self._block_module,
                           stride=self._strides[i],
                           normalize_fn=self._normalize_fn,
                           name='block_group_%d' % i,
                           remat=remat))

        # Classification head is optional; zero-initialized weights.
        if num_classes is not None:
            self._logits_layer = hk.Linear(output_size=num_classes,
                                           w_init=jnp.zeros,
                                           name='logits')
Ejemplo n.º 5
0
 def g(x, remat=False):
   """Builds `module_fn()`, optionally wraps it in `hk.remat`, and averages its output on `x`."""
   layer = module_fn()
   wrapped = hk.remat(layer) if remat else layer
   return jnp.mean(wrapped(x))
 def transformer(x, mask):
     """Applies `residual` to `(x, mask)` with rematerialization via `hk.remat`."""
     checkpointed = hk.remat(residual)
     return checkpointed(x, mask)