Example #1
def ssim(img1,
         img2,
         max_val,
         filter_size=11,
         filter_sigma=1.5,
         k1=0.01,
         k2=0.03):
    """
    ssim operator.
    
    Computes per-channel structural similarity.
    """
    filter_size = int(filter_size)

    N1, CH1, H1, W1 = img1.shape
    N2, CH2, H2, W2 = img2.shape

    if N1 != N2:
        raise ValueError('Images batch must match.')
    if CH1 != CH2:
        raise ValueError('Images channels must match.')

    kernel_key = (ssim, CH1, filter_size, filter_sigma)
    kernel = nc.Cacheton.get_var(kernel_key)
    if kernel is None:
        # Build a (CH1, filter_size, filter_size) Gaussian window once and cache it.
        kernel = np.arange(0, filter_size, dtype=np.float32)
        kernel -= (filter_size - 1) / 2.0
        kernel = kernel**2
        kernel *= (-0.5 / (filter_sigma**2))
        kernel = np.reshape(kernel, (1, -1)) + np.reshape(kernel, (-1, 1))
        kernel_exp = np.exp(kernel)
        kernel = kernel_exp / kernel_exp.sum()   # normalize so the window sums to 1
        kernel = kernel[None, ...]
        kernel = np.tile(kernel, (CH1, 1, 1))    # one copy of the window per channel
        nc.Cacheton.set_var(kernel_key, kernel)

    kernel_t = nn.Tensor_from_value(kernel)
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2

    # Local means via a Gaussian-weighted depthwise convolution.
    mean0 = nn.depthwise_conv2D(img1, kernel_t, stride=1, padding='valid')
    mean1 = nn.depthwise_conv2D(img2, kernel_t, stride=1, padding='valid')
    num0 = mean0 * mean1 * 2.0

    den0 = mean0 * mean0 + mean1 * mean1

    # Luminance term of SSIM.
    luminance = (num0 + c1) / (den0 + c1)

    # Contrast/structure term: local (co)variances from the same Gaussian window.
    num1 = nn.depthwise_conv2D(
        img1 * img2, kernel_t, stride=1, padding='valid') * 2.0
    den1 = nn.depthwise_conv2D(img1 * img1 + img2 * img2,
                               kernel_t,
                               stride=1,
                               padding='valid')

    cs = (num1 - num0 + c2) / (den1 - den0 + c2)

    # Average the SSIM map over the spatial dims -> one score per image and channel.
    return nn.reduce_mean(luminance * cs, axes=(-2, -1))
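For reference, a minimal usage sketch of the operator above; the batch size, channel count, and image size are illustrative assumptions, and max_val=1.0 assumes the inputs are normalized to [0, 1]:

# Hypothetical usage: two batches of 3-channel 64x64 float images in [0, 1].
img1_t = nn.Tensor_from_value(np.random.rand(4, 3, 64, 64).astype(np.float32))
img2_t = nn.Tensor_from_value(np.random.rand(4, 3, 64, 64).astype(np.float32))

# Result has shape (4, 3): one SSIM score per image and channel.
score_t = ssim(img1_t, img2_t, max_val=1.0)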
Example #2
    def forward(self, x):
        n, c, h, w = x.shape

        # Replicate the single kernel self.a across all input channels.
        a = nn.tile(self.a, (c, 1, 1))

        x = nn.depthwise_conv2D(x, a, self.stride, self.dilation, self.padding)
        return x
Example #3
    def forward(self, x, **kwargs):
        x = nn.depthwise_conv2D(x,
                                self.kernel,
                                self.stride,
                                self.dilation,
                                padding=self.padding)

        if self.use_bias:
            # Reshape to (1, C, 1, 1) so the per-channel bias broadcasts over N, H, W.
            x = x + self.bias.reshape((1, -1, 1, 1))
        return x
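Examples #2 and #3 wrap nn.depthwise_conv2D inside a layer's forward method. A simplified sketch of such a wrapper is shown below; it only uses the nn.Tensor_from_value and nn.depthwise_conv2D calls already seen above, and it does not register the kernel as a trainable parameter, so treat it as illustrative rather than the library's own layer class:

class DepthwiseConv2D:
    """Illustrative wrapper (not the library's own layer class)."""
    def __init__(self, in_ch, kernel_size, stride=1, dilation=1, padding='same'):
        # Fixed (non-trainable) kernel for illustration: one filter per input channel.
        kernel_np = np.random.randn(in_ch, kernel_size, kernel_size).astype(np.float32)
        self.kernel = nn.Tensor_from_value(kernel_np)
        self.stride, self.dilation, self.padding = stride, dilation, padding

    def forward(self, x):
        # x: (N, C, H, W) with C == in_ch; each channel is convolved with its own filter.
        return nn.depthwise_conv2D(x, self.kernel, self.stride, self.dilation,
                                   padding=self.padding)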
Example #4
def depthwise_conv2d_test():
    # Assumes numpy as np, traceback, the library's nn namespace and a
    # _numpy_depthwise_conv2d reference implementation are available.
    for padding in ['same', 'valid', 0, 1, 2]:
      for dilation in [1, 2]:
        for stride in [1, 2, 3]:
          for ks in [1, 3, 5, 7]:
            for n in [1, 4]:
              for ic in [1, 2, 4]:
                for ih, iw in zip(*[[4, 8, 16]]*2):
                    # 'valid' padding needs the input to be at least as large as the kernel.
                    if padding == 'valid' and iw < ks:
                        continue
                    try:
                        input_shape  = (n, ic, ih, iw)
                        kernel_shape = (ic, ks, ks)

                        input_n  = np.random.randint( 2**4, size=input_shape ).astype(np.float32)
                        kernel_n = np.random.randint( 2**4, size=kernel_shape ).astype(np.float32)

                        input_t  = nn.Tensor_from_value(input_n)
                        kernel_t = nn.Tensor_from_value(kernel_n)

                        conved_t = nn.depthwise_conv2D(input_t, kernel_t, stride=stride, dilation=dilation, padding=padding)
                        conved_n_grad = np.random.randint( 2**4, size=conved_t.shape ).astype(np.float32)
                        conved_n, dI_val, dK_val = _numpy_depthwise_conv2d(input_n, kernel_n, conved_n_grad, STRIDE=stride, DILATION=dilation, padding=padding)

                        if conved_n.shape != conved_t.shape:
                            raise Exception('shape is not equal')

                        if not all ( np.ndarray.flatten( conved_t.np() == conved_n ) ):
                            raise Exception('data is not equal')

                        # Pre-fill grads with 1.0 so the accumulated gradient can be isolated below.
                        input_t.get_grad().fill(1.0)
                        kernel_t.get_grad().fill(1.0)
                        nn.backward( {conved_t:conved_n_grad}, grad_for_non_trainables=True )

                        if not all ( np.ndarray.flatten( (input_t.get_grad().np()-1.0) == dI_val ) ):
                            raise Exception('dI not equal')

                        if not all ( np.ndarray.flatten( (kernel_t.get_grad().np()-1.0) == dK_val ) ):
                            raise Exception('dK not equal')
                    except Exception:
                        raise Exception(f"""
input_shape    : {input_shape}
kernel_shape   : {kernel_shape}
padding        : {padding}
stride         : {stride}
dilation       : {dilation}
conved_n.shape : {conved_n.shape}
conved_t.shape : {conved_t.shape}
{traceback.format_exc()}
""")