Example #1
0
def psnr(pred, truth):
    """Compute the per-sample peak signal-to-noise ratio (PSNR) in dB.

    Uses PSNR = 20 * log10(MAX_I) - 10 * log10(MSE), see
    https://qiita.com/yoya/items/510043d836c9f2f0fe2f

    Args:
        pred: Predicted batch; the first axis is the batch axis.
        truth: Reference batch, same shape as ``pred``.

    Returns:
        A 1-D result of length ``len(pred)`` holding one PSNR value per
        sample (presumably a chainer Variable, since ``F`` looks like
        ``chainer.functions`` -- confirm against the file's imports).

    Note:
        MAX_I is the per-sample maximum of ``truth``, not a fixed nominal
        data range such as 1.0 or 255.
    """
    batch_size = len(pred)
    # Flatten each sample so the reductions below run over all its elements.
    sq_err = (pred - truth).reshape(batch_size, -1) ** 2
    # PSNR is defined on the *mean* squared error; summing instead would
    # shift the result by -10*log10(num_elements) dB.
    mse = F.mean(sq_err, axis=1)
    # Keep log10 finite when pred == truth (mse == 0).
    mse = F.clip(mse, 1e-8, 1e+8)
    max_i = F.max(truth.reshape(batch_size, -1), axis=1)
    # Guard against a non-positive peak (e.g. an all-zero reference sample).
    max_i = F.clip(max_i, 1e-8, 1e+8)
    return 20 * F.log10(max_i) - 10 * F.log10(mse)
Example #2
0
 def lf(x, t):
     """Compute pixel MSE plus a weighted VGG16 content loss, and report it.

     The four encoder outputs are decoded into ``y``. The total loss is
     ``mse_loss + C * cont_loss``, where ``cont_loss`` is a weighted sum of
     MSEs between the VGG16 outputs for ``y`` and for ``t``. The loss terms
     and the PSNR derived from ``mse_loss`` are reported with this observer.

     Args:
         x: Input batch fed to ``self.encode``.
         t: Target batch compared against the decoded output.

     Returns:
         The total loss (also stored as ``self.loss``).

     Raises:
         IndexError: if ``self.vgg16`` returns more outputs than there are
             weights in ``k`` (5).
     """
     h1, h2, h3, h4 = self.encode(x)
     y = self.decode(h1, h2, h3, h4)
     mse_loss = F.mean_squared_error(y, t)
     # Run VGG16 in inference mode. Only the target features are built
     # without a graph. The prediction features must keep theirs, otherwise
     # no gradient from cont_loss reaches the decoder and the content term is
     # just a constant offset.
     # NOTE(review): confirm self.vgg16's weights are frozen (e.g. via
     # disable_update()) now that gradients flow through it.
     with chainer.using_config('train', False):
         y_content = self.vgg16(y)
         with chainer.no_backprop_mode():
             t_content = self.vgg16(t)
     cont_loss = 0.0
     # Per-output weights, applied to the VGG16 outputs in order.
     k = [0.9, 0.7, 0.5, 0.3, 0.1]
     for i, (y_feat, t_feat) in enumerate(zip(y_content, t_content)):
         cont_loss += k[i] * F.mean_squared_error(y_feat, t_feat)
     self.mse_loss = mse_loss
     self.cont_loss = cont_loss
     self.loss = mse_loss + C * self.cont_loss
     # PSNR treating 1.0 as the peak signal value.
     self.psnr = 10 * F.log10(1.0 / mse_loss)
     chainer.report(
         {
             'mse_loss': self.mse_loss,
             'cont_loss': self.cont_loss,
             'loss': self.loss,
             'psnr': self.psnr
         },
         observer=self)
     return self.loss
Example #3
0
    def __call__(self, x, t):
        """Generator loss: pixel MSE plus 10x the VGG16 content loss.

        Stores the prediction, total loss and PSNR on ``self`` and reports
        ``loss``, ``loss_cont``, ``loss_mse`` and ``PSNR`` to the reporter.

        Args:
            x: Input batch fed to ``self.generator``.
            t: Target batch.

        Returns:
            ``10 * loss_cont + loss_mse`` (also stored as ``self.loss``).
        """
        # Drop references from the previous call before building a new graph.
        self.y = None
        self.loss = None
        self.psnr = None
        self.y = self.generator(x)
        loss_mse = F.mean_squared_error(self.y, t)
        # VGG16 runs in inference mode. Only the target features skip graph
        # construction. The prediction features must stay differentiable, or
        # loss_cont would contribute no gradient to the generator.
        # NOTE(review): confirm self.vgg16's weights are frozen (e.g. via
        # disable_update()) now that gradients flow through it.
        with chainer.using_config('train', False):
            y_cont = self.vgg16(self.y)
            with chainer.using_config('enable_backprop', False):
                t_cont = self.vgg16(t)
        loss_cont = None
        # Both feature lists come from the same network, so they align.
        for y_feat, t_feat in zip(y_cont, t_cont):
            term = F.mean_squared_error(y_feat, t_feat)
            loss_cont = term if loss_cont is None else loss_cont + term

        self.loss = 10 * loss_cont + loss_mse
        # PSNR treating 1.0 as the peak signal value.
        self.psnr = 10 * F.log10(1.0 / loss_mse)
        reporter.report(
            {
                'loss': self.loss,
                'loss_cont': loss_cont,
                'loss_mse': loss_mse,
                'PSNR': self.psnr
            }, self)
        return self.loss
Example #4
0
 def __call__(self, x, t):
     """Run the generator on ``x`` and return its pixel MSE against ``t``.

     The prediction, loss and PSNR are stored on ``self``. ``loss`` and
     ``PSNR`` are reported with this object as the observer.
     """
     # Release the previous call's results before the new forward pass.
     self.y = self.loss = self.psnr = None
     prediction = self.generator(x)
     self.y = prediction
     mse = F.mean_squared_error(prediction, t)
     self.loss = mse
     # PSNR with 1.0 as the peak signal value.
     self.psnr = 10 * F.log10(1.0 / mse)
     metrics = {'loss': mse, 'PSNR': self.psnr}
     reporter.report(metrics, self)
     return mse
Example #5
0
 def lf(x, t):
     """Return the pixel MSE of ``self(x)`` against ``t``.

     The loss and the derived PSNR are stored on ``self`` and reported
     under the keys ``loss`` and ``psnr`` with ``self`` as observer.
     """
     self.loss = F.mean_squared_error(self(x), t)
     # PSNR with 1.0 as the peak signal value.
     self.psnr = 10 * F.log10(1.0 / self.loss)
     metrics = {'loss': self.loss, 'psnr': self.psnr}
     chainer.report(metrics, observer=self)
     return self.loss
Example #6
0
 def lf(x, t):
     """Return the pixel MSE of the encode/decode reconstruction of ``x``.

     The loss and the derived PSNR are stored on ``self`` and reported
     under the keys ``loss`` and ``psnr`` with ``self`` as observer.
     """
     enc1, enc2, enc3, enc4 = self.encode(x)
     reconstruction = self.decode(enc1, enc2, enc3, enc4)
     self.loss = F.mean_squared_error(reconstruction, t)
     # PSNR with 1.0 as the peak signal value.
     self.psnr = 10 * F.log10(1.0 / self.loss)
     metrics = {'loss': self.loss, 'psnr': self.psnr}
     chainer.report(metrics, observer=self)
     return self.loss
Example #7
0
 def log10(self, x):
     """Return the element-wise base-10 logarithm of ``x`` via ``F.log10``."""
     result = F.log10(x)
     return result
 def sid(self, y, t):
     """Symmetric log-ratio divergence between ``y`` and ``t``, scaled.

     Computes ``|sum(y * log10((y+eps)/(t+eps))) +
     sum(t * log10((t+eps)/(y+eps)))|`` with ``eps = self.eps`` (which
     avoids division by zero and log of zero), then multiplies by 0.041.
     """
     eps = self.eps
     forward_term = F.sum(y * F.log10((y + eps) / (t + eps)))
     backward_term = F.sum(t * F.log10((t + eps) / (y + eps)))
     err = F.absolute(forward_term + backward_term)
     # 0.041 is a rounded 2**16 / (14 * 240 * 480) (~0.04063).
     return err * 0.041