def Task(cls):
    """Return the task params for the 'punctuator_rnmt' model.

    A fresh `model.RNMTModel.Params()` object is handed to
    `base_config.SetupRNMTParams` together with the hyperparameters
    below; whatever that helper returns is returned unchanged.
    """
    # Keyword arguments forwarded verbatim to SetupRNMTParams. The
    # grouping below follows the keyword names only.
    hparams = {
        'name': 'punctuator_rnmt',
        # Vocabulary size comes from the class attribute `_VOCAB_SIZE`.
        'vocab_size': cls._VOCAB_SIZE,
        # Dimension / depth keywords.
        'embedding_dim': 1024,
        'hidden_dim': 1024,
        'num_heads': 4,
        'num_encoder_layers': 6,
        'num_decoder_layers': 8,
        # Learning-rate keywords.
        'learning_rate': 1e-4,
        'l2_regularizer_weight': 1e-5,
        'lr_warmup_steps': 500,
        'lr_decay_start': 400000,
        'lr_decay_end': 1200000,
        'lr_min': 0.5,
        # Label-smoothing / dropout keywords.
        'ls_uncertainty': 0.1,
        'atten_dropout_prob': 0.3,
        'residual_dropout_prob': 0.3,
        # Adam keywords.
        'adam_beta2': 0.98,
        'adam_epsilon': 1e-6,
    }
    task_params = base_config.SetupRNMTParams(
        model.RNMTModel.Params(), **hparams)
    return task_params
def Task(cls):
    """Return the task params for 'wmt14_en_de_rnmtplus_base'.

    Calls `base_config.SetupRNMTParams` with the hyperparameters below,
    then sets `eval.samples_per_summary` to 7500 on the result before
    returning it.
    """
    # NOTE(review): no params object is passed positionally here, whereas
    # the other Task definition in this chunk passes
    # `model.RNMTModel.Params()` first -- confirm which form matches the
    # SetupRNMTParams signature in use.
    # Keyword arguments forwarded verbatim to SetupRNMTParams. The
    # grouping below follows the keyword names only.
    hparams = {
        'name': 'wmt14_en_de_rnmtplus_base',
        # Vocabulary size comes from the class attribute `VOCAB_SIZE`.
        'vocab_size': cls.VOCAB_SIZE,
        # Dimension / depth keywords.
        'embedding_dim': 1024,
        'hidden_dim': 1024,
        'num_heads': 4,
        'num_encoder_layers': 6,
        'num_decoder_layers': 8,
        # Learning-rate keywords.
        'learning_rate': 1e-4,
        'l2_regularizer_weight': 1e-5,
        'lr_warmup_steps': 500,
        'lr_decay_start': 400000,
        'lr_decay_end': 1200000,
        'lr_min': 0.5,
        # Label-smoothing / dropout keywords.
        'ls_uncertainty': 0.1,
        'atten_dropout_prob': 0.3,
        'residual_dropout_prob': 0.3,
        # Adam keywords.
        'adam_beta2': 0.98,
        'adam_epsilon': 1e-6,
    }
    task_params = base_config.SetupRNMTParams(**hparams)
    # Override the eval setting on the params returned by the helper.
    task_params.eval.samples_per_summary = 7500
    return task_params