예제 #1
0
파일: loss.py 프로젝트: victorca25/BasicSR
    def __call__(self, input, reference):
        """Compute the combined profile loss between ``input`` and ``reference``.

        The loss is a sum of ``self.trace`` terms, each enabled by a flag:
        ``self.rgb`` compares the images directly, ``self.yuv`` compares their
        ``rgb_to_yuv`` conversions, and ``self.yuvgrad`` compares the
        horizontal and vertical gradients of those YUV images.

        Args:
            input: input image batch.
            reference: reference image batch.

        Returns:
            The summed loss; ``0`` when no component flag is enabled.
        """
        # Use "spl_denorm" when reading a [-1,1] input but you want the loss
        # computed over a [0,1] range; leave it False when inputs and outputs
        # are already in [0,1]. Only rgb_to_yuv() requires [0,1] images, so
        # this denorm is optional depending on the net.
        if self.spl_denorm:
            input = denorm(input)
            reference = denorm(reference)
        total_loss = 0
        if self.rgb:
            total_loss += self.trace(input, reference)
        # The YUV images feed both the yuv and the yuvgrad terms, so build them
        # whenever either is enabled; otherwise enabling yuvgrad alone would
        # reference YUV images that were never computed.
        if self.yuv or self.yuvgrad:
            # rgb_to_yuv() needs images in [0,1] range to work
            if not self.spl_denorm and self.yuv_denorm:
                input = denorm(input)
                reference = denorm(reference)
            input_yuv = rgb_to_yuv(input)
            reference_yuv = rgb_to_yuv(reference)
            if self.yuv:
                total_loss += self.trace(input_yuv, reference_yuv)
            if self.yuvgrad:
                input_h, input_v = get_image_gradients(input_yuv)
                ref_h, ref_v = get_image_gradients(reference_yuv)
                total_loss += self.trace(input_v, ref_v)
                total_loss += self.trace(input_h, ref_h)
        return total_loss
예제 #2
0
    def __call__(self, x, y):
        """Compute the combined profile loss between two image batches.

        The loss is a sum of ``self.trace`` terms, each enabled by a flag:
        ``self.rgb`` compares ``x`` and ``y`` directly, ``self.yuv`` compares
        their ``rgb_to_yuv`` conversions, and ``self.yuvgrad`` compares the
        horizontal and vertical gradients of those YUV images.

        Args:
            x: input image batch.
            y: reference image batch.

        Returns:
            The summed loss; ``0`` when no component flag is enabled.
        """
        if self.spl_denorm:
            x = denorm(x)
            y = denorm(y)
        total_loss = 0
        if self.rgb:
            total_loss += self.trace(x, y)
        # The YUV images feed both the yuv and the yuvgrad terms, so build them
        # whenever either is enabled; otherwise enabling yuvgrad alone would
        # reference YUV images that were never computed.
        if self.yuv or self.yuvgrad:
            # rgb_to_yuv() needs images in [0,1] range to work
            if not self.spl_denorm and self.yuv_denorm:
                x = denorm(x)
                y = denorm(y)
            input_yuv = rgb_to_yuv(x)
            reference_yuv = rgb_to_yuv(y)
            if self.yuv:
                total_loss += self.trace(input_yuv, reference_yuv)
            if self.yuvgrad:
                input_h, input_v = get_image_gradients(input_yuv)
                ref_h, ref_v = get_image_gradients(reference_yuv)
                total_loss += self.trace(input_v, ref_v)
                total_loss += self.trace(input_h, ref_h)
        return total_loss
예제 #3
0
파일: loss.py 프로젝트: victorca25/BasicSR
    def __call__(self, input, reference):
        """Compute the gradient-profile loss between ``input`` and ``reference``.

        Args:
            input: input image batch.
            reference: reference image batch.

        Returns:
            ``self.trace`` of the vertical gradients plus ``self.trace`` of the
            horizontal gradients, as produced by ``get_image_gradients``.
        """
        # Use "spl_denorm" when reading a [-1,1] input but you want the loss
        # computed over a [0,1] range. No YUV conversion happens here, so this
        # only affects the range the gradients are measured in.
        if self.spl_denorm:
            input = denorm(input)
            reference = denorm(reference)
        input_h, input_v = get_image_gradients(input)
        ref_h, ref_v = get_image_gradients(reference)

        trace_v = self.trace(input_v, ref_v)
        trace_h = self.trace(input_h, ref_h)
        return trace_v + trace_h
예제 #4
0
    def __call__(self, x, y):
        """Return the gradient-profile loss between two image batches.

        Args:
            x: input image batch.
            y: reference image batch.

        Returns:
            ``self.trace`` over the vertical gradients plus ``self.trace`` over
            the horizontal gradients, as produced by ``get_image_gradients``.
        """
        if self.spl_denorm:
            # Optionally rescale both batches with denorm() before comparing.
            x, y = denorm(x), denorm(y)
        x_grad_h, x_grad_v = get_image_gradients(x)
        y_grad_h, y_grad_v = get_image_gradients(y)

        vertical = self.trace(x_grad_v, y_grad_v)
        horizontal = self.trace(x_grad_h, y_grad_h)
        return vertical + horizontal
예제 #5
0
 def forward(self, img):
     """Filter ``img`` according to ``self.type``, optionally denormalizing.

     - ``'separator'``: apply ``self.filter_low`` ``self.recursions - 1``
       extra times (only when ``self.recursions > 1``), then return that
       result minus ``self.filter_low`` of it.
     - ``'independent'``: apply ``self.filter_low`` once.
     - any other value: ``img`` passes through unchanged.

     The result is passed through ``denorm`` when ``self.normalize`` is truthy.
     """
     kind = self.type
     if kind == 'separator':
         # Guarded so a non-positive (or float <= 1) recursion count never
         # reaches range().
         passes = self.recursions - 1 if self.recursions > 1 else 0
         for _ in range(passes):
             img = self.filter_low(img)
         # Presumably filter_low is a low-pass filter, which would make this
         # the high-frequency residual — confirm against its definition.
         img = img - self.filter_low(img)
     elif kind == 'independent':
         img = self.filter_low(img)
     if not self.normalize:
         return img
     return denorm(img)