def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                    user_defined_strategy):
    """Run the base-class setup, then build the wrapped pipeline optimizer.

    Reads ``pipeline_configs['micro_batch']`` from ``user_defined_strategy``
    and stores ``PO(self.inner_opt, num_microbatches=<that value>)`` as
    ``self.wrapped_opt``.
    """
    parent = super(PipelineOptimizer, self)
    parent._set_basic_info(loss, role_maker, user_defined_optimizer,
                           user_defined_strategy)
    configs = user_defined_strategy.pipeline_configs
    self.wrapped_opt = PO(self.inner_opt,
                          num_microbatches=configs['micro_batch'])
示例#2
0
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped pipeline optimizer and add cross-node gradient sync.

        Wraps ``self.inner_opt`` in ``PO``, binding this worker's node-local
        rank as ``start_cpu_core_id``. It then runs the wrapped ``minimize``,
        updates the pipeline startup program through ``PipelineHelper`` and
        transpiles the per-stage main programs.

        Args:
            loss: loss variable to minimize.
            startup_program: startup program; when ``None``,
                ``fluid.default_startup_program()`` is recorded instead.
            parameter_list: forwarded to the wrapped ``minimize``.
            no_grad_set: forwarded to the wrapped ``minimize``.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.
        """
        endpoints = self.role_maker._get_trainer_endpoints()
        rank = self.role_maker._worker_index()
        current_endpoint = endpoints[rank]
        # Compute the node-local rank once. It is used both as the first CPU
        # core id for PO and as the program's 'local_rank' pipeline option.
        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches,
                              start_cpu_core_id=self.local_rank)
        node_num = _get_node_num(endpoints)
        gpus_per_node = len(endpoints) // node_num
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        program = loss.block.program
        program._pipeline_opt = dict()
        program._pipeline_opt['local_rank'] = self.local_rank
        # NOTE: the caller's startup_program (possibly None) is forwarded
        # unchanged to the wrapped optimizer.
        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)

        assert prog_list, "pipeline minimize returned no stage programs"
        self.main_program_list = prog_list
        self.main_program = program
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        self.nranks = len(endpoints)
        self.nrings = len(self.main_program_list)

        self.rank = rank
        self.endpoints = endpoints
        self.current_endpoint = current_endpoint

        pipeline_helper = PipelineHelper(self.role_maker)
        pipeline_helper.update_startup_program(
            self.startup_program._pipeline_opt["startup_program"],
            self.inner_parallelism)

        self._transpile_main_program(loss, node_num, gpus_per_node)
        return optimize_ops, params_grads
示例#3
0
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped pipeline optimizer and transpile for data parallel.

        Records the pipeline options (rank, micro-batch size, schedule mode)
        on the loss's program and runs the wrapped ``PO.minimize``. It then
        updates the startup program via ``PipelineHelper`` and transpiles the
        stage programs for ``nranks // inner_parallelism`` pipelines.

        Args:
            loss: loss variable to minimize.
            startup_program: startup program; when ``None``,
                ``fluid.default_startup_program()`` is recorded instead.
            parameter_list: forwarded to the wrapped ``minimize``.
            no_grad_set: forwarded to the wrapped ``minimize``.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.

        Raises:
            AssertionError: if the worker count is not divisible by the node
                count or by the inner parallelism, or no stage programs are
                produced.
        """
        endpoints = self.role_maker._get_trainer_endpoints()
        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches)
        node_num = _get_node_num(endpoints)
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        assert self.nranks % node_num == 0, (
            "worker num {} is not divisible by node num {}".format(
                self.nranks, node_num))

        program = loss.block.program
        program._pipeline_opt = dict()
        # NOTE(review): 'local_rank' receives the global worker index here.
        program._pipeline_opt['local_rank'] = self.rank
        program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
        program._pipeline_opt['schedule_mode'] = self.schedule_mode
        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)
        assert prog_list, "pipeline minimize returned no stage programs"

        self.main_program_list = prog_list
        self.main_program = program
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        assert self.nranks % self.inner_parallelism == 0, (
            "worker num {} is not divisible by inner parallelism {}".format(
                self.nranks, self.inner_parallelism))

        pipeline_helper = PipelineHelper(self.role_maker)
        pipeline_helper.update_startup_program(
            self.startup_program._pipeline_opt["startup_program"],
            self.inner_parallelism)

        # Number of independent pipelines (data-parallel replicas).
        pipeline_num = self.nranks // self.inner_parallelism
        self._transpile_main_program(loss, pipeline_num,
                                     self.inner_parallelism)
        return optimize_ops, params_grads
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped pipeline optimizer and set up pipeline rings.

        Records ring ids and pipeline options on the loss's program and runs
        the wrapped ``PO.minimize``. It then initializes the process groups
        from the returned ``pp_pair`` / ``ring_map``. When more than one
        pipeline exists, it transpiles the main program.

        Args:
            loss: loss variable to minimize.
            startup_program: startup program; when ``None``,
                ``fluid.default_startup_program()`` is used to fetch the
                pipeline startup program.
            parameter_list: forwarded to the wrapped ``minimize``.
            no_grad_set: forwarded to the wrapped ``minimize``.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.

        Raises:
            AssertionError: if the worker count is not divisible by the inner
                parallelism or no stage programs are produced.
        """
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()

        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches)
        # Explicit None check: any Program passed by the caller is honoured,
        # regardless of how the object evaluates in a boolean context.
        if startup_program is not None:
            orig_startup_program = startup_program
        else:
            orig_startup_program = fluid.default_startup_program()
        program = loss.block.program

        program._pipeline_opt = dict()
        program._pipeline_opt['local_rank'] = self.rank
        program._pipeline_opt['global_ring_id'] = self.global_ring_id
        program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
        program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
        program._pipeline_opt['schedule_mode'] = self.schedule_mode
        program._pipeline_opt['use_sharding'] = False
        program._pipeline_opt['mp_degree'] = 1
        program._pipeline_opt['mp_rank'] = 0
        optimize_ops, params_grads, prog_list, pp_pair, ring_map = \
            self.wrapped_opt.minimize(loss, startup_program, parameter_list,
                                      no_grad_set)
        self.startup_program = orig_startup_program._pipeline_opt[
            'startup_program']
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        assert self.nranks % self.inner_parallelism == 0, (
            "worker num {} is not divisible by inner parallelism {}".format(
                self.nranks, self.inner_parallelism))
        assert prog_list, "pipeline minimize returned no stage programs"
        # Derive the pipeline count from the same quantity the divisibility
        # assertion above validated.
        self.pipeline_num = self.nranks // self.inner_parallelism

        self._init_process_group(pp_pair, ring_map)

        self.main_program_list = prog_list
        self.main_program = program
        if self.pipeline_num > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads
示例#5
0
 def get_optimizer_dygraph(self, parameter_list):
     """Build the dygraph optimizer under test.

     Returns a ``PipelineOptimizer`` wrapping ``fluid.optimizer.SGD`` with a
     learning rate of 0.5 over ``parameter_list``.
     """
     return PipelineOptimizer(
         fluid.optimizer.SGD(parameter_list=parameter_list, learning_rate=0.5))
示例#6
0
class PipelineOptimizer(MetaOptimizerBase):
    """Fleet meta optimizer applying pipeline parallelism.

    The inner optimizer is wrapped in ``PO``, whose ``minimize`` returns a
    list of per-stage programs. This version assumes ``gpus_per_node`` stage
    programs. The gradients of stage ``i`` are all-reduced across nodes on
    ring ``i + 1``.
    """

    def __init__(self, optimizer):
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.meta_optimizers_black_list = []

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        """Run base setup and read the micro-batch count from the strategy."""
        super(PipelineOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)
        self.num_microbatches = user_defined_strategy.pipeline_configs[
            'micro_batch']

    def _can_apply(self):
        """Applicable only in collective mode with ``strategy.pipeline`` on."""
        if not self.role_maker._is_collective:
            return False

        if self.user_defined_strategy.pipeline == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.pipeline = False
        dist_strategy.pipeline_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.pipeline = True
        dist_strategy.pipeline_configs = {"micro_batch": 1, }

    def _get_local_rank(self, current_endpoint, endpoints):
        """Return the index of ``current_endpoint`` among endpoints on its host.

        Endpoints are ``"ip:port"`` strings. Two endpoints share a host when
        the whitespace-stripped text before the first ``':'`` is equal.

        Raises:
            ValueError: if ``current_endpoint`` is not in ``endpoints``.
        """
        cur_ip = current_endpoint.split(':')[0].strip()
        cur_node_endpoints = [
            ep for ep in endpoints if ep.split(':')[0].strip() == cur_ip
        ]
        return cur_node_endpoints.index(current_endpoint)

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped pipeline optimizer and add cross-node gradient sync.

        Args:
            loss: loss variable to minimize.
            startup_program: startup program; when ``None``,
                ``fluid.default_startup_program()`` is recorded instead.
            parameter_list: forwarded to the wrapped ``minimize``.
            no_grad_set: forwarded to the wrapped ``minimize``.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.
        """
        endpoints = self.role_maker._get_trainer_endpoints()
        rank = self.role_maker._worker_index()
        current_endpoint = endpoints[rank]
        # Computed once. It is both PO's first CPU core id and the program's
        # 'local_rank' option.
        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches,
                              start_cpu_core_id=self.local_rank)
        node_num = _get_node_num(endpoints)
        gpus_per_node = len(endpoints) // node_num
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        program = loss.block.program
        program._pipeline_opt = dict()
        program._pipeline_opt['local_rank'] = self.local_rank
        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)

        assert prog_list, "pipeline minimize returned no stage programs"
        self.main_program_list = prog_list
        self.main_program = program
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        self.nranks = len(endpoints)
        self.nrings = len(self.main_program_list)

        self.rank = rank
        self.endpoints = endpoints
        self.current_endpoint = current_endpoint

        pipeline_helper = PipelineHelper(self.role_maker)
        pipeline_helper.update_startup_program(
            self.startup_program._pipeline_opt["startup_program"],
            self.inner_parallelism)

        self._transpile_main_program(loss, node_num, gpus_per_node)
        return optimize_ops, params_grads

    def _transpile_main_program(self, loss, node_num, gpus_per_node):
        """Scale the loss grad, then insert gradient all-reduce per stage."""
        self._insert_loss_grad_ops(loss, gpus_per_node, node_num)
        for ring_id in range(1, gpus_per_node + 1):
            self._insert_allreduce_ops(ring_id)

    def _insert_loss_grad_ops(self, loss, gpus_per_node, node_num):
        """
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers.

        The scale op (factor ``1 / node_num``) is inserted right after each
        loss-grad op of the last stage program (index ``gpus_per_node - 1``).
        ``loss`` is unused.
        """
        block = self.main_program_list[gpus_per_node - 1][
            'program'].global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / node_num,
                        OP_ROLE_KEY: OpRole.Backward
                    })

    def _insert_allreduce_ops(self, ring_id):
        """All-reduce the gradients of stage program ``ring_id - 1``.

        A ``c_sync_calc_stream`` is inserted after each backward op that
        carries (param, grad) pairs in ``OP_ROLE_VAR_KEY``. One
        ``c_allreduce_sum`` on ring ``ring_id`` follows for every gradient
        whose original parameter is not distributed. Finally a
        ``c_sync_comm_stream`` is inserted before the first optimizer op.
        """
        block = self.main_program_list[ring_id - 1]['program'].global_block()
        origin_block = self.main_program.global_block()
        grad = None
        # Iterate backwards: inserting after ``idx`` only shifts ops that
        # have already been visited.
        for idx, op in reversed(list(enumerate(block.ops))):
            if not (is_backward_op(op) and
                    OP_ROLE_VAR_KEY in op.attr_names):
                continue
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if len(op_role_var) == 0:
                continue
            assert len(op_role_var) % 2 == 0
            offset = idx
            for i in range(0, len(op_role_var), 2):
                grad = block.vars[op_role_var[i + 1]]
                origin_param = origin_block.vars[op_role_var[i]]
                if origin_param.is_distributed:
                    continue
                if offset == idx:
                    # Make the calc stream finish producing gradients
                    # before any communication on them starts.
                    offset += 1
                    block._insert_op(
                        offset,
                        type='c_sync_calc_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={OP_ROLE_KEY: OpRole.Backward})
                    offset += 1

                # The actual cross-node gradient reduction. The previous
                # code inserted a second c_sync_calc_stream here, so no
                # reduction ever took place.
                block._insert_op(
                    offset,
                    type='c_allreduce_sum',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={
                        'ring_id': ring_id,
                        OP_ROLE_KEY: OpRole.Backward
                    })

        if grad is None:
            return

        # Wait on the comm stream before the first optimizer op consumes the
        # gradients. Insert *at* that op's index (i.e. before it), and stop
        # only once it has been found.
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={'ring_id': ring_id,
                           OP_ROLE_KEY: OpRole.Backward})
                break
class PipelineOptimizer(MetaOptimizerBase):
    """Fleet meta optimizer applying pipeline parallelism (early version).

    ``PO`` is built in ``_set_basic_info``. After ``minimize``, the
    gradients of stage program ``i`` are all-reduced on ring ``i``.
    """

    def __init__(self, optimizer):
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.meta_optimizers_black_list = []

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        """Run base setup and wrap the inner optimizer in ``PO``."""
        super(PipelineOptimizer,
              self)._set_basic_info(loss, role_maker, user_defined_optimizer,
                                    user_defined_strategy)
        num_microbatches = user_defined_strategy.pipeline_configs[
            'micro_batch']
        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=num_microbatches)

    def _can_apply(self):
        if self.user_defined_strategy.pipeline == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.pipeline = False
        dist_strategy.pipeline_configs = {}

    def _enable_strategy(self, dist_strategy, context=None):
        """Do nothing: enabling pipeline automatically is not supported.

        ``context`` is optional and ignored, so callers may invoke this with
        either one or two arguments.
        """
        # we do not support enable pipeline automatically right now
        return

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped optimizer and, with >1 worker, add gradient sync.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.
        """
        optimize_ops, params_grads, prog_list = \
            self.wrapped_opt.minimize(loss, startup_program,
                                      parameter_list, no_grad_set)
        # A single worker has nobody to synchronize gradients with.
        if self.role_maker.worker_num() == 1:
            return optimize_ops, params_grads

        endpoints = self.role_maker.get_trainer_endpoints()
        rank = self.role_maker.worker_index()
        current_endpoint = endpoints[rank]
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        assert prog_list, "pipeline minimize returned no stage programs"
        self.main_program_list = prog_list
        self.main_program = loss.block.program
        self.nranks = len(endpoints)
        self.nrings = len(self.main_program_list)

        self.rank = rank
        self.endpoints = endpoints
        self.current_endpoint = current_endpoint

        pipeline_helper = PipelineHelper(self.role_maker, nrings=self.nrings)
        pipeline_helper.update_startup_program(self.startup_program)

        self._transpile_main_program()
        return optimize_ops, params_grads

    def _transpile_main_program(self):
        """Scale the loss grad, then insert gradient all-reduce per ring."""
        self._insert_loss_grad_ops()
        for ring_id in range(self.nrings):
            self._insert_allreduce_ops(ring_id)

    def _insert_loss_grad_ops(self):
        """
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers.

        The scale op (factor ``1 / self.nranks``) is inserted right after each
        loss-grad op of the last stage program.
        """
        block = self.main_program_list[self.nrings -
                                       1]['program'].global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(idx + 1,
                                 type='scale',
                                 inputs={'X': loss_grad_var},
                                 outputs={'Out': loss_grad_var},
                                 attrs={
                                     'scale': 1.0 / self.nranks,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })

    def _insert_allreduce_ops(self, ring_id):
        """All-reduce the gradients of stage program ``ring_id``.

        A ``c_sync_calc_stream`` is inserted after each backward op that
        carries (param, grad) pairs. One ``c_allreduce_sum`` on ring
        ``ring_id`` follows for every non-distributed parameter's gradient.
        A ``c_sync_comm_stream`` is inserted before the first optimizer op.
        """
        block = self.main_program_list[ring_id]['program'].global_block()
        origin_block = self.main_program.global_block()
        grad = None
        # Iterate backwards: inserting after ``idx`` only shifts ops that
        # have already been visited.
        for idx, op in reversed(list(enumerate(block.ops))):
            if not (is_backward_op(op) and
                    OP_ROLE_VAR_KEY in op.attr_names):
                continue
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if len(op_role_var) == 0:
                continue
            assert len(op_role_var) % 2 == 0
            offset = idx
            for i in range(0, len(op_role_var), 2):
                grad = block.vars[op_role_var[i + 1]]
                origin_param = origin_block.vars[op_role_var[i]]
                if origin_param.is_distributed:
                    continue
                if offset == idx:
                    offset += 1
                    block._insert_op(offset,
                                     type='c_sync_calc_stream',
                                     inputs={'X': grad},
                                     outputs={'Out': grad},
                                     attrs={OP_ROLE_KEY: OpRole.Backward})
                    offset += 1

                # Reduce the gradient across ranks. The previous code
                # inserted another c_sync_calc_stream here instead.
                block._insert_op(offset,
                                 type='c_allreduce_sum',
                                 inputs={'X': grad},
                                 outputs={'Out': grad},
                                 attrs={
                                     'ring_id': ring_id,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })

        if grad is None:
            return

        # Sync the comm stream immediately before the first optimizer op, and
        # stop searching only once that op has been found.
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(idx,
                                 type='c_sync_comm_stream',
                                 inputs={'X': grad},
                                 outputs={'Out': grad},
                                 attrs={
                                     'ring_id': ring_id,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })
                break
示例#8
0
class PipelineOptimizer(MetaOptimizerBase):
    """Fleet meta optimizer applying pipeline parallelism.

    Stage programs come from the wrapped ``PO``. When more than one pipeline
    runs (``nranks // inner_parallelism > 1``), the gradients of stage ``i``
    are all-reduced across pipelines on ring ``i + 1``.
    """

    def __init__(self, optimizer):
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
        ]
        self.meta_optimizers_black_list = [
            "GraphExecutionOptimizer",
        ]

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        """Run base setup and cache the pipeline configs from the strategy."""
        super(PipelineOptimizer,
              self)._set_basic_info(loss, role_maker, user_defined_optimizer,
                                    user_defined_strategy)
        self.micro_batch_size = user_defined_strategy.pipeline_configs[
            'micro_batch_size']
        self.num_microbatches = user_defined_strategy.pipeline_configs[
            'accumulate_steps']
        self.schedule_mode = user_defined_strategy.pipeline_configs[
            'schedule_mode']

    def _can_apply(self):
        """Applicable only in collective mode with ``strategy.pipeline`` on."""
        if not self.role_maker._is_collective:
            return False

        if self.user_defined_strategy.pipeline == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.pipeline = False
        dist_strategy.pipeline_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.pipeline = True
        dist_strategy.pipeline_configs = {
            "micro_batch_size": 1,
            "accumulate_steps": 1,
            "schedule_mode": "1F1B",
        }

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the wrapped pipeline optimizer and transpile for data parallel.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the wrapped optimizer.

        Raises:
            AssertionError: if the worker count is not divisible by the node
                count or by the inner parallelism, or no stage programs are
                produced.
        """
        endpoints = self.role_maker._get_trainer_endpoints()
        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches)
        node_num = _get_node_num(endpoints)
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        assert self.nranks % node_num == 0, (
            "worker num {} is not divisible by node num {}".format(
                self.nranks, node_num))

        program = loss.block.program
        program._pipeline_opt = dict()
        # NOTE(review): 'local_rank' receives the global worker index here.
        program._pipeline_opt['local_rank'] = self.rank
        program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
        program._pipeline_opt['schedule_mode'] = self.schedule_mode
        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)
        assert prog_list, "pipeline minimize returned no stage programs"

        self.main_program_list = prog_list
        self.main_program = program
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        assert self.nranks % self.inner_parallelism == 0, (
            "worker num {} is not divisible by inner parallelism {}".format(
                self.nranks, self.inner_parallelism))

        pipeline_helper = PipelineHelper(self.role_maker)
        pipeline_helper.update_startup_program(
            self.startup_program._pipeline_opt["startup_program"],
            self.inner_parallelism)

        pipeline_num = self.nranks // self.inner_parallelism
        self._transpile_main_program(loss, pipeline_num,
                                     self.inner_parallelism)
        return optimize_ops, params_grads

    def _transpile_main_program(self, loss, pipeline_num, inner_parallelism):
        """Add loss scaling and gradient all-reduce across pipelines."""
        # With a single pipeline there are no replicas to synchronize with.
        if pipeline_num <= 1:
            return
        self._insert_loss_grad_ops(loss, pipeline_num)
        for ring_id in range(1, inner_parallelism + 1):
            self._insert_allreduce_ops(ring_id)

    def _insert_loss_grad_ops(self, loss, pipeline_num):
        """
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers.

        The scale op (factor ``1 / pipeline_num``) is inserted right after
        each loss-grad op of the last stage program. ``loss`` is unused.
        """
        block = self.main_program_list[-1]['program'].global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(idx + 1,
                                 type='scale',
                                 inputs={'X': loss_grad_var},
                                 outputs={'Out': loss_grad_var},
                                 attrs={
                                     'scale': 1.0 / pipeline_num,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })

    def _insert_allreduce_ops(self, ring_id):
        """All-reduce the gradients of stage program ``ring_id - 1``.

        Each parameter is handled once, even if several backward ops list it.
        For non-distributed parameters, a ``c_sync_calc_stream`` and then a
        ``c_allreduce_sum`` on ring ``ring_id`` are inserted after the
        backward op. A ``c_sync_comm_stream`` is inserted before the first
        optimizer op.
        """
        block = self.main_program_list[ring_id - 1]['program'].global_block()
        origin_block = self.main_program.global_block()
        grad = None
        processed_param_name = set()
        # Iterate backwards: inserting after ``idx`` only shifts ops that
        # have already been visited.
        for idx, op in reversed(list(enumerate(block.ops))):
            if not (is_backward_op(op) and
                    OP_ROLE_VAR_KEY in op.attr_names):
                continue
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if len(op_role_var) == 0:
                continue
            assert len(op_role_var) % 2 == 0
            offset = idx
            for i in range(0, len(op_role_var), 2):
                param_name = op_role_var[i]
                if param_name in processed_param_name:
                    continue
                processed_param_name.add(param_name)
                grad = block.vars[op_role_var[i + 1]]
                origin_param = origin_block.vars[param_name]
                if origin_param.is_distributed:
                    continue
                if offset == idx:
                    offset += 1
                    block._insert_op(offset,
                                     type='c_sync_calc_stream',
                                     inputs={'X': grad},
                                     outputs={'Out': grad},
                                     attrs={OP_ROLE_KEY: OpRole.Backward})
                    offset += 1

                block._insert_op(offset,
                                 type='c_allreduce_sum',
                                 inputs={'X': grad},
                                 outputs={'Out': grad},
                                 attrs={
                                     'ring_id': ring_id,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })

        if grad is None:
            return

        # The ``break`` belongs inside the ``if``. Previously it ended the
        # loop after the first op, so the sync was skipped unless op 0 was
        # an optimizer op.
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(idx,
                                 type='c_sync_comm_stream',
                                 inputs={'X': grad},
                                 outputs={'Out': grad},
                                 attrs={
                                     'ring_id': ring_id,
                                     OP_ROLE_KEY: OpRole.Backward
                                 })
                break
class PipelineOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        """Wrap ``optimizer`` and fix the communicator ring ids used later.

        Args:
            optimizer: the inner optimizer; kept as ``self.inner_opt`` and
                later wrapped by ``PO`` in ``minimize_impl``.
        """
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # Meta optimizers listed as combinable with / excluded from this one
        # (presumably consulted by the fleet meta-optimizer selection logic).
        self.meta_optimizers_white_list = ["RecomputeOptimizer", "AMPOptimizer"]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer"]
        # Ring ids: global ring over all trainers, data-parallel ring, and
        # the first id handed out to pipeline stage-pair rings.
        self.global_ring_id, self.dp_ring_id = 1, 2
        self.start_pipeline_ring_id = 20  # Just a magic number

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        """Record basic info and cache the pipeline settings of the strategy.

        Reads ``micro_batch_size``, ``accumulate_steps`` (stored as
        ``self.num_microbatches``) and ``schedule_mode`` from
        ``user_defined_strategy.pipeline_configs``, plus the ``sharding``
        flag.
        """
        super(PipelineOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)
        configs = user_defined_strategy.pipeline_configs
        self.micro_batch_size = configs['micro_batch_size']
        # Gradient-accumulation steps double as the number of micro-batches.
        self.num_microbatches = configs['accumulate_steps']
        self.schedule_mode = configs['schedule_mode']
        self.use_sharding = user_defined_strategy.sharding

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False

        # FIXME revise for hybrid parallelism
        if self.use_sharding:
            return False

        if self.user_defined_strategy.pipeline == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.pipeline = False
        dist_strategy.pipeline_configs = {
            "micro_batch_size": 1,
            "accumulate_steps": 1,
            "schedule_mode": "1F1B",
        }

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.pipeline = True
        dist_strategy.pipeline_configs = {
            "micro_batch_size": 1,
            "accumulate_steps": 1,
            "schedule_mode": "1F1B",
        }

    def _broadcast_params(self, ring_id):
        """Broadcast non-distributed startup parameters from root 0.

        Appends one ``c_broadcast`` op (``root=0``) on ``ring_id`` for every
        parameter in ``self.startup_program``'s global block whose
        ``is_distributed`` flag is false. If at least one op was appended,
        a single ``c_sync_comm_stream`` op on the same ring follows it.
        When no parameter was broadcast, nothing is appended.

        Args:
            ring_id: id of the communicator ring to broadcast on.
        """
        block = self.startup_program.global_block()
        # Track the last *broadcast* parameter. Tracking the last iterated one
        # would anchor the sync op on a skipped (distributed) parameter.
        last_broadcast = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(type='c_broadcast',
                            inputs={'X': param},
                            outputs={'Out': param},
                            attrs={
                                'ring_id': ring_id,
                                'root': 0,
                                OP_ROLE_KEY: OpRole.Forward
                            })
            last_broadcast = param

        if last_broadcast is None:
            return  # nothing was broadcast from this device
        block.append_op(type='c_sync_comm_stream',
                        inputs={'X': last_broadcast},
                        outputs={'Out': last_broadcast},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Forward
                        })

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

        # data parallel ring info
        if self.pipeline_num > 1:
            self.dp_rank = self.rank // self.inner_parallelism
            self.dp_nranks = self.nranks // self.inner_parallelism
            start_index = self.rank % self.inner_parallelism
            self.dp_endpoints = [
                self.endpoints[start_index + i * self.inner_parallelism]
                for i in range(self.pipeline_num)
            ]

    def _init_process_group(self, pipeline_pair, pipeline_ring_map):
        """Append communicator-initialization ops to ``self.startup_program``.

        Rings are set up in this fixed order:
        1. the global ring over all trainers (``self.global_ring_id``);
        2. one two-rank ring per stage pair in ``pipeline_pair``, but only
           when ``self.inner_parallelism > 1``;
        3. the data-parallel ring (``self.dp_ring_id``), but only when
           ``self.pipeline_num > 1``. Parameters are then broadcast on it.

        Args:
            pipeline_pair: iterable of ``(first_stage, second_stage)`` pairs.
                Stage indices are local to one pipeline replica and are
                offset by ``pipeline_id * inner_parallelism`` here.
            pipeline_ring_map: maps ``first_stage * 1000 + second_stage`` to
                the ring id used for that pair.
        """
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Global ring over all trainers. Its id is self.global_ring_id
        # (1 as set in __init__), not 0.
        collective_helper._init_communicator(self.startup_program,
                                             self.current_endpoint,
                                             self.global_endpoints,
                                             self.global_rank,
                                             self.global_ring_id, True,
                                             self.global_ring_id, True)
        # Create pipeline rings
        if self.inner_parallelism > 1:
            pipeline_id = self.rank // self.inner_parallelism
            start_index = pipeline_id * self.inner_parallelism
            for pair in pipeline_pair:
                pair_key = pair[0] * 1000 + pair[1]
                ring_id = pipeline_ring_map[pair_key]
                assert ring_id >= self.start_pipeline_ring_id
                first_node = pair[0] + start_index
                second_node = pair[1] + start_index
                if self.rank != first_node and self.rank != second_node:
                    # Ranks outside the pair still issue a call, with no
                    # endpoints/rank. Presumably this keeps every rank's
                    # sequence of init calls aligned; confirm against
                    # CollectiveHelper._init_communicator.
                    collective_helper._init_communicator(
                        self.startup_program, None, None, None, None, False,
                        self.global_ring_id, True)
                    continue
                pipeline_endpoints = [
                    self.endpoints[first_node], self.endpoints[second_node]
                ]
                pipeline_rank = 0 if self.rank == first_node else 1
                collective_helper._init_communicator(self.startup_program,
                                                     self.current_endpoint,
                                                     pipeline_endpoints,
                                                     pipeline_rank, ring_id,
                                                     False,
                                                     self.global_ring_id, True)

        # Create dp rings
        if self.pipeline_num > 1:
            collective_helper._init_communicator(
                self.startup_program, self.current_endpoint, self.dp_endpoints,
                self.dp_rank, self.dp_ring_id, True, self.global_ring_id, True)
            self._broadcast_params(self.dp_ring_id)

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Apply pipeline parallelism to the program that owns ``loss``.

        Steps:
        - wrap ``self.inner_opt`` in ``PO`` and let it split the program into
          pipeline sections;
        - create the communicator rings in the per-stage startup program;
        - when more than one pipeline replica exists, insert loss-grad
          scaling and gradient all-reduce ops.

        Args:
            loss: loss variable. ``loss.block.program`` is the program that
                gets transformed.
            startup_program: startup program. Defaults to
                ``fluid.default_startup_program()`` when ``None``.
            parameter_list: forwarded to ``PO.minimize``.
            no_grad_set: forwarded to ``PO.minimize``.

        Returns:
            ``(optimize_ops, params_grads)`` as returned by ``PO.minimize``.
        """
        self.endpoints = self.role_maker._get_trainer_endpoints()
        # Query the worker index once. It is both this trainer's rank and
        # its position in the endpoint list.
        self.rank = self.role_maker._worker_index()
        self.current_endpoint = self.endpoints[self.rank]
        self.nranks = self.role_maker._worker_num()

        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches)
        # Resolve the default once and pass this same object to PO, because
        # the per-stage startup program is read back from its _pipeline_opt.
        if startup_program is None:
            startup_program = fluid.default_startup_program()
        program = loss.block.program

        program._pipeline_opt = dict()
        program._pipeline_opt['local_rank'] = self.rank
        program._pipeline_opt['global_ring_id'] = self.global_ring_id
        program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
        program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
        program._pipeline_opt['schedule_mode'] = self.schedule_mode
        program._pipeline_opt['use_sharding'] = False
        program._pipeline_opt['mp_degree'] = 1
        program._pipeline_opt['mp_rank'] = 0
        optimize_ops, params_grads, prog_list, pp_pair, ring_map = \
            self.wrapped_opt.minimize(loss, startup_program, parameter_list,
                                      no_grad_set)
        self.startup_program = startup_program._pipeline_opt[
            'startup_program']
        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
        assert self.nranks % self.inner_parallelism == 0
        assert prog_list
        # Number of pipeline replicas (the data-parallel degree). It is
        # derived from the same rank count that the assert above validates.
        self.pipeline_num = self.nranks // self.inner_parallelism

        self._init_process_group(pp_pair, ring_map)

        self.main_program_list = prog_list
        self.main_program = program
        if self.pipeline_num > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

    def _transpile_main_program(self, loss):
        self._insert_loss_grad_ops(loss, self.pipeline_num)
        self._insert_allreduce_ops(self.dp_ring_id)

    def _insert_loss_grad_ops(self, loss, pipeline_num):
        """
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers.

        For every loss-grad op in the global block of the last program in
        ``self.main_program_list``, a ``scale`` op with factor
        ``1.0 / pipeline_num`` is inserted right after it. The scale op
        rescales that op's first output in place. ``loss`` is unused.
        """
        block = self.main_program_list[-1].global_block()
        snapshot = list(block.ops)
        # Walk backwards so that inserting after position `pos` never shifts
        # the positions of ops that have not been visited yet.
        for pos in range(len(snapshot) - 1, -1, -1):
            op = snapshot[pos]
            if not is_loss_grad_op(op):
                continue
            grad_name = op.output_arg_names[0]
            loss_grad_var = block.vars[grad_name]
            block._insert_op(pos + 1,
                             type='scale',
                             inputs={'X': loss_grad_var},
                             outputs={'Out': loss_grad_var},
                             attrs={
                                 'scale': 1.0 / pipeline_num,
                                 OP_ROLE_KEY: OpRole.Backward
                             })

    def _insert_allreduce_ops(self, ring_id):
        block = self.main_program._pipeline_opt[
            'section_program'].global_block()
        origin_block = self.main_program.global_block()
        grad = None
        processed_param_name = set()
        first_optimize_op_idx = None
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and not first_optimize_op_idx:
                first_optimize_op_idx = idx + 1
                # no optimize phase
                if first_optimize_op_idx == len(block.ops): return
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 0
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.vars[op_role_var[i]]
                    if param_name in processed_param_name: continue
                    processed_param_name.add(param_name)
                    grad_name = op_role_var[i + 1]
                    if not 'MERGED' in grad_name: grad_name += '@MERGED'
                    grad = block.vars[grad_name]
                    origin_param = origin_block.vars[op_role_var[i]]
                    if origin_param.is_distributed:
                        continue

                    block._insert_op(first_optimize_op_idx + offset,
                                     type='c_allreduce_sum',
                                     inputs={'X': grad},
                                     outputs={'Out': grad},
                                     attrs={
                                         'ring_id': ring_id,
                                         'use_calc_stream': True,
                                         OP_ROLE_KEY: OpRole.Optimize
                                     })