Beispiel #1
0
     'desc_bprop': [[128 * 32 * 8 * 16]]}),
 ('LogSoftmax', {
     'block': P.LogSoftmax(),
     'desc_inputs': [[64, 2]],
     'desc_bprop': [[160, 30522]]}),
 ('LogSoftmaxGrad', {
     'block': G.LogSoftmaxGrad(),
     'desc_inputs': [[16, 1234], [16, 1234]],
     'desc_bprop': [[64, 2]],
     'skip': ['backward']}),
 ('LayerNorm', {
     'block': P.LayerNorm(),
     'desc_inputs': [[2, 16], [16], [16]],
     'desc_bprop': [[2, 16], [2, 16], [2, 16]]}),
 ('LayerNormGrad', {
     'block': G.LayerNormGrad(),
     'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
     'desc_bprop': [[2, 16], [16], [16]],
     'skip': ['backward']}),
 ('FusedBatchNorm', {
     'block': P.FusedBatchNorm(),
     'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'skip': []}),
 ('FusedBatchNormGrad', {
     'block': G.FusedBatchNormGrad(),
     'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'skip': ['backward']}),
 ('BatchNorm', {
     'block': P.BatchNorm(),
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.ops import Primitive
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import _constants as Constants

# Module-level operator objects shared by the graph functions in this file.
# NOTE(review): these look like the nodes used to build "before"/"after"
# graphs for a LayerNormGrad split pass test — confirm against the C++ test
# that loads this module.
make_tuple = Primitive('make_tuple')
# The primitive name comes from the shared constants module rather than a
# string literal, so it stays in sync with the framework's own spelling.
tuple_getitem = Primitive(Constants.kTupleGetItem)
# A LayerNormGrad operator instance from _grad_ops, built with its default
# constructor arguments.
layer_norm_grad = G.LayerNormGrad()
# Primitives constructed by name only; presumably the two ops LayerNormGrad
# is expected to be split into (x-backprop and beta/gamma-backprop) — TODO
# confirm against the pass under test.
layer_norm_x_backprop = Primitive('LayerNormXBackprop')
layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackprop')


class FnDict:
    """Registry mapping function names to functions, usable as a decorator.

    Decorating a function with an instance (``@fns``) stores it under its
    ``__name__``; it can later be retrieved with ``fns[name]``.
    """

    def __init__(self):
        # Name -> registered function. The attribute keeps its original name
        # so any code that reads it directly keeps working.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Registering another function with the same name replaces the earlier
        entry. Returning ``fn`` keeps the decorated name bound to the function
        instead of rebinding it to ``None`` (the usual decorator contract).
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered under ``name``.

        Raises:
            KeyError: if no function was registered under ``name``.
        """
        return self.fnDict[name]

Beispiel #3
0
 def __init__(self, begin_norm_axis, begin_params_axis):
     super(LayerNormGradNet, self).__init__()
     self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)