コード例 #1
0
def multilingual_base_architecture(args):
    """Fill in missing hyper-parameters on ``args`` with multilingual defaults.

    Each attribute named below keeps its current value when ``args``
    already defines it; otherwise it is created with the paired default.
    The first table is applied before ``transformer_base_architecture(args)``
    runs and the second table afterwards. ``args`` is mutated in place and
    nothing is returned.
    """
    # Defaults that must be in place before the base architecture is applied.
    before_base = (
        ('encoder_embed_dim', 1024),
        ('share_encoder_input_output_embed', True),
        ('no_token_positional_embeddings', False),
        ('encoder_learned_pos', True),
        ('num_segment', 1),
        ('encoder_layers', 6),
        ('encoder_attention_heads', 8),
        ('encoder_ffn_embed_dim', 4096),
        ('bias_kv', False),
        ('zero_attn', False),
        ('sent_loss', False),
        ('activation_fn', 'gelu'),
        ('encoder_normalize_before', False),
        ('pooler_activation_fn', 'tanh'),
        ('apply_bert_init', True),
    )
    for attr_name, fallback in before_base:
        setattr(args, attr_name, getattr(args, attr_name, fallback))

    transformer_base_architecture(args)

    # Sharing flags are only defaulted once the base architecture has run.
    after_base = (
        ('share_encoder_embeddings', True),
        ('share_decoder_embeddings', True),
        ('share_encoders', True),
        ('share_decoders', True),
    )
    for attr_name, fallback in after_base:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
コード例 #2
0
def transformer_xlm_iwslt_decoder(args):
    """Set encoder/decoder size defaults on ``args``, then apply the base setup.

    Attributes already present on ``args`` are left as they are; missing
    ones receive the values listed below. Afterwards
    ``transformer_base_architecture(args)`` is called. ``args`` is mutated
    in place and nothing is returned.
    """
    size_defaults = {
        "encoder_embed_dim": 768,
        "encoder_ffn_embed_dim": 3072,
        "encoder_layers": 12,
        "decoder_embed_dim": 512,
        "decoder_ffn_embed_dim": 1024,
        "decoder_attention_heads": 4,
        "decoder_layers": 6,
    }
    for option, fallback in size_defaults.items():
        current = getattr(args, option, fallback)
        setattr(args, option, current)

    transformer_base_architecture(args)
コード例 #3
0
def base_architecture(args):
    """Apply ``transformer_base_architecture`` to ``args``.

    Thin wrapper: ``args`` is forwarded unchanged and the callee's result
    is discarded, so this function always returns ``None``.
    """
    transformer_base_architecture(args)
コード例 #4
0
def transformer_iwslt_16(args):
    """Configure ``args`` via ``for_iwslt_16`` followed by the base setup.

    Both helpers are defined outside this snippet and receive the same
    ``args`` object. Nothing is returned.

    NOTE(review): presumably ``for_iwslt_16`` sets its values first so that
    ``transformer_base_architecture`` only fills in what is still missing --
    confirm against the helpers' definitions before reordering these calls.
    """
    for_iwslt_16(args)
    transformer_base_architecture(args)