Code Example #1
    def test(self):
        np.random.seed(seed)
        discriminator = wgan.Discriminator()
        discriminator.weight_clipping()

        assert_allclose(discriminator.W0, tests[2][0], atol=TOLERANCE)
        assert_allclose(discriminator.b0, tests[2][1], atol=TOLERANCE)
        assert_allclose(discriminator.W1, tests[2][2], atol=TOLERANCE)
        assert_allclose(discriminator.b1, tests[2][3], atol=TOLERANCE)
        assert_allclose(discriminator.W2, tests[2][4], atol=TOLERANCE)
        assert_allclose(discriminator.b2, tests[2][5], atol=TOLERANCE)
Code Example #2
    def test(self):
        np.random.seed(seed)
        discriminator = wgan.Discriminator()
        x = np.random.rand(64, 28, 28, 1)
        logit = discriminator.forward(x)
        # logit = np.random.rand(64, 1)
        discriminator.backward(logit, x, -1)

        assert_allclose(discriminator.W0, tests[1][0], atol=TOLERANCE)
        assert_allclose(discriminator.b0, tests[1][1], atol=TOLERANCE)
        assert_allclose(discriminator.W1, tests[1][2], atol=TOLERANCE)
        assert_allclose(discriminator.b1, tests[1][3], atol=TOLERANCE)
        assert_allclose(discriminator.W2, tests[1][4], atol=TOLERANCE)
        assert_allclose(discriminator.b2, tests[1][5], atol=TOLERANCE)
Code Example #3
    def test(self):
        np.random.seed(seed)
        generator = wgan.Generator()
        discriminator = wgan.Discriminator()

        z = np.random.rand(64, 100)
        fake_input = generator.forward(z)
        fake_logit = discriminator.forward(fake_input)

        generator.backward(fake_logit, fake_input, discriminator)

        assert_allclose(generator.W0, tests[4][0], atol=TOLERANCE)
        assert_allclose(generator.b0, tests[4][1], atol=TOLERANCE)
        assert_allclose(generator.W1, tests[4][2], atol=TOLERANCE)
        assert_allclose(generator.b1, tests[4][3], atol=TOLERANCE)
        assert_allclose(generator.W2, tests[4][4], atol=TOLERANCE)
        assert_allclose(generator.b2, tests[4][5], atol=TOLERANCE)
        assert_allclose(generator.W3, tests[4][6], atol=TOLERANCE)
        assert_allclose(generator.b3, tests[4][7], atol=TOLERANCE)
Code Example #4
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU device ID')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='# of epoch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--g_hidden', type=int, default=128)
    parser.add_argument('--g_arch', type=int, default=1)
    parser.add_argument('--g_activate', type=str, default='sigmoid')
    parser.add_argument('--g_channel', type=int, default=512)
    parser.add_argument('--d_arch', type=int, default=1)
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--d_clip', type=float, default=0.01)
    parser.add_argument('--d_use_gamma', action='store_true', default=False)
    parser.add_argument('--d_channel', type=int, default=512)
    parser.add_argument('--initial_iter', type=int, default=10)
    parser.add_argument('--resume', default='')
    parser.add_argument('--out', default='')
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()

    # log directory
    out = datetime.datetime.now().strftime('%m%d%H%M')
    if args.out:
        out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # dump hyperparameters to setting.txt
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    # TensorBoard logging (pre-1.0 TensorFlow summary API)
    if use_tensorboard:
        sess = tf.Session()
        sess.run(tf.initialize_all_variables())

        summary_dir = os.path.join(out_dir, "summaries")
        loss_ = tf.placeholder(tf.float32)
        gen_loss_summary = tf.scalar_summary('gen_loss', loss_)
        dis_loss_summary = tf.scalar_summary('dis_loss', loss_)
        summary_writer = tf.train.SummaryWriter(summary_dir, sess.graph)

    # load the CelebA dataset
    dataset = CelebA(image_type=args.g_activate)
    train_iter = chainer.iterators.MultiprocessIterator(
        dataset, args.batch_size)

    if args.g_arch == 1:
        gen = wgan.Generator(n_hidden=args.g_hidden,
                             activate=args.g_activate,
                             ch=args.g_channel)
    elif args.g_arch == 2:
        gen = wgan.Generator2(n_hidden=args.g_hidden,
                              activate=args.g_activate,
                              ch=args.g_channel)
    else:
        raise ValueError('invalid arch')

    if args.d_arch == 1:
        dis = wgan.Discriminator(ch=args.d_channel, use_gamma=args.d_use_gamma)
    elif args.d_arch == 2:
        dis = wgan.Discriminator2(ch=args.d_channel)
    else:
        raise ValueError('invalid arch')

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # RMSprop with lr=5e-5, as recommended in the WGAN paper
    optimizer_gen = chainer.optimizers.RMSprop(lr=0.00005)
    optimizer_dis = chainer.optimizers.RMSprop(lr=0.00005)

    optimizer_gen.setup(gen)
    optimizer_dis.setup(dis)

    # optimizer_gen.add_hook(chainer.optimizer.WeightDecay(0.00001))
    # optimizer_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))

    # start training
    start = time.time()
    gen_iterations = 0
    for epoch in range(args.epoch):

        # train
        sum_L_gen = []
        sum_L_dis = []

        i = 0
        while i < len(dataset) // args.batch_size:

            # tip: train the critic to near-optimality at the start of training and every 500 generator iterations
            if gen_iterations < args.initial_iter or gen_iterations % 500 == 0:
                d_iters = 100
            else:
                d_iters = args.d_iters

            ############################
            # (1) Update D network
            ###########################
            j = 0
            while j < d_iters:
                batch = train_iter.next()
                x = chainer.Variable(
                    gen.xp.asarray([b[0] for b in batch], 'float32'))
                # attr = chainer.Variable(gen.xp.asarray([b[1] for b in batch], 'int32'))

                # real
                y_real = dis(x)

                # fake
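                # generate fakes with volatile=True so no graph is built through the generator;
                # flipping the flag back on x_fake lets the critic's backward pass stop there,
                # so this update does not touch the generator's parameters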
                z = chainer.Variable(gen.xp.asarray(
                    gen.make_hidden(args.batch_size)),
                                     volatile=True)
                x_fake = gen(z)
                x_fake.volatile = False
                y_fake = dis(x_fake)

                # critic loss: negative of the Wasserstein (EMD) estimate D(x) - D(G(z))
                L_dis = -(y_real - y_fake)
                dis.cleargrads()
                L_dis.backward()
                optimizer_dis.update()

                # WGAN weight clipping: keep the critic's parameters in [-d_clip, d_clip]
                # so it stays (approximately) 1-Lipschitz
                dis.clip_weight(clip=args.d_clip)

                j += 1
                i += 1

            ###########################
            # (2) Update G network
            ###########################
            z = chainer.Variable(
                gen.xp.asarray(gen.make_hidden(args.batch_size)))
            x_fake = gen(z)
            y_fake = dis(x_fake)
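            # generator loss: maximize the critic's score on fakes, i.e. minimize -D(G(z))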
            L_gen = -y_fake
            gen.cleargrads()
            L_gen.backward()
            optimizer_gen.update()

            gen_iterations += 1

            emd = float(-L_dis.data)
            l_gen = float(L_gen.data)
            sum_L_dis.append(emd)
            sum_L_gen.append(l_gen)

            progress_report(epoch * len(dataset) // args.batch_size + i, start,
                            args.batch_size, emd)

            if use_tensorboard:
                summary = sess.run(gen_loss_summary, feed_dict={loss_: l_gen})
                summary_writer.add_summary(summary, gen_iterations)
                summary = sess.run(dis_loss_summary, feed_dict={loss_: emd})
                summary_writer.add_summary(summary, gen_iterations)

        log = 'gen loss={:.5f}, dis loss={:.5f}'.format(
            np.mean(sum_L_gen), np.mean(sum_L_dis))
        print('\n' + log)
        with open(os.path.join(out_dir, "log"), 'a+') as f:
            f.write(log + '\n')

        if epoch % 5 == 0:
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.dis.model".format(epoch)), dis)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.gen.model".format(epoch)), gen)

        visualize(gen,
                  epoch=epoch,
                  savedir=os.path.join(out_dir, 'visualize'),
                  image_type=args.g_activate)
Code Example #5
    def test(self):
        np.random.seed(seed)
        discriminator = wgan.Discriminator()
        x = np.random.rand(64, 28, 28, 1)
        assert_allclose(discriminator.forward(x), tests[0], atol=TOLERANCE)
Code Example #6
File: inpaint.py Project: LetheField/ganinpaint
            root=opt.dataroot,
            transform=transforms.Compose([
                transforms.CenterCrop((218 - fsize // 2 * 2,
                                       178 - fsize // 2 * 2)),  # Crop the edge
                transforms.Resize(opt.imageSize),
                transforms.CenterCrop(opt.imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batchSize,
                                                 shuffle=True,
                                                 num_workers=int(opt.workers))
        # wGAN
        netD = wgan.Discriminator(opt.imageSize, opt.nc, opt.ndf, opt.ngpu,
                                  opt.n_extra_layers)
        netG = wgan.Generator(opt.imageSize, opt.nz, opt.nc, opt.ngf, opt.ngpu,
                              opt.n_extra_layers)

        netD.apply(weight_init)
        netG.apply(weight_init)

        if opt.pretrained:
            netD.load_state_dict(torch.load(opt.pre_D))
            netG.load_state_dict(torch.load(opt.pre_G))

        # device choice
        netD.to(device)
        netG.to(device)
        if (device.type == 'cuda') and (opt.ngpu > 1):
            netD = nn.DataParallel(netD, list(range(opt.ngpu)))