def _ConfigSelfAttenParams(self, trans_atten_p):
    """Picks the self-attention template and installs it on `trans_atten_p`.

    The attention flavour is derived from `p.atten_left_context`,
    `p.atten_right_context` and `p.use_relative_atten`:

      * neither context set               -> 'global' / 'global_relative'
      * left not set, right context == 0  -> 'global_causal'
      * anything else                     -> 'local' / 'local_relative'

    Side effects: defaults `p.relative_pos_emb_dim` to `p.input_dim` when it is
    unset (falsy), and overwrites `trans_atten_p.atten_tpl`.

    Args:
      trans_atten_p: Params of the transformer attention layer whose `atten_tpl`
        field is replaced by the selected attention template.

    Raises:
      AssertionError: if relative attention is requested together with the
        causal (infinite left / zero right context) configuration.
    """
    p = self.params
    if not p.relative_pos_emb_dim:
      p.relative_pos_emb_dim = p.input_dim

    # TODO(jamesqin): add an attention factory in batch_major_attention.
    if not _AttenCtxIsSet(p.atten_left_context) and not _AttenCtxIsSet(
        p.atten_right_context):
      # No atten context set, each position attends to all positions.
      atten_type = 'global' if not p.use_relative_atten else 'global_relative'
    elif not _AttenCtxIsSet(
        p.atten_left_context) and p.atten_right_context == 0:
      # Left context is infinite, right context is 0.
      assert not p.use_relative_atten, (
          'Relative attention isn\'t supported for causal attention.')
      atten_type = 'global_causal'
    else:
      atten_type = 'local_relative' if p.use_relative_atten else 'local'

    # NOTE(review): the three branches below copy from
    # `p.trans_atten_tpl.atten_tpl` while the final branch copies from
    # `trans_atten_p.atten_tpl`. Presumably both hold the same values at this
    # point -- confirm against the caller.
    if atten_type == 'global_relative':
      atten_tpl = (
          attention_lib.MultiHeadedAttentionXL.Params().Set(
              rel_pos_emb_dim=p.relative_pos_emb_dim))
      # `skip` is given as a list, like the other branches, so that field
      # names are compared as whole entries rather than against a bare string.
      hparams_lib.CopyFieldsTo(
          p.trans_atten_tpl.atten_tpl, atten_tpl, skip=['rel_pos_emb_dim'])
    elif atten_type == 'local_relative':
      atten_tpl = attention_lib.LocalSelfAttentionXL.Params().Set(
          left_context=p.atten_left_context,
          right_context=p.atten_right_context,
          rel_pos_emb_dim=p.relative_pos_emb_dim)
      hparams_lib.CopyFieldsTo(
          p.trans_atten_tpl.atten_tpl,
          atten_tpl,
          skip=['left_context', 'right_context', 'rel_pos_emb_dim'])
    elif atten_type == 'local':
      atten_tpl = attention_lib.LocalSelfAttention.Params().Set(
          left_context=p.atten_left_context,
          right_context=p.atten_right_context)
      hparams_lib.CopyFieldsTo(
          p.trans_atten_tpl.atten_tpl,
          atten_tpl,
          skip=['left_context', 'right_context'])
    else:
      # 'global' and 'global_causal' both use plain multi-headed attention,
      # populated from the template already present on `trans_atten_p`.
      assert atten_type in ('global', 'global_causal'), (
          f'Unknown atten_type {atten_type}')
      atten_tpl = attention_lib.MultiHeadedAttention.Params()
      hparams_lib.CopyFieldsTo(trans_atten_p.atten_tpl, atten_tpl)
    trans_atten_p.atten_tpl = atten_tpl
Exemple #2
0
 def testCopyFieldsToMissingKeyInDest(self):
     """Copying a key that the destination does not define must fail."""
     # Source defines only 'a'; destination defines only 'b'.
     src = hyperparams.Params()
     src.Define('a', 'a', '')
     dst = hyperparams.Params()
     dst.Define('b', 'b', '')

     with self.assertRaises(AttributeError):
         hyperparams.CopyFieldsTo(src, dst)
Exemple #3
0
 def testCopyFieldsToMoreKeyInDest(self):
     """Extra keys in the destination are left untouched by the copy."""
     src = hyperparams.Params()
     src.Define('b', 'b', '')

     # Destination shares 'b' (with a different value) and has an extra 'a'.
     dst = hyperparams.Params()
     dst.Define('b', '', '')
     dst.Define('a', 'a', '')

     hyperparams.CopyFieldsTo(src, dst)

     self.assertEqual(dst.b, src.b)
     self.assertEqual('a', dst.a)
Exemple #4
0
 def testCopyFieldsTo(self):
   """Keys listed in `skip` are not copied into the destination."""
   src = hyperparams.Params()
   for key in ('a', 'b', 'c'):
     # Each source key uses its own name as its value.
     src.Define(key, key, '')

   dst = hyperparams.Params()
   dst.Define('a', '', '')

   hyperparams.CopyFieldsTo(src, dst, skip=['b', 'c'])

   self.assertEqual(src.a, dst.a)
   for skipped_key in ('b', 'c'):
     self.assertNotIn(skipped_key, dst)
Exemple #5
0
 def testCopyFieldsToDifferentKeysWithMerge(self):
     """With ignore_unknown_keys=True only the shared keys are copied."""
     src = hyperparams.Params()
     for key in ('a', 'c'):
         # Each source key uses its own name as its value.
         src.Define(key, key, '')

     # Destination shares 'a' with the source and has its own 'b'.
     dst = hyperparams.Params()
     dst.Define('a', '', '')
     dst.Define('b', 'b', '')

     hyperparams.CopyFieldsTo(src, dst, ignore_unknown_keys=True)

     self.assertEqual(src.a, dst.a)
     self.assertEqual('b', dst.b)
     self.assertNotIn('c', dst)
Exemple #6
0
    def __init__(self, params):
        """Creates the child layers configured by `params`.

        Children are created in this order: 'ln', 'linear_start',
        'linear_end', 'norm', 'depthwise_conv1d', 'dropout'.

        Args:
          params: Layer params; read back through `self.params` after the
            base-class constructor has run.
        """
        super().__init__(params)
        p = self.params
        dim = p.input_dim

        self.CreateChild('ln', p.ln_tpl.Copy().Set(name='ln', input_dim=dim))

        # The first projection doubles the width; the second one keeps it.
        start_p = p.linear_start_tpl.Copy().Set(
            name='linear_start', input_dim=dim, output_dim=2 * dim)
        end_p = p.linear_end_tpl.Copy().Set(
            name='linear_end', input_dim=dim, output_dim=dim)
        self.CreateChild('linear_start', start_p)
        self.CreateChild('linear_end', end_p)

        # LayerNorm names its size field 'input_dim'; other norm layers
        # handled here use 'dim'.
        norm_cls = p.conv_norm_layer_tpl.cls
        size_field = 'input_dim' if norm_cls == layers.LayerNorm else 'dim'
        norm_p = p.conv_norm_layer_tpl.Copy().Set(
            name='norm_layer', **{size_field: dim})
        if norm_cls == bn_layers.GroupNormLayer:
            norm_p.cumulative = p.is_causal
        self.CreateChild('norm', norm_p)

        swap_in_causal_conv = (
            p.is_causal and p.depthwise_conv_tpl.cls
            == conv_layers_with_time_padding.DepthwiseConv2DLayer)
        if not swap_in_causal_conv:
            conv_p = p.depthwise_conv_tpl.Copy()
        else:
            # If causal, switch to causal depthwise conv, carrying over the
            # fields of the configured template.
            conv_p = (conv_layers_with_time_padding.
                      CausalDepthwiseConv2DLayer.Params())
            hyperparams.CopyFieldsTo(p.depthwise_conv_tpl, conv_p)
        # 1d depthwise conv with channel_multiplier = 1
        conv_p.Set(
            name='depthwise_conv',
            filter_shape=(p.kernel_size, 1, dim, 1),
            filter_stride=(1, 1))
        self.CreateChild('depthwise_conv1d', conv_p)

        self.CreateChild(
            'dropout',
            p.dropout_tpl.Copy().Set(
                name='dropout', keep_prob=1. - p.dropout_prob))
Exemple #7
0
 def testCopyFieldsToDoesNotCopyClass(self):
   """The destination keeps its own `cls` after CopyFieldsTo."""
   dst_cls = hyperparams.InstantiableParams
   src = hyperparams.InstantiableParams(hyperparams.Params)
   dst = hyperparams.InstantiableParams(dst_cls)

   hyperparams.CopyFieldsTo(src, dst)

   self.assertEqual(dst.cls, dst_cls)
 def testCopyFieldsToDoesNotCopyClass(self):
     """The destination keeps its own `cls` after CopyFieldsTo."""
     dst_cls = _params.InstantiableParams
     src = _params.InstantiableParams(cls=_params.Params)
     dst = _params.InstantiableParams(cls=dst_cls)

     _params.CopyFieldsTo(src, dst)

     self.assertEqual(dst.cls, dst_cls)
Exemple #9
0
  def __init__(self, params):
    """Creates the child layers configured by `params`.

    Children are created in this order: 'ln', then either
    'linear_start_act' + 'linear_start_gated' (when
    `p.split_act_gated_linear_start` is set) or a single 'linear_start',
    followed by 'linear_end', 'norm', 'depthwise_conv1d' and 'dropout'.

    Args:
      params: Layer params; read back through `self.params` after the
        base-class constructor has run.
    """
    super().__init__(params)
    p = self.params
    dim = p.input_dim

    self.CreateChild('ln', p.ln_tpl.Copy().Set(name='ln', input_dim=dim))

    if not p.split_act_gated_linear_start:
      # Single projection with doubled output width.
      self.CreateChild(
          'linear_start',
          p.linear_start_tpl.Copy().Set(
              name='linear_start', input_dim=dim, output_dim=2 * dim))
    else:
      # Two same-shaped projections configured with identical settings.
      split_settings = dict(
          input_dim=dim,
          output_dim=dim,
          device_mesh=p.device_mesh,
          weight_split_dims_mapping=p.weight_split_dims_mapping.df,
          activation_split_dims_mapping=p.activation_split_dims_mapping.blf)
      act_p = p.linear_start_tpl.Copy().Set(**split_settings)
      gated_p = p.linear_start_tpl.Copy().Set(**split_settings)
      self.CreateChild('linear_start_act', act_p)
      self.CreateChild('linear_start_gated', gated_p)

    self.CreateChild(
        'linear_end',
        p.linear_end_tpl.Copy().Set(
            name='linear_end',
            input_dim=dim,
            output_dim=dim,
            device_mesh=p.device_mesh,
            weight_split_dims_mapping=p.weight_split_dims_mapping.fd,
            activation_split_dims_mapping=p.activation_split_dims_mapping.bld))

    # LayerNorm names its size field 'input_dim'; other norm layers handled
    # here use 'dim'.
    norm_cls = p.conv_norm_layer_tpl.cls
    size_field = 'input_dim' if norm_cls == layers.LayerNorm else 'dim'
    norm_p = p.conv_norm_layer_tpl.Copy().Set(
        name='norm_layer', **{size_field: dim})
    if norm_cls == bn_layers.GroupNormLayer:
      norm_p.cumulative = p.is_causal
    self.CreateChild('norm', norm_p)

    swap_in_causal_conv = (
        p.is_causal and p.depthwise_conv_tpl.cls ==
        conv_layers_with_time_padding.DepthwiseConv2DLayer)
    if not swap_in_causal_conv:
      conv_p = p.depthwise_conv_tpl.Copy()
    else:
      # If causal, switch to causal depthwise conv, carrying over the fields
      # of the configured template.
      conv_p = (
          conv_layers_with_time_padding.CausalDepthwiseConv2DLayer.Params())
      hparams_lib.CopyFieldsTo(p.depthwise_conv_tpl, conv_p)

    # Depthwise variants get a last filter dimension of 1 (1d depthwise conv
    # with channel_multiplier = 1); any other conv class gets `input_dim`.
    is_depthwise = issubclass(
        conv_p.cls, conv_layers_with_time_padding.DepthwiseConv2DLayer)
    last_filter_dim = 1 if is_depthwise else dim
    conv_p.filter_shape = (p.kernel_size, 1, dim, last_filter_dim)
    conv_p.Set(name='depthwise_conv', filter_stride=(1, 1))
    self.CreateChild('depthwise_conv1d', conv_p)

    self.CreateChild(
        'dropout',
        p.dropout_tpl.Copy().Set(
            name='dropout', keep_prob=1. - p.dropout_prob))