Example #1
0
 def update_g(self, optimizer):
     """Perform one update step driven by ``self.g_loss``.

     Standard-normal noise is pushed through ``self.G``. The output is scored
     by ``self.D``, and ``optimizer.update`` is called with ``self.g_loss``
     and that score.

     Args:
         optimizer: object whose ``update(lossfun, *args)`` is invoked with
             ``self.g_loss`` and the output of ``self.D``.
     """
     # Latent batch of shape (batch_size, nz, 1, 1), cast to float32.
     latent_shape = (self.args.batch_size, self.args.nz, 1, 1)
     latent = np.asarray(np.random.normal(size=latent_shape), dtype=np.float32)
     latent_var = Variable(latent)
     latent_var.to_device(self.device)
     d_score = self.D(self.G(latent_var))
     optimizer.update(self.g_loss, d_score)
Example #2
0
 def _(bm):
     """Time ``ssim_depthwise_convolution`` on two random inputs.

     ``shape`` comes from the enclosing scope. Both inputs are uniform
     ``numpy.random.rand`` arrays wrapped in ``Variable`` and moved to
     device 0 before timing. Only the SSIM call runs inside ``bm``
     (presumably a benchmark timer context; confirm against the caller).
     """
     pred, target = [Variable(numpy.random.rand(*shape)) for _unused in range(2)]
     for var in (pred, target):
         var.to_device(0)
     with bm:
         # Positional args 11, 1 are forwarded unchanged to the SSIM function.
         ssim_depthwise_convolution(pred, target, 11, 1)
Example #3
0
 def _(bm):
     """Time ``ssim_im2col`` on two random inputs.

     ``shape`` comes from the enclosing scope. Both inputs are uniform
     ``numpy.random.rand`` arrays wrapped in ``Variable`` and moved to
     device 0 before timing. Only the SSIM call runs inside ``bm``
     (presumably a benchmark timer context; confirm against the caller).
     """
     lhs, rhs = [Variable(numpy.random.rand(*shape)) for _unused in range(2)]
     for var in (lhs, rhs):
         var.to_device(0)
     with bm:
         # Positional args 11, 1 are forwarded unchanged to the SSIM function.
         ssim_im2col(lhs, rhs, 11, 1)
Example #4
0
    def update_d(self, optimizer):
        """Perform one update step driven by ``self.d_loss``.

        Two inputs are scored by ``self.D``:

        * a real batch: the next item from the ``'main'`` iterator, converted
          by ``self.converter`` for ``self.device``;
        * a generated batch: ``self.G`` applied to fresh standard-normal noise.

        ``optimizer.update`` is then called with ``self.d_loss`` and both
        scores (real first, generated second).

        Args:
            optimizer: object whose ``update(lossfun, *args)`` is invoked.
        """
        real_batch = self.converter(self.get_iterator('main').next(), self.device)
        real_score = self.D(Variable(real_batch))

        # Generated half: float32 noise of shape (batch_size, nz, 1, 1) through G.
        latent_shape = (self.args.batch_size, self.args.nz, 1, 1)
        latent_var = Variable(
            np.asarray(np.random.normal(size=latent_shape), dtype=np.float32))
        latent_var.to_device(self.device)
        # NOTE(review): G's output is not unchained here, so backprop inside
        # optimizer.update also traverses self.G. If only D should learn in
        # this step, unchaining it may save work; confirm intended behavior.
        fake_score = self.D(self.G(latent_var))

        optimizer.update(self.d_loss, real_score, fake_score)
Example #5
0
if __name__ == '__main__':
    parser = argparse.ArgumentParser("Maximize SSIM")
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--noplot", dest="is_plot", action="store_false")

    args = parser.parse_args()

    device = chainer.get_device(args.device)

    img1 = cv2.imread("assets/einstein.png")

    img1 = img1.astype(np.float32).transpose(2, 0, 1) / 255.0
    img1 = np.expand_dims(img1, 0)
    img1 = Variable(img1)
    img1.to_device(device)

    img2 = L.Parameter(np.random.rand(*img1.shape).astype(np.float32))

    
    img2.to_device(device)
    optimizer = Adam(0.1)
    optimizer.setup(img2)
    device.use()

    print(type(img1), type(img2()))
    ssim_value = ssim_loss(img1, img2(), 11, 11)
    print("Initial ssim:", ssim_value)

    step = 1
    while ssim_value.data < 0.95: