Пример #1
0
def iteration(logger, train, validation, model, optimizer, criterion, tracker):
    """Run one training pass over ``train`` followed by one validation pass.

    ``train`` is shuffled in place with ``np.random.shuffle`` and consumed in
    batches of 128; ``validation`` is consumed in batches of 1024 under
    ``torch.no_grad()``. Both batch streams come from
    ``dlutils.batch_provider`` with the module-level ``process`` collator and
    are unpacked as ``(x, y)`` pairs.

    After every batch, ``torch.sqrt`` of the criterion value is reported to
    ``tracker`` under ``train_loss`` / ``validation_loss``.
    NOTE(review): this is an RMSE only if ``criterion`` is a mean-squared-error
    loss -- confirm.

    Args:
        logger: Not used by this function; kept for interface compatibility.
        train: Mutable sequence of training samples (shuffled in place).
        validation: Sequence of validation samples.
        model: The network; switched to train mode, then left in eval mode.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss callable ``criterion(y_pred, y)``.
        tracker: Object with an ``update(dict)`` method receiving the losses.
    """
    np.random.shuffle(train)
    batches = dlutils.batch_provider(train, 128, process)

    model.train()
    for x, y in batches:
        y_pred = model(x)

        loss = criterion(y_pred, y)

        # Detach before handing the value to the tracker. Passing the live
        # tensor lets the tracker hold a reference to this batch's autograd
        # graph (and any running sum it keeps would chain the graphs of every
        # batch together), leaking memory for the whole epoch.
        tracker.update(dict(train_loss=torch.sqrt(loss.detach())))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        batches = dlutils.batch_provider(validation, 1024, process)

        for x, y in batches:
            y_pred = model(x)

            loss = criterion(y_pred, y)

            # Computed under no_grad, so there is no graph to detach here.
            tracker.update(dict(validation_loss=torch.sqrt(loss)))
Пример #2
0
def inference(model, data):
    """Return the model's predictions for every item of ``data``, stacked along dim 0.

    The model is put into eval mode (and left there). Inputs are batched
    1024 at a time by ``dlutils.batch_provider`` with the module-level
    ``process`` collator; each batch yields an ``(x, y)`` pair of which only
    ``x`` is fed to the model. No autograd graph is recorded.
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        for inputs, _unused_targets in dlutils.batch_provider(data, 1024, process):
            outputs.append(model(inputs))
    return torch.cat(outputs, dim=0)
Пример #3
0
def make_dataloader(dataset, batch_size, device):
    """Wrap ``dataset`` in a ``dlutils.batch_provider`` that emits tensors on ``device``.

    Each raw batch is unpacked as ``(y, x)``. ``x`` is divided by 255 and
    becomes a float32 tensor created with ``requires_grad=True``. ``y``
    becomes an int32 tensor. The collator returns ``(y, x)`` in that order.

    NOTE(review): int32 labels are rejected as class targets by losses such as
    ``nn.CrossEntropyLoss`` (which expects int64) -- confirm how callers use them.
    """

    class BatchCollator(object):
        """Callable turning one ``(labels, images)`` batch into device tensors."""

        def __init__(self, device):
            self.device = device

        def __call__(self, batch):
            with torch.no_grad():
                labels, images = batch
                images_t = torch.tensor(images / 255.0,
                                        requires_grad=True,
                                        dtype=torch.float32,
                                        device=self.device)
                labels_t = torch.tensor(labels,
                                        dtype=torch.int32,
                                        device=self.device)
                return labels_t, images_t

    collate = BatchCollator(device)
    return dlutils.batch_provider(dataset, batch_size, collate)
Пример #4
0
def main():
    """Train the VAE for 40 epochs on ``data_fold_<k>.pkl`` and save ``VAEmodel.pkl``.

    One of five folds (``epoch % 5``) is loaded per epoch, shuffled and
    batched (128 per batch) by ``batch_provider``/``process_batch``. The
    learning rate is divided by 4 every 8 epochs. Every 60 batches the
    average reconstruction loss is printed. At the same point, reconstructions
    and decodings of a fixed latent sample are written as PNGs to
    ``results_rec/QR/`` and ``results_gen/QR/``.
    """
    batch_size = 128
    z_size = 512
    vae = VAE(zsize=z_size, layer_count=5)
    vae.cuda()
    vae.train()
    vae.weight_init(mean=0, std=0.02)

    lr = 0.0005

    vae_optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

    train_epoch = 40

    # Fixed latent codes used to visualise generation over training. Placed on
    # the GPU explicitly so ``vae.decode`` (whose weights are on CUDA) does not
    # rely on a CUDA default tensor type being set elsewhere.
    sample1 = torch.randn(128, z_size).view(-1, z_size, 1, 1).cuda()

    # Samples are written into the ``QR`` sub-directories, so those are what
    # must exist -- creating only ``results_rec``/``results_gen`` makes
    # ``save_image`` fail. Created once instead of on every iteration.
    os.makedirs('results_rec/QR', exist_ok=True)
    os.makedirs('results_gen/QR', exist_ok=True)

    # Report losses and save samples every ``report_every`` iterations.
    report_every = 60

    for epoch in range(train_epoch):
        vae.train()

        with open('data_fold_%d.pkl' % (epoch % 5), 'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train, batch_size, process_batch, report_progress=True)

        rec_loss = 0
        # NOTE(review): ``loss_function`` returns a single loss here, so nothing
        # is ever added to ``kl_loss`` and the printed "KL loss" is always 0.
        kl_loss = 0

        epoch_start_time = time.time()

        if (epoch + 1) % 8 == 0:
            vae_optimizer.param_groups[0]['lr'] /= 4
            print("learning rate change!")

        i = 0
        for x in batches:
            vae.train()
            vae.zero_grad()
            rec, var, mu, logvar = vae(x)
            loss_re = loss_function(rec, var, x, mu, logvar)
            loss_re.backward()
            vae_optimizer.step()
            rec_loss += loss_re.item()

            i += 1
            if i % report_every == 0:
                per_epoch_ptime = time.time() - epoch_start_time
                rec_loss /= report_every
                kl_loss /= report_every
                print('\n[%d/%d] - ptime: %.2f, rec loss: %.9f, KL loss: %.9f' % (
                    (epoch + 1), train_epoch, per_epoch_ptime, rec_loss, kl_loss))
                rec_loss = 0
                kl_loss = 0
                with torch.no_grad():
                    vae.eval()
                    x_rec, x_var, _, _ = vae(x)
                    # Square root of the second decoder output, scaled for
                    # display (std-dev if ``x_var`` is a variance -- assumed).
                    x_var = (x_var ** 0.50) * 0.25 * 3
                    # ``* 0.5 + 0.5`` presumably maps [-1, 1] images to [0, 1].
                    resultsample = torch.cat([x, x_rec]) * 0.5 + 0.5
                    resultsample = torch.cat([resultsample, x_var])
                    resultsample = resultsample.cpu()
                    save_image(resultsample.view(-1, 3, im_size, im_size),
                               'results_rec/QR/sample_' + str(epoch) + "_" + str(i) + '.png')
                    x_rec, x_var = vae.decode(sample1)
                    resultsample = x_rec * 0.5 + 0.5
                    resultsample = resultsample.cpu()
                    save_image(resultsample.view(-1, 3, im_size, im_size),
                               'results_gen/QR/sample_' + str(epoch) + "_" + str(i) + '.png')

        del batches
        del data_train
    print("Training finish!... save training results")
    torch.save(vae.state_dict(), "VAEmodel.pkl")
Пример #5
0
#  Training
# ----------
def process_batch(batch):
    """Stack ``batch`` into a float32 CUDA tensor shaped ``(-1, 1, opt.img_size, opt.img_size)``.

    Pixel values are passed through unscaled.
    """
    as_array = np.asarray(batch, dtype=np.float32)
    on_gpu = torch.from_numpy(as_array).cuda()
    return on_gpu.view(-1, 1, opt.img_size, opt.img_size)


# Module-level epoch loop of a GAN training script.
# NOTE(review): this excerpt appears truncated -- it stops at the
# "Train Generator" banner, so the generator/discriminator updates are not
# visible here.
for epoch in range(opt.n_epochs):
    # Batch counter within the current epoch.
    i = 0
    # The whole training set is re-read from disk at the start of every epoch.
    with open('./data_fold_train_128.pkl', 'rb') as pkl:
        data_train = pickle.load(pkl)
    # Batches are built with ``process_batch``; note that ``data_train`` is
    # not shuffled before batching.
    batches = batch_provider(data_train,
                             opt.batch_size,
                             process_batch,
                             report_progress=True)
    for imgs in batches:
        i = i + 1

        # Adversarial ground truths: one target per image, filled with 1.0
        # (``valid``) and 0.0 (``fake``). ``Tensor`` is a module-level tensor
        # type defined outside this view.
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0),
                         requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0),
                        requires_grad=False)

        # Configure input: cast the batch to the module's ``Tensor`` type.
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
Пример #6
0
def main():
    """Train a single-channel autoencoder and save its weights to ``AEmodel.pkl``.

    Each epoch reloads ``../vae_gan_brain/data_fold_train_128.pkl``, shuffles
    it and trains on batches of 60 from ``batch_provider``/``process_batch``.
    Each update minimises ``loss_fn(reconstruction, input)`` with Adam.
    Every 20th epoch (0, 20, 40, ...), at the 5th batch, the mean
    reconstruction loss of the first five batches is printed. At that point,
    images 20..28 of the batch and their reconstructions are saved as
    grayscale PNGs under ``results_ori/`` and ``results_rec/``.
    """
    batch_size = 60
    z_size = 100
    ae = AE(zsize=z_size, layer_count=5, channels=1)
    ae.cuda()
    ae.train()
    ae.weight_init(mean=0, std=0.02)

    lr = 0.0005

    ae_optimizer = optim.Adam(ae.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

    train_epoch = 1000

    # Output directories are loop-invariant: create them once.
    os.makedirs('results_ori', exist_ok=True)
    os.makedirs('results_rec', exist_ok=True)

    for epoch in range(train_epoch):
        ae.train()
        with open('../vae_gan_brain/data_fold_train_128.pkl', 'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train, batch_size, process_batch, report_progress=True)

        rec_loss = 0

        epoch_start_time = time.time()

        # Step decay: divide the learning rate by 4 every 16 epochs.
        # NOTE(review): over 1000 epochs that is ~62 divisions, so the rate is
        # effectively zero long before training ends -- confirm this is intended.
        if (epoch + 1) % 16 == 0:
            ae_optimizer.param_groups[0]['lr'] /= 4
            print("learning rate change!")

        i = 0
        for x in batches:
            i = i + 1
            ae.train()
            ae.zero_grad()
            rec = ae(x)

            loss_re = loss_fn(rec, x)
            loss_re.backward()
            ae_optimizer.step()
            rec_loss += loss_re.item()

            # Report and dump samples once per 20 epochs, at the 5th batch.
            if epoch % 20 == 0 and i == 5:
                per_epoch_ptime = time.time() - epoch_start_time
                rec_loss /= i
                print('\n[%d/%d] - ptime: %.2f, rec loss: %.9f' % (
                    (epoch + 1), train_epoch, per_epoch_ptime, rec_loss))
                rec_loss = 0
                with torch.no_grad():
                    ae.eval()
                    x_rec = ae(x)

                    x = x.cpu()
                    x_rec = x_rec.cpu()

                    # Save a fixed slice of the batch (indices 20..28); this
                    # assumes the batch holds at least 29 images.
                    for j in range(20, 29):
                        org_img = transforms.ToPILImage()(x[j].squeeze(0)).convert('L')
                        rec_img = transforms.ToPILImage()(x_rec[j].squeeze(0)).convert('L')

                        org_img.save('results_ori/ori_' + str(epoch) + "_" + str(i) + "_" + str(j) + '.png')
                        rec_img.save('results_rec/rec_' + str(epoch) + "_" + str(i) + "_" + str(j) + '.png')

        del batches
        del data_train
    print("Training finish!... save training results")
    torch.save(ae.state_dict(), "AEmodel.pkl")
Пример #7
0
def main():
    """Train a single-channel VAE, dumping samples, late-epoch latents and weights.

    Each epoch reloads ``../vae_gan_brain/data_fold_train_128.pkl``, shuffles
    it and trains on batches of 60 with loss ``loss_re + loss_kl``. The
    learning rate is divided by 4 every 8 epochs. For epochs > 750 the
    fourth VAE output (``latent_space``) of every batch is collected. After
    training it is pickled to ``./latent_space.pkl``, and the weights are
    saved to ``VAEmodel.pkl``. Every 20th epoch, at the 5th batch, averaged
    losses are printed. Images 20..28 of the batch, their reconstructions,
    and decodings of a fixed latent sample are then saved as grayscale PNGs.
    """
    batch_size = 60
    z_size = 100
    vae = VAE(zsize=z_size, layer_count=5, channels=1)
    vae.cuda()
    vae.train()
    vae.weight_init(mean=0, std=0.02)

    lr = 0.0005

    vae_optimizer = optim.Adam(vae.parameters(),
                               lr=lr,
                               betas=(0.5, 0.999),
                               weight_decay=1e-5)

    train_epoch = 1000

    # Fixed latent codes for generation samples; put on the GPU explicitly so
    # ``vae.decode`` does not depend on a CUDA default tensor type.
    sample1 = torch.randn(batch_size, z_size).view(-1, z_size, 1, 1).cuda()
    noise_list = []

    # Output directories are loop-invariant: create them once.
    os.makedirs('results_ori', exist_ok=True)
    os.makedirs('results_rec', exist_ok=True)
    os.makedirs('results_gen', exist_ok=True)

    for epoch in range(train_epoch):
        vae.train()
        with open('../vae_gan_brain/data_fold_train_128.pkl', 'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train,
                                 batch_size,
                                 process_batch,
                                 report_progress=True)

        rec_loss = 0
        kl_loss = 0

        epoch_start_time = time.time()

        # NOTE(review): over 1000 epochs this divides the rate by 4 about 125
        # times, driving it to ~0 early on -- confirm this is intended.
        if (epoch + 1) % 8 == 0:
            vae_optimizer.param_groups[0]['lr'] /= 4
            print("learning rate change!")

        i = 0
        for x in batches:
            i = i + 1
            vae.train()
            vae.zero_grad()
            rec, mu, logvar, latent_space = vae(x)

            loss_re, loss_kl = loss_function(rec, x, mu, logvar)
            (loss_re + loss_kl).backward()
            vae_optimizer.step()
            rec_loss += loss_re.item()
            kl_loss += loss_kl.item()

            # Collect latent codes from the late epochs. Detach and move to
            # CPU: storing the live CUDA tensor would keep every batch's
            # autograd graph and GPU buffers alive until the end of training.
            if epoch > 750:
                noise_list.append(latent_space.detach().cpu())

            # Every 20th epoch, report and save samples once (5th batch).
            if epoch % 20 == 0 and i == 5:
                per_epoch_ptime = time.time() - epoch_start_time
                rec_loss /= i
                kl_loss /= i
                print(
                    '\n[%d/%d] - ptime: %.2f, rec loss: %.9f, KL loss: %.9f' %
                    ((epoch + 1), train_epoch, per_epoch_ptime, rec_loss,
                     kl_loss))
                rec_loss = 0
                kl_loss = 0
                with torch.no_grad():
                    vae.eval()
                    x_rec, _, _, _ = vae(x)
                    x_gen = vae.decode(sample1)
                    x = x.cpu()
                    x_gen = x_gen.cpu()
                    x_rec = x_rec.cpu()

                    # Save a fixed slice (indices 20..28); assumes >= 29 images.
                    for j in range(20, 29):
                        org_img = transforms.ToPILImage()(
                            x[j].squeeze(0)).convert('L')
                        rec_img = transforms.ToPILImage()(
                            x_rec[j].squeeze(0)).convert('L')
                        gen_img = transforms.ToPILImage()(
                            x_gen[j].squeeze(0)).convert('L')
                        org_img.save('results_ori/ori_' + str(epoch) + "_" +
                                     str(i) + "_" + str(j) + '.png')
                        rec_img.save('results_rec/rec_' + str(epoch) + "_" +
                                     str(i) + "_" + str(j) + '.png')
                        gen_img.save('results_gen/gen_' + str(epoch) + "_" +
                                     str(i) + "_" + str(j) + '.png')

        del batches
        del data_train
    print("Training finish!... save training results")
    # Context manager guarantees the file is closed even if pickling fails.
    with open('./latent_space.pkl', 'wb') as output_latent_space:
        pickle.dump(noise_list, output_latent_space)
    torch.save(vae.state_dict(), "VAEmodel.pkl")
Пример #8
0
def main():
    """Train a VAE-GAN: generator ``G`` (encoder + decoder) and discriminator ``D``.

    Each epoch reloads ``../vae_gan_brain/data_fold_train_128.pkl``, shuffles
    it and iterates over batches of up to 60 images. For every batch:

    1. ``D`` is trained (BCE) to output 1 for real images and 0 for both the
       reconstructions ``rec_enc`` and decodings of random noise ``rec_noise``.
    2. The decoder is trained on ``gamma * rec_loss - dis_img_loss``, where
       ``rec_loss`` is the mean squared difference between
       ``D.similarity`` features of the reconstruction and of the input.
    3. The encoder is trained on ``prior_loss + beta * rec_loss``, with
       ``prior_loss`` the element-averaged KL term computed from
       ``mean``/``logvar``.

    Every 5th epoch, every 6th batch, running averages of the tracked
    quantities are printed. Images 20..28 of the batch (original,
    reconstruction, decoded ``fixed_noise``) are then saved as grayscale PNGs.
    Final weights are written to ``G.pkl`` and ``D.pkl``.
    """
    input_channels = 1
    hidden_size = 128
    max_epochs = 500
    lr = 3e-4

    beta = 20
    alpha = 0.2
    gamma = 30
    batch_size = 60

    G = VAE_GAN_Generator(input_channels, hidden_size).cuda()
    D = Discriminator(input_channels).cuda()
    G.apply(weights_init)
    D.apply(weights_init)
    criterion = nn.BCELoss()
    criterion.cuda()

    opt_enc = optim.RMSprop(G.encoder.parameters(), lr=lr, alpha=0.9)
    opt_dec = optim.RMSprop(G.decoder.parameters(), lr=lr, alpha=0.9)
    # The discriminator's learning rate is scaled down by ``alpha``.
    opt_dis = optim.RMSprop(D.parameters(), lr=lr * alpha, alpha=0.9)
    fixed_noise = torch.randn(batch_size, hidden_size).cuda()

    enc_params = list(G.encoder.parameters())
    dec_params = list(G.decoder.parameters())

    # Output directories are loop-invariant: create them once.
    os.makedirs('results_ori', exist_ok=True)
    os.makedirs('results_rec', exist_ok=True)
    os.makedirs('results_gen', exist_ok=True)

    # Every 5th epoch, report and save samples every ``report_every`` batches.
    report_every = 6

    for epoch in range(max_epochs):
        with open('../vae_gan_brain/data_fold_train_128.pkl', 'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train, batch_size, process_batch, report_progress=True)

        # Per-epoch running statistics, stored as Python floats.
        D_real_list, D_rec_enc_list, D_rec_noise_list, D_list = [], [], [], []
        g_loss_list, rec_loss_list, prior_loss_list = [], [], []

        i = 0
        for x in batches:
            # The sampling branch below switches both nets to eval mode;
            # restore training mode on every batch so layers like
            # BatchNorm/Dropout (if present) do not stay in eval mode.
            G.train()
            D.train()

            # Size targets and noise by the actual batch: the last batch of an
            # epoch can be smaller than ``batch_size``, and BCELoss requires
            # the target to have the same size as the prediction.
            cur_batch = x.shape[0]
            ones_label = torch.ones(cur_batch).cuda()
            zeros_label = torch.zeros(cur_batch).cuda()

            datav = x.cuda()
            mean, logvar, rec_enc = G(datav)

            noisev = torch.randn(cur_batch, hidden_size).cuda()
            rec_noise = G.decoder(noisev)

            # ======== Train Discriminator ======== #
            frozen_params(G)
            free_params(D)

            output = D(datav).squeeze(1)
            errD_real = criterion(output, ones_label)
            D_real_list.append(output.mean().item())
            output = D(rec_enc).squeeze(1)
            errD_rec_enc = criterion(output, zeros_label)
            D_rec_enc_list.append(output.mean().item())
            output = D(rec_noise).squeeze(1)
            errD_rec_noise = criterion(output, zeros_label)
            D_rec_noise_list.append(output.mean().item())

            dis_img_loss = errD_real + errD_rec_enc + errD_rec_noise
            D_list.append(dis_img_loss.item())
            opt_dis.zero_grad()
            # The generator part of this graph is reused below.
            dis_img_loss.backward(retain_graph=True)
            opt_dis.step()

            # ======== Train Generator ======== #
            free_params(G)
            frozen_params(D)

            output = D(datav).squeeze(1)
            errD_real = criterion(output, ones_label)
            output = D(rec_enc).squeeze(1)
            errD_rec_enc = criterion(output, zeros_label)
            output = D(rec_noise).squeeze(1)
            errD_rec_noise = criterion(output, zeros_label)

            similarity_rec_enc = D.similarity(rec_enc)
            similarity_data = D.similarity(datav)

            dis_img_loss = errD_real + errD_rec_enc + errD_rec_noise
            gen_img_loss = -dis_img_loss
            g_loss_list.append(gen_img_loss.item())

            rec_loss = ((similarity_rec_enc - similarity_data) ** 2).mean()
            rec_loss_list.append(rec_loss.item())
            err_dec = gamma * rec_loss + gen_img_loss

            prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()
            prior_loss = (-0.5 * torch.sum(prior_loss)) / torch.numel(mean.data)
            prior_loss_list.append(prior_loss.item())
            err_enc = prior_loss + beta * rec_loss

            # Compute both generator gradients from the current graph BEFORE
            # stepping either optimizer. ``opt_dec.step()`` updates decoder
            # weights in place, and ``err_enc`` back-propagates through those
            # same weights (via ``rec_loss``). Stepping first makes the
            # encoder backward fail PyTorch's in-place version check, or use
            # already-updated weights. ``autograd.grad`` also keeps each
            # loss's gradient confined to its own parameter group.
            dec_grads = torch.autograd.grad(err_dec, dec_params,
                                            retain_graph=True, allow_unused=True)
            enc_grads = torch.autograd.grad(err_enc, enc_params, allow_unused=True)

            for param, grad in zip(dec_params, dec_grads):
                param.grad = grad
            opt_dec.step()

            for param, grad in zip(enc_params, enc_grads):
                param.grad = grad
            opt_enc.step()

            i += 1
            if epoch % 5 == 0 and i % report_every == 0:
                print(
                    '[%d/%d]: D_real:%.4f, D_enc:%.4f, D_noise:%.4f, Loss_D:%.4f,Loss_G:%.4f, rec_loss:%.4f, prior_loss:%.4f'
                    % (epoch,
                       max_epochs,
                       torch.mean(torch.tensor(D_real_list)),
                       torch.mean(torch.tensor(D_rec_enc_list)),
                       torch.mean(torch.tensor(D_rec_noise_list)),
                       torch.mean(torch.tensor(D_list)),
                       torch.mean(torch.tensor(g_loss_list)),
                       torch.mean(torch.tensor(rec_loss_list)),
                       torch.mean(torch.tensor(prior_loss_list))))

                with torch.no_grad():
                    D.eval()
                    G.eval()
                    _, _, x_rec = G(datav)
                    x_gen = G.decoder(fixed_noise)
                    x = x.cpu()
                    x_gen = x_gen.cpu()
                    x_rec = x_rec.cpu()

                    # Save a fixed slice (indices 20..28); assumes >= 29 images.
                    for j in range(20, 29):
                        org_img = transforms.ToPILImage()(x[j].squeeze(0)).convert('L')
                        rec_img = transforms.ToPILImage()(x_rec[j].squeeze(0)).convert('L')
                        gen_img = transforms.ToPILImage()(x_gen[j].squeeze(0)).convert('L')
                        org_img.save('results_ori/ori_' + str(epoch) + "_" + str(i) + "_" + str(j) + '.png')
                        rec_img.save('results_rec/rec_' + str(epoch) + "_" + str(i) + "_" + str(j) + '.png')
                        gen_img.save('results_gen/gen_' + str(epoch) + "_" + str(i) + "_" + str(j) + '.png')

        del batches
        del data_train
    print("Training finish!... save training results")
    torch.save(G.state_dict(), "G.pkl")
    torch.save(D.state_dict(), "D.pkl")
Пример #9
0
def main():
    """Train a ``DataParallel`` VAE on the bedroom folds, checkpointing every epoch.

    Reads ``bedroom128_splits/data_fold_<epoch % 2>.pkl`` each epoch, trains
    on full batches of 128 (an incomplete final batch ends the epoch) with
    loss ``loss_re + loss_kl``. Every 100 batches, the averaged losses are
    printed. Reconstructions and decodings of a fixed latent sample are then
    saved under ``results_rec/`` and ``results_gen/``, both as PNG and as a
    uint8 NHWC ``.npy``. After each epoch, the wrapper's state dict is saved to
    ``VAEmodel_epoch<N>.pkl``. Relies on module-level ``args`` (``resume``,
    ``lr``, ``img_size``) and ``device``.
    """
    batch_size = 128
    z_size = 512
    vae = VAE(zsize=z_size, layer_count=5)
    vae.train()
    vae.weight_init(mean=0, std=0.02)
    vae = nn.DataParallel(vae)
    vae.to(device)

    # Loads into the DataParallel wrapper, matching how the per-epoch
    # checkpoints below are saved.
    if args.resume is not None:
        vae.load_state_dict(torch.load(args.resume))

    vae_optimizer = optim.Adam(vae.parameters(),
                               lr=args.lr,
                               betas=(0.5, 0.999),
                               weight_decay=1e-5)

    train_epoch = 40

    # Fixed latent codes for generation samples, placed on the model's device
    # so ``decode`` does not receive a CPU tensor.
    sample1 = torch.randn(64, z_size).view(-1, z_size, 1, 1).to(device)

    folds = 2

    # Report losses and save samples every ``report_every`` iterations.
    report_every = 100

    # Output directories are loop-invariant: create them once.
    os.makedirs('results_rec', exist_ok=True)
    os.makedirs('results_gen', exist_ok=True)

    def save_samples(images, prefix):
        """Write CPU ``images`` to ``prefix + '.png'`` and, as uint8 NHWC, to ``prefix + '.npy'``."""
        images = images.view(-1, 3, args.img_size, args.img_size)
        save_image(images, prefix + '.png')
        as_uint8 = (images.numpy() * 255).transpose([0, 2, 3, 1]).astype(np.uint8)
        np.save(prefix + '.npy', as_uint8)

    for epoch in range(train_epoch):
        vae.train()

        with open('bedroom128_splits/data_fold_%d.pkl' % (epoch % folds),
                  'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train,
                                 batch_size,
                                 process_batch,
                                 report_progress=True)

        rec_loss = 0
        kl_loss = 0

        epoch_start_time = time.time()

        i = 0
        for x in batches:
            # Stop at the first incomplete batch.
            if x.shape[0] != batch_size:
                break
            vae.train()
            vae.zero_grad()
            rec, mu, logvar = vae(x)

            loss_re, loss_kl = loss_function(rec, x, mu, logvar)
            (loss_re + loss_kl).backward()
            vae_optimizer.step()
            rec_loss += loss_re.item()
            kl_loss += loss_kl.item()

            i += 1
            if i % report_every == 0:
                per_epoch_ptime = time.time() - epoch_start_time
                rec_loss /= report_every
                kl_loss /= report_every
                print(
                    '\n[%d/%d] - ptime: %.2f, rec loss: %.9f, KL loss: %.9f' %
                    ((epoch + 1), train_epoch, per_epoch_ptime, rec_loss,
                     kl_loss))
                rec_loss = 0
                kl_loss = 0
                with torch.no_grad():
                    vae.eval()
                    x_rec, _, _ = vae(x)
                    # ``* 0.5 + 0.5`` presumably maps [-1, 1] images to [0, 1].
                    save_samples((torch.cat([x, x_rec]) * 0.5 + 0.5).cpu(),
                                 'results_rec/sample_' + str(epoch) + "_" + str(i))

                    x_gen = vae.module.decode(sample1)
                    save_samples((x_gen * 0.5 + 0.5).cpu(),
                                 'results_gen/sample_' + str(epoch) + "_" + str(i))
        del batches
        del data_train
        # Printed after every epoch, because a checkpoint is written per epoch.
        print("Training finish!... save training results")
        torch.save(vae.state_dict(), "VAEmodel_epoch{}.pkl".format(epoch + 1))
Пример #10
0
def main(args):
    """Train a VAE on 5-fold pickled data; write samples and weights under ``checkpoints/<exp_name>``.

    ``args`` must provide ``dataset``, ``exp_name``, ``batch_size`` and
    ``num_epoch``. Data are read from ``./data/data_fold_<k>.pkl`` when
    ``args.dataset == 'face'`` and from ``data/<dataset>_data_fold_<k>.pkl``
    otherwise, with ``k = epoch % 5``. The learning rate is divided by 10 at
    epochs 8, 16, ..., 48. Every 60 batches the averaged losses are printed
    and recorded. Reconstructions and decodings of a fixed latent sample are
    also saved as PNGs. Final weights go to ``<output_root>/VAEmodel.pkl``.
    """
    if args.dataset == 'face':
        dataset_name = './data/'
    else:
        dataset_name = os.path.join('data', args.dataset + '_')
    output_root = os.path.join('checkpoints', args.exp_name)
    result_rec_pth = os.path.join(output_root, 'results_rec')
    result_gen_pth = os.path.join(output_root, 'results_gen')
    os.makedirs(output_root, exist_ok=True)
    os.makedirs(result_rec_pth, exist_ok=True)
    os.makedirs(result_gen_pth, exist_ok=True)

    batch_size = args.batch_size
    z_size = 512
    vae = VAE(zsize=z_size, layer_count=5)
    vae.cuda()
    vae.train()
    vae.weight_init(mean=0, std=0.02)
    fold = 5
    lr = 0.0005

    vae_optimizer = optim.Adam(vae.parameters(),
                               lr=lr,
                               betas=(0.5, 0.999),
                               weight_decay=1e-5)

    train_epoch = args.num_epoch

    # Fixed latent codes for generation samples; placed on the GPU explicitly
    # so ``vae.decode`` does not depend on a CUDA default tensor type.
    sample1 = torch.randn(batch_size, z_size).view(-1, z_size, 1, 1).cuda()

    # Report losses and save samples every ``report_every`` iterations.
    report_every = 60

    # Loss history (averaged per report) and its x-axis in fractional epochs.
    rec_loss_draw = []
    kl_loss_draw = []
    x_idx = []
    for epoch in range(train_epoch):
        vae.train()

        with open(dataset_name + 'data_fold_%d.pkl' % (epoch % fold),
                  'rb') as pkl:
            data_train = pickle.load(pkl)

        print("Train set size:", len(data_train))

        random.shuffle(data_train)

        batches = batch_provider(data_train,
                                 batch_size,
                                 process_batch,
                                 report_progress=False)

        rec_loss = 0
        kl_loss = 0

        epoch_start_time = time.time()

        if (epoch + 1) % 8 == 0 and (epoch + 1) < 50:
            vae_optimizer.param_groups[0]['lr'] /= 10
            print("learning rate change!")

        i = 0

        for x in batches:
            vae.train()
            vae.zero_grad()
            rec, mu, logvar = vae(x)

            loss_re, loss_kl = loss_function(rec, x, mu, logvar)
            (loss_re + loss_kl).backward()
            vae_optimizer.step()
            rec_loss += loss_re.item()
            kl_loss += loss_kl.item()

            i += 1
            if i % report_every == 0:
                per_epoch_ptime = time.time() - epoch_start_time
                rec_loss /= report_every
                kl_loss /= report_every
                print(
                    '\n[%d/%d] - ptime: %.2f, rec loss: %.9f, KL loss: %.9f' %
                    ((epoch + 1), train_epoch, per_epoch_ptime, rec_loss,
                     kl_loss))

                rec_loss_draw.append(rec_loss)
                kl_loss_draw.append(kl_loss)
                x_idx.append(
                    float('%.2f' % (epoch + (i /
                                             (len(data_train) / batch_size)))))
                print(x_idx)

                rec_loss = 0
                kl_loss = 0

                with torch.no_grad():
                    vae.eval()
                    x_rec, _, _ = vae(x)
                    # ``* 0.5 + 0.5`` presumably maps [-1, 1] images to [0, 1].
                    resultsample = torch.cat([x, x_rec]) * 0.5 + 0.5
                    resultsample = resultsample.cpu()
                    save_image(
                        resultsample.view(-1, 3, im_size,
                                          im_size), result_rec_pth +
                        '/sample_' + str(epoch) + "_" + str(i) + '.png')

                    x_gen = vae.decode(sample1)
                    resultsample = x_gen * 0.5 + 0.5
                    resultsample = resultsample.cpu()
                    save_image(
                        resultsample.view(-1, 3, im_size,
                                          im_size), result_gen_pth +
                        '/sample_' + str(epoch) + "_" + str(i) + '.png')

        del batches
        del data_train

    print("Training finish!... save training results")
    torch.save(vae.state_dict(), os.path.join(output_root, "VAEmodel.pkl"))