def __init__(self):
    """Create three test parameters and three Broadcast primitives.

    Relies on names defined outside this method: ``Net``, ``Parameter``,
    ``initializer``, ``Tensor``, ``P`` and the module-level ``x``.
    """
    super(Net, self).__init__()

    def _new_param(param_name):
        # Build a fresh Tensor from ``x`` for every parameter, so the three
        # parameters are not backed by one shared Tensor object.
        return Parameter(initializer(Tensor(x), x.shape), name=param_name)

    self.x1 = _new_param('x1')
    self.x2 = _new_param('x2')
    self.x3 = _new_param('x3')

    # broadcast1/2/3 are constructed with root 0, 1 and 2 respectively.
    self.broadcast1, self.broadcast2, self.broadcast3 = [
        P.Broadcast(root) for root in (0, 1, 2)
    ]
def broadcast_params(self, optim_result):
    """
    Apply Broadcast operations in the sequential order of parameter groups.

    Parameters are grouped by the rank stored for each of them in
    ``self.param_rank``. For every root in ``range(self.dev_num)``, the group
    owned by that root is passed through ``P.Broadcast(root)``. The broadcast
    outputs are then written back with ``F.assign`` using ref keys built from
    ``self.param_names``.

    Args:
        optim_result: Result of the optimizer step. The first broadcast
            output is made control-dependent on it via ``F.control_depend``.

    Returns:
        bool, the status flag.
        NOTE(review): the value is the output of chained ``F.depend`` /
        ``F.control_depend`` calls; confirm it is actually a bool.
    """
    # One (initially empty) tuple of parameters and one of ref keys per device.
    param_group = []
    key_group = []
    for _ in range(self.dev_num):
        param_group.append(F.make_tuple())
        key_group.append(F.make_tuple())
    # Bucket each parameter (and a ref key made from its name) under the rank
    # recorded for it in self.param_rank. Tuples are rebuilt by concatenation
    # rather than mutated in place.
    for i in range(self.param_length):
        param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i], )
        key = P.MakeRefKey(self.param_names[i])()
        key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key, )
    new_param_group = []
    # Broadcast each rank's group from that rank, then assign every broadcast
    # output back through the matching ref key (presumably updating the
    # named parameter in place — confirm against MakeRefKey semantics).
    for root in range(self.dev_num):
        ops = P.Broadcast(root)
        next_params = ops(param_group[root])
        new_param_group.append(next_params)
        for i in range(F.tuple_len(next_params)):
            F.assign(key_group[root][i], next_params[i])
    # Order the first broadcast after the optimizer result, then chain each
    # subsequent group after the previous one. The order of these statements
    # matters; do not reorder.
    status = F.control_depend(optim_result, new_param_group[0][0])
    for i in range(self.dev_num - 1):
        status = F.depend(F.control_depend(new_param_group[i], new_param_group[i + 1][0]), status)
    return status
# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants depend = P.Depend() all_reduce = P.AllReduce() broadcast = P.Broadcast(1) tensor_move = Primitive('TensorMove') make_tuple = Primitive('MakeTuple') tuple_getitem = Primitive(Constants.kTupleGetItem) assign_add = P.AssignAdd() apply_momentun = P.ApplyMomentum() relu = P.ReLU() class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn