def __init__(
        self,
        max_val=1.0,
        power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
        filter_size=11,
        filter_sigma=1.5,
        k1=0.01,
        k2=0.03):
    """Validate the configuration and build the per-level sub-cells and operators.

    Args:
        max_val (Union[int, float]): Checked to be an int or float and
            compared against 0.0 with ``Rel.GT``. Default: 1.0.
        power_factors (Union[tuple, list]): One weight per level; its length
            determines how many convolution cells are created.
            Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        filter_size (int): Checked as an integer against 1 with ``Rel.GE``;
            passed to the window builder and to every convolution.
            Default: 11.
        filter_sigma (float): Checked with ``check_float_positive``; passed
            to the window builder. Default: 1.5.
        k1 (float): Must be a float. Default: 0.01.
        k2 (float): Must be a float. Default: 0.03.

    Raises:
        Whatever the ``validator`` helpers raise on a failed check; the
        exception types are defined outside this block.
    """
    super(MSSSIM, self).__init__()

    # --- argument validation (order kept so the first failing check wins) ---
    validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
    self.max_val = max_val

    validator.check_value_type(
        'power_factors', power_factors, [tuple, list], self.cls_name)

    self.filter_size = validator.check_integer(
        'filter_size', filter_size, 1, Rel.GE, self.cls_name)
    self.filter_sigma = validator.check_float_positive(
        'filter_sigma', filter_sigma, self.cls_name)

    self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
    self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)

    # --- one fixed-weight convolution per level ---
    # NOTE(review): _create_window is defined elsewhere; presumably it returns
    # the filter kernel for the given size and sigma -- confirm.
    window = _create_window(filter_size, filter_sigma)
    self.level = len(power_factors)

    self.conv = []
    for _ in range(self.level):
        # Each level wraps the window in its own Tensor.
        level_conv = _conv2d(1, 1, filter_size, Tensor(window))
        # The filter weights are excluded from gradient updates.
        level_conv.weight.requires_grad = False
        self.conv.append(level_conv)
    self.multi_convs_list = CellList(self.conv)

    self.weight_tensor = Tensor(power_factors, mstype.float32)

    # --- sub-cells and primitive operators stored for later use ---
    self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
    self.relu = ReLU()
    self.reduce_mean = P.ReduceMean()
    self.prod = P.ReduceProd()
    self.pow = P.Pow()
    self.pack = P.Pack(axis=-1)
    self.concat = P.Concat(axis=1)
'block': ArgminNet(), 'desc_inputs': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], 'desc_bprop': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], 'skip': ['backward']}), ('CumSumNet', { 'block': CumSumNet(), 'desc_const': [0], 'desc_inputs': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))], 'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))]}), ('OneHot', { 'block': P.OneHot(), 'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)], 'desc_inputs': [Tensor(np.array([64]).astype(np.int32))], 'desc_bprop': [[64, 2]]}), ('ReduceProd_0', { 'block': P.ReduceProd(), 'desc_const': [0], 'desc_inputs': [[3, 2]], 'desc_bprop': [[2]]}), ('ReduceProd_1', { 'block': P.ReduceProd(keep_dims=True), 'desc_const': [0], 'desc_inputs': [[3, 2]], 'desc_bprop': [[1, 2]]}), ('CumProd', { 'block': P.CumProd(), 'desc_const': [0], 'desc_inputs': [[3, 2]], 'desc_bprop': [[3, 2]]}), ('ApplyFtrl', { 'block': P.ApplyFtrl(),