# Module-level imports assumed by this excerpt (paths as in the Paddle 2.x fleet
# utils; the private helpers _broadcast_data_help, _apply_collective_grads and
# _apply_collective_grads_eager are defined elsewhere in the same module).
import paddle
from paddle import framework
from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode
from .log_util import logger


def broadcast_input_data(hcg, *inputs, **kwargs):
    # Broadcast every supported input tensor so that all ranks in the
    # corresponding parallel group of `hcg` receive identical data.
    for v in inputs:
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
        else:
            logger.error("broadcast_input_data: unsupported data type {}".format(type(v)))

    for k, v in kwargs.items():
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                _broadcast_data_help(v, v.shape, v.dtype, hcg)
            kwargs[k] = v
        else:
            logger.error("broadcast_input_data: unsupported data type {}".format(type(v)))
    return inputs, kwargs
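
def _example_broadcast_usage(model, train_loader):
    # Illustrative only (not part of the original module): `model` and
    # `train_loader` are hypothetical, and fleet.get_hybrid_communicate_group()
    # is assumed to have been set up by fleet.init(...) beforehand. A minimal
    # sketch of calling broadcast_input_data once per micro-batch so that all
    # ranks of the group consume identical inputs before the forward pass.
    from paddle.distributed import fleet

    hcg = fleet.get_hybrid_communicate_group()
    for images, labels in train_loader():
        (images, labels), _ = broadcast_input_data(hcg, images, labels)
        loss = model(images, labels)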
def sharding_reduce_gradients(parameter_list, hcg):
    # TODO: use reduce instead of allreduce
    # TODO: merge the grad / nrank division with the dp logic
    logger.debug("sharding start gradients sync")
    with framework.no_grad():
        sharding_nrank = hcg.get_sharding_parallel_group().nranks
        for param in parameter_list:
            if param.trainable and (param._grad_ivar() is not None):
                g_var = param._grad_ivar()

                # The allreduce needs to go through trace_op here; the
                # equivalent high-level call would be:
                # paddle.distributed.all_reduce(
                #     g_var, group=hcg.get_sharding_parallel_group(),
                #     use_calc_stream=True)
                paddle.fluid.framework._dygraph_tracer().trace_op(
                    type="c_allreduce_sum",
                    inputs={'X': g_var},
                    outputs={'Out': g_var},
                    attrs={
                        'ring_id': hcg.get_sharding_parallel_group().id,
                        'use_calc_stream': True
                    })

                # grad = grad / sharding_nrank, turning the summed gradient
                # into an average over the sharding group
                div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
                paddle.fluid.framework._dygraph_tracer().trace_op(
                    type="elementwise_div",
                    inputs={'X': g_var,
                            'Y': div_factor},
                    outputs={'Out': g_var},
                    attrs={'axis': -1})
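
def _example_grad_average(g_var, group):
    # Illustrative only (not part of the original module): the two trace_op
    # calls above implement the usual "all-reduce sum, then divide by the group
    # size" averaging. With the public collective API (assumed already
    # initialized via paddle.distributed.init_parallel_env), the same effect
    # can be sketched as:
    import paddle.distributed as dist

    dist.all_reduce(g_var, group=group)      # in-place sum across ranks
    g_var.set_value(g_var / group.nranks)    # sum -> mean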
def fused_allreduce_gradients(parameter_list, hcg):
    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
    logger.debug("dp start fuse allreduce gradients")
    # Eager dygraph mode and legacy dygraph mode use different fused helpers.
    apply_func = (_apply_collective_grads_eager
                  if in_dygraph_mode() else _apply_collective_grads)
    with framework.no_grad():
        apply_func(parameter_list, data_parallel_group)
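
def _example_dp_step(model, optimizer, loss, hcg):
    # Illustrative only (not part of the original module): `model`, `optimizer`
    # and `loss` are hypothetical. In a hybrid-parallel training step the fused
    # data-parallel allreduce is typically invoked between backward and the
    # optimizer update.
    loss.backward()
    fused_allreduce_gradients(list(model.parameters()), hcg)
    optimizer.step()
    optimizer.clear_grad()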
def _sharding_sync_parameters(self):
    """Synchronize parameters across the sharding group."""
    # TODO: speed this up
    logger.debug("sharding start sync parameters")
    with framework.no_grad():
        # TODO: detach may not be needed (?)
        # Each sharding rank owns a slice of the parameters; broadcast every
        # owned parameter from its owner to the rest of the sharding group.
        for rank, params in self._rank2params.items():
            for param in params:
                paddle.distributed.broadcast(
                    param,
                    # the collective API expects the global rank id as src,
                    # not the logical rank id within the group
                    src=self._hcg.get_sharding_parallel_group().ranks[rank],
                    group=self._hcg.get_sharding_parallel_group(),
                    use_calc_stream=True)
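
def _example_broadcast_src(param):
    # Illustrative only (hypothetical group layout): paddle.distributed.broadcast
    # takes a *global* rank id as `src`, so a logical rank inside a sub-group
    # has to be translated through `group.ranks`, exactly as
    # _sharding_sync_parameters does above.
    import paddle.distributed as dist

    group = dist.new_group(ranks=[4, 5, 6, 7])   # a 4-way sharding sub-group
    owner = 2                                    # logical rank inside the group
    dist.broadcast(param, src=group.ranks[owner], group=group)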