Exemplo n.º 1
0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import _constants as Constants

# pylint: disable=unused-variable

# Module-level operator handles. `P.*` / `G.*` calls instantiate operator
# classes from mindspore.ops.operations (and its _grad_ops submodule);
# `Primitive(<name>)` builds a handle from an operator name alone.
tuple_getitem = Primitive(Constants.kTupleGetItem)
add = P.Add()
allreduce = P.AllReduce()
# Attach primitive attribute 'fusion' with value 1 to this AllReduce instance.
allreduce.add_prim_attr('fusion', 1)
make_tuple = Primitive("make_tuple")
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
bn = P.FusedBatchNorm()
relu = P.ReLU()
# Handles created by name only; no operator class is instantiated for these.
conv_bn1 = Primitive('ConvBN1')
bn2_add_relu = Primitive('BN2AddRelu')
bn2_relu = Primitive('BN2Relu')
fused_bn1 = Primitive('FusedBN1')
fused_bn2 = Primitive('FusedBN2')
fused_bn3 = Primitive('FusedBN3')
bn_grad = G.FusedBatchNormGrad()
bn_grad1 = Primitive('BNGrad1')
Exemplo n.º 2
0
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: name-only Primitive handles plus one
# BatchNorm operator instance constructed with is_training=True.
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
BatchNorm = P.BatchNorm(is_training=True)
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive

# Module-level operator handles: `P.*` calls instantiate operator classes,
# `Primitive(<name>)` builds a handle from the operator name string only.
tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
# Max-pool-with-argmax instance: "same" padding, window 3, stride 2.
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2)
make_tuple = Primitive('make_tuple')
transdata = Primitive("TransData")
Transpose = P.Transpose()


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
Exemplo n.º 4
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive

# Module-level operator handles: TensorAdd and Mul operator instances, plus
# name-only Primitive handles (including one named 'FusedMulAdd').
add = P.TensorAdd()
mul = P.Mul()
fused_mul_add = Primitive('FusedMulAdd')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

Exemplo n.º 5
0
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: one ReLU operator instance, the rest are
# name-only Primitive handles ('ReluGrad', 'ReLUV2', 'ReluGradV2', tuple ops).
relu = P.ReLU()
relu_grad = Primitive('ReluGrad')
relu_v2 = Primitive('ReLUV2')
relu_grad_v2 = Primitive('ReluGradV2')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
Exemplo n.º 6
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import Primitive

# Module-level operator handles. Only `bn_grad` instantiates an operator
# class (from _grad_ops); every other name is a name-only Primitive handle.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
bn_grad = G.FusedBatchNormGrad()
bn_grad1 = Primitive('BNGrad1')
bn_grad2 = Primitive('BNGrad2')
bn_grad3 = Primitive('BNGrad3')
bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import operations as P

# Module-level operator handles and one integer constant.
mul = P.Mul()
# ReduceSum instance constructed with keep_dims=True.
reduce_sum = P.ReduceSum(keep_dims=True)
sub = P.Sub()
confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
# NOTE(review): presumably the axis argument handed to reduce_sum — confirm at the use site.
axis = 2


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
Exemplo n.º 8
0
""" opt_test """
import numpy as np

from mindspore import Tensor
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G

# pylint: disable=unused-variable

# opt test data, not for running
# pylint: disable=unused-argument
# pylint: disable=redefined-outer-name

# Name-only Primitive handles; three names come from the `_constants` module
# (imported as Constants) and one is the literal 'Switch'.
scalar_add = Primitive(Constants.kScalarAdd)
scalar_mul = Primitive(Constants.kScalarMul)
tuple_getitem = Primitive(Constants.kTupleGetItem)
switch = Primitive('Switch')


def test_sexp_conversion():
    """Return ``scalar_mul`` applied to 10 and the result of ``scalar_add(5, 4)``."""
    inner_sum = scalar_add(5, 4)
    return scalar_mul(10, inner_sum)


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import _constants as Constants

# Module-level operator handles: arithmetic / Assign operator instances from
# `P`, followed by name-only Primitive handles.
mul = P.Mul()
add = P.Add()
square = P.Square()
sqrt = P.Sqrt()
real_div = P.RealDiv()
sub = P.Sub()
Assign = P.Assign()
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
Exemplo n.º 10
0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G

# Module-level operator handles. Each `P.*` / `G.*` operator instance is
# paired with a `backend_*` name-only Primitive handle for the same operator
# name (e.g. P.Reshape() alongside Primitive('Reshape')).
make_tuple = Primitive('MakeTuple')
reshape = P.Reshape()
backend_reshape = Primitive('Reshape')
cast = P.Cast()
backend_cast = Primitive('Cast')
transpose = P.Transpose()
backend_transpose = Primitive('Transpose')
# Two separate OneHot instances and two separate 'OneHot' handles.
onehot1 = P.OneHot()
onehot2 = P.OneHot()
backend_onehot1 = Primitive('OneHot')
backend_onehot2 = Primitive('OneHot')
stridedslicegrad = G.StridedSliceGrad()
backend_stridedslicegrad = Primitive('StridedSliceGrad')
# Scalar tensors: float32 1.0 and 0.0, and int32 2.
# NOTE(review): presumably the on/off/depth inputs for the OneHot operators — confirm at the use site.
on_value = Tensor(1.0, mstype.float32)
off_value = Tensor(0.0, mstype.float32)
depth = Tensor(2, mstype.int32)
Exemplo n.º 11
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: operator instances from `P`, followed by
# name-only Primitive handles (including one named 'LambUpdateWithLrV2').
Sub = P.Sub()
Mul = P.Mul()
RealDiv = P.RealDiv()
Select = P.Select()
Greater = P.Greater()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
LambUpdateWithLrV2 = Primitive('LambUpdateWithLrV2')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

Exemplo n.º 12
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles and one integer constant.
mul = P.Mul()
# ReduceSum instance built with its default constructor arguments.
reduce_sum = P.ReduceSum()
confusion_mul_grad = Primitive('ConfusionMulGrad')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
# NOTE(review): presumably the axis argument handed to reduce_sum — confirm at the use site.
axis = 2


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles plus two sample tensors.
make_tuple = Primitive('make_tuple')
concat = P.Concat()
add = P.Add()

# Two float32 tensors of shape (1, 11, 20, 1, 1) filled from np.random.randn.
# No seed is set in the visible code, so the values differ between runs.
t1 = Tensor(np.random.randn(1, 11, 20, 1, 1).astype(np.float32))
t2 = Tensor(np.random.randn(1, 11, 20, 1, 1).astype(np.float32))


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
Exemplo n.º 14
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from dataclasses import dataclass
import numpy as np

import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore.model_zoo.resnet import resnet50
from mindspore.ops import Primitive

# Name-only Primitive handle for the operator name 'scalar_add'. The variable
# is spelled `scala_add`, so it is a different name from the operator string.
scala_add = Primitive('scalar_add')


@dataclass
class Point:
    """Two-dimensional point with float coordinates ``x`` and ``y``."""
    x: float
    y: float

    def abs(self):
        """Return the square root of ``x**2 + y**2``."""
        squared_norm = self.x**2 + self.y**2
        return squared_norm**0.5


def scalar_add(x, y):
    """Return the result of ``x + y``."""
    total = x + y
    return total
Exemplo n.º 15
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: operator instances from `P`, followed by
# name-only Primitive handles (including one named 'LambUpdateWithLR').
select = P.Select()
maximum = P.Maximum()
minimum = P.Minimum()
greater = P.Greater()
real_div = P.RealDiv()
mul = P.Mul()
sub = P.Sub()
lamb_update_with_lr = Primitive('LambUpdateWithLR')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

Exemplo n.º 16
0
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants

# Name-only Primitive handles; no operator class is instantiated here.
lars_v2 = Primitive('LarsV2')
square_sum_all = Primitive('SquareSumAll')
lars_v2_update = Primitive('LarsV2Update')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
Exemplo n.º 17
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: TensorAdd and Sub operator instances, then
# name-only Primitive handles.
add = P.TensorAdd()
sub = P.Sub()
make_tuple = Primitive('make_tuple')
four2five = Primitive('Four2Five')
five2four = Primitive('Five2Four')
transdata = Primitive("TransData")
# `cast` and `depend` are name-only handles here, not P.Cast() / P.Depend() instances.
cast = Primitive('Cast')
depend = Primitive('depend')


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
Exemplo n.º 18
0
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles.
# NOTE(review): the positional arguments (0, 8) are presumably axis and
# output count — confirm against the P.Split signature.
split = P.Split(0, 8)
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
splitv = Primitive('SplitV')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

Exemplo n.º 19
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: operator instances from `P` and name-only
# Primitive handles.
depend = P.Depend()
all_reduce = P.AllReduce()
# NOTE(review): the positional argument 1 is presumably the root rank — confirm against P.Broadcast.
broadcast = P.Broadcast(1)
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
assign_add = P.AssignAdd()
# NOTE(review): the variable is spelled `apply_momentun`; the spelling is kept
# because code outside this view may reference it.
apply_momentun = P.ApplyMomentum()
relu = P.ReLU()


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
Exemplo n.º 20
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: arithmetic operator instances from `P`,
# followed by name-only Primitive handles (including one named 'LambNextMV').
Add = P.TensorAdd()
Mul = P.Mul()
RealDiv = P.RealDiv()
Rsqrt = P.Rsqrt()
Sqrt = P.Sqrt()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
LambNextMV = Primitive('LambNextMV')

class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

def test_lamb_next_mv_rule_cond4(tag):
    fns = FnDict()
Exemplo n.º 21
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles plus one scalar tensor.
addn = P.AddN()
mul = P.Mul()
fused_mul_addn = Primitive('FusedMulAddN')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
# float32 tensor holding the single value 1.0.
scalar = Tensor(1.0, mstype.float32)


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
Exemplo n.º 22
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: one ReduceMin operator instance built with
# keep_dims=False, and two separate name-only 'ReduceMin' Primitive handles.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
reduce_min = P.ReduceMin(keep_dims=False)
reduce_min1 = Primitive('ReduceMin')
reduce_min2 = Primitive('ReduceMin')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: Transpose and Reshape operator instances,
# plus name-only handles 'ConfusionTransposeD' and 'make_tuple'.
Transpose = P.Transpose()
Reshape = P.Reshape()
ConfusionTransposeD = Primitive('ConfusionTransposeD')
make_tuple = Primitive('make_tuple')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]


def test_transpose_reshape_fusion(tag):
Exemplo n.º 24
0
 def construct(self, x):
     # Call the 'depend' primitive with x and the result of self.assign applied
     # to self.s1 and (x + self.s1); x is rebound to that call's output.
     # NOTE(review): presumably this ties the assign to x so it is executed
     # before x is consumed — confirm against the 'depend' primitive's contract.
     x = Primitive('depend')(x, self.assign(self.s1, x + self.s1))
     # Rebind both state attributes to (attribute - x) via self.sub.
     self.s1 = self.sub(self.s1, x)
     self.s2 = self.sub(self.s2, x)
     return x
Exemplo n.º 25
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: arithmetic operator instances from `P`,
# followed by name-only Primitive handles (including one named 'AdamApplyOne').
Add = P.TensorAdd()
Sub = P.Sub()
Mul = P.Mul()
RealDiv = P.RealDiv()
Sqrt = P.Sqrt()
Square = P.Square()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
AdamApplyOne = Primitive('AdamApplyOne')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: name-only Primitive handles plus one
# BatchNorm operator instance constructed with is_training=True.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
BatchNorm = P.BatchNorm(is_training=True)
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles and one integer constant.
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
unsorted_segment_sum = P.UnsortedSegmentSum()
# NOTE(review): presumably the segment count passed to unsorted_segment_sum — confirm at the use site.
num_segments = 4
padding = Primitive('Padding')
op_slice = Primitive('Slice')
# Name-only handle for the same operator name as the P.UnsortedSegmentSum() instance above.
op_unsorted_segment_sum = Primitive('UnsortedSegmentSum')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn
Exemplo n.º 28
0
# limitations under the License.
# ============================================================================
""" test_multitype """
import numpy as np

from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from ...ut_filter import non_graph_engine

# Module-level operator handles and the multitype function graph they feed.
tensor_add = P.TensorAdd()
op_add = P.AddN()
# Name-only handle for 'scalar_add'; the variable itself is spelled `scala_add`.
scala_add = Primitive('scalar_add')
# MultitypeFuncGraph named 'add'; overloads are attached with `add.register(...)`.
add = C.MultitypeFuncGraph('add')


@add.register("Number", "Number")
def add_scala(x, y):
    """Overload of ``add`` registered for ("Number", "Number"); applies ``scala_add``."""
    result = scala_add(x, y)
    return result


@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    """Overload of ``add`` registered for ("Tensor", "Tensor"); applies ``tensor_add``."""
    result = tensor_add(x, y)
    return result


@ms_function
def mainf(x, y):
Exemplo n.º 29
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants

# Module-level operator handles: arithmetic operator instances from `P`,
# followed by name-only Primitive handles (including one named 'LambNextMV').
Add = P.Add()
Mul = P.Mul()
RealDiv = P.RealDiv()
Rsqrt = P.Rsqrt()
Sqrt = P.Sqrt()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
LambNextMV = Primitive('LambNextMV')


class FnDict:
    """Registry that stores functions under their ``__name__``.

    An instance can be applied as a decorator to register a function and
    indexed by name to get the registered function back.
    """

    def __init__(self):
        # Maps function name -> registered function; a later registration
        # under the same name overwrites the earlier one.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than ``None``) keeps the decorated name
        bound to the function when the instance is used with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as ``name`` (``KeyError`` if absent)."""
        return self.fnDict[name]

Exemplo n.º 30
0
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator handles: an AddN operator instance and a name-only
# 'MakeTuple' Primitive handle.
addn = P.AddN()
make_tuple = Primitive('MakeTuple')


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]


def test_addn_fission(tag):
    """ test_adam_apply_one_with_decay_rule """