# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import _constants as Constants

# pylint: disable=unused-variable

# Graph helpers and frontend operator instances.
tuple_getitem = Primitive(Constants.kTupleGetItem)
add = P.Add()
# AllReduce instance carrying a 'fusion' attribute set to 1.
allreduce = P.AllReduce()
allreduce.add_prim_attr('fusion', 1)
make_tuple = Primitive("make_tuple")
conv = P.Conv2D(
    out_channel=64,
    kernel_size=7,
    mode=1,
    pad_mode="valid",
    pad=0,
    stride=1,
    dilation=1,
    group=1,
)
bn = P.FusedBatchNorm()
relu = P.ReLU()

# Primitives created purely by name (no frontend operator class).
conv_bn1 = Primitive('ConvBN1')
bn2_add_relu = Primitive('BN2AddRelu')
bn2_relu = Primitive('BN2Relu')
fused_bn1 = Primitive('FusedBN1')
fused_bn2 = Primitive('FusedBN2')
fused_bn3 = Primitive('FusedBN3')

# Batch-norm gradient operator plus a by-name 'BNGrad1' primitive.
bn_grad = G.FusedBatchNormGrad()
bn_grad1 = Primitive('BNGrad1')
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
BatchNorm = P.BatchNorm(is_training=True)
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import operations as P from mindspore.ops import Primitive tuple_getitem = Primitive('tuple_getitem') add = P.TensorAdd() max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) make_tuple = Primitive('make_tuple') transdata = Primitive("TransData") Transpose = P.Transpose() class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name):
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive

add = P.TensorAdd()
mul = P.Mul()
fused_mul_add = Primitive('FusedMulAdd')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

relu = P.ReLU()
relu_grad = Primitive('ReluGrad')
relu_v2 = Primitive('ReLUV2')
relu_grad_v2 = Primitive('ReluGradV2')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G from mindspore.ops import Primitive make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') bn_grad = G.FusedBatchNormGrad() bn_grad1 = Primitive('BNGrad1') bn_grad2 = Primitive('BNGrad2') bn_grad3 = Primitive('BNGrad3') bn_training_update_grad = Primitive('BNTrainingUpdateGrad') bn_training_reduce_grad = Primitive('BNTrainingReduceGrad') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import operations as P

mul = P.Mul()
reduce_sum = P.ReduceSum(keep_dims=True)
sub = P.Sub()
confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
axis = 2


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
""" opt_test """ import numpy as np from mindspore import Tensor from mindspore.ops import Primitive from mindspore.ops import _constants as Constants from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G # pylint: disable=unused-variable # opt test data, not for running # pylint: disable=unused-argument # pylint: disable=redefined-outer-name scalar_add = Primitive(Constants.kScalarAdd) scalar_mul = Primitive(Constants.kScalarMul) tuple_getitem = Primitive(Constants.kTupleGetItem) switch = Primitive('Switch') def test_sexp_conversion(): """ test_sexp_conversion """ return scalar_mul(10, scalar_add(5, 4)) class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn):
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import _constants as Constants

mul = P.Mul()
add = P.Add()
square = P.Square()
sqrt = P.Sqrt()
real_div = P.RealDiv()
sub = P.Sub()
Assign = P.Assign()
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G

make_tuple = Primitive('MakeTuple')

# Each frontend operator instance is paired with a Primitive created by the
# same operator name (prefixed ``backend_``).
reshape = P.Reshape()
backend_reshape = Primitive('Reshape')

cast = P.Cast()
backend_cast = Primitive('Cast')

transpose = P.Transpose()
backend_transpose = Primitive('Transpose')

onehot1 = P.OneHot()
onehot2 = P.OneHot()
backend_onehot1 = Primitive('OneHot')
backend_onehot2 = Primitive('OneHot')

stridedslicegrad = G.StridedSliceGrad()
backend_stridedslicegrad = Primitive('StridedSliceGrad')

# Constant tensors: two float32 scalars and one int32 scalar.
on_value = Tensor(1.0, mstype.float32)
off_value = Tensor(0.0, mstype.float32)
depth = Tensor(2, mstype.int32)
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

Sub = P.Sub()
Mul = P.Mul()
RealDiv = P.RealDiv()
Select = P.Select()
Greater = P.Greater()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
LambUpdateWithLrV2 = Primitive('LambUpdateWithLrV2')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

mul = P.Mul()
reduce_sum = P.ReduceSum()
confusion_mul_grad = Primitive('ConfusionMulGrad')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
axis = 2


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import numpy as np from mindspore.common.tensor import Tensor from mindspore.ops import Primitive from mindspore.ops import operations as P make_tuple = Primitive('make_tuple') concat = P.Concat() add = P.Add() t1 = Tensor(np.random.randn(1, 11, 20, 1, 1).astype(np.float32)) t2 = Tensor(np.random.randn(1, 11, 20, 1, 1).astype(np.float32)) class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name):
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from dataclasses import dataclass

import numpy as np

import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore.model_zoo.resnet import resnet50
from mindspore.ops import Primitive

# Primitive instance created from the operator name 'scalar_add'.
scala_add = Primitive('scalar_add')


@dataclass
class Point:
    """Two-dimensional point with ``x`` and ``y`` coordinates."""

    x: float
    y: float

    def abs(self):
        """Return the Euclidean distance from the origin."""
        squared_norm = self.x**2 + self.y**2
        return squared_norm**0.5


def scalar_add(x, y):
    """Implement `scalar_add`."""
    total = x + y
    return total
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

select = P.Select()
maximum = P.Maximum()
minimum = P.Minimum()
greater = P.Greater()
real_div = P.RealDiv()
mul = P.Mul()
sub = P.Sub()
lamb_update_with_lr = Primitive('LambUpdateWithLR')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants

lars_v2 = Primitive('LarsV2')
square_sum_all = Primitive('SquareSumAll')
lars_v2_update = Primitive('LarsV2Update')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P add = P.TensorAdd() sub = P.Sub() make_tuple = Primitive('make_tuple') four2five = Primitive('Four2Five') five2four = Primitive('Five2Four') transdata = Primitive("TransData") cast = Primitive('Cast') depend = Primitive('depend') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name):
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

split = P.Split(0, 8)
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
splitv = Primitive('SplitV')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants depend = P.Depend() all_reduce = P.AllReduce() broadcast = P.Broadcast(1) memcpy_async = Primitive('memcpy_async') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) assign_add = P.AssignAdd() apply_momentun = P.ApplyMomentum() relu = P.ReLU() class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name):
# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P Add = P.TensorAdd() Mul = P.Mul() RealDiv = P.RealDiv() Rsqrt = P.Rsqrt() Sqrt = P.Sqrt() make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') LambNextMV = Primitive('LambNextMV') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name): return self.fnDict[name] def test_lamb_next_mv_rule_cond4(tag): fns = FnDict()
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

addn = P.AddN()
mul = P.Mul()
fused_mul_addn = Primitive('FusedMulAddN')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
scalar = Tensor(1.0, mstype.float32)


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
reduce_min = P.ReduceMin(keep_dims=False)
reduce_min1 = Primitive('ReduceMin')
reduce_min2 = Primitive('ReduceMin')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P Transpose = P.Transpose() Reshape = P.Reshape() ConfusionTransposeD = Primitive('ConfusionTransposeD') make_tuple = Primitive('make_tuple') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name): return self.fnDict[name] def test_transpose_reshape_fusion(tag):
def construct(self, x):
    # Pass x through a by-name 'depend' primitive whose second input is the
    # result of assigning (x + self.s1) into self.s1.
    # NOTE(review): presumably this ties the assign side effect to x's data
    # flow so it is not pruned/reordered in graph mode — confirm against the
    # MindSpore Depend semantics. Keep the call inline: evaluation order here
    # may matter to the graph parser.
    x = Primitive('depend')(x, self.assign(self.s1, x + self.s1))
    # Rebind both state attributes via self.sub (presumably elementwise
    # subtraction of x — verify against how self.sub is constructed).
    self.s1 = self.sub(self.s1, x)
    self.s2 = self.sub(self.s2, x)
    return x
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

Add = P.TensorAdd()
Sub = P.Sub()
Mul = P.Mul()
RealDiv = P.RealDiv()
Sqrt = P.Sqrt()
Square = P.Square()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
AdamApplyOne = Primitive('AdamApplyOne')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
BatchNorm = P.BatchNorm(is_training=True)
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants make_tuple = Primitive('MakeTuple') tuple_getitem = Primitive(Constants.kTupleGetItem) unsorted_segment_sum = P.UnsortedSegmentSum() num_segments = 4 padding = Primitive('Padding') op_slice = Primitive('Slice') op_unsorted_segment_sum = Primitive('UnsortedSegmentSum') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn
# limitations under the License. # ============================================================================ """ test_multitype """ import numpy as np from mindspore import Tensor from mindspore.common.api import ms_function from mindspore.common.parameter import Parameter from mindspore.ops import Primitive from mindspore.ops import composite as C from mindspore.ops import operations as P from ...ut_filter import non_graph_engine tensor_add = P.TensorAdd() op_add = P.AddN() scala_add = Primitive('scalar_add') add = C.MultitypeFuncGraph('add') @add.register("Number", "Number") def add_scala(x, y): return scala_add(x, y) @add.register("Tensor", "Tensor") def add_tensor(x, y): return tensor_add(x, y) @ms_function def mainf(x, y):
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

Add = P.Add()
Mul = P.Mul()
RealDiv = P.RealDiv()
Rsqrt = P.Rsqrt()
Sqrt = P.Sqrt()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
LambNextMV = Primitive('LambNextMV')


class FnDict:
    """Name-keyed registry of functions.

    Calling an instance with a function (e.g. as a ``@fns`` decorator)
    stores it under ``fn.__name__``; ``fns[name]`` retrieves it.
    """

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` keeps a decorated name bound to the function
        rather than rebinding it to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (KeyError if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops import operations as P addn = P.AddN() make_tuple = Primitive('MakeTuple') class FnDict: def __init__(self): self.fnDict = {} def __call__(self, fn): self.fnDict[fn.__name__] = fn def __getitem__(self, name): return self.fnDict[name] def test_addn_fission(tag): """ test_adam_apply_one_with_decay_rule """