Example #1
  def _TransformerParamsWithEmbeddings(self,
                                       num_decoder_layers=0,
                                       num_encoder_layers=4,
                                       splits=1,
                                       num_micro_batches=1):
    model_dim = 4
    params = GPipeTransformerStack.Params()
    params.name = 'transformer'
    params.model_dim = model_dim
    params.num_decoder_layers = num_decoder_layers
    params.decoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.decoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_encoder_layers = num_encoder_layers
    params.encoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.encoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_micro_batches = num_micro_batches
    params.use_pipelined_embeddings = True
    params.state_dtype = tf.float32

    emb_params = params.emb_tpl
    # Default config for the token embedding.
    emb_params.token_emb.use_matmul = True
    emb_params.token_emb.use_3d_weight_tensor = False
    emb_params.token_emb.vocab_size = 10
    emb_params.token_emb.embedding_dim = model_dim

    # Default config for the position embedding.
    emb_params.position_emb.embedding_dim = model_dim
    emb_params.position_emb.trainable_scaling = False
    params.splits = splits
    params.random_seed = 0
    return params
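
The helper above only returns a configured Params object. A minimal sketch of how it might be used inside the same test class follows; the test-method name, the argument values, and the use of Lingvo's Params.Instantiate() are illustrative assumptions rather than part of the original example, and GPipeTransformerStack and tf are taken to be imported as in the surrounding test file.

  # Sketch of assumed usage inside the same test class (illustrative only).
  def testTransformerStackWithEmbeddings(self):  # hypothetical test name
    params = self._TransformerParamsWithEmbeddings(
        num_decoder_layers=2, splits=2, num_micro_batches=2)
    xformer = params.Instantiate()  # builds the GPipeTransformerStack layer
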
Example #2
def _TransformerParamsWithEmbeddings(num_decoder_layers=0,
                                     num_encoder_layers=4,
                                     splits=1,
                                     num_micro_batches=1,
                                     has_softmax=False,
                                     use_task_ids=False):
    model_dim = 4
    params = GPipeTransformerStack.Params()
    params.name = 'transformer'
    params.model_dim = model_dim
    params.num_decoder_layers = num_decoder_layers
    params.decoder_tpl.source_dim = model_dim
    params.decoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.decoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_encoder_layers = num_encoder_layers
    params.encoder_tpl.source_dim = model_dim
    params.encoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.encoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_micro_batches = num_micro_batches
    params.state_dtype = tf.float32
    if has_softmax:
        params.softmax_tpl.input_dim = model_dim
        params.softmax_tpl.num_classes = 2
    else:
        params.softmax_tpl = None

    emb_params = params.emb_tpl
    # Default config for the token embedding.
    emb_params.token_emb.use_matmul = True
    emb_params.token_emb.use_3d_weight_tensor = False
    emb_params.token_emb.vocab_size = 10
    emb_params.token_emb.embedding_dim = model_dim

    # Default config for the position embedding.
    emb_params.position_emb.embedding_dim = model_dim
    emb_params.position_emb.trainable_scaling = False

    # Task embeddings.
    if use_task_ids:
        emb_params.enc_task_emb = emb_params.token_emb.Copy()
        emb_params.dec_task_emb = emb_params.token_emb.Copy()
    params.splits = splits
    params.random_seed = 0
    return params
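
Compared to Example #1, this variant is a module-level helper and adds two switches: has_softmax fills in a small softmax template (input_dim = model_dim, 2 classes) instead of clearing softmax_tpl, and use_task_ids copies the token-embedding template into the encoder and decoder task embeddings. A minimal sketch of calling it follows; the argument values are illustrative assumptions, and GPipeTransformerStack and tf are assumed to be imported as in the surrounding file.

# Sketch (illustrative argument values, not from the original example).
params = _TransformerParamsWithEmbeddings(
    num_decoder_layers=2, has_softmax=True, use_task_ids=True)
assert params.softmax_tpl.num_classes == 2      # set by the has_softmax branch
assert params.emb_tpl.enc_task_emb is not None  # set by the use_task_ids branch
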
Example #3
  def _TransformerParams(self,
                         num_decoder_layers=0,
                         num_encoder_layers=4,
                         splits=1,
                         num_micro_batches=1):
    model_dim = 2
    params = GPipeTransformerStack.Params()
    params.name = 'transformer'
    params.model_dim = model_dim
    params.num_decoder_layers = num_decoder_layers
    params.decoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.decoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_encoder_layers = num_encoder_layers
    params.encoder_tpl.tr_atten_tpl.num_attention_heads = 1
    params.encoder_tpl.tr_fflayer_tpl.hidden_dim = model_dim
    params.num_micro_batches = num_micro_batches
    params.splits = splits
    params.random_seed = 0
    return params