Example #1
    def __init__(self, params: InstantiableParams) -> None:
        super().__init__(params)
        p = self.params
        asserts.not_none(p.self_atten_tpl)
        self.create_child('self_atten', p.self_atten_tpl)

        self.create_child('norm', p.norm_tpl)

        # Initialize residual dropout.
        params = p.residual_dropout_tpl.Copy()
        params.keep_prob = (1.0 - p.residual_dropout_prob)
        self.create_child('residual_dropout', params)
Example #2
  def __init__(self, params: InstantiableParams) -> None:
    """Constructor for the learner."""
    assert params.name, ('Learner params for %s must have a "name"' %
                         self.__class__.__name__)
    module_name = params.name
    NestedMap.CheckKey(module_name)

    self._params = params.Copy()

    p = self.params
    asserts.not_none(p.optimizer)
    asserts.not_none(p.loss_name)
    self._optimizer = p.optimizer.Instantiate()
    self._grad_tx = self.optimizer.get_grad_transformation()
Example #3
  def __init__(self, params: InstantiableParams) -> None:
    """Constructor for the MultiOptimizer learner."""
    super().__init__(params)
    p = self.params
    asserts.not_none(p.optimizer)
    asserts.not_none(p.loss_name)
    if len(p.auxiliary_optimizers) != len(p.auxiliary_regex):
      raise ValueError('The length of the auxiliary regex must match the '
                       'length of the auxiliary optimizers.')
    self._optimizer = p.optimizer.Instantiate()
    self._auxiliary_optimizers = [
        opt.Instantiate() for opt in p.auxiliary_optimizers
    ]
    self._grad_tx_fn = self._optimizer.get_grad_transformation
    self._auxiliary_grad_tx_fn = [
        opt.get_grad_transformation for opt in self._auxiliary_optimizers
    ]
Example #4
    def __init__(self, params):
        """Initializes GroupNorm layer and checks parameters."""
        super().__init__(params)
        p = self.params
        asserts.not_none(p.name)
        asserts.gt(p.num_groups, 0)
        asserts.gt(p.min_group_size, 0)
        asserts.le(p.min_group_size, p.dim)
        asserts.eq(p.dim % p.min_group_size, 0)

        if p.dim >= p.num_groups:
            asserts.eq(
                p.dim % p.num_groups,
                0,
                msg='p.dim({0}) is not divisible by p.num_groups({1})'.format(
                    p.dim, p.num_groups))

        asserts.in_set(p.input_rank, (3, 4))
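
For readers who have not seen the praxis-style asserts helpers, the parameter checks above map onto ordinary Python assertions. A minimal sketch with made-up values standing in for p.dim, p.num_groups, p.min_group_size and p.input_rank:

# Hypothetical values standing in for the GroupNorm params.
dim, num_groups, min_group_size, input_rank = 512, 32, 8, 4

assert num_groups > 0                  # asserts.gt(p.num_groups, 0)
assert min_group_size > 0              # asserts.gt(p.min_group_size, 0)
assert min_group_size <= dim           # asserts.le(p.min_group_size, p.dim)
assert dim % min_group_size == 0       # asserts.eq(p.dim % p.min_group_size, 0)
if dim >= num_groups:
  # Channels must split evenly into groups.
  assert dim % num_groups == 0
assert input_rank in (3, 4)            # asserts.in_set(p.input_rank, (3, 4))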
Example #5
    def _zoneout_internal(self, prev_v: JTensor, cur_v: JTensor,
                          padding_v: JTensor, zo_prob: float, is_eval: bool,
                          random_uniform: JTensor) -> JTensor:
        """Apply ZoneOut regularlization to cur_v.

    Implements ZoneOut regularization as described in
    https://arxiv.org/abs/1606.01305

    Args:
      prev_v: Values from the previous timestep.
      cur_v: Values from the current timestep.
      padding_v: The paddings vector for the cur timestep.
      zo_prob: Probability at which to apply ZoneOut regularization.
      is_eval: Whether or not in eval mode.
      random_uniform: Random uniform numbers. This can be None if zo_prob=0.0

    Returns:
      cur_v after ZoneOut regularization has been applied.
    """
        prev_v = jnp.array(prev_v)
        cur_v = jnp.array(cur_v)
        padding_v = jnp.array(padding_v)
        if zo_prob == 0.0:
            # Special case for when ZoneOut is not enabled.
            return jnp.where(padding_v, prev_v, cur_v)

        if is_eval:
            mix_prev = jnp.full(prev_v.shape, zo_prob) * prev_v
            mix_curr = jnp.full(cur_v.shape, 1.0 - zo_prob) * cur_v
            mix = mix_prev + mix_curr

            # If padding_v is 1, it always carries over the previous state.
            return jnp.where(padding_v, prev_v, mix)
        else:
            asserts.not_none(random_uniform)
            zo_p = (random_uniform < zo_prob).astype(padding_v.dtype)
            zo_p += padding_v
            # If padding_v is 1, we always carry over the previous state.
            zo_p = jnp.minimum(zo_p, 1.0)
            zo_p = jax.lax.stop_gradient(zo_p)
            return jnp.where(zo_p, prev_v, cur_v)
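
In eval mode the method above does not sample a mask; it blends the previous and current values deterministically. A tiny standalone sketch of that eval-mode blend with made-up tensors (padding handling omitted):

import jax.numpy as jnp

zo_prob = 0.1
prev_v = jnp.array([1.0, 2.0, 3.0])
cur_v = jnp.array([4.0, 5.0, 6.0])

# Deterministic blend used at eval time: zo_prob * prev + (1 - zo_prob) * cur.
mix = zo_prob * prev_v + (1.0 - zo_prob) * cur_v
# mix == [3.7, 4.7, 5.7]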
Example #6
  def test_not_none_raises(self):
    value = None
    with self.assertRaisesRegex(
        ValueError, f'`value={value}` must not be `None`.$'):
      asserts.not_none(value)
    with self.assertRaisesRegex(
        ValueError, f'`custom_value={value}` must not be `None`.$'):
      asserts.not_none(value, value_str=f'custom_value={value}')
    custom_error_msg = 'This is a custom error message.'
    with self.assertRaisesRegex(ValueError, f'{custom_error_msg}$'):
      asserts.not_none(value, msg=custom_error_msg)
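
The test pins down the error-message contract of asserts.not_none: a default message built from the value, an optional value_str override, and an optional fully custom msg. A minimal sketch of a checker that satisfies these three cases (an illustration only; the real praxis implementation may differ):

from typing import Any, Optional

def not_none(value: Any,
             value_str: Optional[str] = None,
             msg: Optional[str] = None) -> None:
  """Raises ValueError if `value` is None, mirroring the messages tested above."""
  if value is not None:
    return
  if msg is None:
    value_str = value_str if value_str is not None else f'value={value}'
    msg = f'`{value_str}` must not be `None`.'
  raise ValueError(msg)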
Example #7
  def _compute_new_c(self, state0: NestedMap, i_i: JTensor, i_g: JTensor,
                     f_g: JTensor) -> JTensor:
    asserts.not_none(i_g)
    forget_gate = jax.nn.sigmoid(f_g) * state0.c
    input_gate = jax.nn.sigmoid(i_g) * jnp.tanh(i_i)
    return forget_gate + input_gate
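
The method computes the standard LSTM cell-state update, new_c = sigmoid(f_g) * c_prev + sigmoid(i_g) * tanh(i_i). A self-contained sketch of the same arithmetic, with a plain namespace standing in for the NestedMap state and made-up gate pre-activations:

import types
import jax
import jax.numpy as jnp

state0 = types.SimpleNamespace(c=jnp.array([0.5, -0.5]))  # stand-in for the NestedMap state
i_i = jnp.array([0.2, 0.3])   # cell-input pre-activation
i_g = jnp.array([1.0, -1.0])  # input-gate pre-activation
f_g = jnp.array([2.0, 0.0])   # forget-gate pre-activation

# Same update as _compute_new_c above.
new_c = jax.nn.sigmoid(f_g) * state0.c + jax.nn.sigmoid(i_g) * jnp.tanh(i_i)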
Example #8
  def test_not_none(self, value):
    asserts.not_none(value)
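
Example #8 is the happy-path counterpart of Example #6 and receives its value as a test parameter. How the cases are generated is not shown in the snippet; below is a hedged sketch using absl's parameterized helpers (the decorator, values, and import path are assumptions, not the original test setup):

from absl.testing import absltest
from absl.testing import parameterized
from praxis import asserts  # assumed import path for the asserts module

class NotNoneTest(parameterized.TestCase):

  @parameterized.parameters(0, 1.5, 'text')  # any non-None value should pass
  def test_not_none(self, value):
    asserts.not_none(value)

if __name__ == '__main__':
  absltest.main()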
Example #9
    def __init__(self, params: InstantiableParams) -> None:
        super().__init__(params)
        p = self.params
        asserts.in_set(
            p.layer_order,
            ['mhsa', 'conv', 'mhsa_before_conv', 'conv_before_mhsa'])

        if p.dropout_prob is not None:
            all_dropouts = [
                p.atten_dropout, p.atten_residual_dropout,
                p.conv_residual_dropout, p.ffn_residual_dropout,
                p.ffn_relu_dropout
            ]
            for prob in all_dropouts:
                assert prob is None or prob == p.dropout_prob

            p.atten_dropout = p.dropout_prob
            p.atten_residual_dropout = p.dropout_prob
            p.conv_residual_dropout = p.dropout_prob
            p.ffn_residual_dropout = p.dropout_prob
            p.ffn_relu_dropout = p.dropout_prob

        if p.fflayer_start_tpl:
            if p.input_dims == p.model_dims:
                fflayer_start_p = p.fflayer_start_tpl.Copy().Set(
                    name='fflayer_start',
                    activation=p.ff_activation,
                    input_dims=p.input_dims,
                    hidden_dims=p.model_dims * p.ffn_dim_multiplier,
                    residual_weight=p.ff_residual_weight,
                    residual_dropout_prob=p.ffn_residual_dropout,
                    relu_dropout_prob=p.ffn_relu_dropout,
                )
            else:
                # input_dims != model_dims, so the fflayer also needs an
                # output projection to model_dims.
                fflayer_start_p = p.fflayer_start_tpl.Copy().Set(
                    name='fflayer_start',
                    activation=p.ff_activation,
                    input_dims=p.input_dims,
                    output_dims=p.model_dims,
                    hidden_dims=p.model_dims * p.ffn_dim_multiplier,
                    residual_weight=p.ff_residual_weight,
                    residual_dropout_prob=p.ffn_residual_dropout,
                    relu_dropout_prob=p.ffn_relu_dropout,
                )
            self.create_child(fflayer_start_p.name, fflayer_start_p)

        if p.fflayer_end_tpl:
            fflayer_end_p = p.fflayer_end_tpl.Copy().Set(
                name='fflayer_end',
                activation=p.ff_activation,
                input_dims=p.model_dims,
                hidden_dims=p.model_dims * p.ffn_dim_multiplier,
                residual_weight=p.ff_residual_weight,
                residual_dropout_prob=p.ffn_residual_dropout,
                relu_dropout_prob=p.ffn_relu_dropout,
            )
            if not p.fflayer_weight_sharing:
                self.create_child(fflayer_end_p.name, fflayer_end_p)
            else:
                asserts.not_none(p.fflayer_start_tpl)

        if 'mhsa' in p.layer_order:
            trans_atten_p = p.trans_atten_tpl.Copy().Set(
                residual_dropout_prob=p.atten_residual_dropout,
                self_atten_tpl=p.trans_atten_tpl.self_atten_tpl.Copy().Set(
                    input_dim=p.model_dims,
                    hidden_dim=p.model_dims,
                    atten_dropout_prob=p.atten_dropout,
                    num_heads=p.atten_num_heads))
            if p.trans_atten_tpl.norm_tpl.cls == normalizations.LayerNorm:
                trans_atten_p.norm_tpl = trans_atten_p.norm_tpl.Copy().Set(
                    input_dims=p.model_dims)
            else:
                trans_atten_p.norm_tpl = trans_atten_p.norm_tpl.Copy().Set(
                    dim=p.model_dims)
            self.create_child('trans_atten', trans_atten_p)

        if 'conv' in p.layer_order:
            lconv_p = p.lconv_tpl.Copy().Set(
                input_dims=p.model_dims,
                kernel_size=p.kernel_size,
                dropout_prob=p.conv_residual_dropout)
            self.create_child('lconv', lconv_p)

        ln_p = p.final_ln_tpl.Copy().Set(name='final_ln',
                                         input_dims=p.model_dims)
        self.create_child('final_ln', ln_p)
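
The dropout handling at the top of Example #9 lets a single p.dropout_prob override the five specific probabilities, provided none of them disagrees with it. The rule in isolation, as a small sketch with hypothetical values:

dropout_prob = 0.1
specific_probs = {
    'atten_dropout': None,
    'atten_residual_dropout': 0.1,
    'conv_residual_dropout': None,
    'ffn_residual_dropout': None,
    'ffn_relu_dropout': 0.1,
}

if dropout_prob is not None:
  # Every specific probability must be unset or already equal to the global one...
  for prob in specific_probs.values():
    assert prob is None or prob == dropout_prob
  # ...and then the global value replaces them all.
  specific_probs = {name: dropout_prob for name in specific_probs}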