def _ConfigSelfAttenParams(self, trans_atten_p):
  """Configures the self-attention template on `trans_atten_p`.

  Chooses a MultiHeadedAttention variant from the left/right attention
  context and `p.use_relative_atten`, copies the user-supplied fields of the
  configured attention template onto the chosen class's params, and installs
  the result as `trans_atten_p.atten_tpl`.

  Args:
    trans_atten_p: Transformer attention layer params; its `atten_tpl` is
      replaced in place.
  """
  p = self.params
  if not p.relative_pos_emb_dim:
    # Default the relative position embedding size to the model dimension.
    p.relative_pos_emb_dim = p.input_dim
  # TODO(jamesqin): add an attention factory in batch_major_attention.
  if not _AttenCtxIsSet(p.atten_left_context) and not _AttenCtxIsSet(
      p.atten_right_context):
    # No atten context set, each position attends to all positions.
    atten_type = 'global' if not p.use_relative_atten else 'global_relative'
  elif not _AttenCtxIsSet(
      p.atten_left_context) and p.atten_right_context == 0:
    # Left context is infinite, right context is 0.
    assert not p.use_relative_atten, (
        'Relative attention isn\'t supported for causal attention.')
    atten_type = 'global_causal'
  else:
    atten_type = 'local_relative' if p.use_relative_atten else 'local'

  if atten_type == 'global_relative':
    atten_tpl = (
        attention_lib.MultiHeadedAttentionXL.Params().Set(
            rel_pos_emb_dim=p.relative_pos_emb_dim))
    # Preserve user overrides, except the field set explicitly above.
    hparams_lib.CopyFieldsTo(
        p.trans_atten_tpl.atten_tpl, atten_tpl, skip='rel_pos_emb_dim')
  elif atten_type == 'local_relative':
    atten_tpl = attention_lib.LocalSelfAttentionXL.Params().Set(
        left_context=p.atten_left_context,
        right_context=p.atten_right_context,
        rel_pos_emb_dim=p.relative_pos_emb_dim)
    hparams_lib.CopyFieldsTo(
        p.trans_atten_tpl.atten_tpl,
        atten_tpl,
        skip=['left_context', 'right_context', 'rel_pos_emb_dim'])
  elif atten_type == 'local':
    atten_tpl = attention_lib.LocalSelfAttention.Params().Set(
        left_context=p.atten_left_context,
        right_context=p.atten_right_context)
    hparams_lib.CopyFieldsTo(
        p.trans_atten_tpl.atten_tpl,
        atten_tpl,
        skip=['left_context', 'right_context'])
  else:
    # No op for 'global' atten
    assert atten_type in ('global', 'global_causal'), (
        f'Unknown atten_type {atten_type}')
    atten_tpl = attention_lib.MultiHeadedAttention.Params()
    # NOTE(review): this branch copies from `trans_atten_p.atten_tpl`, while
    # the branches above copy from `p.trans_atten_tpl.atten_tpl` — confirm
    # both refer to the same template at this point.
    hparams_lib.CopyFieldsTo(trans_atten_p.atten_tpl, atten_tpl)
  trans_atten_p.atten_tpl = atten_tpl
def testCopyFieldsToMissingKeyInDest(self):
  """CopyFieldsTo raises AttributeError when dest lacks a source key."""
  src = hyperparams.Params()
  src.Define('a', 'a', '')
  dst = hyperparams.Params()
  dst.Define('b', 'b', '')
  # 'a' exists on the source but not on the destination, so the copy fails.
  with self.assertRaises(AttributeError):
    hyperparams.CopyFieldsTo(src, dst)
def testCopyFieldsToMoreKeyInDest(self):
  """Extra keys in dest are left untouched by CopyFieldsTo."""
  src = hyperparams.Params()
  src.Define('b', 'b', '')
  dst = hyperparams.Params()
  dst.Define('b', '', '')
  dst.Define('a', 'a', '')
  hyperparams.CopyFieldsTo(src, dst)
  # 'b' is copied from the source; dest-only 'a' keeps its original value.
  self.assertEqual(dst.b, src.b)
  self.assertEqual('a', dst.a)
def testCopyFieldsTo(self):
  """Keys listed in `skip` are neither copied nor created in dest."""
  src = hyperparams.Params()
  for key in ('a', 'b', 'c'):
    src.Define(key, key, '')
  dst = hyperparams.Params()
  dst.Define('a', '', '')
  hyperparams.CopyFieldsTo(src, dst, skip=['b', 'c'])
  self.assertEqual(src.a, dst.a)
  # Skipped keys must not appear on the destination at all.
  self.assertNotIn('b', dst)
  self.assertNotIn('c', dst)
def testCopyFieldsToDifferentKeysWithMerge(self):
  """With ignore_unknown_keys, source-only keys drop; dest-only keys stay."""
  src = hyperparams.Params()
  src.Define('a', 'a', '')
  src.Define('c', 'c', '')
  dst = hyperparams.Params()
  dst.Define('a', '', '')
  dst.Define('b', 'b', '')
  hyperparams.CopyFieldsTo(src, dst, ignore_unknown_keys=True)
  # Shared key 'a' is copied; dest-only 'b' survives; source-only 'c' is
  # silently ignored instead of raising.
  self.assertEqual(src.a, dst.a)
  self.assertEqual('b', dst.b)
  self.assertNotIn('c', dst)
def __init__(self, params):
  """Builds the convolution block's sub-layers.

  Creates, from `self.params`: a layer norm ('ln'), the start/end linear
  projections, a post-conv normalization layer ('norm'), a 1-d depthwise
  convolution ('depthwise_conv1d', switched to a causal variant when
  `p.is_causal`), and a dropout layer.

  Args:
    params: Hyperparams for this layer.
  """
  super().__init__(params)
  p = self.params
  ln_p = p.ln_tpl.Copy().Set(name='ln', input_dim=p.input_dim)
  self.CreateChild('ln', ln_p)
  # Start projection doubles the width; presumably split into
  # activation/gate halves downstream — confirm against the FProp.
  linear_start_p = p.linear_start_tpl.Copy().Set(
      name='linear_start',
      input_dim=p.input_dim,
      output_dim=2 * p.input_dim)
  linear_end_p = p.linear_end_tpl.Copy().Set(
      name='linear_end', input_dim=p.input_dim, output_dim=p.input_dim)
  self.CreateChild('linear_start', linear_start_p)
  self.CreateChild('linear_end', linear_end_p)
  # LayerNorm takes `input_dim`; other norm-layer templates take `dim`.
  if p.conv_norm_layer_tpl.cls == layers.LayerNorm:
    norm_p = p.conv_norm_layer_tpl.Copy().Set(
        name='norm_layer', input_dim=p.input_dim)
  else:
    norm_p = p.conv_norm_layer_tpl.Copy().Set(
        name='norm_layer', dim=p.input_dim)
  if p.conv_norm_layer_tpl.cls == bn_layers.GroupNormLayer:
    # Tie group norm's cumulative mode to causality.
    norm_p.cumulative = p.is_causal
  self.CreateChild('norm', norm_p)
  if (p.is_causal and p.depthwise_conv_tpl.cls ==
      conv_layers_with_time_padding.DepthwiseConv2DLayer):
    # If causal, switch to causal depthwise conv.
    depthwise_conv_p = (
        conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
    hyperparams.CopyFieldsTo(p.depthwise_conv_tpl, depthwise_conv_p)
  else:
    depthwise_conv_p = p.depthwise_conv_tpl.Copy()
  # 1d depthwise conv with channel_multiplier = 1
  depthwise_conv_p.Set(
      name='depthwise_conv',
      filter_shape=(p.kernel_size, 1, p.input_dim, 1),
      filter_stride=(1, 1))
  self.CreateChild('depthwise_conv1d', depthwise_conv_p)
  dropout_p = p.dropout_tpl.Copy().Set(
      name='dropout', keep_prob=1. - p.dropout_prob)
  self.CreateChild('dropout', dropout_p)
def testCopyFieldsToDoesNotCopyClass(self):
  """CopyFieldsTo must not overwrite the destination's `cls` field."""
  src = hyperparams.InstantiableParams(hyperparams.Params)
  dst = hyperparams.InstantiableParams(hyperparams.InstantiableParams)
  hyperparams.CopyFieldsTo(src, dst)
  # The destination keeps its own class even though the source differs.
  self.assertEqual(dst.cls, hyperparams.InstantiableParams)
def testCopyFieldsToDoesNotCopyClass(self):
  """CopyFieldsTo must not overwrite the destination's `cls` field."""
  src = _params.InstantiableParams(cls=_params.Params)
  dst = _params.InstantiableParams(cls=_params.InstantiableParams)
  _params.CopyFieldsTo(src, dst)
  # The destination keeps its own class even though the source differs.
  self.assertEqual(dst.cls, _params.InstantiableParams)
def __init__(self, params):
  """Builds the convolution block's sub-layers (sharding-aware variant).

  Creates, from `self.params`: a layer norm ('ln'), the start projection
  (either a single fused 2x-wide projection, or separate act/gated
  projections carrying device-mesh sharding annotations when
  `p.split_act_gated_linear_start`), the end projection, a post-conv
  normalization layer, a 1-d depthwise convolution (switched to a causal
  variant when `p.is_causal`), and a dropout layer.

  Args:
    params: Hyperparams for this layer.
  """
  super().__init__(params)
  p = self.params
  ln_p = p.ln_tpl.Copy().Set(name='ln', input_dim=p.input_dim)
  self.CreateChild('ln', ln_p)
  if p.split_act_gated_linear_start:
    # Two separate input_dim -> input_dim projections (activation and
    # gate), each with explicit weight/activation sharding mappings.
    linear_start_act_p = p.linear_start_tpl.Copy().Set(
        input_dim=p.input_dim,
        output_dim=p.input_dim,
        device_mesh=p.device_mesh,
        weight_split_dims_mapping=p.weight_split_dims_mapping.df,
        activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
    linear_start_gated_p = p.linear_start_tpl.Copy().Set(
        input_dim=p.input_dim,
        output_dim=p.input_dim,
        device_mesh=p.device_mesh,
        weight_split_dims_mapping=p.weight_split_dims_mapping.df,
        activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
    self.CreateChild('linear_start_act', linear_start_act_p)
    self.CreateChild('linear_start_gated', linear_start_gated_p)
  else:
    # Single fused projection at 2x width; presumably split into
    # activation/gate halves downstream — confirm against the FProp.
    linear_start_p = p.linear_start_tpl.Copy().Set(
        name='linear_start',
        input_dim=p.input_dim,
        output_dim=2 * p.input_dim)
    self.CreateChild('linear_start', linear_start_p)
  linear_end_p = p.linear_end_tpl.Copy().Set(
      name='linear_end',
      input_dim=p.input_dim,
      output_dim=p.input_dim,
      device_mesh=p.device_mesh,
      weight_split_dims_mapping=p.weight_split_dims_mapping.fd,
      activation_split_dims_mapping=p.activation_split_dims_mapping.bld)
  self.CreateChild('linear_end', linear_end_p)
  # LayerNorm takes `input_dim`; other norm-layer templates take `dim`.
  if p.conv_norm_layer_tpl.cls == layers.LayerNorm:
    norm_p = p.conv_norm_layer_tpl.Copy().Set(
        name='norm_layer', input_dim=p.input_dim)
  else:
    norm_p = p.conv_norm_layer_tpl.Copy().Set(
        name='norm_layer', dim=p.input_dim)
  if p.conv_norm_layer_tpl.cls == bn_layers.GroupNormLayer:
    # Tie group norm's cumulative mode to causality.
    norm_p.cumulative = p.is_causal
  self.CreateChild('norm', norm_p)
  if (p.is_causal and p.depthwise_conv_tpl.cls ==
      conv_layers_with_time_padding.DepthwiseConv2DLayer):
    # If causal, switch to causal depthwise conv.
    depthwise_conv_p = (
        conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
    hparams_lib.CopyFieldsTo(p.depthwise_conv_tpl, depthwise_conv_p)
  else:
    depthwise_conv_p = p.depthwise_conv_tpl.Copy()
  # DepthwiseConv2DLayer subclasses take a channel-multiplier-style filter
  # shape (..., input_dim, 1); other conv classes get a full
  # (..., input_dim, input_dim) shape.
  if issubclass(depthwise_conv_p.cls,
                conv_layers_with_time_padding.DepthwiseConv2DLayer):
    depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim, 1)
  else:
    depthwise_conv_p.filter_shape = (p.kernel_size, 1, p.input_dim,
                                     p.input_dim)
  # 1d depthwise conv with channel_multiplier = 1
  depthwise_conv_p.Set(
      name='depthwise_conv', filter_stride=(1, 1))
  self.CreateChild('depthwise_conv1d', depthwise_conv_p)
  dropout_p = p.dropout_tpl.Copy().Set(
      name='dropout', keep_prob=1. - p.dropout_prob)
  self.CreateChild('dropout', dropout_p)