# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad_reducer_thor"""
import mindspore.common.dtype as mstype
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp

reduce_opt = C.MultitypeFuncGraph("reduce_opt")

_all_reduce_G = AllReduce()


def _init_optimizer_allreduce(group):
    """Rebuild the module-level AllReduce with SUM semantics over the world
    group; the 'fusion' attribute tags it for fused communication."""
    global _all_reduce_G
    _all_reduce_G = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce_G.add_prim_attr('fusion', group)


@reduce_opt.register("Function", "Number", "Tensor")
def _tensors_allreduce_mean(mul, degree, grad):
    """Allreduce a gradient, then scale it by 1.0/degree to take the mean
    across devices."""
    degree = F.scalar_cast(degree, F.dtype(grad))
    grad = _all_reduce_G(grad)
    cast_op = P.Cast()
    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
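
# Sketch (assumed usage, not part of the original file): a registered
# MultitypeFuncGraph such as reduce_opt is typically mapped over the whole
# gradient tuple with C.HyperMap, binding `mul` and `degree` up front via
# F.partial. The helper name below is illustrative only.
_hyper_map = C.HyperMap()


def _mean_allreduce_all(grads, degree):
    # grads: tuple of Tensors; degree: device count, e.g. get_group_size().
    # HyperMap applies _tensors_allreduce_mean to every gradient in the tuple.
    return _hyper_map(F.partial(reduce_opt, P.Mul(), degree), grads)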
from mindspore.nn import Cell, Dense, ReLU
from mindspore.ops.operations.comm_ops import AllReduce


class AllReduceNet(Cell):
    """Dense -> AllReduce -> ReLU; `op` selects the reduce operation
    (e.g. ReduceOp.SUM)."""

    def __init__(self, input_channel, out_channel, op):
        super(AllReduceNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.reduce = AllReduce(op)
        self.relu = ReLU()
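    # Hedged sketch (added for illustration; the original snippet only showed
    # __init__): a matching construct that chains the layers. The exact
    # dataflow is an assumption based on the attributes defined above.
    def construct(self, x):
        x = self.dense(x)
        x = self.reduce(x)  # reduce the activations across all devices
        return self.relu(x)

# Usage sketch, assuming a distributed launch where
# mindspore.communication.init() has already been called:
#   net = AllReduceNet(16, 8, ReduceOp.SUM)
#   out = net(Tensor(np.ones([2, 16]), mstype.float32))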
"""grad reducer cell for distributed training"""
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
import mindspore.common.dtype as mstype

reduce_opt = C.MultitypeFuncGraph("reduce_opt")

_all_reduce = AllReduce()


def _init_optimizer_allreduce():
    global _all_reduce
    _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce.add_prim_attr('fusion', 1)


@reduce_opt.register("Function", "Number", "Bool", "Tensor")
def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
    """
    Apply mean and allreduce on gradient. Allreduce is a communication
    operation used for distributed deep learning.

    Args:
        mul (Function): Mul operator; used to scale the gradient by 1.0/degree.
        degree (int): The mean coefficient, i.e. the number of devices.
        allreduce_filter (bool): When True, allreduce is applied to the gradient.
        grad (Tensor): The gradient tensor before the operation.

    Returns:
        Tensor, the gradient tensor after the operation.
    """
    # The original snippet was truncated after the docstring; the body below
    # follows the pattern of the unfiltered variants in this section, gated
    # by allreduce_filter.
    if allreduce_filter:
        degree = F.scalar_cast(degree, F.dtype(grad))
        grad = _all_reduce(grad)
        cast_op = P.Cast()
        return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
    return grad
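
# Sketch of how the filtered variant is typically consumed (assumed wiring,
# modeled on the DistributedGradReducer idiom; the class name and defaults
# below are illustrative, not the library's API): each parameter contributes
# a Bool flag, and the flags tuple is zipped with the gradients by HyperMap.
class _GradReducerSketch(Cell):
    def __init__(self, parameters, degree=None):
        super(_GradReducerSketch, self).__init__()
        self.hyper_map = C.HyperMap()
        self.mul = P.Mul()
        self.degree = degree if degree is not None else get_group_size()
        # Reduce only parameters whose layerwise_parallel flag is off.
        self.allreduce_filter = tuple(not x.layerwise_parallel for x in parameters)
        _init_optimizer_allreduce()

    def construct(self, grads):
        # Applies _tensors_allreduce_mean(mul, degree, filter_i, grad_i)
        # element-wise over the (filter, grad) pairs.
        return self.hyper_map(F.partial(reduce_opt, self.mul, self.degree),
                              self.allreduce_filter, grads)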
"""grad_reducer_thor"""
import mindspore.common.dtype as mstype
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp

reduce_opt = C.MultitypeFuncGraph("reduce_opt")

_all_reduce_A = AllReduce()


def _init_optimizer_allreduce(group):
    global _all_reduce_A
    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce_A.add_prim_attr('fusion', group)


@reduce_opt.register("Function", "Number", "Tensor")
def _tensors_allreduce_mean(mul, degree, grad):
    degree = F.scalar_cast(degree, F.dtype(grad))
    grad = _all_reduce_A(grad)
    cast_op = P.Cast()
    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
from mindspore.ops.operations.comm_ops import AllReduce


def test_all_reduce(x):
    """Run a bare AllReduce over `x` and print the input and the reduced result."""
    print('test_all_reduce with %s' % (x))
    all_reduce = AllReduce()
    y = all_reduce(x)
    print('y=%s' % (y))
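
# Hedged driver sketch (assumptions: the execution mode, the tensor shape,
# and that this runs under a multi-device launch so init() can succeed):
if __name__ == '__main__':
    import numpy as np
    from mindspore import Tensor, context
    from mindspore.communication.management import init

    context.set_context(mode=context.PYNATIVE_MODE)
    init()  # join the default communication group before any AllReduce
    test_all_reduce(Tensor(np.ones([2, 3]).astype(np.float32)))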