示例#1
0
 def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
     """Create the batch-norm fold primitives for the current device target.

     When the context's ``device_target`` is ``"GPU"``, a training and an
     inference ``P.BatchNormFold`` are built. For any other target a
     ``P.BNTrainingReduce`` and a training ``P.BatchNormFoldD`` are built
     instead.

     Args:
         momentum: Forwarded positionally to the fold primitive(s).
         epsilon: Stored on ``self.epsilon`` and forwarded positionally to
             the fold primitive(s).
         freeze_bn: Forwarded as ``freeze_bn`` to the fold primitive(s).
     """
     super().__init__()
     self.epsilon = epsilon
     device_target = context.get_context('device_target')
     self.is_gpu = device_target == "GPU"
     # Both branches pass the same leading positional arguments.
     fold_args = (momentum, epsilon)
     if not self.is_gpu:
         # Non-GPU targets: a reduce primitive plus the "D" fold variant.
         self.bn_reduce = P.BNTrainingReduce()
         self.bn_update = P.BatchNormFoldD(
             *fold_args, is_training=True, freeze_bn=freeze_bn)
     else:
         # GPU target: separate primitives for training and inference.
         self.bn_train = P.BatchNormFold(
             *fold_args, is_training=True, freeze_bn=freeze_bn)
         self.bn_infer = P.BatchNormFold(
             *fold_args, is_training=False, freeze_bn=freeze_bn)
 def __init__(self):
     """Create the cell with one ``P.BatchNormFold`` built with ``freeze_bn=10``."""
     super().__init__()
     fold_op = P.BatchNormFold(freeze_bn=10)
     self.op = fold_op
示例#3
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 weight_init=None,
                 beta_init=None,
                 gamma_init=None,
                 mean_init=None,
                 var_init=None,
                 quant_delay=0,
                 freeze_bn=100000,
                 fake=True,
                 num_bits=8,
                 per_channel=False,
                 symmetric=False,
                 narrow_range=False):
        """Create the parameters and primitives for the conv + batch-norm-fold cell.

        Args:
            in_channels: Stored on ``self``; used for the default weight shape.
            out_channels: Stored on ``self``; used for the default parameter
                shapes and forwarded to ``P.Conv2D`` and the weight fake-quant.
            kernel_size: An int is expanded to ``(k, k)``; any other value is
                stored unchanged.
            stride: Passed through ``twice`` and forwarded to ``P.Conv2D``.
            pad_mode: Forwarded to ``P.Conv2D`` as ``pad_mode``.
            padding: Forwarded to ``P.Conv2D`` as ``pad``.
            dilation: Passed through ``twice`` and forwarded to ``P.Conv2D``.
            group: Forwarded to ``P.Conv2D``; divides ``in_channels`` in the
                default weight shape.
            eps: Forwarded as ``epsilon`` to both ``P.BatchNormFold`` primitives.
            momentum: Stored on ``self`` and forwarded to both
                ``P.BatchNormFold`` primitives.
            weight_init: Initial value for ``weight``; defaults to a 'normal'
                initializer of shape ``[out_channels, in_channels // group, *kernel_size]``.
            beta_init: Initial value for ``beta``; defaults to 'zeros'.
            gamma_init: Initial value for ``gamma``; defaults to 'ones'.
            mean_init: Initial value for ``moving_mean``; defaults to 'zeros'.
            var_init: Initial value for ``moving_variance``; defaults to 'ones'.
            quant_delay: Stored on ``self`` and forwarded to the weight fake-quant.
            freeze_bn: Stored on ``self`` and forwarded to both
                ``P.BatchNormFold`` primitives and to ``self.batchnorm_fold2``.
            fake: Stored on ``self.fake``.
            num_bits: Forwarded to the weight fake-quant.
            per_channel: Forwarded to the weight fake-quant.
            symmetric: Forwarded to the weight fake-quant.
            narrow_range: Forwarded to the weight fake-quant.
        """
        super().__init__()

        def _per_channel_init(kind):
            # Default initializer for a parameter with one value per output channel.
            return initializer(kind, [out_channels])

        # Plain configuration attributes.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.stride = twice(stride)
        self.group = group
        self.fake = fake
        self.freeze_bn = freeze_bn
        self.momentum = momentum
        self.quant_delay = quant_delay
        self.kernel_size = (
            (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        )

        # Trainable parameters; each falls back to a default initializer.
        if weight_init is None:
            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
            weight_init = initializer('normal', weight_shape)
        self.weight = Parameter(weight_init, name='weight')

        if gamma_init is None:
            gamma_init = _per_channel_init('ones')
        self.gamma = Parameter(gamma_init, name='gamma')

        if beta_init is None:
            beta_init = _per_channel_init('zeros')
        self.beta = Parameter(beta_init, name='beta')

        # Moving statistics are created with requires_grad=False.
        if mean_init is None:
            mean_init = _per_channel_init('zeros')
        self.moving_mean = Parameter(
            mean_init, name='moving_mean', requires_grad=False)

        if var_init is None:
            var_init = _per_channel_init('ones')
        self.moving_variance = Parameter(
            var_init, name='moving_variance', requires_grad=False)

        # NOTE(review): the int32 step parameter uses the 'normal' initializer;
        # confirm this is intended rather than an explicit zero start.
        step_init = initializer('normal', [1], dtype=mstype.int32)
        self.step = Parameter(step_init, name='step', requires_grad=False)

        # Primitives and sub-cells.
        self.conv = P.Conv2D(
            out_channel=self.out_channels,
            kernel_size=self.kernel_size,
            mode=1,
            pad_mode=self.pad_mode,
            pad=self.padding,
            stride=self.stride,
            dilation=self.dilation,
            group=self.group,
        )
        self.fake_quant_weight = FakeQuantWithMinMax(
            min_init=-6,
            max_init=6,
            ema=False,
            num_bits=num_bits,
            quant_delay=quant_delay,
            per_channel=per_channel,
            out_channels=out_channels,
            symmetric=symmetric,
            narrow_range=narrow_range,
        )

        # The training and inference fold primitives differ only in is_training.
        fold_kwargs = dict(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
        self.batchnorm_fold_train = P.BatchNormFold(is_training=True, **fold_kwargs)
        self.batchnorm_fold_infer = P.BatchNormFold(is_training=False, **fold_kwargs)

        self.correct_mul = P.CorrectionMul()
        self.relu = P.ReLU()
        self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn)
        # The inference variant is always built with freeze_bn=0.
        self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
        self.one = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()