Exemplo n.º 1
0
 def test_decoder_dp(self):
     """Forward-annotation completion of the decoder program under "dp".

     Sets the module-level parallel strategy to "dp" and the process mesh
     to ranks [0, 1, 2, 3]. It then builds the decoder pretraining forward
     program and runs ``Completer.complete_forward_annotation`` on it. The
     test passes when ``validate_dist_attr_for_program()`` returns True.
     """
     # NOTE(review): these globals are presumably consumed by
     # decoder_pretrain_forward when it builds the program — confirm.
     global _global_parallel_strategy, _global_process_mesh
     _global_parallel_strategy = "dp"
     _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

     main_prog, startup_prog = static.Program(), static.Program()
     ctx = DistributedContext()
     main_prog, startup_prog = decoder_pretrain_forward(main_prog,
                                                        startup_prog)

     Completer(ctx).complete_forward_annotation(main_prog)
     self.assertTrue(ctx.validate_dist_attr_for_program())
 def test_attn_dp(self):
     """Forward-annotation completion of the attention program under "dp".

     Sets the module-level parallel strategy to "dp" and the process mesh
     to ranks [0, 1, 2, 3]. It then builds the attention pretraining
     forward program and completes its distributed annotations with
     ``Completer``. The test passes when
     ``validate_dist_attr_for_program()`` returns True.
     """
     # NOTE(review): these globals are presumably consumed by
     # attn_pretrain_forward when it builds the program — confirm.
     global _global_parallel_strategy
     _global_parallel_strategy = "dp"
     global _global_process_mesh
     _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
     train_program = static.Program()
     start_program = static.Program()
     dist_context = DistributedContext()
     train_program, start_program = attn_pretrain_forward(train_program,
                                                          start_program)
     # Use the same Completer API as test_decoder_dp. The older free
     # function auto.complete_annotation may not be exported from `auto`
     # in releases that provide Completer.
     completer = Completer(dist_context)
     completer.complete_forward_annotation(train_program)
     self.assertTrue(dist_context.validate_dist_attr_for_program())