Example #1
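A test-style snippet (self is the enclosing unittest.TestCase): it pushes a child scope and checks that the parent/child relationship between scopes survives the lifting.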
def f(scope):
  a = scope.push('a')

  def g(scopes, _):
    # The lifted function receives the (scope, subscope) pair; the
    # parent/child relationship must survive lift.vmap.
    scope, a = scopes
    self.assertEqual(a.parent, scope)

  # No variables or rngs are lifted; only the array argument is mapped.
  lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,)))
Example #2
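A minimal lifted dense layer: the leading axis of the (3, 2) input is mapped while the feature-count argument 2 is broadcast (in_axes=(0, None)).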
def f(scope):
    dense = lift.vmap(nn.dense,
                      in_axes=(0, None),
                      out_axes=0,
                      variable_axes={'params': 0},
                      split_rngs={'params': True})
    dense(scope.push('dense'), np.ones((3, 2)), 2)
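Because variable_axes={'params': 0} gives each of the three rows its own kernel and bias, split_rngs={'params': True} is needed so that every row is initialized with an independent rng; mapping the parameters with None instead would broadcast a single parameter set and require no rng split (compare Example #4).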
Example #3
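Multi-head dot-product attention built entirely from lift.vmap: an inner vmap creates the heads, and one outer vmap per batch axis maps over the data. Note that this snippet appears to use an older flax.core API revision, with separate variable_in_axes/variable_out_axes and a 'param' (singular) collection name.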
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q: Array,
                                     inputs_kv: Array,
                                     bias: Optional[Array] = None,
                                     qkv_features: Optional[int] = None,
                                     out_features: Optional[int] = None,
                                     attn_fn: Callable = softmax_attn,
                                     batch_axes: Sequence[int] = (0, ),
                                     num_heads: int = 1,
                                     dtype=jnp.float32,
                                     broadcast_dropout=False):

    if qkv_features is None:
        qkv_features = inputs_q.shape[-1]
    if out_features is None:
        out_features = inputs_q.shape[-1]

    attn_fn = partial(dot_product_attention,
                      attn_fn=attn_fn,
                      qkv_features=qkv_features // num_heads,
                      out_features=out_features,
                      dtype=dtype)
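    # Inner vmap over heads: axis_size=num_heads copies with independent
    # parameters and rngs; the inputs are broadcast (in_axes all None) and
    # the per-head outputs are stacked on axis -2.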
    attn_fn = lift.vmap(attn_fn,
                        in_axes=(None, None, None),
                        out_axes=-2,
                        axis_size=num_heads,
                        variable_in_axes={'param': 0},
                        variable_out_axes={'param': 0},
                        split_rngs={
                            'param': True,
                            'dropout': not broadcast_dropout
                        })
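    # Outer vmaps, one per batch axis: parameters are shared across the
    # batch (variable axes None, no param rng split), while dropout rngs
    # are split per example unless dropout is broadcast.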
    for axis in reversed(sorted(batch_axes)):
        attn_fn = lift.vmap(attn_fn,
                            in_axes=(axis, axis, axis),
                            out_axes=axis,
                            variable_in_axes={'param': None},
                            variable_out_axes={'param': None},
                            split_rngs={
                                'param': False,
                                'dropout': not broadcast_dropout
                            })

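    # Run the fully lifted attention; heads come out stacked on axis -2.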
    y = attn_fn(scope, inputs_q, inputs_kv, bias)
    return y.mean(axis=-2)
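Averaging over the heads axis replaces the usual concatenate-and-project combination of head outputs; the rest of the structure, with parameters split per head but shared across batch axes, falls out of how the two lifting levels are configured.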
Example #4
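An MLP whose dense layers are lifted over the leading axis of x, with a flag that switches between one parameter set shared by all rows and an independent set per row.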
def mlp_vmap(scope: Scope, x: Array,
             sizes: Sequence[int] = (8, 1),
             act_fn: Callable[[Array], Array] = nn.relu,
             share_params: bool = False):
  if share_params:
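    # Share one parameter set across all rows of x; the params rng is
    # not split.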
    dense_vmap = lift.vmap(nn.dense,
                           in_axes=(0, None),
                           variable_axes={'params': None},
                           split_rngs={'params': False})
  else:
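    # Give each row of x its own parameter set, each initialized from
    # its own rng.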
    dense_vmap = lift.vmap(nn.dense,
                           in_axes=(0, None),
                           variable_axes={'params': 0},
                           split_rngs={'params': True})

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_vmap, prefix='hidden_')(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_vmap, 'out')(x, sizes[-1])
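A minimal usage sketch, assuming the flax.core functional API (init); the input shape and rng are illustrative:

import jax
import jax.numpy as jnp
from flax.core import init

x = jnp.ones((4, 3))  # 4 examples with 3 features each (illustrative)
y, variables = init(mlp_vmap)(jax.random.PRNGKey(0), x)
# With share_params=False (the default) each of the 4 rows gets its own
# dense parameters; with share_params=True one set is broadcast.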
Example #5
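A project-specific wrapper that bakes default variable and rng handling into lift.vmap (Signal appears to be a project-defined axis specification for the signal pytree) and restores the wrapped function's __name__.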
def vmap(f,
         variable_axes={
             'params': -1,
             'const': None
         },
         split_rngs={
             'params': True,
         },
         in_axes=(Signal(-1, None), ),
         out_axes=Signal(-1, None)):
    # in_axes needs to be wrapped in a tuple; see Flax's lifted vmap implementation:
    # https://github.com/google/flax/blob/82e9798274c927286878c4600b4b09650d1e7935/flax/core/lift.py#L395
    vf = lift.vmap(f,
                   variable_axes=variable_axes,
                   split_rngs=split_rngs,
                   in_axes=in_axes,
                   out_axes=out_axes)
    vf.__name__ = 'vmapped_' + f.__name__  # workaround: the lifted transform does not preserve the original name
    return vf
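One caveat worth noting: the dict-valued defaults for variable_axes and split_rngs are mutable default arguments, which is harmless only as long as neither the wrapper's callers nor lift.vmap mutate them in place; constructing them inside the function body would be the more defensive choice.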