def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
             filter_sigma=1.5, k1=0.01, k2=0.03):
    """Validate the MS-SSIM hyper-parameters and build the per-scale operators.

    Args:
        max_val (Union[int, float]): Dynamic range of the pixel values; must be > 0. Default: 1.0.
        power_factors (Union[tuple, list]): One weight per scale. Its length sets the number
            of scales (``self.level``) and it must not be empty.
        filter_size (int): Size of the filter window; must be >= 1. Default: 11.
        filter_sigma (float): Sigma passed to ``_create_window``; must be positive. Default: 1.5.
        k1 (float): Constant stored as ``self.k1``; only its type is checked here. Default: 0.01.
        k2 (float): Constant stored as ``self.k2``; only its type is checked here. Default: 0.03.

    Raises:
        ValueError: If ``power_factors`` is empty. Other invalid arguments are rejected by
            ``validator``.
    """
    super(MSSSIM, self).__init__()
    validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
    self.max_val = max_val
    validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
    # An empty sequence would give self.level == 0, an empty CellList and an empty
    # weight tensor, i.e. a metric with no scales at all. Fail fast with a clear message.
    if not power_factors:
        raise ValueError("For '{}', 'power_factors' must not be empty.".format(self.cls_name))
    self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
    self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
    self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
    self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
    window = _create_window(filter_size, filter_sigma)
    self.level = len(power_factors)
    # One fixed (non-trainable) convolution per scale, each initialized from its own
    # Tensor built from the same window values.
    self.conv = []
    for _ in range(self.level):
        scale_conv = _conv2d(1, 1, filter_size, Tensor(window))
        scale_conv.weight.requires_grad = False
        self.conv.append(scale_conv)
    self.multi_convs_list = CellList(self.conv)
    self.weight_tensor = Tensor(power_factors, mstype.float32)
    self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
    self.relu = ReLU()
    self.reduce_mean = P.ReduceMean()
    self.prod = P.ReduceProd()
    self.pow = P.Pow()
    self.pack = P.Pack(axis=-1)
    self.concat = P.Concat(axis=1)
def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
    """Validate optimizer hyper-parameters through ``validator``; returns nothing.

    Every numeric argument must be a ``float`` and is compared against 0.0:
    ``initial_accum``, ``l1``, ``l2`` and ``weight_decay`` must be >= 0.0, while
    ``lr_power`` must be <= 0.0. ``use_locking`` must be a ``bool``.

    Args:
        initial_accum (float): Initial accumulator value.
        lr_power (float): Learning-rate power.
        l1 (float): L1 regularization strength.
        l2 (float): L2 regularization strength.
        use_locking (bool): Locking flag.
        weight_decay (float): Weight decay. Default: 0.0.
        prim_name (str): Name reported in validator errors. Default: None.
    """
    # (argument name, value, accepted types, (limit, relation) or None for a type-only check),
    # listed in the exact order the checks run.
    rules = (
        ("initial_accum", initial_accum, [float], (0.0, Rel.GE)),
        ("lr_power", lr_power, [float], (0.0, Rel.LE)),
        ("l1", l1, [float], (0.0, Rel.GE)),
        ("l2", l2, [float], (0.0, Rel.GE)),
        ("use_locking", use_locking, [bool], None),
        ("weight_decay", weight_decay, [float], (0.0, Rel.GE)),
    )
    for arg_name, arg_value, valid_types, bound in rules:
        validator.check_value_type(arg_name, arg_value, valid_types, prim_name)
        if bound is not None:
            limit, relation = bound
            validator.check_number(arg_name, arg_value, limit, relation, prim_name)
def __init__(self, clip_norm=1.0, use_norm=None):
    """Validate the clipping threshold and create the operators used for clipping.

    Args:
        clip_norm (float): Clipping threshold; must be > 0. Stored as a one-element
            float32 Tensor. Default: 1.0.
        use_norm: Reserved for interface compatibility and not otherwise used. When
            supplied it must be >= 0. Default: None.
    """
    super(_ClipByGlobalNorm, self).__init__()
    owner = self.cls_name
    # use_norm is accepted only as part of the interface; it is range-checked when given.
    if use_norm is not None:
        validator.check_number("use_norm", use_norm, 0.0, Rel.GE, owner)
    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, owner)
    self.clip_norm = Tensor([clip_norm], mstype.float32)
    self.hyper_map = C.HyperMap()
    self.greater_equal = P.GreaterEqual()
def __init__(self, power=0, name='PowerTransform', param=None):
    """Initialize the PowerTransform with a non-negative ``power``.

    Args:
        power (Union[int, float]): Power parameter; validated to be a number >= 0. Default: 0.
        name (str): Name forwarded to the parent constructor. Default: 'PowerTransform'.
        param (dict): Parameter dict forwarded to the parent constructor. When None, a
            snapshot of this call's arguments is taken from ``locals()``. Default: None.
    """
    # Must remain the first statement: at this point locals() holds exactly the call
    # arguments (self, power, name, param). Any local bound earlier would leak into the dict.
    param = dict(locals()) if param is None else param
    # Keep the explicit two-argument super(): zero-argument super() adds an implicit
    # __class__ free variable, which locals() would then include in the snapshot above.
    super(PowerTransform, self).__init__(name=name, param=param)
    validator.check_value_type('power', power, [int, float], self.name)
    validator.check_number("power", power, 0, Rel.GE, self.name)
    self._power = power
    self.pow = P.Pow()
    # Elementwise helpers defined outside this block.
    self.exp = exp_generic
    self.expm1 = expm1_generic
    self.log = log_generic
    self.log1p = log1p_generic
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """Validate the SSIM hyper-parameters and create the depthwise mean operator.

    Args:
        max_val (Union[int, float]): Dynamic range of the pixel values; must be > 0. Default: 1.0.
        filter_size (int): Kernel size of the depthwise convolution; must be >= 1. Default: 11.
        filter_sigma (float): Filter sigma; must be positive. Default: 1.5.
        k1 (float): Constant, required to lie strictly inside (0, 1). Default: 0.01.
        k2 (float): Constant, required to lie strictly inside (0, 1). Default: 0.03.
    """
    super(SSIM, self).__init__()
    owner = self.cls_name
    validator.check_value_type('max_val', max_val, [int, float], owner)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, owner)
    self.max_val = max_val
    self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, owner)
    self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, owner)
    # k1 then k2: type check, then open-interval range check, then store on self.
    for const_name, const_value in (('k1', k1), ('k2', k2)):
        validator.check_value_type(const_name, const_value, [float], owner)
        checked = validator.check_number_range(const_name, const_value, 0.0, 1.0, Rel.INC_NEITHER, owner)
        setattr(self, const_name, checked)
    self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """Validate the SSIM hyper-parameters and build the fixed window convolution.

    Args:
        max_val (Union[int, float]): Dynamic range of the pixel values; must be > 0. Default: 1.0.
        filter_size (int): Size of the filter window; must be >= 1. Default: 11.
        filter_sigma (float): Sigma passed to ``_create_window``; must be positive. Default: 1.5.
        k1 (float): Constant stored as ``self.k1``; only its type is checked. Default: 0.01.
        k2 (float): Constant stored as ``self.k2``; only its type is checked. Default: 0.03.
    """
    super(SSIM, self).__init__()
    owner = self.cls_name
    validator.check_value_type('max_val', max_val, [int, float], owner)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, owner)
    self.max_val = max_val
    self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, owner)
    self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, owner)
    self.k1 = validator.check_value_type('k1', k1, [float], owner)
    self.k2 = validator.check_value_type('k2', k2, [float], owner)
    window_tensor = Tensor(_create_window(filter_size, filter_sigma))
    self.conv = _conv2d(1, 1, filter_size, window_tensor)
    # The window weights are fixed; exclude them from gradient updates.
    self.conv.weight.requires_grad = False
    self.reduce_mean = P.ReduceMean()
    self.concat = P.Concat(axis=1)
def __init__(self, max_val=1.0):
    """Validate and store the peak signal value.

    Args:
        max_val (Union[int, float]): Dynamic range of the pixel values; must be > 0. Default: 1.0.
    """
    super(PSNR, self).__init__()
    owner = self.cls_name
    validator.check_value_type('max_val', max_val, [int, float], owner)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, owner)
    self.max_val = max_val
def _check_value(clip_norm):
    """Check that ``clip_norm`` is strictly positive and hand it back unchanged.

    Args:
        clip_norm: Clipping threshold to validate.

    Returns:
        The same ``clip_norm`` object that was passed in.
    """
    op_name = "clip_by_global_norm"
    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, op_name)
    return clip_norm
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
                 prim_name=None):
    """Validate optimizer hyper-parameters through ``validator``; returns nothing.

    Numeric arguments must be ``float``: ``initial_accum``, ``l1``, ``l2`` and
    ``weight_decay`` must be >= 0.0, ``learning_rate`` must be > 0.0, ``lr_power``
    must be <= 0.0, and ``loss_scale`` must be >= 1.0. ``use_locking`` must be a ``bool``.

    Args:
        initial_accum (float): Initial accumulator value.
        learning_rate (float): Learning rate.
        lr_power (float): Learning-rate power.
        l1 (float): L1 regularization strength.
        l2 (float): L2 regularization strength.
        use_locking (bool): Locking flag.
        loss_scale (float): Loss scale. Default: 1.0.
        weight_decay (float): Weight decay. Default: 0.0.
        prim_name (str): Name reported in validator errors. Default: None.
    """
    # (argument name, value, accepted types, (limit, relation) or None for a type-only check),
    # listed in the exact order the checks run.
    rules = (
        ("initial_accum", initial_accum, [float], (0.0, Rel.GE)),
        ("learning_rate", learning_rate, [float], (0.0, Rel.GT)),
        ("lr_power", lr_power, [float], (0.0, Rel.LE)),
        ("l1", l1, [float], (0.0, Rel.GE)),
        ("l2", l2, [float], (0.0, Rel.GE)),
        ("use_locking", use_locking, [bool], None),
        ("loss_scale", loss_scale, [float], (1.0, Rel.GE)),
        ("weight_decay", weight_decay, [float], (0.0, Rel.GE)),
    )
    for arg_name, arg_value, valid_types, bound in rules:
        validator.check_value_type(arg_name, arg_value, valid_types, prim_name)
        if bound is not None:
            limit, relation = bound
            validator.check_number(arg_name, arg_value, limit, relation, prim_name)