def __init__(self, is_training):
    """Construct the cell and create its fused BatchNormGrad primitive.

    Args:
        is_training (bool): Forwarded verbatim to ``G.BatchNormGrad``;
            selects the training- vs inference-mode gradient kernel.
    """
    # NOTE(review): this chunk shows only the method; the enclosing class is
    # presumably ``Net`` (matching the super() call) — confirm against the file.
    super(Net, self).__init__()
    # Fixed epsilon of 1e-5 is baked in; only the training flag is configurable.
    self.fused_bn_grad_ex = G.BatchNormGrad(is_training=is_training, epsilon=1e-5)
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import _constants as Constants

# Primitives and pre-built gradient ops referenced by the test graphs below.
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
BNInferGrad = Primitive('BNInferGrad')
BNTrainingUpdateGrad = Primitive('BNTrainingUpdateGrad')


class FnDict:
    """Decorator-style registry mapping a function's ``__name__`` to the function.

    Usage: call the instance with a function to register it, then index the
    instance by name to fetch it back.
    """

    def __init__(self):
        # Backing store: function name -> function object.
        self.fnDict = {}

    def __call__(self, fn):
        # Register *fn* under its own __name__ (decorator usage); returns None.
        self.fnDict.update({fn.__name__: fn})

    def __getitem__(self, name):
        # Retrieve a previously registered function by name.
        registered = self.fnDict[name]
        return registered
# NOTE(review): this chunk starts mid-entry — the opening ``('FusedBatchNorm', {``
# of the first record lies outside the visible region. Each record below is a
# (name, config) pair: 'block' is the op under test, 'desc_inputs' the input
# shapes, 'desc_bprop' the gradient shapes, and 'skip' the phases to skip.
'block': P.FusedBatchNorm(),
'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
'skip': []}),
('FusedBatchNormGrad', {
    'block': G.FusedBatchNormGrad(),
    'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
    # Gradient ops are not differentiated again, so backward is skipped.
    'skip': ['backward']}),
('BatchNorm', {
    'block': P.BatchNorm(),
    'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'skip': []}),
('BatchNormGrad', {
    'block': G.BatchNormGrad(),
    'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'skip': ['backward']}),
('ApplyMomentum', {
    'block': P.ApplyMomentum(),
    'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
                    [32, 32, 64], [32, 32, 64], [32, 32, 64]],
    'desc_bprop': [[128, 32, 32, 64]],
    'skip': ['backward']}),
('TopK', {
    'block': P.TopK(),
    # k=5 is passed as a constant, not a tensor input.
    'desc_const': [5],
    'desc_inputs': [[20, 20, 10]],
    'desc_bprop': [[20, 20, 5]],
    'skip': ['backward']}),
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ from mindspore.ops import Primitive from mindspore.ops.operations import _grad_ops as G from mindspore.ops import _constants as Constants from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype make_tuple = Primitive('MakeTuple') tuple_getitem = Primitive(Constants.kTupleGetItem) bn_grad = G.BatchNormGrad(is_training=True) sync_bn_grad = G.SyncBatchNormGrad() bn_grad1 = Primitive('BNGrad1') bn_grad2 = Primitive('BNGrad2') bn_grad3 = Primitive('BNGrad3') bn_training_update_grad = Primitive('BNTrainingUpdateGrad') bn_training_reduce_grad = Primitive('BNTrainingReduceGrad') allreduce = Primitive('AllReduce') mul = Primitive('Mul') mul_value = Tensor(0.5, mstype.float32) class FnDict: def __init__(self): self.fnDict = {}
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops.operations import _grad_ops as G

# Inference-mode BatchNormGrad and the primitive it is expected to become.
batch_norm_grad = G.BatchNormGrad(is_training=False)
bn_infer_grad = Primitive('BNInferGrad')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    """Decorator-style registry: stores callables keyed by their ``__name__``.

    Register with ``fns(fn)``; retrieve with ``fns[name]``.
    """

    def __init__(self):
        # Name -> function mapping populated via __call__.
        self.fnDict = {}

    def __call__(self, fn):
        # Record *fn* under its own name; implicit None return, as before.
        key = fn.__name__
        self.fnDict[key] = fn

    def __getitem__(self, name):
        # Fetch a registered function; raises KeyError for unknown names.
        return self.fnDict[name]
# NOTE(review): this chunk is an interior slice of a test-case list — the list
# opening precedes it and the final 'TopK' record is cut off mid-entry. Each
# record is a (name, config) pair: 'block' is the op under test, 'desc_inputs'
# the input shapes, 'desc_bprop' the gradient shapes, 'skip' the phases to skip.
('FusedBatchNormGrad', {
    'block': G.FusedBatchNormGrad(),
    'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
    # Gradient ops are not differentiated again, so backward is skipped.
    'skip': ['backward']
}),
('BatchNorm', {
    'block': P.BatchNorm(),
    'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'skip': []
}),
('BatchNormGrad', {
    'block': G.BatchNormGrad(),
    'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
    'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
    'skip': ['backward']
}),
('ApplyMomentum', {
    'block': P.ApplyMomentum(),
    'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
                    [32, 32, 64], [32, 32, 64], [32, 32, 64]],
    'desc_bprop': [[128, 32, 32, 64]],
    'skip': ['backward']
}),
('TopK', {
    # NOTE(review): record truncated here — remaining keys lie past this chunk.
    'block': P.TopK(),