def __init__(self, params: InstantiableParams) -> None:
  super().__init__(params)
  p = self.params
  asserts.not_none(p.self_atten_tpl)
  self.create_child('self_atten', p.self_atten_tpl)
  self.create_child('norm', p.norm_tpl)
  # Initialize residual dropout.
  params = p.residual_dropout_tpl.Copy()
  params.keep_prob = (1.0 - p.residual_dropout_prob)
  self.create_child('residual_dropout', params)
def __init__(self, params: InstantiableParams) -> None:
  """Constructor for the learner."""
  assert params.name, ('Learner params for %s must have a "name"' %
                       self.__class__.__name__)
  module_name = params.name
  NestedMap.CheckKey(module_name)
  self._params = params.Copy()
  p = self.params
  asserts.not_none(p.optimizer)
  asserts.not_none(p.loss_name)
  self._optimizer = p.optimizer.Instantiate()
  self._grad_tx = self.optimizer.get_grad_transformation()
def __init__(self, params: InstantiableParams) -> None:
  """Constructor for the MultiOptimizer learner."""
  super().__init__(params)
  p = self.params
  asserts.not_none(p.optimizer)
  asserts.not_none(p.loss_name)
  if len(p.auxiliary_optimizers) != len(p.auxiliary_regex):
    raise ValueError('The length of the auxiliary regex must match the '
                     'length of the auxiliary optimizers.')
  self._optimizer = p.optimizer.Instantiate()
  self._auxiliary_optimizers = [
      opt.Instantiate() for opt in p.auxiliary_optimizers
  ]
  self._grad_tx_fn = self._optimizer.get_grad_transformation
  self._auxiliary_grad_tx_fn = [
      opt.get_grad_transformation for opt in self._auxiliary_optimizers
  ]
def __init__(self, params):
  """Initializes GroupNorm layer and checks parameters."""
  super().__init__(params)
  p = self.params
  asserts.not_none(p.name)
  asserts.gt(p.num_groups, 0)
  asserts.gt(p.min_group_size, 0)
  asserts.le(p.min_group_size, p.dim)
  asserts.eq(p.dim % p.min_group_size, 0)
  if p.dim >= p.num_groups:
    asserts.eq(
        p.dim % p.num_groups,
        0,
        msg='p.dim({0}) is not divisible by p.num_groups({1})'.format(
            p.dim, p.num_groups))
  asserts.in_set(p.input_rank, (3, 4))
def _zoneout_internal(self, prev_v: JTensor, cur_v: JTensor,
                      padding_v: JTensor, zo_prob: float, is_eval: bool,
                      random_uniform: JTensor) -> JTensor:
  """Applies ZoneOut regularization to cur_v.

  Implements ZoneOut regularization as described in
  https://arxiv.org/abs/1606.01305.

  Args:
    prev_v: Values from the previous timestep.
    cur_v: Values from the current timestep.
    padding_v: The paddings vector for the current timestep.
    zo_prob: Probability at which to apply ZoneOut regularization.
    is_eval: Whether or not in eval mode.
    random_uniform: Random uniform numbers. This can be None if zo_prob=0.0.

  Returns:
    cur_v after ZoneOut regularization has been applied.
  """
  prev_v = jnp.array(prev_v)
  cur_v = jnp.array(cur_v)
  padding_v = jnp.array(padding_v)
  if zo_prob == 0.0:
    # Special case for when ZoneOut is not enabled.
    return jnp.where(padding_v, prev_v, cur_v)

  if is_eval:
    mix_prev = jnp.full(prev_v.shape, zo_prob) * prev_v
    mix_curr = jnp.full(cur_v.shape, 1.0 - zo_prob) * cur_v
    mix = mix_prev + mix_curr
    # If padding_v is 1, it always carries over the previous state.
    return jnp.where(padding_v, prev_v, mix)
  else:
    asserts.not_none(random_uniform)
    zo_p = (random_uniform < zo_prob).astype(padding_v.dtype)
    zo_p += padding_v
    # If padding_v is 1, we always carry over the previous state.
    zo_p = jnp.minimum(zo_p, 1.0)
    zo_p = jax.lax.stop_gradient(zo_p)
    return jnp.where(zo_p, prev_v, cur_v)
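# A minimal standalone sketch (not part of the layer above; the values are
# illustrative) of the eval-mode ZoneOut mix computed in _zoneout_internal:
# mix = zo_prob * prev_v + (1 - zo_prob) * cur_v, while padded steps carry
# over the previous state unchanged.
import jax.numpy as jnp

prev_v = jnp.array([1.0, 2.0])
cur_v = jnp.array([3.0, 4.0])
padding_v = jnp.array([0.0, 1.0])  # second step is padding
zo_prob = 0.25

mix = zo_prob * prev_v + (1.0 - zo_prob) * cur_v
out = jnp.where(padding_v, prev_v, mix)  # -> [2.5, 2.0]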
def test_not_none_raises(self):
  value = None
  with self.assertRaisesRegex(ValueError,
                              f'`value={value}` must not be `None`.$'):
    asserts.not_none(value)
  with self.assertRaisesRegex(
      ValueError, f'`custom_value={value}` must not be `None`.$'):
    asserts.not_none(value, value_str=f'custom_value={value}')
  custom_error_msg = 'This is a custom error message.'
  with self.assertRaisesRegex(ValueError, f'{custom_error_msg}$'):
    asserts.not_none(value, msg=custom_error_msg)
def _compute_new_c(self, state0: NestedMap, i_i: JTensor, i_g: JTensor,
                   f_g: JTensor) -> JTensor:
  """Computes the new cell state from the input, input-gate and forget-gate."""
  asserts.not_none(i_g)
  forget_gate = jax.nn.sigmoid(f_g) * state0.c
  input_gate = jax.nn.sigmoid(i_g) * jnp.tanh(i_i)
  return forget_gate + input_gate
def test_not_none(self, value):
  asserts.not_none(value)
def __init__(self, params: InstantiableParams) -> None:
  super().__init__(params)
  p = self.params
  asserts.in_set(
      p.layer_order, ['mhsa', 'conv', 'mhsa_before_conv', 'conv_before_mhsa'])

  if p.dropout_prob is not None:
    all_dropouts = [
        p.atten_dropout, p.atten_residual_dropout, p.conv_residual_dropout,
        p.ffn_residual_dropout, p.ffn_relu_dropout
    ]
    for prob in all_dropouts:
      assert prob is None or prob == p.dropout_prob

    p.atten_dropout = p.dropout_prob
    p.atten_residual_dropout = p.dropout_prob
    p.conv_residual_dropout = p.dropout_prob
    p.ffn_residual_dropout = p.dropout_prob
    p.ffn_relu_dropout = p.dropout_prob

  if p.fflayer_start_tpl:
    if p.input_dims == p.model_dims:
      fflayer_start_p = p.fflayer_start_tpl.Copy().Set(
          name='fflayer_start',
          activation=p.ff_activation,
          input_dims=p.input_dims,
          hidden_dims=p.model_dims * p.ffn_dim_multiplier,
          residual_weight=p.ff_residual_weight,
          residual_dropout_prob=p.ffn_residual_dropout,
          relu_dropout_prob=p.ffn_relu_dropout,
      )
    else:
      # Need to add another projection layer in fflayer.
      fflayer_start_p = p.fflayer_start_tpl.Copy().Set(
          name='fflayer_start',
          activation=p.ff_activation,
          input_dims=p.input_dims,
          output_dims=p.model_dims,
          hidden_dims=p.model_dims * p.ffn_dim_multiplier,
          residual_weight=p.ff_residual_weight,
          residual_dropout_prob=p.ffn_residual_dropout,
          relu_dropout_prob=p.ffn_relu_dropout,
      )
    self.create_child(fflayer_start_p.name, fflayer_start_p)

  if p.fflayer_end_tpl:
    fflayer_end_p = p.fflayer_end_tpl.Copy().Set(
        name='fflayer_end',
        activation=p.ff_activation,
        input_dims=p.model_dims,
        hidden_dims=p.model_dims * p.ffn_dim_multiplier,
        residual_weight=p.ff_residual_weight,
        residual_dropout_prob=p.ffn_residual_dropout,
        relu_dropout_prob=p.ffn_relu_dropout,
    )
    if not p.fflayer_weight_sharing:
      self.create_child(fflayer_end_p.name, fflayer_end_p)
    else:
      asserts.not_none(p.fflayer_start_tpl)

  if 'mhsa' in p.layer_order:
    trans_atten_p = p.trans_atten_tpl.Copy().Set(
        residual_dropout_prob=p.atten_residual_dropout,
        self_atten_tpl=p.trans_atten_tpl.self_atten_tpl.Copy().Set(
            input_dim=p.model_dims,
            hidden_dim=p.model_dims,
            atten_dropout_prob=p.atten_dropout,
            num_heads=p.atten_num_heads))
    if p.trans_atten_tpl.norm_tpl.cls == normalizations.LayerNorm:
      trans_atten_p.norm_tpl = trans_atten_p.norm_tpl.Copy().Set(
          input_dims=p.model_dims)
    else:
      trans_atten_p.norm_tpl = trans_atten_p.norm_tpl.Copy().Set(
          dim=p.model_dims)
    self.create_child('trans_atten', trans_atten_p)

  if 'conv' in p.layer_order:
    lconv_p = p.lconv_tpl.Copy().Set(
        input_dims=p.model_dims,
        kernel_size=p.kernel_size,
        dropout_prob=p.conv_residual_dropout)
    self.create_child('lconv', lconv_p)

  ln_p = p.final_ln_tpl.Copy().Set(name='final_ln', input_dims=p.model_dims)
  self.create_child('final_ln', ln_p)