Esempio n. 1
0
    def image_test(real_A,
                   mask_A,
                   diff_A,
                   real_B,
                   mask_B,
                   diff_B,
                   shade_alpha=1):
        """Apply reference image B's makeup to source image A.

        Loads the pretrained PSGAN generator on CPU, computes the makeup
        parameters for the A->B transfer and for the identity A->A
        transfer, blends them with ``shade_alpha`` (1 = full reference
        makeup, 0 = source unchanged), and renders the blended result.

        NOTE(review): there is no ``self``/``cls`` parameter, so this is
        presumably a ``@staticmethod`` on the enclosing class — confirm.
        Argument shapes/dtypes are dictated by ``Solver_PSGAN.generate``
        and are not visible here.
        """
        G = net.Generator()
        # Pretrained weights, mapped onto CPU so no GPU is required.
        G.load_state_dict(
            torch.load(pwd + '/pretrained_models/G.pth',
                       map_location=torch.device('cpu')))
        G.eval()

        cur_prama = None
        with torch.no_grad():
            # Makeup parameters for transferring B's style onto A
            # (ret=True appears to make generate() return the parameters
            # rather than an image — inferred from usage; confirm against
            # Solver_PSGAN).
            cur_prama = Solver_PSGAN.generate(real_A,
                                              real_B,
                                              None,
                                              None,
                                              mask_A,
                                              mask_B,
                                              diff_A,
                                              diff_B,
                                              ret=True,
                                              generator=G,
                                              mode='test')
            # Identity transfer A->A: the source's own makeup parameters.
            cur_prama_source = Solver_PSGAN.generate(real_A,
                                                     real_A,
                                                     None,
                                                     None,
                                                     mask_A,
                                                     mask_A,
                                                     diff_A,
                                                     diff_A,
                                                     ret=True,
                                                     generator=G,
                                                     mode='test')
            # Convex blend of (gamma, beta) between transfer and identity.
            shade_gamma = cur_prama[0] * shade_alpha + cur_prama_source[0] * (
                1 - shade_alpha)
            shade_beta = cur_prama[1] * shade_alpha + cur_prama_source[1] * (
                1 - shade_alpha)
            # Render using the blended parameters.
            fake_A = Solver_PSGAN.generate(real_A,
                                           real_B,
                                           None,
                                           None,
                                           mask_A,
                                           mask_B,
                                           diff_A,
                                           diff_B,
                                           gamma=shade_gamma,
                                           beta=shade_beta,
                                           generator=G,
                                           mode='test')
        fake_A = data2img(fake_A)  # tensor -> displayable image (see data2img)
        return fake_A
Esempio n. 2
0
    def build_model(self):
        """Create both generators, both discriminators, losses and optimizers.

        Construction, initialization and logging happen in the same order
        as before so the RNG stream (and therefore initial weights) is
        unchanged.
        """
        # A->B and B->A generators plus one discriminator per domain.
        self.G_A = net.Generator(self.g_conv_dim, self.g_repeat_num)
        self.G_B = net.Generator(self.g_conv_dim, self.g_repeat_num)
        self.D_A = net.Discriminator(self.img_size, self.d_conv_dim,
                                     self.d_repeat_num)
        self.D_B = net.Discriminator(self.img_size, self.d_conv_dim,
                                     self.d_repeat_num)

        # Reconstruction and adversarial criteria.
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionGAN = GANLoss(use_lsgan=True,
                                    tensor=torch.cuda.FloatTensor)

        # One Adam optimizer over both generators, one per discriminator
        # (only trainable discriminator parameters are handed over).
        gen_params = itertools.chain(self.G_A.parameters(),
                                     self.G_B.parameters())
        self.g_optimizer = torch.optim.Adam(gen_params, self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_A_optimizer = torch.optim.Adam(
            (p for p in self.D_A.parameters() if p.requires_grad),
            self.d_lr, [self.beta1, self.beta2])
        self.d_B_optimizer = torch.optim.Adam(
            (p for p in self.D_B.parameters() if p.requires_grad),
            self.d_lr, [self.beta1, self.beta2])

        # Xavier initialization (same network order as the original).
        for model in (self.G_A, self.D_A, self.G_B, self.D_B):
            model.apply(self.weights_init_xavier)

        # Log every architecture.
        #  self.print_network(self.E, 'E')
        for model, tag in ((self.G_A, 'G_A'), (self.D_A, 'D_A'),
                           (self.G_B, 'G_B'), (self.D_B, 'D_B')):
            self.print_network(model, tag)

        if torch.cuda.is_available():
            for model in (self.G_A, self.G_B, self.D_A, self.D_B):
                model.cuda()
Esempio n. 3
0
def main():
    """Load a trained MNIST generator checkpoint and write a sample grid."""
    arg_parser = argparse.ArgumentParser(description='Generate mnist image')
    arg_parser.add_argument('--genpath', type=str, help='path to a trained generator')
    arg_parser.add_argument('--dimz', '-z', type=int, default=20, help='dimention of encoded vector')
    arg_parser.add_argument('--out', '-o', type=str, default='result', help='path to the output directory')
    args = arg_parser.parse_args()

    # Create the output directory on first use.
    if not os.path.exists(args.out):
        os.makedirs(args.out)

    print(args)

    # Rebuild the generator topology (784-dim output, 500 hidden units),
    # then restore the trained weights from the checkpoint.
    generator = net.Generator(784, args.dimz, 500)
    chainer.serializers.load_npz(args.genpath, generator)
    print('Generator model loaded successfully from {}'.format(args.genpath))

    # Render a 10x10 grid of samples (seed 0) into <out>/image.png.
    generate_image(generator, 10, 10, 0, args.out + '/image.png')
Esempio n. 4
0
    def build_model(self):
        """Instantiate the generator, per-class discriminators, losses and
        optimizers.

        Discriminators and their optimizers are stored as dynamic
        attributes named ``D_<cls>`` / ``d_<cls>_optimizer``.
        """
        self.G = net.Generator()
        # One discriminator per makeup class.
        for cls_name in self.cls:
            setattr(
                self, "D_" + cls_name,
                net.Discriminator(self.img_size, self.d_conv_dim,
                                  self.d_repeat_num, self.norm))

        # Pixel, feature and adversarial criteria.
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL2 = torch.nn.MSELoss()
        self.criterionGAN = GANLoss(use_lsgan=True,
                                    tensor=torch.cuda.FloatTensor)
        # Pretrained VGG-16 used as a fixed perceptual feature extractor.
        self.vgg = models.vgg16(pretrained=True)

        # One optimizer for G, one per discriminator (trainable params only).
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        for cls_name in self.cls:
            disc = getattr(self, "D_" + cls_name)
            trainable = (p for p in disc.parameters() if p.requires_grad)
            setattr(self, "d_" + cls_name + "_optimizer",
                    torch.optim.Adam(trainable, self.d_lr,
                                     [self.beta1, self.beta2]))

        # Xavier weight initialization for every network.
        self.G.apply(self.weights_init_xavier)
        for cls_name in self.cls:
            getattr(self, "D_" + cls_name).apply(self.weights_init_xavier)

        # Log the architectures.
        self.print_network(self.G, 'G')
        for cls_name in self.cls:
            self.print_network(getattr(self, "D_" + cls_name), "D_" + cls_name)

        if torch.cuda.is_available():
            self.G.cuda()
            self.vgg.cuda()
            for cls_name in self.cls:
                getattr(self, "D_" + cls_name).cuda()
Esempio n. 5
0
def main():
    """Train a semi-supervised GAN on labeled/unlabeled image batches.

    Builds labeled/unlabeled/test iterators, a DCGAN-style generator and
    a discriminator (optionally with minibatch discrimination), their
    Adam optimizers and loss criteria, then delegates the training loop
    to ``run.NNRun``.
    """
    args = easydict.EasyDict({
        "dataroot": "/mnt/gold/users/s18150/mywork/pytorch/data",
        "save_dir": "./",
        "prefix": "test",
        "workers": 8,
        "batch_size": 128,
        "image_size": 32,
        "nc": 1,
        "nz": 100,
        "ngf": 32,
        "ndf": 32,
        "epochs": 1,
        "lr": 0.0002,
        "beta1": 0.5,
        "gpu": 7,
        "use_cuda": True,
        "feature_matching": True,
        "mini_batch": True,
        "iters": 50000,
        "label_batch_size": 100,
        "unlabel_batch_size": 100,
        "test_batch_size": 10,
        "out_dir": './result',
        "log_interval": 500,
        "label_num": 20
    })

    # Seed every RNG in use so runs are reproducible.
    manualSeed = 999
    np.random.seed(manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    device = torch.device(
        'cuda:{}'.format(args.gpu) if args.use_cuda else 'cpu')

    # Create the output directory up front.  Bug fix: the config key is
    # "out_dir"; the original read the non-existent ``args.out``, which
    # raises AttributeError on an EasyDict.
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    print(args)

    # Labeled / unlabeled / test iterators for semi-supervised training.
    data_iterators = dataset.get_iters(root_path=args.dataroot,
                                       l_batch_size=args.label_batch_size,
                                       ul_batch_size=args.unlabel_batch_size,
                                       test_batch_size=args.test_batch_size,
                                       workers=args.workers,
                                       n_labeled=args.label_num)

    trainloader_label = data_iterators['labeled']
    trainloader_unlabel = data_iterators['unlabeled']
    testloader = data_iterators['test']

    # Generator, with the DCGAN weight initialization.
    netG = net.Generator(args.nz, args.ngf, args.nc).to(device)
    netG.apply(net.weights_init)

    # Discriminator (optionally with minibatch discrimination).
    netD = net.Discriminator(args.nc, args.ndf, device, args.batch_size,
                             args.mini_batch).to(device)
    netD.apply(net.weights_init)

    # Multi-class real/fake criterion for the discriminator.
    criterionD = nn.CrossEntropyLoss()

    # Feature matching trains G to match intermediate D features (MSE);
    # otherwise fall back to the standard BCE objective.  Bug fix:
    # 'elementwise_mean' was deprecated and later removed in PyTorch;
    # 'mean' is the equivalent reduction.
    if args.feature_matching is True:
        criterionG = nn.MSELoss(reduction='mean')
    else:
        criterionG = nn.BCELoss()

    # Fixed latent batch used only to visualize G's progress.
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Adam optimizers, per the DCGAN recipe (beta1 = 0.5).
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    logger = TrainLogger(args)
    r = run.NNRun(netD, netG, optimizerD, optimizerG, criterionD, criterionG,
                  device, fixed_noise, logger, args)

    # Train.
    r.train(trainloader_label, trainloader_unlabel, testloader)
Esempio n. 6
0
import numpy as np
from chainer import Variable
import chainer.computational_graph as c

import net

# Build one forward pass of the GAN and dump its computational graph
# to Graphviz dot format for inspection.
gen = net.Generator(784, 20, 500)
dis = net.Discriminator(784, 500)

# Dummy 784-dim "real" input; values are uninitialized, which is fine
# because we only trace the graph, never train on it.
x_real = np.empty((1, 784), dtype=np.float32)
z = Variable(np.asarray(gen.make_hidden(1)))

# One forward pass through each path so every edge appears in the graph.
y_real = dis(x_real)
x_fake = gen(z)
y_fake = dis(x_fake)

# Collect all three output variables and write the dot file.
g = c.build_computational_graph([y_real, x_fake, y_fake])
with open('graph.dot', 'w') as o:
    o.write(g.dump())
Esempio n. 7
0
# Training configuration and per-GPU model replicas.
train_LSTM_using_cached_features = False
train_lstm_prob = .5
train_dis = True
image_save_interval = 200000
model_save_interval = image_save_interval
out_image_row_num = 7
out_image_col_num = 14
if train_LSTM:
    # LSTM batches hold whole sequences, so shrink the batch so that
    # batch * seq_length stays constant.  Bug fix: use floor division —
    # under Python 3 the original ``/=`` turned BATCH_SIZE into a float,
    # which breaks integer batch handling downstream (Python 2 legacy).
    BATCH_SIZE //= seq_length
# Normalization constant for the loss (pixels * channels * 60).
normer = args.size * args.size * 3 * 60
image_path = ['../../../trajectories/al5d']
np.random.seed(1241)
image_size = args.size
# One model replica per GPU; index 0 is the master copy and the rest
# are deep copies of it.
enc_model = [net.Encoder(density=8, size=image_size, latent_size=latent_size)]
gen_model = [
    net.Generator(density=8, size=image_size, latent_size=latent_size)
]
dis_model = [net.Discriminator(density=8, size=image_size)]
for i in range(num_gpus - 1):
    enc_model.append(copy.deepcopy(enc_model[0]))
    gen_model.append(copy.deepcopy(gen_model[0]))
    dis_model.append(copy.deepcopy(dis_model[0]))

# Separate encoder/generator used on the discriminator side.
enc_dis_model = net.Encoder(density=8,
                            size=image_size,
                            latent_size=latent_size)
gen_dis_model = net.Generator(density=8,
                              size=image_size,
                              latent_size=latent_size)
rnn_model = rnn_net.MDN_RNN(IN_DIM=latent_size + 5,
                            HIDDEN_DIM=300,
Esempio n. 8
0
def main():
    """Train an LSGAN on CIFAR-10 or on a directory of image files.

    Sets up generator/discriminator, Adam optimizers with weight decay,
    an image-resizing/augmenting converter, and a Chainer trainer with
    snapshot, logging and preview extensions.
    """
    parser = argparse.ArgumentParser(description='LSGAN')

    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=20,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1000,
                        help='Number of sweeps over the dataset to train')

    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset',
                        '-i',
                        default='',
                        help='Directory of image files.  Default is cifar-10.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')

    parser.add_argument('--n_hidden',
                        '-n',
                        type=int,
                        default=100,
                        help='Number of hidden units (z)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--image_size',
                        type=int,
                        default=64,
                        help='Size of output image')

    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console (iter)')
    parser.add_argument('--preview_interval',
                        type=int,
                        default=1,
                        help='Interval of preview (epoch)')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10,
                        help='Interval of snapshot (epoch)')

    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# n_hidden: {}'.format(args.n_hidden))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up the networks to train.
    gen = net.Generator(n_hidden=args.n_hidden, image_size=args.image_size)
    dis = net.Discriminator2()

    if args.gpu >= 0:
        # Make the specified GPU current and move both models onto it.
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Adam optimizer with a small weight-decay hook attached."""
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    if args.dataset == '':
        # Default to CIFAR-10 when no image directory is given.
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=255.)
    else:
        all_files = os.listdir(args.dataset)
        image_files = [f for f in all_files if ('png' in f or 'jpg' in f)]
        print('{} contains {} image files'.format(args.dataset,
                                                  len(image_files)))
        train = chainer.datasets.ImageDataset(paths=image_files,
                                              root=args.dataset)

    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize,
                                                        n_processes=4)

    def resize_converter(batch, device=None, padding=None):
        """Square-crop, resize, augment and rescale a batch to [-1, 1]."""
        new_batch = []
        for image in batch:
            C, W, H = image.shape

            # Drop the alpha channel if present.
            if C == 4:
                image = image[:3, :, :]

            # Center-crop to a square along the longer side.
            if W < H:
                offset = (H - W) // 2
                image = image[:, :, offset:offset + W]
            elif W > H:
                offset = (W - H) // 2
                image = image[:, offset:offset + H, :]

            # imresize expects HWC; transpose around the call.
            image = image.transpose(1, 2, 0)
            image = imresize(image, (args.image_size, args.image_size),
                             interp='bilinear')
            image = image.transpose(2, 0, 1)

            image = image / 255.  # 0. ~ 1.

            # Augmentation: random flip along the width (last) axis.
            if np.random.rand() < 0.5:
                image = image[:, :, ::-1]

            # Augmentation: random tone correction.
            mode = np.random.randint(4)
            # mode == 0 -> no correction
            if mode == 1:
                gain = 0.2 * np.random.rand() + 0.9  # 0.9 ~ 1.1
                image = np.power(image, gain)
            elif mode == 2:
                # S-curve via tanh, renormalized back to [0, 1].
                gain = 1.5 * np.random.rand() + 1e-10  # 0 ~ 1.5
                image = np.tanh(gain * (image - 0.5))

                range_min = np.tanh(gain * (-0.5))  # value at image = 0
                range_max = np.tanh(gain * 0.5)  # value at image = 1
                image = (image - range_min) / (range_max - range_min)
            elif mode == 3:
                # Inverse S-curve via sinh, renormalized back to [0, 1].
                gain = 2.0 * np.random.rand() + 1e-10  # 0 ~ 2.0
                image = np.sinh(gain * (image - 0.5))

                # Bug fix: normalize with the sinh range.  The original
                # used np.tanh here (copy-paste from mode 2), which left
                # the result slightly outside [0, 1].
                range_min = np.sinh(gain * (-0.5))  # value at image = 0
                range_max = np.sinh(gain * 0.5)  # value at image = 1
                image = (image - range_min) / (range_max - range_min)

            # Rescale from [0, 1] to [-1, 1] for the GAN.
            image = 2. * image - 1.
            new_batch.append(image.astype(np.float32))
        return concat_examples(new_batch, device=device, padding=padding)

    # Set up the trainer.
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=args.gpu,
                           converter=resize_converter)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    display_interval = (args.display_interval, 'iteration')
    preview_interval = (args.preview_interval, 'epoch')
    snapshot_interval = (args.snapshot_interval, 'epoch')

    # Periodic snapshots of the whole trainer state and of each model.
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    # Console logging and a progress bar.
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'gen/loss',
        'dis/loss',
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Periodic preview images from the generator.
    trainer.extend(out_generated_image(gen, dis, 10, 10, args.seed, args.out),
                   trigger=preview_interval)

    if args.resume:
        # Resume the training from a snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training.
    trainer.run()
Esempio n. 9
0
def main():
    """Train a DCGAN with optional feature matching on an image folder.

    Builds the dataloader, a generator/discriminator pair, their Adam
    optimizers and loss criteria, then delegates the training loop to
    ``run.NNRun``.
    """
    args = easydict.EasyDict({
        "dataroot": "/mnt/gold/users/s18150/mywork/pytorch/data/gan",
        "save_dir": "./",
        "prefix": "feature",
        "workers": 8,
        "batch_size": 128,
        "image_size": 64,
        "nc": 3,
        "nz": 100,
        "ngf": 64,
        "ndf": 64,
        "epochs": 50,
        "lr": 0.0002,
        "beta1": 0.5,
        "gpu": 6,
        "use_cuda": True,
        "feature_matching": True,
        "mini_batch": False
    })

    # Seed the RNGs for reproducibility.
    manualSeed = 999
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    device = torch.device(
        'cuda:{}'.format(args.gpu) if args.use_cuda else 'cpu')

    # Standard DCGAN preprocessing: resize, center-crop, scale to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = dset.ImageFolder(root=args.dataroot, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers)

    # Generator, with the DCGAN weight initialization.
    netG = net.Generator(args.nz, args.ngf, args.nc).to(device)
    netG.apply(net.weights_init)

    # Discriminator (optionally with minibatch discrimination).
    netD = net.Discriminator(args.nc, args.ndf, device, args.batch_size,
                             args.mini_batch).to(device)
    netD.apply(net.weights_init)

    # Binary real/fake criterion for the discriminator.
    criterionD = nn.BCELoss()

    # Feature matching trains G to match intermediate D features (MSE);
    # otherwise fall back to the standard BCE objective.  Bug fix:
    # 'elementwise_mean' was deprecated and later removed in PyTorch;
    # 'mean' is the equivalent reduction.
    if args.feature_matching is True:
        criterionG = nn.MSELoss(reduction='mean')
    else:
        criterionG = nn.BCELoss()

    # Fixed latent batch used only to visualize G's progress.
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Adam optimizers, per the DCGAN recipe (beta1 = 0.5).
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    r = run.NNRun(netD, netG, optimizerD, optimizerG, criterionD, criterionG,
                  device, fixed_noise, args)

    # Train.
    r.train(dataloader)
Esempio n. 10
0
    def partial_test(real_A,
                     mask_aug_A,
                     diff_A,
                     real_B,
                     mask_aug_B,
                     diff_B,
                     real_C,
                     mask_aug_C,
                     diff_C,
                     mask2use,
                     shade_alpha=1):
        """Apply makeup from two references (B and C) to different regions of A.

        Makeup parameters from B are used where ``mask2use`` is 1 and
        parameters from C where it is 0; the combined parameters are then
        blended with A's own (identity) parameters via ``shade_alpha``
        (1 = full reference makeup, 0 = source unchanged).

        NOTE(review): there is no ``self``/``cls`` parameter, so this is
        presumably a ``@staticmethod`` on the enclosing class — confirm.
        Argument shapes/dtypes are dictated by ``Solver_PSGAN.generate``
        and are not visible here.
        """
        G = net.Generator()
        # Pretrained generator weights, mapped onto CPU.
        G.load_state_dict(
            torch.load(pwd + '/G.pth', map_location=torch.device('cpu')))
        G.eval()
        with torch.no_grad():
            # Makeup parameters for transferring B's style onto A
            # (ret=True appears to return parameters instead of an image —
            # inferred from usage; confirm against Solver_PSGAN).
            cur_prama_B = Solver_PSGAN.generate(real_A,
                                                real_B,
                                                None,
                                                None,
                                                mask_aug_A,
                                                mask_aug_B,
                                                diff_A,
                                                diff_B,
                                                ret=True,
                                                generator=G,
                                                mode='test')
            # Makeup parameters for transferring C's style onto A.
            cur_prama_C = Solver_PSGAN.generate(real_A,
                                                real_C,
                                                None,
                                                None,
                                                mask_aug_A,
                                                mask_aug_C,
                                                diff_A,
                                                diff_C,
                                                ret=True,
                                                generator=G,
                                                mode='test')

            # Identity transfer A->A: the source's own makeup parameters.
            cur_prama_source = Solver_PSGAN.generate(real_A,
                                                     real_A,
                                                     None,
                                                     None,
                                                     mask_aug_A,
                                                     mask_aug_A,
                                                     diff_A,
                                                     diff_A,
                                                     ret=True,
                                                     generator=G,
                                                     mode='test')

            # Region-wise mix: B's (gamma, beta) inside mask2use, C's outside.
            partial_gamma = cur_prama_B[0] * mask2use + cur_prama_C[0] * (
                1 - mask2use)
            partial_beta = cur_prama_B[1] * mask2use + cur_prama_C[1] * (
                1 - mask2use)

            # Shade blend of the mixed parameters against the identity ones.
            partial_gamma = partial_gamma * shade_alpha + cur_prama_source[
                0] * (1 - shade_alpha)
            partial_beta = partial_beta * shade_alpha + cur_prama_source[1] * (
                1 - shade_alpha)

            # Render using the combined parameters.
            fake_A = Solver_PSGAN.generate(real_A,
                                           real_B,
                                           None,
                                           None,
                                           mask_aug_A,
                                           mask2use,
                                           diff_A,
                                           diff_B,
                                           gamma=partial_gamma,
                                           beta=partial_beta,
                                           ret=False,
                                           generator=G,
                                           mode='test')
        fake_A = data2img(fake_A)  # tensor -> displayable image (see data2img)
        return fake_A
Esempio n. 11
0
def main():
    """Train a DCGAN: alternate discriminator and generator updates.

    D maximizes log(D(x)) + log(1 - D(G(z))); G maximizes log(D(G(z))).
    Generated grids from a fixed noise batch and both loss curves are
    saved for later inspection.
    """
    # Fix the random seed for reproducibility.
    manualSeed = 999
    # manualSeed = random.randint(1, 10000)
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # Select the training device.
    device = torch.device("cuda:1" if (torch.cuda.is_available() and p.ngpu > 0) else "cpu")

    # Build the networks; wrap in DataParallel when several GPUs are requested.
    netG = net.Generator(p.ngpu).to(device)
    netD = net.Discriminator(p.ngpu).to(device)
    if (device.type == 'cuda' and (p.ngpu > 1)):
        netG = nn.DataParallel(netG, list(range(p.ngpu)))
        netD = nn.DataParallel(netD, list(range(p.ngpu)))
    print(netG)
    print(netD)

    # Binary real/fake criterion.
    criterion = nn.BCELoss()

    # Fixed latent batch used to visualize the generator's progression.
    fixed_noise = torch.randn(p.batch_size, p.nz, 1, 1, device=device)

    # Label conventions for real and fake samples.
    real_label = 1
    fake_label = 0

    # Adam optimizers for both networks.
    optimizerG = optim.Adam(netG.parameters(), lr=p.g_lr, betas=(p.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=p.d_lr, betas=(p.beta1, 0.999))

    # Training loop bookkeeping.
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    for epoch in range(p.num_epochs):
        for i, data in enumerate(dataset.dataloader, 0):
            # (1) Update D: maximize log(D(x)) + log(1 - D(G(z))).
            netD.zero_grad()
            # Format the real batch.  Bug fix: BCELoss needs float
            # targets, and recent PyTorch infers an integer dtype from
            # the integer fill value — request float explicitly.
            real = data.to(device)
            label = torch.full((data.size(0),), real_label,
                               dtype=torch.float, device=device)
            # Forward pass the real batch through D.
            output = netD(real).view(-1)  # [batch, 1, 1, 1] -> [batch]
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # Generate a batch of fakes from fresh latent vectors.
            noise = torch.randn(data.size(0), p.nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify the fake batch (detach: no gradients into G here).
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Gradients from both halves already accumulated via the two
            # backward() calls; errD is only reported.
            errD = errD_real + errD_fake
            optimizerD.step()

            # (2) Update G: maximize log(D(G(z))).
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are "real" for the G cost
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Periodic training stats.
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.8f\tLoss_G: %.8f\tD(x): %.8f\tD(G(z)): %.8f / %.8f'
                      % (epoch, p.num_epochs, i, len(dataset.dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            # Record losses for plotting later.
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Snapshot G's output on the fixed noise batch.
            if (iters % 500 == 0) or ((epoch == p.num_epochs - 1) and (i == len(dataset.dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True).numpy())
            iters += 1
    np.save(p.result_dir + "img_list.npy", img_list)
    after_train(D_losses, G_losses)
    show_after_train(img_list)
Esempio n. 12
0
def main():
    """Train an image-to-image translation GAN (Chainer).

    Builds the train/validation datasets and iterators, instantiates the
    generator and discriminator (optionally restoring weights and optimiser
    state from snapshots), assembles a ``chainer.training.Trainer`` with
    snapshot/log/plot/LR-schedule/visualisation extensions, and runs the
    training loop.
    """
    args = arguments()

    #    chainer.config.type_check = False
    chainer.config.autotune = True
    chainer.config.dtype = dtypes[args.dtype]
    chainer.print_runtime_info()
    #print('Chainer version: ', chainer.__version__)
    #print('GPU availability:', chainer.cuda.available)
    #print('cuDNN availability:', chainer.cuda.cudnn_enabled)

    ## dataset preparation: DICOM volumes need a dedicated loader
    if args.imgtype == "dcm":
        from dataset_dicom import Dataset
    else:
        from dataset import Dataset
    train_d = Dataset(args.train,
                      args.root,
                      args.from_col,
                      args.to_col,
                      crop=(args.crop_height, args.crop_width),
                      random=args.random_translate,
                      grey=args.grey)
    test_d = Dataset(args.val,
                     args.root,
                     args.from_col,
                     args.to_col,
                     crop=(args.crop_height, args.crop_width),
                     random=args.random_translate,
                     grey=args.grey)

    # setup training/validation data iterators
    train_iter = chainer.iterators.SerialIterator(train_d, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test_d,
                                                 args.nvis,
                                                 shuffle=False)
    test_iter_gt = chainer.iterators.SerialIterator(
        train_d, args.nvis,
        shuffle=False)  ## same as training data; used for validation

    # infer channel counts from the first (input, target) sample pair
    args.ch = len(train_d[0][0])
    args.out_ch = len(train_d[0][1])
    print("Input channels {}, Output channels {}".format(args.ch, args.out_ch))

    ## Set up models
    gen = net.Generator(args)
    dis = net.Discriminator(args)

    ## load learnt models; remember the matching optimiser snapshot names
    optimiser_files = []
    if args.model_gen:
        serializers.load_npz(args.model_gen, gen)
        print('model loaded: {}'.format(args.model_gen))
        optimiser_files.append(args.model_gen.replace('gen_', 'opt_gen_'))
    if args.model_dis:
        serializers.load_npz(args.model_dis, dis)
        print('model loaded: {}'.format(args.model_dis))
        optimiser_files.append(args.model_dis.replace('dis_', 'opt_dis_'))

    ## send models to GPU
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Setup optimisers
    def make_optimizer(model, lr, opttype='Adam'):
        """Create an ``opttype`` optimiser with weight decay and set it
        up on ``model``."""
        #        eps = 1e-5 if args.dtype==np.float16 else 1e-8
        optimizer = optim[opttype](lr)
        if args.weight_decay > 0:
            # Adam-family optimisers take decay as a built-in rate; the
            # others get a gradient hook (L2 or L1 per the configured norm).
            if opttype in ['Adam', 'AdaBound', 'Eve']:
                optimizer.weight_decay_rate = args.weight_decay
            else:
                if args.weight_decay_norm == 'l2':
                    optimizer.add_hook(
                        chainer.optimizer.WeightDecay(args.weight_decay))
                else:
                    optimizer.add_hook(
                        chainer.optimizer_hooks.Lasso(args.weight_decay))
        optimizer.setup(model)
        return optimizer

    opt_gen = make_optimizer(gen, args.learning_rate, args.optimizer)
    opt_dis = make_optimizer(dis, args.learning_rate, args.optimizer)
    optimizers = {'opt_g': opt_gen, 'opt_d': opt_dis}

    ## resume optimisers from file
    if args.load_optimizer:
        # NOTE: the pairing relies on insertion order -- ``optimiser_files``
        # was filled gen-first then dis, matching the dict literal above.
        for (m, e) in zip(optimiser_files, optimizers):
            if m:
                try:
                    serializers.load_npz(m, optimizers[e])
                    print('optimiser loaded: {}'.format(m))
                except Exception:
                    # Best-effort resume: a missing or incompatible snapshot
                    # is reported but not fatal.  (The original bare
                    # ``except:`` would also have swallowed
                    # KeyboardInterrupt/SystemExit.)
                    print("couldn't load {}".format(m))

    # Set up trainer
    updater = pixupdater(
        models=(gen, dis),
        iterator={
            'main': train_iter,
            'test': test_iter,
            'test_gt': test_iter_gt
        },
        optimizer={
            'gen': opt_gen,
            'dis': opt_dis
        },
        #        converter=convert.ConcatWithAsyncTransfer(),
        params={'args': args},
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    ## save learnt results at an interval (negative means: only at the end)
    if args.snapinterval < 0:
        args.snapinterval = args.epoch
    snapshot_interval = (args.snapinterval, 'epoch')
    display_interval = (args.display_interval, 'iteration')

    trainer.extend(extensions.snapshot_object(gen, 'gen_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(opt_gen,
                                              'opt_gen_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    # discriminator snapshots/graphs only make sense with adversarial loss
    if args.lambda_dis > 0:
        trainer.extend(extensions.snapshot_object(dis,
                                                  'dis_{.updater.epoch}.npz'),
                       trigger=snapshot_interval)
        trainer.extend(
            extensions.dump_graph('dis/loss_real', out_name='dis.dot'))
        trainer.extend(extensions.snapshot_object(
            opt_dis, 'opt_dis_{.updater.epoch}.npz'),
                       trigger=snapshot_interval)

    if args.lambda_rec_l1 > 0:
        trainer.extend(extensions.dump_graph('gen/loss_L1',
                                             out_name='gen.dot'))
    elif args.lambda_rec_l2 > 0:
        trainer.extend(extensions.dump_graph('gen/loss_L2',
                                             out_name='gen.dot'))

    ## log outputs
    log_keys = ['epoch', 'iteration', 'lr']
    log_keys_gen = [
        'gen/loss_L1', 'gen/loss_L2', 'gen/loss_dis', 'myval/loss_L2',
        'gen/loss_tv'
    ]
    log_keys_dis = ['dis/loss_real', 'dis/loss_fake', 'dis/loss_mispair']
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport(log_keys + log_keys_gen +
                                          log_keys_dis),
                   trigger=display_interval)
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(log_keys_gen,
                                  'iteration',
                                  trigger=display_interval,
                                  file_name='loss_gen.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_dis,
                                  'iteration',
                                  trigger=display_interval,
                                  file_name='loss_dis.png'))
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.ParameterStatistics(gen))
    # learning rate scheduling: decay by 0.33 five times over the run.
    # SGD-family optimisers expose the rate as ``lr``; Adam-family as ``alpha``.
    if args.optimizer in ['SGD', 'Momentum', 'AdaGrad', 'RMSprop']:
        trainer.extend(extensions.observe_lr(optimizer_name='gen'),
                       trigger=display_interval)
        trainer.extend(extensions.ExponentialShift('lr',
                                                   0.33,
                                                   optimizer=opt_gen),
                       trigger=(args.epoch / 5, 'epoch'))
        trainer.extend(extensions.ExponentialShift('lr',
                                                   0.33,
                                                   optimizer=opt_dis),
                       trigger=(args.epoch / 5, 'epoch'))
    elif args.optimizer in ['Adam', 'AdaBound', 'Eve']:
        trainer.extend(extensions.observe_lr(optimizer_name='gen'),
                       trigger=display_interval)
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.33,
                                                   optimizer=opt_gen),
                       trigger=(args.epoch / 5, 'epoch'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.33,
                                                   optimizer=opt_dis),
                       trigger=(args.epoch / 5, 'epoch'))

    # evaluation: periodic visualisation of generator outputs
    vis_folder = os.path.join(args.out, "vis")
    os.makedirs(vis_folder, exist_ok=True)
    if not args.vis_freq:
        args.vis_freq = len(train_d) // 2
    trainer.extend(VisEvaluator({
        "test": test_iter,
        "train": test_iter_gt
    }, {"gen": gen},
                                params={'vis_out': vis_folder},
                                device=args.gpu),
                   trigger=(args.vis_freq, 'iteration'))

    # ChainerUI: removed until ChainerUI updates to be compatible with Chainer 6.0
    #    trainer.extend(CommandsExtension())

    # Run the training
    print("trainer start")
    trainer.run()
Esempio n. 13
0
    dataset = Dataset(path=args.root,
                      args=args,
                      base=args.HU_baseA,
                      rang=args.HU_rangeA,
                      random=0)
    args.ch = dataset.ch
    #    iterator = chainer.iterators.MultiprocessIterator(dataset, args.batch_size, n_processes=3, repeat=False, shuffle=False)
    iterator = chainer.iterators.MultithreadIterator(
        dataset, args.batch_size, n_threads=3, repeat=False,
        shuffle=False)  ## best performance
    #    iterator = chainer.iterators.SerialIterator(dataset, args.batch_size,repeat=False, shuffle=False)

    ## load generator models
    if "gen" in args.load_models:
        gen = net.Generator(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, gen)
        if args.gpu >= 0:
            gen.to_gpu()
        xp = gen.xp
        is_AE = False
    elif "enc" in args.load_models:
        enc = net.Encoder(args)
        print('Loading {:s}..'.format(args.load_models))
        serializers.load_npz(args.load_models, enc)
        dec = net.Decoder(args)
        modelfn = args.load_models.replace('enc_x', 'dec_y')
        modelfn = modelfn.replace('enc_y', 'dec_x')
        print('Loading {:s}..'.format(modelfn))
        serializers.load_npz(modelfn, dec)
Esempio n. 14
0
def gan_training(args, train, test):
    """Train the texture GAN on ``train`` and periodically visualise.

    Sets up data iterators, the generator/discriminator pair and their
    Adam optimisers, and a ``chainer.training.Trainer`` with logging,
    plotting, snapshot, visualisation, and alpha-decay extensions, then
    runs the training loop.  ``test`` is only used for visualisation.
    """
    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    if args.loaderjob:
        train_iter = chainer.iterators.MultiprocessIterator(
            train, args.batchsize, n_processes=args.loaderjob)
    else:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Prepare Texture GAN model, defined in net.py
    gen = net.Generator(args.dimz)
    dis = net.Discriminator()

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # (the original also computed an unused ``xp`` backend alias here)
    opt_gen = make_optimizer(gen, args, alpha=args.alpha)
    opt_dis = make_optimizer(dis, args, alpha=args.alpha)

    # Updater
    updater = GAN_Updater(models=(gen, dis),
                          iterator=train_iter,
                          optimizer={
                              'gen': opt_gen,
                              'dis': opt_dis
                          },
                          device=args.gpu)

    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    visualize_interval = (args.visualize_interval, 'iteration')
    log_interval = (args.log_interval, 'iteration')

    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations)
    trainer.extend(extensions.LogReport(trigger=log_interval))

    trainer.extend(
        extensions.PlotReport(['gen/loss', 'dis/loss'],
                              trigger=log_interval,
                              file_name='plot.png'))

    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'gen/loss', 'dis/loss']),
                   trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    # Fixed sample indices so successive visualisations are comparable;
    # re-seed from entropy afterwards so training randomness is unaffected.
    np.random.seed(0)
    train_indices = np.random.randint(0, len(train),
                                      args.rows * args.cols).tolist()
    test_indices = np.random.randint(0, len(test),
                                     args.rows * args.cols).tolist()
    np.random.seed()
    # NOTE(review): the last grid cells are pinned to specific dataset
    # entries -- presumably hand-picked examples; confirm the intent.
    train_indices[-2] = len(train) - 3
    train_indices[-3] = len(train) - 1

    trainer.extend(visualizer.extension(train,
                                        gen,
                                        train_indices,
                                        args,
                                        'train',
                                        rows=args.rows,
                                        cols=args.cols),
                   trigger=visualize_interval)
    trainer.extend(visualizer.extension(test,
                                        gen,
                                        test_indices,
                                        args,
                                        'test',
                                        rows=args.rows,
                                        cols=args.cols),
                   trigger=visualize_interval)

    # Halve Adam's alpha for both networks at the configured iteration.
    if args.adam_decay_iteration:
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_gen),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_dis),
                       trigger=(args.adam_decay_iteration, 'iteration'))

    # Resume the whole trainer state (models, optimisers, extensions).
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Esempio n. 15
0
def _to_uint8(x):
    """Map float data in [-1, 1] to uint8 pixel values in [0, 255]."""
    return np.asarray(np.clip((x + 1.) * (255. / 2.), 0.0, 255.0),
                      dtype=np.uint8)


def _tile_video(y, rows, cols):
    """Tile a (B, CH, T, H, W) batch into a per-frame image mosaic.

    Requires B == rows * cols.  Returns an array of shape
    (T, rows * H, cols * W, CH): each frame is a rows x cols grid of
    the batch samples, channels last.
    """
    B, CH, T, H, W = y.shape
    Y = y.reshape((rows, cols, CH, T, H, W))
    Y = Y.transpose(3, 0, 4, 1, 5, 2)  # T, rows, H, cols, W, CH
    return Y.reshape((T, rows * H, cols * W, CH))


def gan_test(args, model_path):
    """Generate sample videos from random latents with trained GANs.

    Loads the flow and texture generators from ``model_path`` (keys
    ``gen_flow``/``gen_tex``), samples 10 batches of latent vectors,
    and writes per-frame JPEGs plus an animated GIF of the generated
    video, foreground, mask, background, and flow under
    ``args.out/<batch>/``.
    """
    # Prepare Flow and Texture GAN model, defined in net.py
    gen_flow = net.FlowGenerator()
    serializers.load_npz(model_path["gen_flow"], gen_flow)
    gen_tex = net.Generator(dimz=100)
    serializers.load_npz(model_path["gen_tex"], gen_tex)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen_flow.to_gpu()
        gen_tex.to_gpu()
    xp = np if args.gpu < 0 else cuda.cupy

    rows = 5
    cols = 5

    ### generate videos from Z (fixed seed for reproducible samples)
    np.random.seed(0)
    for i in range(10):
        print(i)
        z_flow = Variable(xp.asarray(gen_flow.make_hidden(rows * cols)))
        z_tex = Variable(xp.asarray(gen_tex.make_hidden(rows * cols)))

        ### generate flow
        with chainer.using_config('train', False):
            flow_fake, _, _ = gen_flow(z_flow)
        flow = chainer.cuda.to_cpu(flow_fake.data)

        ### generate video conditioned on the flow
        with chainer.using_config('train', False):
            y, fore_vid, back_img, h_mask = gen_tex(z_tex, flow_fake)
        y = chainer.cuda.to_cpu(y.data)
        fore_vid = chainer.cuda.to_cpu(fore_vid.data)
        back_img = chainer.cuda.to_cpu(back_img.data)
        y_mask = chainer.cuda.to_cpu(h_mask.data)

        preview_dir = '{}/{:03}/'.format(args.out, i)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)

        ## save composed video: per-frame JPEGs and an animated GIF
        frames = _tile_video(_to_uint8(y), rows, cols)
        for j in range(frames.shape[0]):
            Image.fromarray(frames[j]).save(
                preview_dir + 'img_{:03}.jpg'.format(j + 1))
        imageio.mimsave(preview_dir + 'movie_{:03}.gif'.format(i), frames)

        ### save foreground video
        frames = _tile_video(_to_uint8(fore_vid), rows, cols)
        for j in range(frames.shape[0]):
            Image.fromarray(frames[j]).save(
                preview_dir + 'fore_{:03}.jpg'.format(j + 1))

        ### save mask video (mask is already in [0, 1]: scale, don't shift)
        mask_u8 = np.asarray(np.clip(y_mask * 255., 0.0, 255.0),
                             dtype=np.uint8)
        frames = _tile_video(mask_u8, rows, cols)
        for j in range(frames.shape[0]):
            Image.fromarray(frames[j]).save(
                preview_dir + 'mask_{:03}.jpg'.format(j + 1))

        ### save background image (static: only frame 0 is kept)
        b = _to_uint8(back_img)
        B, CH, T, H, W = b.shape
        b = b[:, :, 0]  # drop the time axis
        grid = b.reshape((rows, cols, CH, H, W))
        grid = grid.transpose(0, 3, 1, 4, 2)  # rows, H, cols, W, CH
        grid = grid.reshape((rows * H, cols * W, CH))
        Image.fromarray(grid).save(preview_dir + 'back.jpg')

        ### save flow: the two flow channels are written side by side
        frames = _tile_video(_to_uint8(flow), rows, cols)
        for j in range(frames.shape[0]):
            flow_img = np.hstack((frames[j, :, :, 0], frames[j, :, :, 1]))
            Image.fromarray(flow_img).save(
                preview_dir + 'flow_{:03}.jpg'.format(j + 1))
Esempio n. 16
0
    [0x2464, 0x2466],
    range(0x2468, 0x246e),
    [0x246f, 0x2472, 0x2473],
    _flatten(
        map(lambda x: range(x + 0x21, x + 0x7f), range(0x3000, 0x4f00,
                                                       0x100))),
    range(0x4f21, 0x4f54),
])

# Map each JIS code point in ``jis_codes`` to Unicode and build the reverse
# lookup table from Unicode code point to class index.
# NOTE(review): this is Python-2-era code -- ``len(codes)`` requires ``map``
# to return a list, and ``args.text.decode`` / ``volatile=True`` below are
# likewise Python 2 / old-Chainer APIs; porting would need list() wrappers.
codes = map(jiscode_to_unicode, jis_codes)
code_to_index = dict(zip(codes, range(len(codes))))

# Pick the generator architecture matching the configured image size,
# then restore its trained weights from the HDF5 snapshot.
if image_size == 48:
    gen = net.Generator48()
else:
    gen = net.Generator()
serializers.load_hdf5(args.model, gen)

# Select the array backend: CuPy when a GPU is requested, NumPy otherwise.
gpu_device = None
if args.gpu >= 0:
    cuda.check_cuda_available()
    gpu_device = args.gpu
    gen.to_gpu(gpu_device)
    xp = cuda.cupy
else:
    xp = np

# Encode the input text as per-character class indices for the generator.
# ``volatile=True`` marks inference-only input in the old Chainer API.
latent_size = gen.latent_size
characters = map(lambda x: code_to_index[ord(x)],
                 args.text.decode(args.character_code))
y = Variable(xp.asarray(characters).astype(np.int32), volatile=True)
Esempio n. 17
0
def gan_training(args, train, test, model_path):
    """Fine-tune the flow and texture GANs jointly from saved snapshots.

    Loads all four pre-trained networks from ``model_path`` (keys
    ``gen_flow``, ``dis_flow``, ``gen_tex``, ``dis_tex``), builds their
    optimisers, and runs adversarial training with logging, snapshot,
    visualisation, and alpha-decay extensions.  ``test`` is only used
    for visualisation.
    """
    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    if args.loaderjob:
        train_iter = chainer.iterators.MultiprocessIterator(
            train, args.batchsize, n_processes=args.loaderjob)
    else:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Prepare Flow and Texture GAN model, defined in net.py
    gen_flow = net.FlowGenerator()
    dis_flow = net.FlowDiscriminator()
    gen_tex = net.Generator(args.dimz)
    dis_tex = net.Discriminator()

    serializers.load_npz(model_path["gen_flow"], gen_flow)
    serializers.load_npz(model_path["dis_flow"], dis_flow)
    serializers.load_npz(model_path["gen_tex"], gen_tex)
    serializers.load_npz(model_path["dis_tex"], dis_tex)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen_flow.to_gpu()
        dis_flow.to_gpu()
        gen_tex.to_gpu()
        dis_tex.to_gpu()

    # Flow networks fine-tune with a much smaller step than texture ones.
    # (the original also computed an unused ``xp`` backend alias here)
    opt_flow_gen = make_optimizer(gen_flow, args, alpha=1e-7)
    opt_flow_dis = make_optimizer(dis_flow, args, alpha=1e-7)
    opt_tex_gen = make_optimizer(gen_tex, args, alpha=1e-6)
    opt_tex_dis = make_optimizer(dis_tex, args, alpha=1e-6)

    # Updater
    updater = GAN_Updater(models=(gen_flow, dis_flow, gen_tex, dis_tex),
                          iterator=train_iter,
                          optimizer={
                              'gen_flow': opt_flow_gen,
                              'dis_flow': opt_flow_dis,
                              'gen_tex': opt_tex_gen,
                              'dis_tex': opt_tex_dis
                          },
                          device=args.gpu,
                          C=args.C)

    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    visualize_interval = (args.visualize_interval, 'iteration')
    log_interval = (args.log_interval, 'iteration')

    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations)
    trainer.extend(extensions.LogReport(trigger=log_interval))

    trainer.extend(
        extensions.PlotReport(
            ['gen_flow/loss', 'dis_flow/loss', 'gen_tex/loss', 'dis_tex/loss'],
            trigger=log_interval,
            file_name='plot.png'))

    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'gen_flow/loss', 'dis_flow/loss', 'gen_tex/loss',
        'dis_tex/loss'
    ]),
                   trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen_flow, 'gen_flow_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis_flow, 'dis_flow_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen_tex, 'gen_tex_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis_tex, 'dis_tex_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    trainer.extend(visualizer.extension(test, (gen_flow, gen_tex),
                                        args,
                                        rows=args.rows,
                                        cols=args.cols),
                   trigger=visualize_interval)

    # Halve Adam's alpha for every network at the configured iteration.
    if args.adam_decay_iteration:
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_flow_gen),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_flow_dis),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_tex_gen),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_tex_dis),
                       trigger=(args.adam_decay_iteration, 'iteration'))

    # Resume the whole trainer state (models, optimisers, extensions).
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Esempio n. 18
0
def main():
    """Train WGAN or WGAN-GP on MNIST, driven by command-line arguments.

    Raises:
        ValueError: if ``--mode`` or ``--optimizer`` names an
            unsupported choice.
    """
    parser = argparse.ArgumentParser(description='WGAN(GP) MNIST')
    parser.add_argument('--mode', '-m', type=str, default='WGAN',
                        help='WGAN or WGANGP')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='number of epochs to learn')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='learning minibatch size')
    parser.add_argument('--optimizer', type=str, default='Adam',
                        help='optimizer')
    parser.add_argument('--out', '-o', type=str, default='model',
                        help='path to the output directory')
    parser.add_argument('--dimz', '-z', type=int, default=20,
                        help='dimention of encoded vector')
    parser.add_argument('--n_dis', type=int, default=5,
                        help='number of discriminator updates per generator update')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapepoch', '-s', type=int, default=10,
                        help='number of epochs to snapshot')
    parser.add_argument('--load_gen_model', type=str, default='',
                        help='load generator model')
    parser.add_argument('--load_dis_model', type=str, default='',
                        help='load discriminator model')
    args = parser.parse_args()

    # Fail fast on unsupported choices: previously an unknown --mode or
    # --optimizer only surfaced much later as a NameError on an unbound
    # local (``updater`` / ``log_keys`` / ``opt_gen``).
    if args.mode not in ('WGAN', 'WGANGP'):
        raise ValueError(
            "--mode must be 'WGAN' or 'WGANGP', got {!r}".format(args.mode))
    if args.optimizer not in ('Adam', 'RMSprop'):
        raise ValueError(
            "--optimizer must be 'Adam' or 'RMSprop', got {!r}".format(
                args.optimizer))

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    print(args)

    gen = net.Generator(784, args.dimz, 500)
    dis = net.Discriminator(784, 500)

    if args.load_gen_model != '':
        chainer.serializers.load_npz(args.load_gen_model, gen)
    if args.load_dis_model != '':
        chainer.serializers.load_npz(args.load_dis_model, dis)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()
        print('GPU {}'.format(args.gpu))

    # (the original also computed an unused ``xp`` backend alias here)
    if args.optimizer == 'Adam':
        # beta1=0.5, beta2=0.9: the usual settings for WGAN-style training
        opt_gen = chainer.optimizers.Adam(alpha=0.0001, beta1=0.5, beta2=0.9)
        opt_dis = chainer.optimizers.Adam(alpha=0.0001, beta1=0.5, beta2=0.9)
        opt_gen.setup(gen)
        opt_dis.setup(dis)
        opt_dis.add_hook(WeightClipping(0.01))
    else:  # RMSprop (validated above)
        opt_gen = chainer.optimizers.RMSprop(5e-5)
        opt_dis = chainer.optimizers.RMSprop(5e-5)
        opt_gen.setup(gen)
        opt_gen.add_hook(chainer.optimizer.GradientClipping(1))
        opt_dis.setup(dis)
        opt_dis.add_hook(chainer.optimizer.GradientClipping(1))
        opt_dis.add_hook(WeightClipping(0.01))

    train, _ = chainer.datasets.get_mnist(withlabel=False)
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize, shuffle=True)

    if args.mode == 'WGAN':
        updater = WGANUpdater(
            models=(gen, dis),
            iterators={
                'main': train_iter
            },
            optimizers={
                'gen': opt_gen,
                'dis': opt_dis
            },
            device=args.gpu,
            params={
                'batchsize': args.batchsize,
                'n_dis': args.n_dis
            })
    else:  # WGANGP (validated above)
        updater = WGANGPUpdater(
            models=(gen, dis),
            iterators={
                'main': train_iter
            },
            optimizers={
                'gen': opt_gen,
                'dis': opt_dis
            },
            device=args.gpu,
            params={
                'batchsize': args.batchsize,
                'n_dis': args.n_dis,
                'lam': 10
            })

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # 'was. dist' is the Wasserstein-distance key reported by the updaters
    trainer.extend(extensions.dump_graph('was. dist'))

    snapshot_interval = (args.snapepoch, 'epoch')
    trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}.npz'),
        trigger=snapshot_interval)

    trainer.extend(extensions.PlotReport(['loss/gen'], 'epoch', file_name='generator.png'))

    if args.mode == 'WGAN':
        log_keys = ['epoch', 'was. dist', 'gen/loss', 'dis/loss']
    else:  # WGANGP
        log_keys = ['epoch', 'was. dist', 'grad. pen', 'gen/loss']
    trainer.extend(extensions.LogReport(keys=log_keys))
    trainer.extend(extensions.PrintReport(log_keys))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Dump a 20x20 grid of generated digits every epoch.
    trainer.extend(out_generated_image(gen, 20, 20, args.seed, args.out),
        trigger=(1, 'epoch'))

    trainer.run()
Esempio n. 19
0
    kwargs = {'batch_size': BATCH_SIZE, 'num_workers': 1, 'pin_memory': True}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root=ROOT_DIR_PATH, train=True, transform=transform, download=True),
        shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root=ROOT_DIR_PATH, train=False, transform=transform),
        shuffle=False, **kwargs)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # p(x|y,z)
    p = net.Generator().to(device)

    # q(z|x,y)
    q = net.Inference().to(device)

    # prior p(z)
    prior = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0),
                   var=["z"], features_shape=[net.Z_DIM], name="p_{prior}").to(device)

    loss = (KullbackLeibler(q, prior) - Expectation(q, LogProb(p))).mean()
    model = Model(loss=loss, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr": 1e-3})
    # print(model)

    x_fixed, y_fixed = next(iter(test_loader))
    x_fixed = x_fixed[:8].to(device)
    y_fixed = y_fixed[:8]
Esempio n. 20
0
def gan_training(args, train):
    """Train the flow GAN on ``train``.

    Builds iterators, the generator/discriminator pair and their
    optimisers, and a ``chainer.training.Trainer`` with logging,
    snapshot, visualisation, and alpha-decay extensions, then runs
    the training loop.
    """
    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    if args.loaderjob:
        train_iter = chainer.iterators.MultiprocessIterator(
            train, args.batchsize, n_processes=args.loaderjob)
    else:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Prepare Flow GAN model, defined in net.py
    gen = net.Generator(video_len=args.video_len)
    dis = net.Discriminator()

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # (the original also computed an unused ``xp`` backend alias here)
    opt_gen = make_optimizer(gen, args)
    opt_dis = make_optimizer(dis, args)

    # Updater
    updater = GAN_Updater(models=(gen, dis),
                          iterator=train_iter,
                          optimizer={
                              'gen': opt_gen,
                              'dis': opt_dis
                          },
                          device=args.gpu)

    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    visualize_interval = (args.visualize_interval, 'iteration')
    log_interval = (args.log_interval, 'iteration')

    trainer.extend(extensions.LogReport(trigger=log_interval))

    trainer.extend(
        extensions.PlotReport(['gen/loss', 'dis/loss'],
                              trigger=log_interval,
                              file_name='plot.png'))

    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'gen/loss', 'dis/loss']),
                   trigger=log_interval)

    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=snapshot_interval)

    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iteration_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)

    # periodic sample generation for visual inspection
    trainer.extend(extension(gen, args), trigger=visualize_interval)

    # Halve Adam's alpha for both networks at the configured iteration.
    if args.adam_decay_iteration:
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_gen),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_dis),
                       trigger=(args.adam_decay_iteration, 'iteration'))

    # Resume the whole trainer state (models, optimisers, extensions).
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Esempio n. 21
0
def main():
    """Train a GAN on MNIST with Chainer.

    Parses command-line options, builds the generator/discriminator pair,
    wires up the updater/trainer with logging, snapshot and visualization
    extensions, then runs training for ``--epoch`` epochs.
    """
    # This enables a ctrl-C exit without triggering a traceback.
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    parser = argparse.ArgumentParser(description='GAN practice on MNIST')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=100,
                        help='number of epochs to learn')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        default='model',
                        help='path to the output directory')
    parser.add_argument('--dimz',
                        '-z',
                        type=int,
                        default=20,
                        help='dimension of encoded vector')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapepoch',
                        '-s',
                        type=int,
                        default=10,
                        help='number of epochs to snapshot')
    parser.add_argument('--load_gen_model',
                        type=str,
                        default='',
                        help='load generator model')
    parser.add_argument('--load_dis_model',
                        type=str,
                        default='',
                        # was a copy-pasted "load generator model"
                        help='load discriminator model')
    args = parser.parse_args()

    # Idempotent; replaces the racy exists()+makedirs() pair.
    os.makedirs(args.out, exist_ok=True)

    print(args)

    gen = net.Generator(784, args.dimz, 500)
    dis = net.Discriminator(784, 500)

    # Optionally resume either network from a serialized snapshot.
    if args.load_gen_model != '':
        chainer.serializers.load_npz(args.load_gen_model, gen)
        print('Generator model loaded successfully!')
    if args.load_dis_model != '':
        chainer.serializers.load_npz(args.load_dis_model, dis)
        print('Discriminator model loaded successfully!')

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()
        print('GPU {}'.format(args.gpu))

    opt_gen = chainer.optimizers.Adam()
    opt_dis = chainer.optimizers.Adam()
    opt_gen.setup(gen)
    opt_dis.setup(dis)

    dataset = MnistDataset('./data')
    # train, val = chainer.datasets.split_dataset_random(dataset, int(len(dataset) * 0.9))

    train_iter = chainer.iterators.SerialIterator(dataset,
                                                  args.batchsize,
                                                  shuffle=True)
    # val_iter = chainer.iterators.SerialIterator(val, args.batchsize, repeat=False, shuffle=False)

    updater = GANUpdater(models=(gen, dis),
                         iterators={'main': train_iter},
                         optimizers={
                             'gen': opt_gen,
                             'dis': opt_dis
                         },
                         device=args.gpu,
                         params={
                             'batchsize': args.batchsize,
                             'n_latent': args.dimz
                         })
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapepoch, 'epoch')
    display_interval = (100, 'iteration')
    # Full-trainer and per-model snapshots every --snapepoch epochs.
    trainer.extend(
        extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(gen, 'gen{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(dis, 'dis{.updater.epoch}.npz'),
                   trigger=snapshot_interval)

    log_keys = ['epoch', 'iteration', 'gen/loss', 'dis/loss']
    trainer.extend(
        extensions.LogReport(keys=log_keys, trigger=display_interval))
    trainer.extend(extensions.PrintReport(log_keys), trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Render a 10x10 grid of generated digits once per epoch.
    trainer.extend(out_generated_image(gen, 10, 10, args.seed, args.out),
                   trigger=(1, 'epoch'))

    trainer.run()
Esempio n. 22
0
def main():
    """Train CAGAN (clothing-swap GAN) in PyTorch.

    Parses options, prepares output/log directories, tensorboard and a
    file/stream logger, then runs the adversarial training loop for
    ``--n_iter`` updates, periodically logging scalars, histograms,
    checkpoints and sample images.
    """
    parser = argparse.ArgumentParser(description='Train CAGAN')
    # common
    parser.add_argument('--n_iter', default=10000, type=int,
                        help='number of update')
    parser.add_argument('--lr', default=0.0002, type=float,
                        help='learning rate of Adam')
    # type=int so that "--cuda 0" is falsy (a string "0" is truthy).
    parser.add_argument('--cuda', default=0, type=int,
                        help='0 indicates CPU')
    parser.add_argument('--out', default='results',
                        help='log file')
    parser.add_argument('--log_interval', default=20, type=int,
                        help='log interval')
    parser.add_argument('--ckpt_interval', default=50, type=int,
                        help='save interval')
    # type=int: torch.manual_seed() rejects a string from the CLI.
    parser.add_argument('--seed', default=42, type=int)
    # was read below (DataLoader) but never declared on the parser
    parser.add_argument('--batch_size', default=16, type=int,
                        help='minibatch size')
    # model
    parser.add_argument('--deconv', default='upconv',
                        help='deconv or upconv')
    parser.add_argument('--relu', default='relu',
                        help='relu or relu6')
    # NOTE(review): any non-empty CLI string (even "False") is truthy — confirm
    # whether --bias should be parsed as a real boolean.
    parser.add_argument('--bias', default=True,
                        help='use bias or not')
    parser.add_argument('--init', default=None,
                        help='weight initialization')
    # loss
    parser.add_argument('--gamma_cycle', default=1.0, type=float,
                        help='coefficient for cycle loss')
    parser.add_argument('--gamma_id', default=1.0, type=float,
                        help='coefficient for mask')
    parser.add_argument('--norm', default=1, type=int,
                        help='select norm type. Default is 1')
    parser.add_argument('--cycle_norm', default=2, type=int,
                        help='select norm type. Default is 2')
    # dataset
    parser.add_argument('--root', default='data',
                        help='root directory')
    parser.add_argument('--base_root', default='images',
                        help='root directory to images')
    parser.add_argument('--triplet', default='triplet.json',
                        help='triplet list')
    args = parser.parse_args()
    time = dt.now().strftime('%m%d_%H%M')
    print('+++++ begin at {} +++++'.format(time))
    # dict(args) raises TypeError; vars(args).items() yields (key, value) pairs.
    for key, value in vars(args).items():
        print('### {}: {}'.format(key, value))
    print('+++++')

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    # One timestamped run directory holding logs, tensorboard and checkpoints.
    out = os.path.join(args.out, time)
    if not os.path.isdir(out):
        os.makedirs(out)
    log_dir = os.path.join(out, 'tensorboard')
    os.makedirs(log_dir)
    ckpt_dir = os.path.join(out, 'checkpoints')
    os.makedirs(ckpt_dir)
    args.out = out
    args.log_dir = log_dir
    args.ckpt_dir = ckpt_dir

    # logger: mirror DEBUG output to log.txt and stderr
    logger = getLogger()
    f_h = FileHandler(os.path.join(out, 'log.txt'), 'a', 'utf-8')
    f_h.setLevel(DEBUG)
    s_h = StreamHandler()
    s_h.setLevel(DEBUG)
    logger.setLevel(DEBUG)
    logger.addHandler(f_h)
    logger.addHandler(s_h)

    # tensorboard
    writer = SummaryWriter(log_dir)

    logger.debug("=====")
    logger.debug(json.dumps(args.__dict__, indent=2))
    logger.debug("=====")

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        dataset.TripletDataset(
            args.root,
            os.path.join(args.root, args.triplet),
            transform=transforms.Compose([
                transforms.Scale((132, 100)),
                transforms.RandomCrop((128, 96)),
                # ToTensor must precede Normalize: Normalize operates on
                # tensors, not PIL images.
                transforms.ToTensor(),
                # NOTE(review): ImageNet std is usually 0.224, not 0.234 —
                # confirm the intended statistics.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.234, 0.225])])),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    generator = net.Generator(args.deconv, args.relu, args.bias, args.norm)
    discriminator = net.Discriminator(args.relu, args.bias)
    opt_G = torch.optim.Adam(generator.parameters(), lr=args.lr)
    opt_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr)
    if args.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    start = dt.now()
    logger.debug('===start {}'.format(start.strftime("%m/%d, %H:%M")))
    # DataLoader has no .next(); keep an explicit iterator and restart it
    # whenever an epoch's worth of batches is exhausted.
    data_iter = iter(train_loader)
    for iter_ in range(args.n_iter):
        # register params to tensorboard (histograms need CPU numpy arrays)
        if args.cuda:
            generator = generator.cpu()
            discriminator = discriminator.cpu()
        # named_parameters() yields tensors; named_children() yields modules,
        # which have no requires_grad. requires_grad is an attribute, not a
        # method.
        for name, params in generator.named_parameters():
            if params.requires_grad:
                writer.add_histogram(tag='generator/' + name,
                                     values=params.data.numpy(),
                                     global_step=iter_)
        for name, params in discriminator.named_parameters():
            if params.requires_grad:
                writer.add_histogram(tag='discriminator/' + name,
                                     values=params.data.numpy(),
                                     global_step=iter_)
        if args.cuda:
            generator = generator.cuda()
            discriminator = discriminator.cuda()

        try:
            list_of_tensor = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            list_of_tensor = next(data_iter)
        # y_1: exchanged, y_2: cycle
        loss_d, loss_g, mask_norm, cycle_loss, y_1, y_2 = train(
            iter_, generator, discriminator, opt_G, opt_D, list_of_tensor,
            args.cycle_norm, args.gamma_cycle, args.gamma_id, args.cuda)

        writer.add_scalars(
            main_tag='training',
            tag_scalar_dict={
                'discriminator/loss': loss_d,
                # was loss_d duplicated — the generator scalar must be loss_g
                'generator/loss': loss_g,
                'generator/mask_norm': mask_norm,
                'generator/cycle_loss': cycle_loss
            }, global_step=iter_)

        if iter_ % args.log_interval == 0:
            msg = "Iter {} \tDis loss: {:.5f}\tGen loss: {:.5f}\tNorm: {:.5f}\n"
            logger.debug(
                msg.format(iter_, loss_d, loss_g, mask_norm + cycle_loss))

        if iter_ % args.ckpt_interval == 0:
            # save checkpoint and images used in this iteration
            if args.cuda:
                generator = generator.cpu()
                discriminator = discriminator.cpu()
            generator.eval()
            discriminator.eval()
            torch.save({
                'iteration': iter_,
                'gen_state': generator.state_dict(),
                'dis_state': discriminator.state_dict(),
                'opt_gen': opt_G.state_dict(),
                'opt_dis': opt_D.state_dict()},
                os.path.join(ckpt_dir, 'ckpt_iter_{}.pth'.format(iter_)))

            # save images
            if args.cuda:
                y_1 = y_1.cpu()
                y_2 = y_2.cpu()
            writer.add_image(tag='input/human',
                             image_tensor=list_of_tensor[0],
                             global_step=iter_)
            writer.add_image(tag='input/item',
                             image_tensor=list_of_tensor[1],
                             global_step=iter_)
            # a comma was missing after the tag= argument (SyntaxError)
            writer.add_image(tag='input/want',
                             image_tensor=list_of_tensor[2],
                             global_step=iter_)
            writer.add_image(tag='output/changed',
                             image_tensor=y_1.data,
                             global_step=iter_)
            writer.add_image(tag='output/cycle',
                             image_tensor=y_2.data,
                             global_step=iter_)
            # restore training mode: eval() would otherwise persist through
            # the following updates (affects dropout/batch-norm layers)
            generator.train()
            discriminator.train()
            if args.cuda:
                generator = generator.cuda()
                discriminator = discriminator.cuda()

    # save final checkpoint
    torch.save({
        'iteration': args.n_iter,
        'gen_state': generator.state_dict(),
        'dis_state': discriminator.state_dict(),
        'opt_gen': opt_G.state_dict(),
        'opt_dis': opt_D.state_dict()},
        os.path.join(ckpt_dir, 'ckpt_iter_{}.pth'.format(args.n_iter)))

    # join with a separator — the original concatenated "log_dir" and the
    # filename without one
    writer.export_scalars_to_json(os.path.join(log_dir, 'scalars.json'))
    writer.close()
    end = dt.now()
    logger.debug('===end {}, {}[min]'.format(end.strftime(
        '%m/%d, %H:%M'), (end - start).total_seconds() / 60.))
Esempio n. 23
0
        with open(os.path.join(args.out,"filenames.txt"),'w') as output:
            for file in glob.glob(os.path.join(args.root,"**/*.{}".format(args.imgtype)), recursive=True):
                output.write('{}\n'.format(file))
        dataset = Dataset(os.path.join(args.out,"filenames.txt"), "", [0], [0], crop=(args.crop_height,args.crop_width), random=False, grey=args.grey)
        
#    iterator = chainer.iterators.MultiprocessIterator(dataset, args.batch_size, n_processes=3, repeat=False, shuffle=False)
    iterator = chainer.iterators.MultithreadIterator(dataset, args.batch_size, n_threads=3, repeat=False, shuffle=False)   ## best performance
#    iterator = chainer.iterators.SerialIterator(dataset, args.batch_size,repeat=False, shuffle=False)

    args.ch = len(dataset[0][0])
    args.out_ch = len(dataset[0][1])
    print("Input channels {}, Output channels {}".format(args.ch,args.out_ch))

    ## load generator models
    if args.model_gen:
            gen = net.Generator(args)
            print('Loading {:s}..'.format(args.model_gen))
            serializers.load_npz(args.model_gen, gen)
            if args.gpu >= 0:
                gen.to_gpu()
            xp = gen.xp
    else:
        print("Specify a learnt model.")
        exit()        

    ## start measuring timing
    os.makedirs(outdir, exist_ok=True)
    start = time.time()

    cnt = 0
    salt = str(random.randint(1000, 999999))
Esempio n. 24
0
def main():
    """Train a two-generator / two-discriminator translation GAN (Chainer).

    Sets up paired A<->B datasets and iterators, builds gen_g/gen_f and
    dis_x/dis_y with their optimisers (optionally resuming from snapshots),
    then runs a Trainer with logging, plotting, snapshot and visual
    evaluation extensions.
    """
    args = arguments()
    out = os.path.join(args.out, dt.now().strftime('%m%d_%H%M'))
    print(args)
    print(out)
    save_args(args, out)
    # Map string option values onto the actual dtype/activation objects.
    args.dtype = dtypes[args.dtype]
    args.dis_activation = activation[args.dis_activation]
    args.gen_activation = activation[args.gen_activation]
    args.gen_out_activation = activation[args.gen_out_activation]

    if args.imgtype == "dcm":
        from dataset_dicom import DatasetOutMem as Dataset
    else:
        from dataset_jpg import DatasetOutMem as Dataset

    if not chainer.cuda.available:
        print("CUDA required")
        exit()

    if len(args.gpu) == 1 and args.gpu[0] >= 0:
        chainer.cuda.get_device_from_id(args.gpu[0]).use()

    # Enable autotuner of cuDNN
    chainer.config.autotune = True
    chainer.config.dtype = args.dtype
    chainer.print_runtime_info()
    #    print('Chainer version: ', chainer.__version__)
    #    print('GPU availability:', chainer.cuda.available)
    #    print('cuDNN availability:', chainer.cuda.cudnn_enabled)

    ## dataset iterator
    print("Setting up data iterators...")
    train_A_dataset = Dataset(path=os.path.join(args.root, 'trainA'),
                              baseA=args.HU_base,
                              rangeA=args.HU_range,
                              slice_range=args.slice_range,
                              crop=(args.crop_height, args.crop_width),
                              random=args.random_translate,
                              forceSpacing=0,
                              imgtype=args.imgtype,
                              dtype=args.dtype)
    train_B_dataset = Dataset(path=os.path.join(args.root, 'trainB'),
                              baseA=args.HU_base,
                              rangeA=args.HU_range,
                              slice_range=args.slice_range,
                              crop=(args.crop_height, args.crop_width),
                              random=args.random_translate,
                              forceSpacing=args.forceSpacing,
                              imgtype=args.imgtype,
                              dtype=args.dtype)
    # Test sets use no random translation (random=0) for reproducible eval.
    test_A_dataset = Dataset(path=os.path.join(args.root, 'testA'),
                             baseA=args.HU_base,
                             rangeA=args.HU_range,
                             slice_range=args.slice_range,
                             crop=(args.crop_height, args.crop_width),
                             random=0,
                             forceSpacing=0,
                             imgtype=args.imgtype,
                             dtype=args.dtype)
    test_B_dataset = Dataset(path=os.path.join(args.root, 'testB'),
                             baseA=args.HU_base,
                             rangeA=args.HU_range,
                             slice_range=args.slice_range,
                             crop=(args.crop_height, args.crop_width),
                             random=0,
                             forceSpacing=args.forceSpacing,
                             imgtype=args.imgtype,
                             dtype=args.dtype)

    args.ch = train_A_dataset.ch
    test_A_iter = chainer.iterators.SerialIterator(test_A_dataset,
                                                   args.nvis_A,
                                                   shuffle=False)
    test_B_iter = chainer.iterators.SerialIterator(test_B_dataset,
                                                   args.nvis_B,
                                                   shuffle=False)

    # Multiprocess loading only pays off for batch sizes > 1.
    if args.batch_size > 1:
        train_A_iter = chainer.iterators.MultiprocessIterator(
            train_A_dataset,
            args.batch_size,
            n_processes=3,
            shuffle=not args.conditional_discriminator)
        train_B_iter = chainer.iterators.MultiprocessIterator(
            train_B_dataset,
            args.batch_size,
            n_processes=3,
            shuffle=not args.conditional_discriminator)
    else:
        train_A_iter = chainer.iterators.SerialIterator(
            train_A_dataset,
            args.batch_size,
            shuffle=not args.conditional_discriminator)
        train_B_iter = chainer.iterators.SerialIterator(
            train_B_dataset,
            args.batch_size,
            shuffle=not args.conditional_discriminator)

    # setup models: two generators (A->B and B->A) and two discriminators
    gen_g = net.Generator(args)
    gen_f = net.Generator(args)
    dis_y = net.Discriminator(args)
    dis_x = net.Discriminator(args)
    models = {'gen_g': gen_g, 'gen_f': gen_f, 'dis_x': dis_x, 'dis_y': dis_y}

    ## load learnt models; the filename pattern of each model is derived from
    ## the gen_g path by substituting the model key.
    optimiser_files = []
    if args.load_models:
        for e in models:
            m = args.load_models.replace('gen_g', e)
            try:
                serializers.load_npz(m, models[e])
                print('model loaded: {}'.format(m))
            except Exception:
                # best-effort resume: a missing snapshot is reported, not fatal
                print("couldn't load {}".format(m))
            optimiser_files.append(m.replace(e, 'opt_' + e[-1]))

    # select GPU
    if len(args.gpu) == 1:
        for e in models:
            models[e].to_gpu()
    else:
        print("mandatory GPU use: currently only a single GPU can be used")
        exit()

    # Setup optimisers
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        # fp16 needs a larger eps to avoid vanishing denominators in Adam
        eps = 1e-5 if args.dtype == np.float16 else 1e-8
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, eps=eps)
        optimizer.setup(model)
        if args.weight_decay > 0:
            if args.weight_decay_norm == 'l2':
                optimizer.add_hook(
                    chainer.optimizer.WeightDecay(args.weight_decay))
            else:
                optimizer.add_hook(
                    chainer.optimizer_hooks.Lasso(args.weight_decay))
        return optimizer

    opt_g = make_optimizer(gen_g, alpha=args.learning_rate_g)
    opt_f = make_optimizer(gen_f, alpha=args.learning_rate_g)
    opt_y = make_optimizer(dis_y, alpha=args.learning_rate_d)
    opt_x = make_optimizer(dis_x, alpha=args.learning_rate_d)
    #    opt_g.add_hook(chainer.optimizer_hooks.GradientClipping(5))
    optimizers = {
        'opt_g': opt_g,
        'opt_f': opt_f,
        'opt_x': opt_x,
        'opt_y': opt_y
    }
    if args.load_optimizer:
        for (m, e) in zip(optimiser_files, optimizers):
            if m:
                try:
                    serializers.load_npz(m, optimizers[e])
                    print('optimiser loaded: {}'.format(m))
                except Exception:
                    # same best-effort policy as model loading above
                    print("couldn't load {}".format(m))

    # Set up an updater
    print("Preparing updater...")
    updater = Updater(models=(gen_g, gen_f, dis_x, dis_y),
                      iterator={
                          'main': train_A_iter,
                          'train_B': train_B_iter,
                      },
                      optimizer=optimizers,
                      device=args.gpu[0],
                      converter=convert.ConcatWithAsyncTransfer(),
                      params={'args': args})

    if not args.snapinterval:
        # default: 5 snapshots over the whole run; the original summed
        # lrdecay_start twice, which was an apparent copy-paste slip — the
        # total schedule length is lrdecay_start + lrdecay_period (see the
        # Trainer stop trigger below)
        args.snapinterval = (args.lrdecay_start + args.lrdecay_period) // 5
    log_interval = (200, 'iteration')
    model_save_interval = (args.snapinterval, 'epoch')
    vis_interval = (args.vis_freq, 'iteration')
    plot_interval = (500, 'iteration')

    # Set up a trainer
    print("Preparing trainer...")
    trainer = training.Trainer(
        updater, (args.lrdecay_start + args.lrdecay_period, 'epoch'), out=out)
    for e in models:
        trainer.extend(extensions.snapshot_object(models[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)
    for e in optimizers:
        trainer.extend(extensions.snapshot_object(optimizers[e],
                                                  e + '{.updater.epoch}.npz'),
                       trigger=model_save_interval)

    ## log
    log_keys = ['epoch', 'iteration']
    log_keys_cycle = [
        'opt_g/loss_cycle_y', 'opt_f/loss_cycle_x', 'myval/cycle_y_l1',
        'myval/cycle_x_l1', 'opt_g/loss_tv'
    ]
    log_keys.extend(['myval/id_xy_grad', 'myval/id_xy_l1'])
    # 'myval/cycle_avgy_l1','myval/id_avgx_l1','myval/id_x_l2','myval/cycle_y_l2'
    log_keys_d = [
        'opt_x/loss_real', 'opt_x/loss_fake', 'opt_y/loss_real',
        'opt_y/loss_fake', 'opt_x/loss_gp', 'opt_y/loss_gp'
    ]
    log_keys_adv = ['opt_g/loss_adv', 'opt_f/loss_adv']
    log_keys.extend([
        'opt_g/loss_dom', 'opt_f/loss_dom', 'opt_g/loss_id', 'opt_f/loss_id',
        'opt_g/loss_idem', 'opt_f/loss_idem'
    ])
    log_keys.extend([
        'opt_g/loss_grad', 'opt_f/loss_grad', 'opt_g/loss_air',
        'opt_f/loss_air'
    ])

    log_keys_all = log_keys + log_keys_d + log_keys_adv + log_keys_cycle
    trainer.extend(
        extensions.LogReport(keys=log_keys_all, trigger=log_interval))
    trainer.extend(extensions.PrintReport(log_keys_all), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    ## to dump graph, set -lix 1 --warmup 0
    #    trainer.extend(extensions.dump_graph('opt_g/loss_id', out_name='gen.dot'))
    #    trainer.extend(extensions.dump_graph('opt_x/loss_real', out_name='dis.dot'))

    # ChainerUI
    trainer.extend(CommandsExtension())

    if extensions.PlotReport.available():
        # Separate plots per loss family (generator/discriminator/adv/cycle).
        trainer.extend(
            extensions.PlotReport(log_keys[2:],
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_d,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_d.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_adv,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_adv.png'))
        trainer.extend(
            extensions.PlotReport(log_keys_cycle,
                                  'iteration',
                                  trigger=plot_interval,
                                  file_name='loss_cyc.png'))
    ## visualisation
    os.makedirs(out, exist_ok=True)
    vis_folder = os.path.join(out, "vis")
    os.makedirs(vis_folder, exist_ok=True)

    ## output filenames of training dataset
    with open(os.path.join(out, 'trainA.txt'), 'w') as output:
        output.writelines("\n".join(train_A_dataset.ids))
    with open(os.path.join(out, 'trainB.txt'), 'w') as output:
        output.writelines("\n".join(train_B_dataset.ids))
    # archive the scripts alongside the results for reproducibility
    rundir = os.path.dirname(os.path.realpath(__file__))
    import zipfile
    with zipfile.ZipFile(os.path.join(out, 'script.zip'),
                         'w',
                         compression=zipfile.ZIP_DEFLATED) as new_zip:
        for f in [
                'train.py', 'net.py', 'updater.py', 'consts.py', 'losses.py',
                'arguments.py', 'convert.py'
        ]:
            new_zip.write(os.path.join(rundir, f), arcname=f)


#    trainer.extend(visualize( (gen_g, gen_f), vis_folder, test_A_iter, test_B_iter), trigger=(1, 'epoch'))
    trainer.extend(VisEvaluator({
        "main": test_A_iter,
        "testB": test_B_iter
    }, {
        "gen_g": gen_g,
        "gen_f": gen_f
    },
                                params={
                                    'vis_out': vis_folder,
                                    'single_encoder': None
                                },
                                device=args.gpu[0]),
                   trigger=vis_interval)

    # Run the training
    trainer.run()
Esempio n. 25
0
                    type=str,
                    help='dataset file path')
parser.add_argument('--size',
                    '-s',
                    default=96,
                    type=int,
                    choices=[48, 96],
                    help='image size')
args = parser.parse_args()

# Pick the network variant matching the requested resolution
# (choices above restrict --size to 48 or 96).
image_size = args.size
if image_size == 48:
    gen_model = net.Generator48()
    dis_model = net.Discriminator48()
else:
    gen_model = net.Generator()
    dis_model = net.Discriminator()

# Adam with DCGAN-style beta1=0.5 plus light weight decay for both networks.
optimizer_gen = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_gen.setup(gen_model)
optimizer_gen.add_hook(chainer.optimizer.WeightDecay(0.00001))
optimizer_dis = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_dis.setup(dis_model)
optimizer_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))

# Resume models and optimizer state from HDF5 snapshots when --input was given.
# "is not None" is the idiomatic identity test for the None sentinel.
if args.input is not None:
    serializers.load_hdf5(args.input + '.gen.model', gen_model)
    serializers.load_hdf5(args.input + '.gen.state', optimizer_gen)
    serializers.load_hdf5(args.input + '.dis.model', dis_model)
    serializers.load_hdf5(args.input + '.dis.state', optimizer_dis)
Esempio n. 26
0
def gan_training(args, train):
    """Train the Flow GAN on *train* and write logs/snapshots to args.out.

    Args:
        args: parsed command-line namespace (batchsize, gpu, iteration,
            loaderjob, snapshot/visualize/log intervals, ...).
        train: the training dataset object fed to the iterator.
    """
    # These iterators load the images with subprocesses running in parallel
    # to the training; use a multiprocess iterator only when --loaderjob
    # requests worker processes.
    if args.loaderjob:
        train_iter = chainer.iterators.MultiprocessIterator(
            train, args.batchsize, n_processes=args.loaderjob)
    else:
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Prepare the Flow GAN model pair, defined in net.py
    gen = net.Generator(video_len=args.video_len)
    dis = net.Discriminator()

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Setup optimizers used to minimize each loss
    opt_gen = make_optimizer(gen, args)
    opt_dis = make_optimizer(dis, args)

    # Updater: computes losses and applies parameter updates each iteration
    updater = GAN_Updater(models=(gen, dis),
                          iterator=train_iter,
                          optimizer={
                              'gen': opt_gen,
                              'dis': opt_dis
                          },
                          device=args.gpu)

    # Trainer drives mini-batch loading, forward/backward and updates
    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               out=args.out)

    # All triggers are expressed in iterations.
    snapshot_interval = (args.snapshot_interval, 'iteration')
    visualize_interval = (args.visualize_interval, 'iteration')
    log_interval = (args.log_interval, 'iteration')

    # Extensions run in decreasing order of priority.
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(
        extensions.PlotReport(['gen/loss', 'dis/loss'],
                              trigger=log_interval,
                              file_name='plot.png'))
    trainer.extend(
        extensions.PrintReport(['epoch', 'iteration', 'gen/loss', 'dis/loss']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Full-trainer and per-model snapshots.
    trainer.extend(extensions.snapshot(), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iteration_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iteration_{.updater.iteration}.npz'),
        trigger=snapshot_interval)

    # Periodic sample visualization from the generator.
    trainer.extend(extension(gen, args), trigger=visualize_interval)

    # Halve Adam's alpha for both optimizers every adam_decay_iteration steps.
    if args.adam_decay_iteration:
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_gen),
                       trigger=(args.adam_decay_iteration, 'iteration'))
        trainer.extend(extensions.ExponentialShift("alpha",
                                                   0.5,
                                                   optimizer=opt_dis),
                       trigger=(args.adam_decay_iteration, 'iteration'))

    # Resume the whole trainer state (models, optimizers, extensions).
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()