def __init__(self):
        """Set up three parameters and three Broadcast operators.

        Every parameter is initialised from the array ``x`` found in the
        surrounding scope (it is not an argument of this method); the
        Broadcast operators are constructed with arguments 0, 1 and 2.
        """
        super(Net, self).__init__()

        def _new_param(label):
            # All parameters share the same initial data and shape and
            # differ only in the name they are created with.
            return Parameter(initializer(Tensor(x), x.shape), name=label)

        self.x1 = _new_param('x1')
        self.x2 = _new_param('x2')
        self.x3 = _new_param('x3')

        # One operator per constructor argument 0..2 (presumably the root
        # rank of each broadcast -- confirm against the Broadcast API).
        self.broadcast1, self.broadcast2, self.broadcast3 = (
            P.Broadcast(0), P.Broadcast(1), P.Broadcast(2))
Пример #2
0
    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Parameters are grouped by ``self.param_rank``; one Broadcast operator
        is created per rank in ``range(self.dev_num)`` and applied to that
        rank's group, and each broadcast result is assigned back through a
        reference key built from the corresponding parameter name.

        Args:
            optim_result: value used as the first operand of the initial
                control dependency (presumably the result of the optimizer
                update that must precede the broadcasts -- confirm with the
                caller).

        Returns:
             bool, the status flag.
        """
        # One tuple per device rank: the parameters it owns and the matching
        # reference keys, both built up incrementally below.
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        # Append every parameter (and a ref key made from its name) to the
        # group selected by its entry in self.param_rank.
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[
                self.param_rank[i]] + (self.parameters[i], )
            key = P.MakeRefKey(self.param_names[i])()
            key_group[
                self.param_rank[i]] = key_group[self.param_rank[i]] + (key, )
        # Broadcast each group with its own rank as the Broadcast argument and
        # write the received values back via the ref keys of that group.
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        # Chain dependencies: the first broadcast result is tied to
        # optim_result, then each group i is tied to the first element of
        # group i + 1 (presumably to force the broadcasts to run in rank
        # order -- confirm against control_depend semantics).
        # NOTE(review): new_param_group[0][0] requires dev_num >= 1 and a
        # non-empty group for rank 0.
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            status = F.depend(
                F.control_depend(new_param_group[i],
                                 new_param_group[i + 1][0]), status)

        return status
Пример #3
0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator / primitive instances.
depend = P.Depend()
all_reduce = P.AllReduce()
# Broadcast constructed with argument 1 (presumably the root rank -- confirm).
broadcast = P.Broadcast(1)
# Primitives constructed directly from an operator name instead of via `P`.
tensor_move = Primitive('TensorMove')
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
assign_add = P.AssignAdd()
# NOTE(review): the name looks like a misspelling of "apply_momentum"; it is
# left unchanged because code outside this view may reference it.
apply_momentun = P.ApplyMomentum()
relu = P.ReLU()


class FnDict:
    def __init__(self):
        """Create the registry with an empty ``fnDict`` mapping."""
        # Populated by __call__, which stores each function under its __name__.
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn