def broadcast_input_data(hcg, *inputs, **kwargs):
    # broadcast every tensor positional argument in place so all ranks in the
    # group receive the same input data
    for v in inputs:
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
        else:
            logger.error("it doesn't support data type {}".format(type(v)))

    # broadcast tensor keyword arguments the same way
    for k, v in kwargs.items():
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
            kwargs[k] = v
        else:
            logger.error("it doesn't support data type {}".format(type(v)))
    return inputs, kwargs
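
A minimal usage sketch for broadcast_input_data, assuming fleet has already been initialized with a hybrid-parallel strategy and the script is launched on multiple ranks; the tensor names and shapes are made up, and depending on the Paddle version the core.VarBase type check may only match tensors in legacy dygraph mode.

# hedged usage sketch: fleet is assumed to be initialized with a hybrid
# parallel strategy and the script launched on several ranks, e.g. with
# python -m paddle.distributed.launch; shapes below are illustrative only
import paddle
from paddle.distributed import fleet

hcg = fleet.get_hybrid_communicate_group()

images = paddle.randn([8, 3, 224, 224])
labels = paddle.randint(0, 10, shape=[8, 1])

# after the call every rank in the group holds the same copy of the batch
inputs, kwargs = broadcast_input_data(hcg, images, labels=labels)
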
def sharding_reduce_gradients(parameter_list, hcg):
    # TODO allreduce --> reduce
    # TODO merge grad / nrank with dp
    logger.debug("sharding start gradients sync")
    with framework.no_grad():

        sharding_nrank = hcg.get_sharding_parallel_group().nranks
        for param in parameter_list:
            if param.trainable and (param._grad_ivar() is not None):

                g_var = param._grad_ivar()

                # need to use trace_op to issue the allreduce; the commented-out
                # python API below is the equivalent high-level call
                # paddle.distributed.all_reduce(
                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
                paddle.fluid.framework._dygraph_tracer().trace_op(
                    type="c_allreduce_sum",
                    inputs={'X': g_var},
                    outputs={'Out': g_var},
                    attrs={
                        'ring_id': hcg.get_sharding_parallel_group().id,
                        'use_calc_stream': True
                    })

                # divide grad by sharding_nrank to average over the sharding group
                div_factor = paddle.to_tensor(sharding_nrank,
                                              dtype=g_var.dtype)
                paddle.fluid.framework._dygraph_tracer().trace_op(
                    type="elementwise_div",
                    inputs={
                        'X': g_var,
                        'Y': div_factor
                    },
                    outputs={'Out': g_var},
                    attrs={'axis': -1})
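
The two traced ops above amount to an allreduce-sum followed by an average. A simplified sketch of the same idea with the public APIs (mirroring the commented-out all_reduce call) might look like the following; it is not the library's implementation, and the availability of param.grad and Tensor.scale_ depends on the Paddle version.

# hedged, simplified equivalent of sharding_reduce_gradients using public
# APIs; same reduce-then-average idea, not the library's implementation
import paddle
import paddle.distributed as dist

def sharding_reduce_gradients_sketch(parameter_list, hcg):
    group = hcg.get_sharding_parallel_group()
    with paddle.no_grad():
        for param in parameter_list:
            if param.trainable and param.grad is not None:
                # sum the gradient across the sharding group ...
                dist.all_reduce(param.grad, group=group)
                # ... then divide by the group size to get the average
                param.grad.scale_(1.0 / group.nranks)
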
def fused_allreduce_gradients(parameter_list, hcg):
    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
    logger.debug("dp start fuse allreduce gradients")
    apply_func = (_apply_collective_grads_eager
                  if in_dygraph_mode() else _apply_collective_grads)
    with framework.no_grad():
        apply_func(parameter_list, data_parallel_group)
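
A hedged sketch of where such a helper is typically called: after backward() and before the optimizer step, in setups where data-parallel gradient synchronization is handled manually (for example with gradient accumulation). model, opt, loss_fn, images, labels and hcg are placeholders assumed to exist in the surrounding training script.

# hedged usage sketch; model, opt, loss_fn, images, labels and hcg are
# placeholders from a hybrid-parallel training script
loss = loss_fn(model(images), labels)
loss.backward()

# manually synchronize data-parallel gradients before the optimizer step
fused_allreduce_gradients(list(model.parameters()), hcg)

opt.step()
opt.clear_grad()
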
Example #4
    def _sharding_sync_parameters(self):
        """
        sync parameter across sharding group
        """
        # TODO speed up this function

        logger.debug("sharding start sync parameters")
        with framework.no_grad():
            # TODO detach may not be needed (?)
            for rank, params in self._rank2params.items():
                for param in params:
                    paddle.distributed.broadcast(
                        param,
                        # the collective API needs src to be the global rank id,
                        # not the relative logical rank id within the group
                        src=self._hcg.get_sharding_parallel_group().ranks[rank],
                        group=self._hcg.get_sharding_parallel_group(),
                        use_calc_stream=True)
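
The comment about src points at a common pitfall: paddle.distributed.broadcast expects a global rank id, while rank here indexes positions inside the sharding group, and group.ranks maps one to the other. A tiny illustration with made-up numbers:

# made-up numbers: 8 ranks split into two sharding groups of 4
group_ranks = [4, 5, 6, 7]     # global ids of the ranks in one sharding group
local_rank = 2                 # logical position inside that group
src = group_ranks[local_rank]  # -> 6, the global id that broadcast() expects
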
def fused_allreduce_gradients(parameter_list, hcg):
    # same helper without the eager-mode dispatch used in the version above
    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
    logger.debug("dp start fuse allreduce gradients")
    with framework.no_grad():
        _apply_collective_grads(parameter_list, data_parallel_group)