Example #1
0
 ('LayerNorm', {
     'block': P.LayerNorm(),
     'desc_inputs': [[2, 16], [16], [16]],
     'desc_bprop': [[2, 16], [2, 16], [2, 16]]}),
 ('LayerNormGrad', {
     'block': G.LayerNormGrad(),
     'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
     'desc_bprop': [[2, 16], [16], [16]],
     'skip': ['backward']}),
 ('FusedBatchNorm', {
     'block': P.FusedBatchNorm(),
     'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'skip': []}),
 ('FusedBatchNormGrad', {
     'block': G.FusedBatchNormGrad(),
     'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
     'skip': ['backward']}),
 ('BatchNorm', {
     'block': P.BatchNorm(),
     'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
     'skip': []}),
 ('BatchNormGrad', {
     'block': G.BatchNormGrad(),
     'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
     'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
     'skip': ['backward']}),
 ('ApplyMomentum', {
     'block': P.ApplyMomentum(),
# Module-level operator/primitive instances. `P`, `G` and `Primitive` are not
# defined in this chunk; they are presumably MindSpore's ops, grad-ops and
# Primitive imports from elsewhere in the file — confirm.
# Presumably shared by graph-building test functions; verify against callers.

# Generic structural, arithmetic and communication primitives.
tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
allreduce = P.AllReduce()
# Tag this AllReduce instance with prim attr 'fusion' = 1
# (presumably a fusion-group id — confirm).
allreduce.add_prim_attr('fusion', 1)
make_tuple = Primitive('make_tuple')

# Concrete convolution / batch-norm / activation ops used in forward graphs.
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
bn = P.FusedBatchNorm()
relu = P.ReLU()

# Primitives constructed by name only (no op class). Presumably these stand
# for fused/split kernels that a graph pass is expected to produce — confirm
# against the tests that reference them.
conv_bn1 = Primitive('ConvBN1')
bn2_add_relu = Primitive('BN2AddRelu')
bn2_relu = Primitive('BN2Relu')
fused_bn1 = Primitive('FusedBN1')
fused_bn2 = Primitive('FusedBN2')
fused_bn3 = Primitive('FusedBN3')

# Batch-norm gradient op, plus name-only primitives that are presumably its
# split pieces after a pass — confirm.
bn_grad = G.FusedBatchNormGrad()
bn_grad1 = Primitive('BNGrad1')
bn_grad2 = Primitive('BNGrad2')
bn_grad3 = Primitive('BNGrad3')


class FnDict:
    """Registry mapping function names to functions.

    An instance is meant to be used as a decorator: every decorated function
    is stored under its ``__name__`` and can later be retrieved with
    ``registry[name]``. Registering another function with the same name
    replaces the earlier entry.
    """

    def __init__(self):
        # name -> function. Attribute name kept unchanged for existing callers.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it unchanged.

        Returning ``fn`` (rather than the implicit ``None``) keeps the
        decorated name bound to the function when used as ``@registry``.

        Args:
            fn: any object with a ``__name__`` attribute, typically a function.

        Returns:
            The same ``fn`` object.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered under ``name``.

        Raises:
            KeyError: if nothing has been registered under ``name``.
        """
        return self.fnDict[name]