Exemplo n.º 1
0
def see_image(args):
    """Render generator samples for saved checkpoints and dump their attributes.

    A fixed latent batch ``z`` and a fixed attribute batch ``y`` are drawn once
    and fed to every checkpoint, so the saved grids are comparable across
    epochs. Checkpoints are read from ``../model/gen_epoch_{ep}.model`` for
    every multiple of 5 strictly below ``args.max_epoch``; one image grid per
    checkpoint is written to ``../out/sample_out_ep{ep}.png``.

    The attribute names active in ``y`` are written to
    ``../out/sample_attr.txt``, one line per sample. Every sample gets a line,
    including the last one and samples with no active attribute.

    Args:
        args: namespace providing ``show_size`` (number of samples) and
            ``max_epoch`` (exclusive upper bound on the checkpoint epochs).
    """
    celeba = CelebA('../../../local/CelebA/',
                    'img_align_celeba',
                    'list_attr_celeba.csv',
                    transform=transform)
    gen = Generator(140)
    gen = gen.to(device)

    z = torch.randn((args.show_size, 100)).to(device)
    y = randomsample(celeba, args.show_size).to(device)

    # Multiples of 5 in [1, max_epoch): same epochs the original list selected.
    for ep in range(5, args.max_epoch, 5):
        path = '../model/gen_epoch_{}.model'.format(ep)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        gen.load_state_dict(torch.load(path, map_location=device))

        out = gen(z, y).detach().cpu()
        vutils.save_image(out, '../out/sample_out_ep{}.png'.format(ep))

    # The attribute batch does not depend on the checkpoint, so it is decoded
    # once, outside the epoch loop (and also when no checkpoint was selected).
    attrnames = list(celeba.df.columns)[1:]
    multihot = y.detach().cpu().numpy()
    with open('../out/sample_attr.txt', 'w') as f:
        # Walk the rows directly instead of tracking a pivot over np.where
        # output: that approach lost attributes at every row change and never
        # flushed the final row.
        for row_idx, row in enumerate(multihot):
            active = [attrnames[j] for j in np.flatnonzero(row == 1)]
            f.write('{}th image: {}\n'.format(row_idx, active))
Exemplo n.º 2
0
def cli(ctx, profile, snapshot):
    """Build the inference stack and publish it on ``ctx.obj``.

    Args:
        ctx: context whose ``obj`` mapping receives the ``'hps'``,
            ``'dataset'`` and ``'inferer'`` entries.
        profile: handed to ``util.load_profile`` to obtain hyper-parameters.
        snapshot: optional pre-trained snapshot; when it is not ``None`` the
            hyper-parameters are switched to warm start from it.
    """
    # Hyper-parameters, optionally redirected to a pre-trained snapshot.
    hps = util.load_profile(profile)
    util.manual_seed(hps.ablation.seed)
    if snapshot is not None:
        hps.general.warm_start = True
        hps.general.pre_trained = snapshot

    # Graph in inference mode.
    state = Builder(hps).build(training=False)

    # Dataset: centre crop to 160, resize to 128, convert to tensor.
    dataset_root = hps.dataset.root
    preprocessing = transforms.Compose((
        transforms.CenterCrop(160),
        transforms.Resize(128),
        transforms.ToTensor(),
    ))
    dataset = CelebA(root=dataset_root, transform=preprocessing)

    # Inference driver wired to the freshly built graph and devices.
    inferer = Inferer(
        hps=hps,
        graph=state['graph'],
        devices=state['devices'],
        data_device=state['data_device'],
    )

    # Hand everything to the sub-commands through the shared context object.
    shared = (('hps', hps), ('dataset', dataset), ('inferer', inferer))
    for key, value in shared:
        ctx.obj[key] = value
def main():
    """Train the ``wgan`` generator/critic pair on CelebA with Chainer.

    Steps, in order:

    1. Parse command-line options (unknown options are ignored because
       ``parse_known_args`` is used). ``--resume`` is parsed but not read
       anywhere in this function.
    2. Create ``runs/<MMDDHHMM>[_<out>]/{models,visualize}`` under the current
       directory and write every option to ``setting.txt``.
    3. When the module-level flag ``use_tensorboard`` is truthy, create a
       TensorFlow session and scalar summaries for both losses.
    4. Build the dataset iterator, generator, critic and one RMSprop optimizer
       per network.
    5. Run the training loop: several critic updates with weight clipping,
       then one generator update.
    6. After every epoch append the mean losses to ``log`` and call
       ``visualize``; every 5th epoch save both networks as HDF5.

    Raises:
        ValueError: if ``--g_arch`` or ``--d_arch`` is not 1 or 2.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU device ID')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='# of epoch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--g_hidden', type=int, default=128)
    parser.add_argument('--g_arch', type=int, default=1)
    parser.add_argument('--g_activate', type=str, default='sigmoid')
    parser.add_argument('--g_channel', type=int, default=512)
    parser.add_argument('--d_arch', type=int, default=1)
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--d_clip', type=float, default=0.01)
    parser.add_argument('--d_use_gamma', action='store_true', default=False)
    parser.add_argument('--d_channel', type=int, default=512)
    parser.add_argument('--initial_iter', type=int, default=10)
    parser.add_argument('--resume', default='')
    parser.add_argument('--out', default='')
    # args = parser.parse_args()
    # parse_known_args tolerates options this parser does not declare.
    args, unknown = parser.parse_known_args()

    # log directory: runs/<month day hour minute>[_<out>]
    out = datetime.datetime.now().strftime('%m%d%H%M')
    if args.out:
        out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # hyper parameter: echo every option and persist it next to the run
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    # tensorboard (use_tensorboard is defined outside this function)
    if use_tensorboard:
        sess = tf.Session()
        sess.run(tf.initialize_all_variables())

        summary_dir = os.path.join(out_dir, "summaries")
        # A single placeholder feeds both scalar summaries.
        loss_ = tf.placeholder(tf.float32)
        gen_loss_summary = tf.scalar_summary('gen_loss', loss_)
        dis_loss_summary = tf.scalar_summary('dis_loss', loss_)
        summary_writer = tf.train.SummaryWriter(summary_dir, sess.graph)

    # load celebA
    dataset = CelebA(image_type=args.g_activate)
    train_iter = chainer.iterators.MultiprocessIterator(
        dataset, args.batch_size)

    # generator architecture selected by --g_arch
    if args.g_arch == 1:
        gen = wgan.Generator(n_hidden=args.g_hidden,
                             activate=args.g_activate,
                             ch=args.g_channel)
    elif args.g_arch == 2:
        gen = wgan.Generator2(n_hidden=args.g_hidden,
                              activate=args.g_activate,
                              ch=args.g_channel)
    else:
        raise ValueError('invalid arch')

    # critic architecture selected by --d_arch
    if args.d_arch == 1:
        dis = wgan.Discriminator(ch=args.d_channel, use_gamma=args.d_use_gamma)
    elif args.d_arch == 2:
        dis = wgan.Discriminator2(ch=args.d_channel)
    else:
        raise ValueError('invalid arch')

    # A negative --gpu keeps both networks on the CPU.
    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    optimizer_gen = chainer.optimizers.RMSprop(lr=0.00005)
    optimizer_dis = chainer.optimizers.RMSprop(lr=0.00005)

    optimizer_gen.setup(gen)
    optimizer_dis.setup(dis)

    # optimizer_gen.add_hook(chainer.optimizer.WeightDecay(0.00001))
    # optimizer_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))

    # start training
    start = time.time()
    gen_iterations = 0
    for epoch in range(args.epoch):

        # train: per-epoch loss histories, averaged for the epoch log line
        sum_L_gen = []
        sum_L_dis = []

        # i counts critic batches; it is advanced inside the critic loop, so
        # an epoch ends once the critic has consumed one pass worth of batches.
        i = 0
        while i < len(dataset) // args.batch_size:

            # tips for critic reach optimality: 100 critic steps during the
            # first --initial_iter generator steps and on every 500th one
            if gen_iterations < args.initial_iter or gen_iterations % 500 == 0:
                d_iters = 100
            else:
                d_iters = args.d_iters

            ############################
            # (1) Update D network
            ###########################
            j = 0
            while j < d_iters:
                batch = train_iter.next()
                # Only element 0 of each dataset item is used here.
                x = chainer.Variable(
                    gen.xp.asarray([b[0] for b in batch], 'float32'))
                # attr = chainer.Variable(gen.xp.asarray([b[1] for b in batch], 'int32'))

                # real
                y_real = dis(x)

                # fake
                # NOTE(review): volatile=True presumably keeps the generator
                # out of the backward graph during the critic step (Chainer v1
                # flag) - confirm against the Chainer version in use.
                z = chainer.Variable(gen.xp.asarray(
                    gen.make_hidden(args.batch_size)),
                                     volatile=True)
                x_fake = gen(z)
                x_fake.volatile = False
                y_fake = dis(x_fake)

                # calc EMD
                L_dis = -(y_real - y_fake)
                dis.cleargrads()
                L_dis.backward()
                optimizer_dis.update()

                # clip the critic weights after every critic update
                dis.clip_weight(clip=args.d_clip)

                j += 1
                i += 1

            ###########################
            # (2) Update G network
            ###########################
            z = chainer.Variable(
                gen.xp.asarray(gen.make_hidden(args.batch_size)))
            x_fake = gen(z)
            y_fake = dis(x_fake)
            L_gen = -y_fake
            gen.cleargrads()
            L_gen.backward()
            optimizer_gen.update()

            gen_iterations += 1

            # L_dis here is the value from the last critic step above.
            emd = float(-L_dis.data)
            l_gen = float(L_gen.data)
            sum_L_dis.append(emd)
            sum_L_gen.append(l_gen)

            progress_report(epoch * len(dataset) // args.batch_size + i, start,
                            args.batch_size, emd)

            if use_tensorboard:
                summary = sess.run(gen_loss_summary, feed_dict={loss_: l_gen})
                summary_writer.add_summary(summary, gen_iterations)
                summary = sess.run(dis_loss_summary, feed_dict={loss_: emd})
                summary_writer.add_summary(summary, gen_iterations)

        # epoch summary: mean losses to stdout and to <out_dir>/log
        log = 'gen loss={:.5f}, dis loss={:.5f}'.format(
            np.mean(sum_L_gen), np.mean(sum_L_dis))
        print('\n' + log)
        with open(os.path.join(out_dir, "log"), 'a+') as f:
            f.write(log + '\n')

        # checkpoint both networks every 5th epoch (including epoch 0)
        if epoch % 5 == 0:
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.dis.model".format(epoch)), dis)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.gen.model".format(epoch)), gen)

        visualize(gen,
                  epoch=epoch,
                  savedir=os.path.join(out_dir, 'visualize'),
                  image_type=args.g_activate)
# RMSprop optimizer over the parameters of Dis.modelQ.
optimizerQ = torch.optim.RMSprop(Dis.modelQ.parameters(),
                                 lr=opt.lrQ,
                                 alpha=0.9,
                                 eps=1e-8)

# One exponential learning-rate schedule per optimizer, all using opt.decayLr.
# optimizerE, optimizerD and optimizerDis are defined outside this excerpt.
lrSchedulerE = torch.optim.lr_scheduler.ExponentialLR(optimizerE, opt.decayLr)
lrSchedulerD = torch.optim.lr_scheduler.ExponentialLR(optimizerD, opt.decayLr)
lrSchedulerDis = torch.optim.lr_scheduler.ExponentialLR(
    optimizerDis, opt.decayLr)
lrSchedulerQ = torch.optim.lr_scheduler.ExponentialLR(optimizerQ, opt.decayLr)

# Path to data (relative to the working directory).
imgDirectory = "../../Downloads/img_align_celeba/"

# Create dataloader.
dataloader = DataLoader(CelebA(imgDirectory),
                        batch_size=opt.batchSize,
                        shuffle=True,
                        num_workers=opt.workers)

# Save the first (shuffled) batch as a reference image, and draw 64 latent
# vectors of size opt.latentSize from a standard normal.
# NOTE(review): testBatch / testLatent look like fixed monitoring inputs -
# confirm against the code that consumes them.
testBatch = next(iter(dataloader)).type(Tensor)
torchvision.utils.save_image(testBatch,
                             opt.outDir + '/testBatch.png',
                             normalize=True)
testLatent = torch.randn(64, opt.latentSize).type(Tensor)

# Lists to keep track of progress.
lossesKl = []  # KL divergence of the encoded distribution from the prior.
lossesD = []  # Decoder/Generator losses from the GAN objective.
lossesDis = []  # Discriminator losses from the GAN objective.
lossesDisL = [
Exemplo n.º 5
0
# Echo the run configuration, then build every network once; all modes below
# share these instances.
print("Application Params: ", args)
print("Using GPU(s): ", args.gpu)
print("Running Mode:", args.mode)
print(" - Initializing Networks...")
decoder = Decoder(args)
encoder = Encoder(args)
generator = Generator(args, decoder)
discriminator = Discriminator(args, encoder)
adjuster = Adjuster(args, discriminator, generator)

# Dispatch on the requested mode.
if args.mode == "train":
    # Refuse to train from an uncommitted working tree unless debugging.
    repo = Repo(".")
    if repo.is_dirty() and not args.debug:  # repo has uncommitted changes and this is not debug mode
        raise EnvironmentError(
            "Git repo is Dirty! Please train after committed.")
    data = CelebA(args)
    print("Using Attribute:", data.label)
    model = EagerTrainer(args, generator, discriminator, adjuster, data)
    model.train()
elif args.mode == "visual":  # visualise losses etc.
    print("The result path is ", path.join(args.result_dir, "log"))
    # NOTE(review): the command line is built by string concatenation and run
    # through `system`; args.result_dir must be trusted.
    system("tensorboard --host 0.0.0.0 --logdir " +
           path.join(args.result_dir, "log"))
elif args.mode == "plot":
    args.reuse = True
    # No dataset is passed to the trainer in this mode.
    model = EagerTrainer(args, generator, discriminator, adjuster, None)
    model.plot()  # output a diagram of the model structure
elif args.mode == "random-sample":
    args.reuse = True
    data = CelebA(args)
    print("Using Attribute:", data.label)
Exemplo n.º 6
0
from dataset import CelebA
from models import alexnet
from models import resnet18

from args import args

# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

if __name__ == '__main__':
    # Data
    print('==> Preparing data ...')

    # Same data root and attribute for both splits.
    train_set = CelebA(data_root=args.data_dir, attr=args.attr, split='train')
    val_set = CelebA(data_root=args.data_dir, attr=args.attr, split='val')

    # Both loaders drop the final incomplete batch.
    # NOTE(review): the training loader uses shuffle=False, so batches come in
    # dataset order every epoch - confirm this is intended.
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              drop_last=True,
                              shuffle=False,
                              num_workers=args.workers)
    val_loader = DataLoader(val_set,
                            args.batch_size,
                            drop_last=True,
                            shuffle=False,
                            num_workers=args.workers)
    # NOTE(review): presumably the label names for the binary attribute -
    # verify the index order against the dataset's label encoding.
    classes = ('yes', 'no')

    # Model
Exemplo n.º 7
0
from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping
# Switch TensorFlow to eager execution before anything else is built.
tf.enable_eager_execution()

# Training configuration.
batch_size = 64
num_epochs = 40
img_height = 32
img_width = 32

# Dataset wrapper with the listed attribute columns dropped.
# NOTE(review): 'Young' is not in this list, so it is presumably the target
# attribute - confirm against the CelebA wrapper.
celeba = CelebA(drop_features=[
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
    'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
    'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
    'Wearing_Necktie'
])

# Train, validation and test generators share identical preprocessing:
# rescale by 1/255, then per-sample centering and std normalisation.
train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   samplewise_center=True,
                                   samplewise_std_normalization=True)
val_datagen = ImageDataGenerator(rescale=1. / 255,
                                 samplewise_center=True,
                                 samplewise_std_normalization=True)
test_datagen = ImageDataGenerator(rescale=1. / 255,
                                  samplewise_center=True,
                                  samplewise_std_normalization=True)
def main():
    """Train the ``vaewgan`` encoder/generator/critic trio on CelebA.

    Steps, in order:

    1. Parse command-line options (unknown options are ignored because
       ``parse_known_args`` is used). ``--g_arch`` and ``--resume`` are parsed
       but not read anywhere in this function.
    2. Create ``runs/<MMDDHHMM>[_<out>]/{models,visualize}`` under the current
       directory and write every option to ``setting.txt``.
    3. When the module-level flag ``use_tensorboard`` is truthy, create a
       TensorFlow session and scalar summaries for the four losses.
    4. Build the dataset iterator, the three networks and one RMSprop
       optimizer per network.
    5. Run the training loop: several critic updates with weight clipping,
       then one generator update and one encoder update.
    6. After every epoch append the mean losses to ``log`` and call
       ``visualize``; every 5th epoch save the three networks as HDF5.

    Raises:
        ValueError: if ``--d_arch`` is not 1 or 2.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU device ID')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='# of epoch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--g_hidden', type=int, default=128)
    parser.add_argument('--g_arch', type=int, default=1)
    parser.add_argument('--g_activate', type=str, default='sigmoid')
    parser.add_argument('--g_channel', type=int, default=512)
    # --C weights the reconstruction term in the generator loss below.
    parser.add_argument('--C', type=float, default=1)
    parser.add_argument('--d_arch', type=int, default=1)
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--d_clip', type=float, default=0.01)
    parser.add_argument('--d_channel', type=int, default=512)
    parser.add_argument('--initial_iter', type=int, default=10)
    parser.add_argument('--resume', default='')
    parser.add_argument('--out', default='')
    # args = parser.parse_args()
    # parse_known_args tolerates options this parser does not declare.
    args, unknown = parser.parse_known_args()

    # log directory: runs/<month day hour minute>[_<out>]
    out = datetime.datetime.now().strftime('%m%d%H%M')
    if args.out:
        out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # hyper parameter: echo every option and persist it next to the run
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    # tensorboard (use_tensorboard is defined outside this function)
    if use_tensorboard:
        sess = tf.Session()
        sess.run(tf.initialize_all_variables())

        summary_dir = os.path.join(out_dir, "summaries")
        # A single placeholder feeds all four scalar summaries.
        loss_ = tf.placeholder(tf.float32)
        gen_loss_summary = tf.scalar_summary('gen_loss', loss_)
        dis_loss_summary = tf.scalar_summary('dis_loss', loss_)
        enc_loss_summary = tf.scalar_summary('enc_loss', loss_)
        rec_loss_summary = tf.scalar_summary('rec_loss', loss_)
        summary_writer = tf.train.SummaryWriter(summary_dir, sess.graph)

    # load celebA
    dataset = CelebA(image_type=args.g_activate)
    train_iter = chainer.iterators.MultiprocessIterator(
        dataset, args.batch_size)

    gen = vaewgan.Generator(n_hidden=args.g_hidden,
                            activate=args.g_activate,
                            ch=args.g_channel)
    # critic architecture selected by --d_arch
    if args.d_arch == 1:
        dis = vaewgan.Discriminator(ch=args.d_channel)
    elif args.d_arch == 2:
        dis = vaewgan.Discriminator2(ch=args.d_channel)
    else:
        raise ValueError('invalid arch')
    enc = vaewgan.Encoder(n_hidden=args.g_hidden)

    # A negative --gpu keeps all networks on the CPU.
    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()
        enc.to_gpu()

    optimizer_gen = chainer.optimizers.RMSprop(lr=0.00005)
    optimizer_dis = chainer.optimizers.RMSprop(lr=0.00005)
    optimizer_enc = chainer.optimizers.RMSprop(lr=0.00005)

    optimizer_gen.setup(gen)
    optimizer_dis.setup(dis)
    optimizer_enc.setup(enc)

    # optimizer_gen.add_hook(chainer.optimizer.WeightDecay(0.00001))
    # optimizer_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))
    # optimizer_enc.add_hook(chainer.optimizer.WeightDecay(0.00001))

    # start training
    start = time.time()
    gen_iterations = 0
    for epoch in range(args.epoch):

        # train: per-epoch loss histories, averaged for the epoch log line
        sum_L_gen = []
        sum_L_dis = []
        sum_L_enc = []
        sum_L_rec = []

        # i counts critic batches only; the extra batch drawn for the
        # encoder/generator step below is not counted.
        i = 0
        while i < len(dataset) // args.batch_size:

            # tips for critic reach optimality: 100 critic steps during the
            # first --initial_iter generator steps and on every 500th one
            if gen_iterations < args.initial_iter or gen_iterations % 500 == 0:
                d_iters = 100
            else:
                d_iters = args.d_iters

            ############################
            # (1) Update D network
            ###########################
            j = 0
            while j < d_iters:
                batch = train_iter.next()
                # Only element 0 of each dataset item is used here.
                x = chainer.Variable(
                    gen.xp.asarray([b[0] for b in batch], 'float32'))
                # attr = chainer.Variable(gen.xp.asarray([b[1] for b in batch], 'int32'))

                # real image (dis returns three values; only the first is
                # used during the critic update)
                y_real, _, _ = dis(x)

                # fake image from random noise
                # NOTE(review): volatile=True presumably keeps the generator
                # out of the backward graph during the critic step (Chainer v1
                # flag) - confirm against the Chainer version in use.
                z = chainer.Variable(gen.xp.asarray(
                    gen.make_hidden_normal(args.batch_size)),
                                     volatile=True)
                x_fake = gen(z)
                x_fake.volatile = False
                y_fake, _, _ = dis(x_fake)

                # fake image from reconstruction
                x.volatile = True
                mu_z, ln_var_z = enc(x)
                z_rec = F.gaussian(mu_z, ln_var_z)
                x_rec = gen(z_rec)
                x_rec.volatile = False
                y_rec, _, _ = dis(x_rec)

                # Sampled and reconstructed fakes each carry half the weight.
                L_dis = -(y_real - 0.5 * y_fake - 0.5 * y_rec)
                # print(j, -L_dis.data)

                dis.cleargrads()
                L_dis.backward()
                optimizer_dis.update()

                # clip the critic weights after every critic update
                dis.clip_weight(clip=args.d_clip)

                j += 1
                i += 1

            ###########################
            # (2) Update Enc and Dec network
            ###########################
            batch = train_iter.next()
            x = chainer.Variable(
                gen.xp.asarray([b[0] for b in batch], 'float32'))
            z = chainer.Variable(
                gen.xp.random.normal(
                    0, 1, (args.batch_size, args.g_hidden)).astype(np.float32))

            # real image
            y_real, l2_real, l3_real = dis(x)

            # fake image from random noise
            x_fake = gen(z)
            y_fake, _, _ = dis(x_fake)

            # fake image from reconstruction
            mu_z, ln_var_z = enc(x)
            z_rec = F.gaussian(mu_z, ln_var_z)
            x_rec = gen(z_rec)
            y_rec, l2_rec, l3_rec = dis(x_rec)

            # Reconstruction error is measured on the critic's second output
            # (l2_*), scaled by the size of that output's axes 2 and 3.
            L_rec = F.mean_squared_error(
                l2_real,
                l2_rec) * l2_real.data.shape[2] * l2_real.data.shape[3]
            L_prior = F.gaussian_kl_divergence(mu_z,
                                               ln_var_z) / args.batch_size
            L_gan = -0.5 * y_fake - 0.5 * y_rec

            L_gen = L_gan + args.C * L_rec
            L_enc = L_rec + L_prior

            gen.cleargrads()
            L_gen.backward()
            optimizer_gen.update()

            enc.cleargrads()
            L_enc.backward()
            optimizer_enc.update()

            gen_iterations += 1

            # L_dis here is the value from the last critic step above.
            l_dis = float(-L_dis.data)
            l_gen = float(L_gen.data)
            l_enc = float(L_enc.data)
            l_rec = float(L_rec.data)
            sum_L_dis.append(l_dis)
            sum_L_gen.append(l_gen)
            sum_L_enc.append(l_enc)
            sum_L_rec.append(l_rec)

            progress_report(epoch * len(dataset) // args.batch_size + i, start,
                            args.batch_size, l_dis)

            if use_tensorboard:
                summary = sess.run(gen_loss_summary, feed_dict={loss_: l_gen})
                summary_writer.add_summary(summary, gen_iterations)
                summary = sess.run(dis_loss_summary, feed_dict={loss_: l_dis})
                summary_writer.add_summary(summary, gen_iterations)
                summary = sess.run(enc_loss_summary, feed_dict={loss_: l_enc})
                summary_writer.add_summary(summary, gen_iterations)
                summary = sess.run(rec_loss_summary, feed_dict={loss_: l_rec})
                summary_writer.add_summary(summary, gen_iterations)

        # epoch summary: mean losses to stdout and to <out_dir>/log
        log = 'gen loss={:.5f} dis loss={:.5f} enc loss={:.5f} rec loss={:.5f}' \
            .format(np.mean(sum_L_gen), np.mean(sum_L_dis), np.mean(sum_L_enc), np.mean(sum_L_rec))
        print('\n' + log)
        with open(os.path.join(out_dir, "log"), 'a+') as f:
            f.write(log + '\n')

        # checkpoint the three networks every 5th epoch (including epoch 0)
        if epoch % 5 == 0:
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.dis.model".format(epoch)), dis)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.gen.model".format(epoch)), gen)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.enc.model".format(epoch)), enc)

        visualize(gen,
                  enc,
                  train_iter,
                  epoch=epoch,
                  savedir=os.path.join(out_dir, 'visualize'),
                  image_type=args.g_activate)
Exemplo n.º 9
0
    enc_z.load_state_dict(torch.load(args.enc_z_path))
    gen.load_state_dict(torch.load(args.gen_path))

    enc_y.eval()
    enc_z.eval()
    gen.eval()

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    test_dataset = CelebA(args.root_dir,
                          args.img_dir,
                          args.ann_dir,
                          transform=transform,
                          train=False)
    testloader = DataLoader(test_dataset,
                            batch_size=args.show_size,
                            shuffle=True,
                            collate_fn=collate_fn,
                            drop_last=True)
    for i in testloader:
        sample = i
        break
    im = sample['image']

    grid = vutils.make_grid(im, normalize=True)
    for i in range(100):
        if not os.path.exists(os.path.join(args.save_path, str(i))):
Exemplo n.º 10
0
def main():
    """Train the ``vaegan`` encoder/generator/discriminator trio on CelebA.

    Steps, in order:

    1. Parse command-line options (unknown options are ignored because
       ``parse_known_args`` is used).
    2. Create ``runs/<MMDDHHMM>[_<out>]/{models,visualize}`` under the current
       directory and write every option to ``setting.txt``.
    3. When the module-level flag ``use_tensorboard`` is truthy, create a
       TensorFlow session and scalar summaries for the four losses.
    4. Build the dataset iterator, the three networks and one Adam optimizer
       (with weight decay) per network.
    5. When ``--init_epoch`` is given, restore the networks and optimizer
       states from ``<--input>.{enc,gen,dis}.{model,state}``.
    6. Train for ``--epoch`` epochs of 1000 iterations each, updating the
       encoder, the generator and the discriminator once per iteration.
    7. After every epoch append the mean losses to ``log`` and call
       ``visualize``; every 5th epoch save the three networks as HDF5.

    ``--init_epoch`` and ``--input`` are declared here because the restore
    step reads them; without the declarations the function raised
    ``AttributeError`` before training could start.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU device ID')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='# of epoch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=100,
                        help='learning minibatch size')
    parser.add_argument('--g_hidden', type=int, default=128)
    parser.add_argument('--g_arch', type=int, default=1)
    parser.add_argument('--g_activate', type=str, default='sigmoid')
    parser.add_argument('--g_channel', type=int, default=512)
    # --C weights the reconstruction term in the generator loss below.
    parser.add_argument('--C', type=float, default=1)
    parser.add_argument('--d_arch', type=int, default=1)
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--d_clip', type=float, default=0.01)
    parser.add_argument('--d_channel', type=int, default=512)
    parser.add_argument('--initial_iter', type=int, default=10)
    parser.add_argument('--resume', default='')
    parser.add_argument('--out', default='')
    # Options consumed by the restore step; the default of None for
    # --init_epoch keeps a fresh run as the default behaviour.
    parser.add_argument('--init_epoch',
                        type=int,
                        default=None,
                        help='restore networks and optimizers before training')
    parser.add_argument('--input',
                        default='',
                        help='path prefix of the files to restore from')
    # parse_known_args tolerates options this parser does not declare.
    args, unknown = parser.parse_known_args()

    # log directory: runs/<month day hour minute>[_<out>]
    out = datetime.datetime.now().strftime('%m%d%H%M')
    if args.out:
        out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # hyper parameter: echo every option and persist it next to the run
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    # tensorboard (use_tensorboard is defined outside this function)
    if use_tensorboard:
        sess = tf.Session()
        sess.run(tf.initialize_all_variables())

        summary_dir = os.path.join(out_dir, "summaries")
        # A single placeholder feeds all four scalar summaries.
        loss_ = tf.placeholder(tf.float32)
        gen_loss_summary = tf.scalar_summary('gen_loss', loss_)
        dis_loss_summary = tf.scalar_summary('dis_loss', loss_)
        enc_loss_summary = tf.scalar_summary('enc_loss', loss_)
        rec_loss_summary = tf.scalar_summary('rec_loss', loss_)
        summary_writer = tf.train.SummaryWriter(summary_dir, sess.graph)

    # load celebA
    dataset = CelebA(image_type=args.g_activate)
    train_iter = chainer.iterators.MultiprocessIterator(
        dataset, args.batch_size)

    gen = vaegan.Generator(n_hidden=args.g_hidden,
                           activate=args.g_activate,
                           ch=args.g_channel)
    dis = vaegan.Discriminator(ch=args.d_channel)
    enc = vaegan.Encoder(n_hidden=args.g_hidden)

    # A negative --gpu keeps all networks on the CPU.
    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        enc.to_gpu()
        gen.to_gpu()
        dis.to_gpu()

    optimizer_enc = chainer.optimizers.Adam(alpha=0.0001, beta1=0.5)
    optimizer_gen = chainer.optimizers.Adam(alpha=0.0001, beta1=0.5)
    optimizer_dis = chainer.optimizers.Adam(alpha=0.0001, beta1=0.5)

    optimizer_enc.setup(enc)
    optimizer_gen.setup(gen)
    optimizer_dis.setup(dis)

    optimizer_enc.add_hook(chainer.optimizer.WeightDecay(0.00001))
    optimizer_gen.add_hook(chainer.optimizer.WeightDecay(0.00001))
    optimizer_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))

    # Optional restore of networks and optimizer states.
    if args.init_epoch is not None:
        serializers.load_hdf5(args.input + '.enc.model', enc)
        serializers.load_hdf5(args.input + '.enc.state', optimizer_enc)
        serializers.load_hdf5(args.input + '.gen.model', gen)
        serializers.load_hdf5(args.input + '.gen.state', optimizer_gen)
        serializers.load_hdf5(args.input + '.dis.model', dis)
        serializers.load_hdf5(args.input + '.dis.state', optimizer_dis)

    start_time = time.time()
    train_count = 0
    for epoch in range(1, args.epoch + 1):
        print('Epoch {}'.format(epoch))

        # per-epoch loss histories, averaged for the epoch log line
        sum_L_gen = []
        sum_L_dis = []
        sum_L_enc = []
        sum_L_rec = []

        # An "epoch" here is a fixed 1000 iterations, not one dataset pass.
        loop = 1000
        for i in range(loop):
            batch = train_iter.next()
            # Only element 0 of each dataset item is used here.
            x = chainer.Variable(
                gen.xp.asarray([b[0] for b in batch], 'float32'))
            # attr = chainer.Variable(gen.xp.asarray([b[1] for b in batch], 'int32'))

            # real image: the discriminator target for real data is class 0
            y_real, l2_real, l3_real = dis(x)
            L_dis_real = F.softmax_cross_entropy(
                y_real,
                chainer.Variable(
                    gen.xp.zeros(args.batch_size).astype(np.int32)))

            # fake image from random noise: the generator aims for class 0,
            # the discriminator for class 1
            z = chainer.Variable(
                gen.xp.asarray(gen.make_hidden_normal(args.batch_size)))
            x_fake = gen(z)
            y_fake, _, _ = dis(x_fake)
            L_gen_fake = F.softmax_cross_entropy(
                y_fake,
                chainer.Variable(
                    gen.xp.zeros(args.batch_size).astype(np.int32)))
            L_dis_fake = F.softmax_cross_entropy(
                y_fake,
                chainer.Variable(
                    gen.xp.ones(args.batch_size).astype(np.int32)))

            # fake image from reconstruction, with the same target convention
            mu_z, ln_var_z = enc(x)
            z_rec = F.gaussian(mu_z, ln_var_z)
            x_rec = gen(z_rec)
            y_rec, l2_rec, l3_rec = dis(x_rec)

            L_prior = F.gaussian_kl_divergence(mu_z,
                                               ln_var_z) / args.batch_size
            L_gen_rec = F.softmax_cross_entropy(
                y_rec,
                chainer.Variable(
                    gen.xp.zeros(args.batch_size).astype(np.int32)))
            L_dis_rec = F.softmax_cross_entropy(
                y_rec,
                chainer.Variable(
                    gen.xp.ones(args.batch_size).astype(np.int32)))

            # Reconstruction error is measured on the discriminator's second
            # output (l2_*), scaled by the size of that output's axes 2 and 3.
            L_rec = F.mean_squared_error(
                l2_real,
                l2_rec) * l2_real.data.shape[2] * l2_real.data.shape[3]
            # L_rec = F.gaussian_nll(l0, l2, chainer.Variable(xp.zeros(l0.data.shape).astype('float32'))) / args.batch_size

            # calc loss
            L_dis = L_dis_real + 0.5 * L_dis_fake + 0.5 * L_dis_rec
            L_enc = L_prior + L_rec
            L_gen = L_gen_fake + L_gen_rec + args.C * L_rec

            # update: encoder, then generator, then discriminator
            optimizer_enc.target.cleargrads()
            L_enc.backward()
            optimizer_enc.update()

            optimizer_gen.target.cleargrads()
            L_gen.backward()
            optimizer_gen.update()

            optimizer_dis.target.cleargrads()
            L_dis.backward()
            optimizer_dis.update()

            train_count += 1

            l_dis = float(L_dis.data)
            l_gen = float(L_gen.data)
            l_enc = float(L_enc.data)
            l_rec = float(L_rec.data)
            sum_L_dis.append(l_dis)
            sum_L_gen.append(l_gen)
            sum_L_enc.append(l_enc)
            sum_L_rec.append(l_rec)

            progress_report(train_count, start_time, args.batch_size, l_dis)

            if use_tensorboard:
                summary = sess.run(gen_loss_summary, feed_dict={loss_: l_gen})
                summary_writer.add_summary(summary, train_count)
                summary = sess.run(dis_loss_summary, feed_dict={loss_: l_dis})
                summary_writer.add_summary(summary, train_count)
                summary = sess.run(enc_loss_summary, feed_dict={loss_: l_enc})
                summary_writer.add_summary(summary, train_count)
                summary = sess.run(rec_loss_summary, feed_dict={loss_: l_rec})
                summary_writer.add_summary(summary, train_count)

        # epoch summary: mean losses to stdout and to <out_dir>/log
        log = 'gen loss={:.5f}, dis loss={:.5f} enc loss={:.5f} rec loss={:.5f}' \
            .format(np.mean(sum_L_gen), np.mean(sum_L_dis), np.mean(sum_L_enc), np.mean(sum_L_rec))
        print('\n' + log)
        with open(os.path.join(out_dir, "log"), 'a+') as f:
            f.write(log + '\n')

        # checkpoint the three networks every 5th epoch
        if epoch % 5 == 0:
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.dis.model".format(epoch)), dis)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.gen.model".format(epoch)), gen)
            serializers.save_hdf5(
                os.path.join(out_dir, "models",
                             "{:03d}.enc.model".format(epoch)), enc)

        visualize(gen,
                  enc,
                  train_iter,
                  epoch=epoch,
                  savedir=os.path.join(out_dir, 'visualize'),
                  image_type=args.g_activate)
Exemplo n.º 11
0
# `parser` is defined outside this excerpt.
args = parser.parse_args()

# Values derived from the command line.
c_dims = len(args.selected_attrs)  # one label dimension per selected attribute
num_epochs = args.epochs
lamb_cls = args.lam_cls
lamb_rec = 10  # fixed here, unlike lamb_cls which comes from the command line

# Preprocessing: random horizontal flip, 178 centre crop, resize to 64,
# tensor conversion, then per-channel normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.CenterCrop(178),
    transforms.Resize(size=64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CelebA(root=args.directory,
                 attributes=args.selected_attrs,
                 transform=transform,
                 download=args.download)

# NOTE(review): no `shuffle` argument is passed, so the loader's default
# ordering applies - confirm this is intended for training.
data_loader = loader.DataLoader(dataset, batch_size=args.batch_size)


def fakeLabels(lth):
    """
    lth (int): no of labels required
    """
    label = torch.tensor([])
    for i in range(lth):
        arr = np.zeros(c_dims)
        arr[0] = 1
        np.random.shuffle(arr)
        label = torch.cat((label, torch.tensor(arr).float().unsqueeze(0)),
Exemplo n.º 12
0
from matplotlib import pyplot as plt
import cv2


def test_instance(db=None):
    """Print a summary of one fixed sample of ``db`` and show it with its box drawn.

    ``db`` is indexed with a hard-coded position and must yield
    ``(image, label, (tl_x, tl_y, br_x, br_y))``.  The image is annotated
    with the bounding box and displayed in a matplotlib window.
    """
    instance_idx = 39
    image, label, bbox = db[instance_idx]
    left, top, right, bottom = bbox

    # Box extent, reused for both the shape line and the area line.
    box_w = right - left
    box_h = bottom - top

    print('Instance Index: {}'.format(instance_idx))
    print('Label: {}'.format(label))
    img_h, img_w = image.shape[0], image.shape[1]
    print('Image Shape (Width × Height): ({} × {})'.format(img_w, img_h))
    print('Facial Shape (Width × Height): ({} × {})'.format(box_w, box_h))
    print('Facial Area: {}'.format(box_w * box_h))

    plt.figure(num='Dataset Tester', figsize=(10, 10))
    plt.title('Identity: {}'.format(label))

    # cv2 needs integer pixel coordinates for the rectangle corners.
    top_left = (int(left), int(top))
    bottom_right = (int(right), int(bottom))
    image = cv2.rectangle(image, top_left, bottom_right,
                          color=(1, 0, 0),
                          thickness=2)
    plt.imshow(image, cmap='gray')
    plt.show()


# Build the validation split and run the visual smoke test on it.
db = CelebA(dataset_type=CelebA.VALIDATION)
# Shuffling is left disabled, so the fixed index inside test_instance is applied to the unshuffled order.
# db.shuffle()
test_instance(db)
Exemplo n.º 13
0
def train(args):
    """Train an ``AttrEncoder`` attribute classifier on real and reconstructed images.

    A pre-trained generator and the two pre-trained encoders (``enc_y`` for
    attributes, ``enc_z`` for the latent code) are loaded from fixed
    checkpoint paths and kept frozen in eval mode.  Each step the classifier
    is trained with binary cross-entropy on the real image and on its
    reconstruction ``gen(enc_z(x), enc_y(x))``.

    Args:
        args: parsed options.  Fields read here: ``enable_cuda``,
            ``image_size``, ``dataset``, ``root_dir``, ``batch_size``,
            ``show_size``, ``residual``, ``nz``, ``model_ep``, ``enc_ep``,
            ``learning_rate``, ``betas``, ``use_tensorboard``, ``num_epoch``,
            ``log_every`` and ``save_every``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.

    Note:
        When ``args.use_tensorboard`` is set, a ``writer`` object is expected
        to exist in the enclosing scope; it is not created here.
    """
    device = torch.device('cuda' if torch.cuda.is_available() and args.enable_cuda else 'cpu')

    # transforms applied
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if args.dataset == 'celeba':
        dataset_cls, img_dir, ann_dir = CelebA, 'cropped', 'list_attr_celeba.csv'
    elif args.dataset == 'sunattributes':
        dataset_cls, img_dir, ann_dir = SUN_Attributes, 'cropped', 'SUNAttributeDB'
    else:
        raise NotImplementedError()
    train_dataset = dataset_cls(args.root_dir, img_dir, ann_dir, transform=transform, train=True)
    test_dataset = dataset_cls(args.root_dir, img_dir, ann_dir, transform=transform, train=False)

    fsize = train_dataset.feature_size

    trainloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
    testloader = DataLoader(test_dataset, batch_size=args.show_size, shuffle=True, collate_fn=collate_fn, drop_last=True)

    # dataloader returns dictionaries.
    # sample : {'image' : (bs, 64, 64, 3), 'attributes' : (bs, att_size)}

    # Frozen, pre-trained generator.  It must live on the same device as the
    # encoder outputs it consumes.
    # NOTE(review): the checkpoint directory is hard-coded to celeba/res_False
    # even when args.residual or another dataset is selected -- confirm.
    if args.residual:
        gen = Generator_Res(in_c=args.nz + fsize)
    else:
        gen = Generator(in_c=args.nz + fsize)
    gen = gen.to(device)
    MODELPATH = '../model/celeba/res_False/gen_epoch_{}.ckpt'.format(args.model_ep)
    gen.load_state_dict(torch.load(MODELPATH, map_location=device))

    # Frozen, pre-trained encoders; each one loads its own checkpoint.
    enc_y = Encoder(fsize).to(device)
    MODELPATH = '../model/celeba/res_False/enc_y_epoch_{}.ckpt'.format(args.enc_ep)
    enc_y.load_state_dict(torch.load(MODELPATH, map_location=device))
    enc_z = Encoder(args.nz, for_y=False).to(device)
    MODELPATH = '../model/celeba/res_False/enc_z_epoch_{}.ckpt'.format(args.enc_ep)
    enc_z.load_state_dict(torch.load(MODELPATH, map_location=device))

    gen.eval()
    enc_y.eval()
    enc_z.eval()

    # The only trainable network.
    model = AttrEncoder(outdims=fsize).to(device)
    att_optim = optim.Adam(model.parameters(), lr=args.learning_rate, betas=args.betas)
    criterion = nn.BCELoss().to(device)

    if args.use_tensorboard:
        writer.add_text("Text", "begin training, lr={}".format(args.learning_rate))
    print("begin training, lr={}".format(args.learning_rate), flush=True)
    stepcnt = 0

    for ep in range(args.num_epoch):
        # The evaluation helper may leave the model in eval mode.
        model.train()
        window_start = time.time()
        for it, sample in enumerate(trainloader):
            image = sample['image'].to(device)
            att = sample['attributes'].to(device)

            # Gradients accumulate across steps unless cleared.
            att_optim.zero_grad()

            out = model(image)
            loss = criterion(out, att)

            # The reconstruction comes from frozen networks, so no graph is
            # needed for it.
            with torch.no_grad():
                recon = gen(enc_z(image), enc_y(image))
            # NOTE(review): the reconstruction is classified by `model` so the
            # loss compares attribute predictions with attribute targets
            # (comparing the image itself with `att` cannot match in shape).
            # Assumed intent -- confirm.
            out2 = model(recon)
            loss2 = criterion(out2, att)

            (loss + loss2).backward()
            att_optim.step()
            stepcnt += 1

            if it % args.log_every == (args.log_every - 1):
                loss_value = loss.item()
                if args.use_tensorboard:
                    # Global step, so points from different epochs do not overlap.
                    writer.add_scalar('loss', loss_value, stepcnt)
                now = time.time()
                print("{}th iter \t loss: {:.8f} \t time per iter: {:.05f}s".format(it+1, loss_value, (now - window_start) / args.log_every), flush=True)
                window_start = now

        cnt, allcnt = eval(model, testloader, device)
        print("-" * 50)
        print("epoch {} done. accuracy: {:.03f}%. num guessed: [{:05d}/{:05d}]".format(ep+1, cnt / allcnt * 100, cnt, allcnt))
        print("-" * 50, flush=True)

        if ep % args.save_every == (args.save_every - 1):
            torch.save(model.state_dict(), "../model/atteval_epoch_{}.model".format(ep+1))

    cnt, allcnt = eval(model, testloader, device, early=False)
    print("-" * 50)
    # Three placeholders, three values: accuracy, correct count, total count.
    print("training done. final accuracy: {:.03f}% num guessed: [{:07d}/{:07d}]".format(cnt / allcnt * 100, cnt, allcnt))
    print("-" * 50, flush=True)
    print("-" * 50, flush=True)
Exemplo n.º 14
0
 def test_celeba(self):
     """Construct the CelebA dataset from the fixed '/Data/CelebA' root."""
     celeba_root = '/Data/CelebA'
     dataset = CelebA(root=celeba_root)
Exemplo n.º 15
0
def train(args):
    """Train a conditional GAN whose discriminator also predicts attributes.

    The discriminator returns a real/fake score and an attribute prediction.
    Each step it is updated on a real batch and on a generated batch
    conditioned on the real attributes; the generator is then updated on a
    batch conditioned on randomly sampled attributes.  Smoothed targets
    (0.9 for real, 0.1 for fake) are used for the adversarial loss.
    Checkpoints and sample grids are written periodically.

    Args:
        args: parsed options.  Fields read here: ``use_tensorboard``,
            ``dataset``, ``image_size``, ``learning_rate``, ``batch_size``,
            ``nz``, ``enable_cuda``, ``root_dir``, ``opt_method``, ``betas``,
            ``momentum``, ``patience``, ``show_size``, ``model_path``,
            ``output_path``, ``num_epoch``, ``lambda_c``, ``log_every``,
            ``image_every`` and ``save_model_every``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
        ValueError: if ``args.opt_method`` is neither ``'Adam'`` nor ``'SGD'``.
    """
    print(args)

    if args.use_tensorboard:
        log_name = "{}_im{}_lr{}_bs{}_nz{}_aux".format(args.dataset, args.image_size, args.learning_rate, args.batch_size, args.nz)
        print("logging to runs/{}".format(log_name))
        writer = SummaryWriter(log_dir=os.path.join('runs', log_name))

    device = torch.device('cuda' if torch.cuda.is_available() and args.enable_cuda else 'cpu')

    # transforms applied
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # dataset and dataloader for training
    if args.dataset == 'celeba':
        img_dir = 'img_align_celeba'
        ann_dir = 'list_attr_celeba.csv'
        train_dataset = CelebA(args.root_dir, img_dir, ann_dir, transform=transform)
        attnames = list(train_dataset.df.columns)[1:]
    elif args.dataset == 'sunattributes':
        img_dir = 'images'
        ann_dir = 'SUNAttributeDB'
        train_dataset = SUN_Attributes(args.root_dir, img_dir, ann_dir, transform=transform)
        attnames = train_dataset.attrnames
    else:
        raise NotImplementedError()

    fsize = train_dataset.feature_size

    trainloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True, num_workers=8)

    # dataloader returns dictionaries.
    # sample : {'image' : (bs, H, W, 3), 'attributes' : (bs, fsize)}

    # model, optimizer, criterion
    gen = Generator(in_c=args.nz + fsize)
    dis = Discriminator_Aux(ny=fsize)
    # NOTE(review): only every second parameter tensor is counted, as in the
    # original; presumably this skips biases -- confirm.
    gen_params = sum(p.numel() for p in list(gen.parameters())[::2])
    dis_params = sum(p.numel() for p in list(dis.parameters())[::2])
    print("# of parameters in generator : {}".format(gen_params))
    print("# of parameters in discriminator : {}".format(dis_params))
    gen = gen.to(device)
    dis = dis.to(device)

    # initialize weights for conv layers
    gen.apply(init_weights)
    dis.apply(init_weights)

    if args.opt_method == 'Adam':
        gen_optim = optim.Adam(gen.parameters(), lr=args.learning_rate, betas=args.betas)
        dis_optim = optim.Adam(dis.parameters(), lr=args.learning_rate, betas=args.betas)
    elif args.opt_method == 'SGD':
        gen_optim = optim.SGD(gen.parameters(), lr=args.learning_rate, momentum=args.momentum)
        dis_optim = optim.SGD(dis.parameters(), lr=args.learning_rate, momentum=args.momentum)
        gen_scheduler = optim.lr_scheduler.ReduceLROnPlateau(gen_optim, patience=args.patience)
        dis_scheduler = optim.lr_scheduler.ReduceLROnPlateau(dis_optim, patience=args.patience)
    else:
        # Fail up front instead of with a NameError on the first optimizer step.
        raise ValueError("unsupported opt_method: {!r}".format(args.opt_method))
    criterion = nn.BCELoss()

    fixed_noise = torch.randn(args.show_size, args.nz).to(device)
    # fixed_label = torch.zeros(args.show_size, fsize).to(device)
    fixed_label = randomsample(train_dataset, args.show_size).to(device)

    # Smoothed adversarial targets.  They are constant because drop_last=True
    # keeps every batch at args.batch_size.
    real_tar = 0.9
    fake_tar = 0.1
    real_label = torch.full((args.batch_size,), real_tar, device=device)
    fake_label = torch.full((args.batch_size,), fake_tar, device=device)

    # makedirs also covers nested paths and already-existing directories.
    os.makedirs(args.model_path, exist_ok=True)
    os.makedirs(args.output_path, exist_ok=True)

    if args.use_tensorboard:
        writer.add_text("Text", "begin training, lr={}".format(args.learning_rate))
    print("begin training, lr={}".format(args.learning_rate), flush=True)
    stepcnt = 0
    gen.train()
    dis.train()

    for ep in range(args.num_epoch):

        # Epoch totals are plain floats so no autograd graph is kept alive.
        Dloss = 0.0
        Gloss = 0.0
        Closs = 0.0

        window_start = time.time()

        for it, sample in enumerate(trainloader):

            x_real = sample['image'].to(device)
            y_real = sample['attributes'].to(device)

            '''training of Discriminator'''
            # train on real images and real attributes, target are ones (real)
            dis.zero_grad()
            out, y_pred = dis(x_real)
            real_dis_loss = 0.5 * criterion(out, real_label)
            real_cls_loss = 0.5 * criterion(y_pred, y_real)
            err = real_dis_loss + args.lambda_c * real_cls_loss
            err.backward()
            D_x = out.mean().item()
            cls1 = real_cls_loss.item()

            # train on fake images and real attributes, target are zeros (fake).
            # The real attributes stay the conditioning and the target here;
            # the discriminator's prediction goes to a separate name.
            z = torch.randn(args.batch_size, args.nz, device=device)
            x_fake = gen(z, y_real)
            # detach: this pass only updates the discriminator.
            out, y_pred = dis(x_fake.detach())
            fake_dis_loss = 0.5 * criterion(out, fake_label)
            fake_cls_loss = 0.5 * criterion(y_pred, y_real)
            err = fake_dis_loss + args.lambda_c * fake_cls_loss
            err.backward()
            D_G_z1 = out.mean().item()
            cls2 = fake_cls_loss.item()

            dis_optim.step()
            errD = real_dis_loss.item() + fake_dis_loss.item()
            errD_C = cls1 + cls2

            '''training of Generator'''
            # train on fake images and sampled attributes, target are ones (real)
            gen.zero_grad()
            y_sample = randomsample(train_dataset, args.batch_size).to(device)
            z = torch.randn(args.batch_size, args.nz, device=device)
            x_prime = gen(z, y_sample)
            out, y_pred = dis(x_prime)
            gen_loss = criterion(out, real_label)
            gen_cls_loss = criterion(y_pred, y_sample)
            err = gen_loss + args.lambda_c * gen_cls_loss
            err.backward()
            D_G_z2 = out.mean().item()

            gen_optim.step()
            errG = gen_loss.item()
            errG_C = gen_cls_loss.item()

            dloss = errD
            gloss = errG
            closs = errD_C + errG_C
            Dloss += dloss
            Gloss += gloss
            Closs += closs

            '''log the losses and images, get time of loop'''
            if it % args.log_every == (args.log_every - 1):

                if args.use_tensorboard:
                    writer.add_scalar("D_x", D_x, stepcnt)
                    writer.add_scalar("D_G_z1", D_G_z1, stepcnt)
                    writer.add_scalar("D_G_z2", D_G_z2, stepcnt)
                    writer.add_scalar("D_loss", dloss, stepcnt)
                    writer.add_scalar("G_loss", gloss, stepcnt)
                    writer.add_scalar("C_loss", closs, stepcnt)

                # Average over the whole logging window, not just the last step.
                after = time.time()
                print("{}th iter\tD(x, y): {:.4f}\tD(G(z, y))_1: {:.4f}\tD(G(z, y))_2: {:.4f}\tcls_loss: {:.4f}\t{:.4f}s per loop".format(it+1, D_x, D_G_z1, D_G_z2, closs, (after - window_start) / args.log_every), flush=True)
                window_start = after

            if it % args.image_every == (args.image_every - 1):
                with torch.no_grad():
                    im = gen(fixed_noise, fixed_label)
                grid = vutils.make_grid(im, normalize=True)

                if args.use_tensorboard:
                    writer.add_image('iter {}'.format(it+1), grid, stepcnt+1)
                    print("logged image, iter {}".format(it+1))

            stepcnt += 1

        # learning rate scheduler, driven by the epoch totals
        if args.opt_method == 'SGD':
            gen_scheduler.step(Gloss)
            dis_scheduler.step(Dloss)

        '''log losses per epoch'''
        epoch_msg = "epoch [{}/{}] done | Disc loss: {:.6f}\tGen loss: {:.6f}\tClassification loss: {:.6f}".format(ep+1, args.num_epoch, Dloss, Gloss, Closs)
        print(epoch_msg, flush=True)
        if args.use_tensorboard:
            writer.add_text("epoch loss", epoch_msg, ep+1)

        '''save models and generated images on fixed labels'''
        if ep % args.save_model_every == (args.save_model_every - 1):
            savedir = os.path.join(args.model_path, "{}/aux/".format(args.dataset))
            os.makedirs(savedir, exist_ok=True)
            gen_pth = "gen_epoch_{}.model".format(ep+1)
            dis_pth = "dis_epoch_{}.model".format(ep+1)

            torch.save(gen.state_dict(), os.path.join(savedir, gen_pth))
            torch.save(dis.state_dict(), os.path.join(savedir, dis_pth))
            print("saved model to {}".format(savedir))

            '''save and add images based on fixed noise and labels'''
            imdir = os.path.join(args.output_path, "{}/aux/cgan_epoch_{}".format(args.dataset, ep+1))
            os.makedirs(imdir, exist_ok=True)
            with torch.no_grad():
                img = gen(fixed_noise, fixed_label).cpu()
            grid = vutils.make_grid(img, normalize=True)
            if args.use_tensorboard:
                writer.add_image('epoch {}'.format(ep+1), grid, ep+1)
            vutils.save_image(grid, os.path.join(imdir, "original.png"))

            '''save images based on fixed noise and labels, make attribute 1'''
            for i in range(fsize):
                # Work on a copy so the shared fixed_label is not overwritten
                # attribute by attribute.
                att = fixed_label.clone()
                att[:, i] = 1
                with torch.no_grad():
                    img = gen(fixed_noise, att).cpu()
                grid = vutils.make_grid(img, normalize=True)
                vutils.save_image(grid, os.path.join(imdir, "{}.png".format(attnames[i].replace("/", ""))))
            print("saved original and transformed images in {}".format(imdir), flush=True)

    print("end training")
    if args.use_tensorboard:
        writer.add_text("Text", "end training")
        writer.close()
Exemplo n.º 16
0
def train(args):
    """Fit an ``AttrEncoder`` on CelebA attributes and checkpoint it every epoch.

    Reads ``root_dir``, ``img_dir``, ``ann_dir``, ``batch_size``, ``attrnum``,
    ``learning_rate`` and ``num_epoch`` from ``args``.  After each epoch the
    model is scored on the 'Val' split and its weights are saved under
    ``../out``.
    """
    if torch.cuda.is_available():
        device_name = 'cuda:1'
    else:
        device_name = 'cpu'
    device = torch.device(device_name)

    resize_and_tensor = [transforms.Resize((224, 224)), transforms.ToTensor()]
    transform = transforms.Compose(resize_and_tensor)

    # Train and validation splits share location and preprocessing.
    location = (args.root_dir, args.img_dir, args.ann_dir)
    train_dataset = CelebA(*location, transform=transform)
    test_dataset = CelebA(*location, mode='Val', transform=transform)

    # Both loaders are configured identically.
    loader_options = dict(batch_size=args.batch_size,
                          shuffle=True,
                          collate_fn=collate_fn)
    trainloader = DataLoader(train_dataset, **loader_options)
    testloader = DataLoader(test_dataset, **loader_options)

    # model, optimizer, criterion
    model = AttrEncoder(outdims=args.attrnum).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    criterion = nn.BCELoss().to(device)

    rule = "-" * 30

    print("begin training", flush=True)
    for ep in range(args.num_epoch):
        model.train()
        for step, batch in enumerate(trainloader, start=1):
            images = batch['image'].to(device)
            targets = batch['attributes'].to(device)

            optimizer.zero_grad()
            loss = criterion(model.forward(images), targets)
            loss.backward()
            optimizer.step()

            # Report every tenth step.
            if step % 10 == 0:
                print("{}th iter \t loss: {}".format(step, loss), flush=True)

        cnt, allcnt = eval(model, testloader, device)

        print(rule)
        print("epoch [{}/{}] done | accuracy [{}/{}]".format(
            ep, args.num_epoch, cnt, allcnt))
        print(rule, flush=True)
        torch.save(model.state_dict(), "../out/epoch_{:1d}.model".format(ep))
    print("end training")
Exemplo n.º 17
0
# Script entry point: load a fixed profile, build the graph, and train on CelebA.
if __name__ == '__main__':
    # Exit with status 0 on Ctrl-C instead of surfacing a KeyboardInterrupt traceback.
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
    print("hello")
    # Command-line parsing is disabled; the profile path is hard-coded instead.
    #args = parse_args()
    args = "profile/celeba.json"
    # initialize logging
    util.init_output_logging()

    # Load hyper-parameters from the profile; `args` is the path string itself here,
    # not a namespace, so it is passed directly.
    # hps = util.load_profile(args.profile)
    hps = util.load_profile(args)
    # Seed comes from the profile's ablation section.
    util.manual_seed(hps.ablation.seed)

    # Build the graph; the returned state is later unpacked into the Trainer.
    builder = Builder(hps)
    state = builder.build()

    # Load the dataset: 160px centre crop, resize to 64, convert to tensor.
    dataset = CelebA(root=hps.dataset.root,
                     transform=transforms.Compose((
                         transforms.CenterCrop(160),
                         transforms.Resize(64),
                         transforms.ToTensor()
                     )))

    # Start training; every entry of `state` is forwarded as a keyword argument.
    trainer = Trainer(hps=hps, dataset=dataset, **state)
    trainer.train()
Exemplo n.º 18
0
def train(args):
    """Train the attribute encoder and the identity encoder for a frozen generator.

    ``enc_y`` learns to regress the attribute vector from real images.
    ``enc_z`` learns to recover the latent code from images the pre-trained
    generator produces for freshly sampled latents and sampled attributes.
    Both use an MSE loss.  Encoder checkpoints are written after every epoch,
    and every ``args.recon_every`` epochs one test batch is reconstructed and
    saved as image grids.

    Args:
        args: parsed options.  Fields read here: ``use_tensorboard``,
            ``learning_rate``, ``enable_cuda``, ``image_size``, ``root_dir``,
            ``img_dir``, ``ann_dir``, ``batch_size``, ``show_size``, ``nz``,
            ``model_path``, ``dataset``, ``residual``, ``model_ep``,
            ``opt_method``, ``betas``, ``momentum``, ``patience``,
            ``num_epoch``, ``log_every``, ``recon_every`` and ``output_path``.

    Raises:
        ValueError: if ``args.opt_method`` is neither ``'Adam'`` nor ``'SGD'``.
    """
    if args.use_tensorboard:
        log_name = "enc_lr{}".format(args.learning_rate)
        writer = SummaryWriter(log_dir=os.path.join('runs', log_name))

    device = torch.device('cuda' if args.enable_cuda and torch.cuda.is_available() else 'cpu')

    # transforms applied
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # dataset and dataloader for training
    train_dataset = CelebA(args.root_dir, args.img_dir, args.ann_dir, transform=transform)
    test_dataset = CelebA(args.root_dir, args.img_dir, args.ann_dir, transform=transform, train=False)
    fsize = train_dataset.feature_size
    trainloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True, num_workers=4)
    testloader = DataLoader(test_dataset, batch_size=args.show_size, shuffle=True, collate_fn=collate_fn, drop_last=True, num_workers=4)

    # dataloader returns dictionaries.
    # sample : {'image' : (bs, 64, 64, 3), 'attributes' : (bs, att_size)}

    # Pre-trained generator, used frozen.
    gen = Generator(in_c=args.nz + fsize)
    gen = gen.to(device)
    modeldir = os.path.join(args.model_path, "{}/res_{}".format(args.dataset, args.residual))
    MODELPATH = os.path.join(modeldir, 'gen_epoch_{}.ckpt'.format(args.model_ep))
    # map_location lets a checkpoint saved on GPU load on a CPU-only run.
    gen.load_state_dict(torch.load(MODELPATH, map_location=device))

    enc_y = Encoder(fsize).to(device)
    enc_z = Encoder(args.nz, for_y=False).to(device)

    # NOTE(review): an attribute-transfer evaluation that used a pre-trained
    # AttrEncoder was disabled in the original and lived only as dead string
    # literals; it is left out here.

    # initialize weights for encoders
    enc_y.apply(init_weights)
    enc_z.apply(init_weights)

    if args.opt_method == 'Adam':
        enc_y_optim = optim.Adam(enc_y.parameters(), lr=args.learning_rate, betas=args.betas)
        enc_z_optim = optim.Adam(enc_z.parameters(), lr=args.learning_rate, betas=args.betas)
    elif args.opt_method == 'SGD':
        enc_y_optim = optim.SGD(enc_y.parameters(), lr=args.learning_rate, momentum=args.momentum)
        enc_z_optim = optim.SGD(enc_z.parameters(), lr=args.learning_rate, momentum=args.momentum)
        enc_y_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_y_optim, patience=args.patience)
        enc_z_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_z_optim, patience=args.patience)
    else:
        # Fail up front instead of with a NameError on the first optimizer step.
        raise ValueError("unsupported opt_method: {!r}".format(args.opt_method))
    criterion = nn.MSELoss()

    if args.use_tensorboard:
        writer.add_text("Text", "begin training, lr={}".format(args.learning_rate))
    print("begin training, lr={}".format(args.learning_rate), flush=True)
    stepcnt = 0
    gen.eval()
    enc_y.train()
    enc_z.train()

    for ep in range(args.num_epoch):

        YLoss = 0
        ZLoss = 0

        ittime = time.time()
        run_time = 0
        run_ittime = 0

        for it, sample in enumerate(trainloader):

            elapsed = time.time()

            x = sample['image'].to(device)
            y = sample['attributes'].to(device)

            '''training of attribute encoder'''
            # train on real images, target are real attributes
            enc_y.zero_grad()
            out = enc_y(x)
            l2loss = criterion(out, y)
            l2loss.backward()
            enc_y_optim.step()
            loss_y = l2loss.item()

            '''training of identity encoder'''
            # train on fake images generated with real labels, target are original identities.
            # A fresh latent batch is drawn every step; reusing one fixed batch
            # would only ever show the encoder batch_size distinct codes.
            enc_z.zero_grad()
            noise = torch.randn(args.batch_size, args.nz, device=device)
            y_sample = randomsample(train_dataset, args.batch_size).to(device)
            with torch.no_grad():
                x_fake = gen(noise, y_sample)
            z_recon = enc_z(x_fake)
            l2loss2 = criterion(z_recon, noise)
            l2loss2.backward()
            enc_z_optim.step()
            loss_z = l2loss2.item()

            YLoss += loss_y
            ZLoss += loss_z

            # run_time covers the compute of this step; run_ittime also
            # includes the time spent waiting on the data loader.
            ittime_a = time.time() - ittime
            run_time += time.time() - elapsed
            run_ittime += ittime_a

            '''log the losses and images, get time of loop'''
            if it % args.log_every == (args.log_every - 1):
                if args.use_tensorboard:
                    writer.add_scalar('y loss', loss_y, stepcnt+1)
                    writer.add_scalar('z loss', loss_z, stepcnt+1)

                print("{}th iter\ty loss: {:.5f}\tz loss: {:.5f}\t{:.4f}s per step, {:.4f}s per iter".format(it+1, loss_y, loss_z, run_time / args.log_every, run_ittime / args.log_every), flush=True)
                run_time = 0
                run_ittime = 0
            ittime = time.time()

            stepcnt += 1

        # The plateau schedulers only exist for SGD and must be stepped to have any effect.
        if args.opt_method == 'SGD':
            enc_y_scheduler.step(YLoss)
            enc_z_scheduler.step(ZLoss)

        epoch_msg = "epoch [{}/{}] done | y loss: {:.6f} \t z loss: {:.6f}]".format(ep+1, args.num_epoch, YLoss, ZLoss)
        print(epoch_msg, flush=True)
        if args.use_tensorboard:
            writer.add_text("epoch loss", epoch_msg, ep+1)

        savepath = os.path.join(args.model_path, "{}/res_{}".format(args.dataset, args.residual))
        try:
            # Create the checkpoint directory on first use.
            os.makedirs(savepath, exist_ok=True)
            torch.save(enc_y.state_dict(), os.path.join(savepath, "enc_y_epoch_{}.ckpt".format(ep+1)))
            torch.save(enc_z.state_dict(), os.path.join(savepath, "enc_z_epoch_{}.ckpt".format(ep+1)))
            print("saved encoder model at {}".format(savepath))
        except OSError:
            # Best effort: a failed save must not abort training.
            print("failed to save model for epoch {}".format(ep+1))

        if ep % args.recon_every == (args.recon_every - 1):
            # reconstruction of one batch of test images
            outpath = os.path.join(args.output_path, "{}/res_{}".format(args.dataset, args.residual))
            SAVEPATH = os.path.join(outpath, 'enc_epoch_{}'.format(ep+1))
            # makedirs also creates the missing parent directories.
            os.makedirs(SAVEPATH, exist_ok=True)

            with torch.no_grad():
                for sample in testloader:
                    im = sample['image']
                    grid = vutils.make_grid(im, normalize=True)
                    vutils.save_image(grid, os.path.join(SAVEPATH, 'original.png'))
                    im = im.to(device)
                    y = enc_y(im)
                    z = enc_z(im)
                    recon = gen(z, y).cpu()
                    grid = vutils.make_grid(recon, normalize=True)
                    vutils.save_image(grid, os.path.join(SAVEPATH, 'recon.png'))
                    # Only the first test batch is visualised.
                    break

    print("end training")
    if args.use_tensorboard:
        writer.add_text("Text", "end training")
        writer.close()