コード例 #1
0
class AnoGAN:
    """AnoGAN anomaly-detection model trained with the WGAN objective.

    The generator/critic pair is trained adversarially (weight clipping,
    multiple critic steps per generator step).  At test time a latent
    vector is optimized per image and the final weighted reconstruction
    loss is used as the abnormality score.
    """

    def __init__(self, opt):
        """Set up data, networks, optimizers and output directories.

        Args:
            opt: option namespace; must expose niter, lr, batchsize,
                category, experiment_group and experiment_name.
        """
        self.opt = opt

        self.niter = self.opt.niter  # number of training epochs
        self.start_iter = 0          # resume epoch (set below when checkpoints exist)
        self.netd_niter = 5          # critic updates per generator update (WGAN)
        self.test_iter = 100         # latent-search iterations per test image
        self.lr = self.opt.lr
        self.batchsize = {'train': self.opt.batchsize, 'test': 1}

        self.pretrained = False

        self.phase = 'train'
        self.outf = self.opt.experiment_group
        self.algorithm = 'wgan'

        # Load the datasets.  `provider` is a project-level factory --
        # assumed to yield (image, label) batches; TODO confirm.
        self.dataloader = {
            'train':
            provider('train',
                     opt.category,
                     batch_size=self.batchsize['train'],
                     num_workers=4),
            'test':
            provider('test',
                     opt.category,
                     batch_size=self.batchsize['test'],
                     num_workers=4)
        }

        # Directory layout: <outf>/<experiment_name>/{train,test}/...
        self.trn_dir = os.path.join(self.outf, self.opt.experiment_name,
                                    'train')
        self.tst_dir = os.path.join(self.outf, self.opt.experiment_name,
                                    'test')

        # Per-epoch test images.
        self.test_img_dir = os.path.join(self.outf, self.opt.experiment_name,
                                         'test', 'images')
        os.makedirs(self.test_img_dir, exist_ok=True)

        # Snapshot of the test images for the best AUC so far.
        self.best_test_dir = os.path.join(self.outf, self.opt.experiment_name,
                                          'test', 'best_images')
        os.makedirs(self.best_test_dir, exist_ok=True)

        self.weight_dir = os.path.join(self.trn_dir, 'weights')
        os.makedirs(self.weight_dir, exist_ok=True)

        # -- Misc attributes
        self.epoch = 0

        self.l_con = l1_loss  # reconstruction (residual) loss
        self.l_enc = l2_loss  # latent-space (discrimination) loss

        # Create and initialize networks (CUDA required).
        self.netg = NetG().cuda()
        self.netd = NetD().cuda()

        # Setup optimizers; RMSprop for the critic as in the WGAN paper.
        self.optimizer_d = optim.RMSprop(self.netd.parameters(), lr=self.lr)
        self.optimizer_g = optim.Adam(self.netg.parameters(), lr=self.lr)

        # Resume training when both netG.pth and netD.pth are present.
        self.weight_path = os.path.join(self.outf, self.opt.experiment_name,
                                        'train', 'weights')
        if os.path.exists(self.weight_path) and len(
                os.listdir(self.weight_path)) == 2:
            print("Loading pre-trained networks...\n")
            # Load each checkpoint file once (the original code re-read
            # netG.pth three times and netD.pth twice from disk).
            ckpt_g = torch.load(os.path.join(self.weight_path, 'netG.pth'))
            ckpt_d = torch.load(os.path.join(self.weight_path, 'netD.pth'))
            self.netg.load_state_dict(ckpt_g['state_dict'])
            self.netd.load_state_dict(ckpt_d['state_dict'])
            self.optimizer_g.load_state_dict(ckpt_g['optimizer'])
            self.optimizer_d.load_state_dict(ckpt_d['optimizer'])
            self.start_iter = ckpt_g['epoch']

    ##
    def start(self):
        """Train for `self.niter` epochs, testing and logging each epoch."""
        best_auc = -1  # best test AUC observed so far

        for self.epoch in range(self.start_iter, self.niter):
            # One epoch of adversarial training, then a full test pass.
            mean_wass = self.train()
            (auc, res, best_rec, best_threshold), res_total = self.test()

            # BUG FIX: '%.3d' rendered the float Wasserstein distance as
            # an integer; '%.3f' keeps the fractional part.
            message = ''
            message += 'Wasserstein Distance:%.3f ' % mean_wass
            message += 'AUC: %.3f ' % auc

            print(message)

            # Always checkpoint the latest weights (used for resuming).
            torch.save(
                {
                    'epoch': self.epoch + 1,
                    'state_dict': self.netg.state_dict(),
                    'optimizer': self.optimizer_g.state_dict()
                }, '%s/netG.pth' % (self.weight_dir))

            torch.save(
                {
                    'epoch': self.epoch + 1,
                    'state_dict': self.netd.state_dict(),
                    'optimizer': self.optimizer_d.state_dict()
                }, '%s/netD.pth' % (self.weight_dir))

            if auc > best_auc:
                best_auc = auc
                new_message = "******** New optimal found, saving state ********"
                message = message + '\n' + new_message
                print(new_message)

                # Replace the previous best images with this epoch's output.
                for img in os.listdir(self.best_test_dir):
                    os.remove(os.path.join(self.best_test_dir, img))

                for img in os.listdir(self.test_img_dir):
                    shutil.copyfile(os.path.join(self.test_img_dir, img),
                                    os.path.join(self.best_test_dir, img))

                shutil.copyfile('%s/netG.pth' % (self.weight_dir),
                                '%s/netg_best.pth' % (self.weight_dir))

            # Append the epoch summary to the experiment log.
            log_name = os.path.join(self.outf, self.opt.experiment_name,
                                    'loss_log.txt')
            message = 'Epoch%3d:' % self.epoch + ' ' + message
            with open(log_name, "a") as log_file:
                if self.epoch == 0:
                    log_file.write('\n\n')
                log_file.write('%s\n' % message)

        print(">> Training %s Done..." % self.opt.experiment_name)

    ##
    def train(self):
        """Train both networks for one epoch.

        Returns:
            float: mean Wasserstein estimate over the epoch's batches.
        """
        print("\n>>> Epoch %d/%d, Running " % (self.epoch + 1, self.niter) +
              self.opt.experiment_name)

        self.netg.train()
        self.netd.train()

        mean_wass = 0

        tk0 = tqdm(self.dataloader['train'],
                   total=len(self.dataloader['train']))
        for i, itr in enumerate(tk0):
            input, _ = itr
            input = input.cuda()
            wasserstein_d = None

            # -- Critic: netd_niter updates per generator step (WGAN).
            for _ in range(self.netd_niter):
                self.optimizer_d.zero_grad()

                # Sample a fake batch.  NOTE(review): `fake` is not
                # detached, so each critic step also backprops through
                # the generator graph; the gradients are discarded by
                # optimizer_g.zero_grad() below but cost extra compute --
                # consider .detach().
                latent_i = torch.rand(self.batchsize['train'], 64, 1, 1).cuda()
                fake = self.netg(latent_i)
                _, pred_real = self.netd(input)
                _, pred_fake = self.netd(fake)

                # Minimize E[D(fake)] - E[D(real)], i.e. maximize the
                # Wasserstein estimate of the critic.
                wasserstein_d = (pred_fake.mean() - pred_real.mean()) * 1
                wasserstein_d.backward()
                self.optimizer_d.step()

                # Weight clipping enforces the Lipschitz constraint.
                for p in self.netd.parameters():
                    p.data.clamp_(-0.01, 0.01)

            # -- Generator: maximize the critic's score on fakes.
            self.optimizer_g.zero_grad()
            noise = torch.rand(self.batchsize['train'], 64, 1, 1).cuda()
            fake = self.netg(noise)
            _, pred_fake = self.netd(fake)
            err_g_d = -pred_fake.mean()  # negative: ascend on D(fake)

            err_g_d.backward()
            self.optimizer_g.step()

            errors = {
                'loss_netD': wasserstein_d.item(),
                'loss_netG': round(err_g_d.item(), 3),
            }

            mean_wass += wasserstein_d.item()
            tk0.set_postfix(errors)

            # Periodically dump real/fake grids for visual inspection.
            if i % 50 == 0:
                img_dir = os.path.join(self.outf, self.opt.experiment_name,
                                       'train', 'images')
                os.makedirs(img_dir, exist_ok=True)
                self.save_image_cv2(input.data, '%s/reals.png' % img_dir)
                self.save_image_cv2(fake.data,
                                    '%s/fakes%03d.png' % (img_dir, i))

        mean_wass /= len(self.dataloader['train'])
        return mean_wass

    ##
    def test(self):
        """Evaluate on the test set via per-image latent optimization.

        For each test image a latent vector is optimized for
        `self.test_iter` steps while the networks stay frozen; the final
        weighted loss serves as the abnormality score.

        Returns:
            tuple: ((auc, res, best_rec, best_threshold), res_total) as
            produced by Meter_AnoGAN.get_metrics().
        """
        self.netg.eval()
        self.netd.eval()

        # Clear the previous epoch's test images.
        for img in os.listdir(self.test_img_dir):
            os.remove(os.path.join(self.test_img_dir, img))

        self.phase = 'test'
        meter = Meter_AnoGAN()
        tk1 = tqdm(self.dataloader['test'], total=len(self.dataloader['test']))
        for i, itr in enumerate(tk1):
            input, target = itr
            input = input.cuda()

            # Optimize only the latent code so the generator reconstructs
            # this particular image.
            latent_i = torch.rand(self.batchsize['test'], 64, 1, 1).cuda()
            latent_i.requires_grad = True

            optimizer_latent = optim.Adam([latent_i], lr=self.lr)
            test_loss = None
            for _ in range(self.test_iter):
                optimizer_latent.zero_grad()
                fake = self.netg(latent_i)
                residual_loss = self.l_con(input, fake)
                latent_o, _ = self.netd(fake)
                discrimination_loss = self.l_enc(latent_i, latent_o)
                alpha = 0.1  # weight between residual and latent terms
                test_loss = (
                    1 - alpha) * residual_loss + alpha * discrimination_loss
                test_loss.backward()
                optimizer_latent.step()

            abnormal_score = test_loss
            meter.update(abnormal_score, target)  # TODO: pass a scalar (.item())?

            # Save (input | reconstruction) side by side.
            combine = torch.cat([input.cpu(), fake.cpu()], dim=0)
            self.save_image_cv2(combine,
                                '%s/%05d.jpg' % (self.test_img_dir, i + 1))

        criterion, res_total = meter.get_metrics()

        # Rename each saved image with its per-image result string.
        for i, res in enumerate(res_total):
            os.rename('%s/%05d.jpg' % (self.test_img_dir, i + 1),
                      '%s/%05d_%s.jpg' % (self.test_img_dir, i + 1, res))

        return criterion, res_total

    @staticmethod
    def save_image_cv2(tensor, filename):
        """Write a batch of images as one grid file via OpenCV.

        NOTE(review): make_grid's positional arguments are kept exactly
        as written (their meaning shifts across torchvision versions).
        cv2.imwrite expects BGR while the grid is presumably RGB, so
        channels may be swapped in the output -- confirm.
        """
        from torchvision.utils import make_grid
        grid = make_grid(tensor, 8, 2, 0, False, None, False)
        ndarray = grid.mul_(255).clamp_(0, 255).permute(1, 2, 0).to(
            'cpu', torch.uint8).numpy()
        cv2.imwrite(filename, ndarray)
コード例 #2
0
         netg.train()
     fake = (fake + 1) / 2 * 255
     real = (ass_label + 1) / 2 * 255
     ori = (img + 1) / 2 * 255
     al = th.cat((fake, real, ori), 2)
     display = make_grid(al, 20).cpu().numpy()
     if win1 is None:
         win1 = vis.image(display,
                          opts=dict(title="train", caption='train'))
     else:
         vis.image(display, win=win1)
 if iteration % 500 == 0:
     state = {
         'netA': neta.state_dict(),
         'netG': netg.state_dict(),
         'netD': netd.state_dict()
     }
     th.save(state, './snapshot_%d.t7' % iteration)
     print('iter = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
         iteration, lossG/2, lossA/3, lossD/3
     ))
 if iteration % 20 == 0:
     if win is None:
         win = vis.line(X=np.array([[iteration, iteration,
                                     iteration]]),
                        Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                        opts=dict(
                            title='GaitGAN',
                            ylabel='loss',
                            xlabel='iterations',
                            legend=['lossG', 'lossA', 'lossD']
コード例 #3
0
ファイル: main.py プロジェクト: Jarlonyan/pytorch_learning
def train():
    """Standard DCGAN loop: alternate discriminator and generator updates."""
    dataset = torchvision.datasets.ImageFolder(conf.data_path,
                                               transform=transforms)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=conf.batch_size,
                                         shuffle=True,
                                         drop_last=True)
    generator = NetG(conf.ngf, conf.nz)
    discriminator = NetD(conf.ndf)

    criterion = nn.BCELoss()
    opt_g = torch.optim.Adam(generator.parameters(),
                             lr=conf.lr,
                             betas=(conf.beta1, 0.999))
    opt_d = torch.optim.Adam(discriminator.parameters(),
                             lr=conf.lr,
                             betas=(conf.beta1, 0.999))

    # Reusable target tensor; filled with 1s (real) or 0s (fake) per pass.
    label = torch.FloatTensor(conf.batch_size)
    real_label = 1
    fake_label = 0

    for epoch in range(1, conf.epoch + 1):
        for step, (images, _) in enumerate(loader):
            # --- Step 1: update D with G frozen ---
            opt_d.zero_grad()
            # D should classify real images as 1.
            prediction = discriminator(images)
            label.data.fill_(real_label)
            loss_d_real = criterion(prediction, label)
            loss_d_real.backward()
            # D should classify generated images as 0.
            label.data.fill_(fake_label)
            z = torch.randn(conf.batch_size, conf.nz, 1, 1)
            fake = generator(z)
            # detach() keeps this pass from backpropagating into G.
            prediction = discriminator(fake.detach())
            loss_d_fake = criterion(prediction, label)
            loss_d_fake.backward()
            loss_d = loss_d_fake + loss_d_real
            opt_d.step()

            # --- Step 2: update G with D frozen ---
            opt_g.zero_grad()
            # G wants its samples classified as real (1).
            label.data.fill_(real_label)
            prediction = discriminator(fake)
            loss_g = criterion(prediction, label)
            loss_g.backward()
            opt_g.step()

            if step % 4 == 0:
                rate = step * 1.0 / len(loader) * 100
                logger.info(
                    "epoch={}, i={}, N={}, rate={}%, errD={}, errG={}".format(
                        epoch, step, len(loader), rate, loss_d, loss_g))
        # End of epoch: dump samples and checkpoint both networks.
        save_image(fake.data,
                   '%s/fake_samples_epoch_%03d.png' %
                   (conf.checkpoints, epoch),
                   normalize=True)
        torch.save(generator.state_dict(),
                   '%s/netG_%03d.pth' % (conf.checkpoints, epoch))
        torch.save(discriminator.state_dict(),
                   '%s/netD_%03d.pth' % (conf.checkpoints, epoch))
コード例 #4
0
            errD_real = netD(realData)
            errD_real = errD_real.mean()
            errD_real.backward(one)
            # train with fake
            z.data.resize_(batchSize_now, nz, 1, 1).normal_()
            fakeData = netG(z)
            # pdb.set_trace()
            errD_fake = netD(fakeData.detach())
            errD_fake = errD_fake.mean()
            errD_fake.backward(mone)
            optimizerD.step()
            id += 1
        ############################
        # (2) Update G network
        ###########################
        netG.zero_grad()
        errG = netD(fakeData)
        errG = errG.mean()
        errG.backward(one)
        optimizerG.step()
        ig += 1
        hhh = fakeData.data.cpu()
        hhh = hhh / 2 + 0.5
        vis.image(torchvision.utils.make_grid(hhh), win=win)
        print('epoch %d, batch %d, Dreal: %.4f, Dfake: %.4f, errG: %.4f' %
              (it, ib, errD_real.data[0], errD_fake.data[0], errG.data[0]))
    # do checkpointing
    hhh = netG(z).data.cpu()
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (checkRoot, it))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (checkRoot, it))
コード例 #5
0
def train():
    """Train a DCGAN, reading configuration from the module-level `opt`.

    Alternates discriminator updates (every `opt.d_every` batches) and
    generator updates (every `opt.g_every` batches), optionally plotting
    progress through visdom when `opt.vis` is set, and checkpointing
    every `opt.save_every` epochs.
    """
    # BUG FIX: `torch.cuda.is_available` was referenced without calling
    # it; the function object is always truthy, so the code picked
    # 'cuda' even on CPU-only machines and crashed.  Call it.
    device = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')

    if opt.vis:
        from visualizer import Visualizer
        vis = Visualizer(opt.env)

    # Rescale inputs to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(opt.image_size),
        transforms.CenterCrop(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.ImageFolder(opt.data_path, transform=transform)

    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)

    netd = NetD(opt)
    netg = NetG(opt)
    map_location = lambda storage, loc: storage
    # BUG FIX: `map_location` is a torch.load argument, not a
    # load_state_dict argument; the original raised TypeError whenever a
    # checkpoint path was configured.
    if opt.netd_path:
        netd.load_state_dict(
            torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(
            torch.load(opt.netg_path, map_location=map_location))

    if torch.cuda.is_available():
        netd.to(device)
        netg.to(device)

    # Optimizers and loss.
    optimizer_g = torch.optim.Adam(netg.parameters(),
                                   opt.lr1,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr2,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss().to(device)

    # Real targets are 1, fake targets are 0; `noises` is the generator input.
    true_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    # Fixed noise gives comparable samples across epochs.
    fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if torch.cuda.is_available():
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    for epoch in range(opt.max_epoch):
        print("epoch:", epoch, end='\r')
        for ii, (img, _) in enumerate(dataloader):
            real_img = Variable(img)
            if torch.cuda.is_available():
                real_img = real_img.cuda()

            # Train the discriminator: real -> 1, fake -> 0.
            if (ii + 1) % opt.d_every == 0:
                # Real pass.
                optimizer_d.zero_grad()
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()
                # Fake pass; detach() keeps gradients out of G here.
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                fake_output = netd(fake_img)
                error_d_fake = criterion(fake_output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            # Train the generator: make D classify its samples as real.
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Visualize fixed-noise samples and a real batch.
                fix_fake_img = netg(fix_noises)
                vis.images(
                    fix_fake_img.detach().cpu().numpy()[:opt.batch_size] * 0.5
                    + 0.5,
                    win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:opt.batch_size] * 0.5 +
                           0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save sample images and model checkpoints.
            # BUG FIX: `fix_fake_img` was only assigned inside the
            # visualization branch, so this raised NameError whenever
            # opt.vis was off; regenerate it from the fixed noise here.
            fix_fake_img = netg(fix_noises)
            tv.utils.save_image(fix_fake_img.data[:opt.batch_size],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
コード例 #6
0
ファイル: train.py プロジェクト: LehaoLin/getbaby
        noise = noise.to(device)
        fake = netG(noise)  # 生成假图
        output = netD(fake.detach())  #避免梯度传到G,因为G不用更新
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_fake + errD_real
        optimizerD.step()

        # 固定鉴别器D,训练生成器G
        optimizerG.zero_grad()
        # 让D尽可能把G生成的假图判别为1
        label.data.fill_(real_label)
        label = label.to(device)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f' %
              (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))

    # vutils.save_image(fake.data,
    #                   '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
    #                   normalize=True)
    # torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))
    # torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))

    vutils.save_image(fake.data, 'imgs/fake_samples.png', normalize=True)
    torch.save(netG.state_dict(), 'imgs/netG.pth')
    torch.save(netD.state_dict(), 'imgs/netD.pth')
コード例 #7
0
ファイル: train.py プロジェクト: haewngX/DCGAN-Pytorch
                   errG.item(), D_x, D_G_z1, D_G_z2))

        # 记录损失画图 Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # 观察生成器 Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and
                                  (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            # 保存图片
            vutils.save_image(fake,
                              '%s/fake_samples_epoch_%03d.png' %
                              ('inm/', epoch),
                              normalize=True)
        iters += 1

    # 保存模型 do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % ('models/', epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % ('models/', epoch))

# Plot the generator and discriminator losses recorded during training.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
コード例 #8
0
        errD_real.backward()
        # 让D尽可能把假图片判别为0
        label.data.fill_(fake_label)
        noise = torch.randn(opt.batchSize, opt.nz, 1, 1)
        noise = noise.to(device)
        fake = netG(noise)  # 生成假图
        output = netD(fake.detach())  # 避免梯度传到G,因为G不用更新
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_fake + errD_real
        optimizerD.step()

        # 固定鉴别器D,训练生成器G
        optimizerG.zero_grad()
        # 让D尽可能把G生成的假图判别为1
        label.data.fill_(real_label)
        label = label.to(device)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f' %
              (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))

    vutils.save_image(fake.data,
                      '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)
    torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))
コード例 #9
0
            noise = noise.to(device)
            fake = netG(noise)  # Generate fake map
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            errD = errD_fake + errD_real
            optimizerD.step()

            # Step 2: Fix discriminator D and train generator G
            optimizerG.zero_grad()
            label.data.fill_(real_label)
            label = label.to(device)
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f' %
                  (epoch, opt.epoch, i, len(dataloader), errD.item(),
                   errG.item()))

        vutils.save_image(fake.data,
                          '%s/fake_samples_epoch_%03d.png' %
                          (opt.output, epoch),
                          normalize=True)
        torch.save(netG.state_dict(),
                   '%s/netG_%03d.pth' % (opt.model_path, epoch))
        torch.save(netD.state_dict(),
                   '%s/netD_%03d.pth' % (opt.model_path, epoch))

    print("end training")
コード例 #10
0
ファイル: train.py プロジェクト: tA-bot-git/TAC-GAN_JeHaYaFa
class TACGAN():
    """TAC-GAN training harness.

    Owns the generator (``NetG``) and discriminator (``NetD``), their Adam
    optimizers, dataset loading, the per-epoch training loop, checkpoint
    saving/loading and the loss-curve plot.  All configuration comes from
    the ``args`` namespace handed to ``__init__``.
    """

    def __init__(self, args):
        # Hyper-parameters and bookkeeping copied off the CLI namespace so
        # the rest of the class never needs to touch the global ``args``.
        self.lr = args.lr
        self.cuda = args.use_cuda
        self.batch_size = args.batch_size
        self.image_size = args.image_size
        self.epochs = args.epochs
        self.data_root = args.data_root
        self.dataset = args.dataset
        self.num_classes = args.num_cls
        self.save_dir = args.save_dir
        self.save_prefix = args.save_prefix
        self.continue_training = args.continue_training
        self.netG_path = args.netg_path
        self.netD_path = args.netd_path
        self.save_after = args.save_after  # checkpoint every N epochs
        self.trainset_loader = None
        self.evalset_loader = None
        self.num_workers = args.num_workers
        self.n_z = args.n_z  # length of the noise vector
        self.nl_d = args.nl_d
        self.nl_g = args.nl_g
        self.nf_g = args.nf_g
        self.nf_d = args.nf_d
        self.bce_loss = nn.BCELoss()
        # NOTE(review): nll_loss is constructed (and moved to CUDA) but never
        # used below — trainEpoch uses bce_loss for the class outputs too.
        self.nll_loss = nn.NLLLoss()
        self.netD = NetD(n_cls=self.num_classes, n_t=self.nl_d, n_f=self.nf_d)
        self.netG = NetG(n_z=self.n_z, n_l=self.nl_g, n_c=self.nf_g)

        # Move networks and loss modules to the GPU when requested and present.
        if self.cuda and torch.cuda.is_available():
            print('CUDA is enabled')
            self.netD = self.netD.cuda()
            self.netG = self.netG.cuda()
            self.bce_loss = self.bce_loss.cuda()
            self.nll_loss = self.nll_loss.cuda()

        # Optimizers for netD and netG (DCGAN-style betas).
        self.optimizerD = optim.Adam(params=self.netD.parameters(),
                                     lr=self.lr,
                                     betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(params=self.netG.parameters(),
                                     lr=self.lr,
                                     betas=(0.5, 0.999))

        # Create dirs for checkpoints and other results if they do not exist.
        for subdir in ('', 'netd_checkpoints', 'netg_checkpoints',
                       'generated_images'):
            path = os.path.join(self.save_dir, subdir)
            if not os.path.exists(path):
                os.makedirs(path)

    # start training process
    def train(self):
        """Load the dataset and run ``trainEpoch`` for every epoch.

        Also writes the training parameters to ``training_log.txt``, updates
        the loss graph after each epoch, and saves checkpoints.
        """
        # write to the log file and print it
        log_msg = '********************************************\n'
        log_msg += '            Training Parameters\n'
        log_msg += 'Dataset:%s\nImage size:%dx%d\n' % (
            self.dataset, self.image_size, self.image_size)
        log_msg += 'Batch size:%d\n' % (self.batch_size)
        log_msg += 'Number of epochs:%d\nlr:%f\n' % (self.epochs, self.lr)
        log_msg += 'nz:%d\nnl-d:%d\nnl-g:%d\n' % (self.n_z, self.nl_d,
                                                  self.nl_g)
        log_msg += 'nf-g:%d\nnf-d:%d\n' % (self.nf_g, self.nf_d)
        log_msg += '********************************************\n\n'
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),
                  'a') as log_file:
            log_file.write(log_msg)
        # load trainset and evalset
        imtext_ds = ImTextDataset(data_dir=self.data_root,
                                  dataset=self.dataset,
                                  train=True,
                                  image_size=self.image_size)
        self.trainset_loader = DataLoader(dataset=imtext_ds,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          num_workers=2)
        print("Dataset loaded successfully")
        # Load checkpoints for continuing training.
        # BUGFIX: was ``args.continue_training`` (module-global); use the
        # value stored on self in __init__ instead.
        if self.continue_training:
            self.loadCheckpoints()

        # repeat for the number of epochs
        netd_losses = []
        netg_losses = []
        for epoch in range(self.epochs):
            netd_loss, netg_loss = self.trainEpoch(epoch)
            netd_losses.append(netd_loss)
            netg_losses.append(netg_loss)
            self.saveGraph(netd_losses, netg_losses)
            #self.evalEpoch(epoch)
            self.saveCheckpoints(epoch)

    # train epoch
    def trainEpoch(self, epoch):
        """Train both networks for one epoch.

        Returns a ``(netd_avg_loss, netg_avg_loss)`` tuple of the epoch's
        average discriminator and generator losses.
        """
        self.netD.train()  # set to train mode
        self.netG.train()  #! set to train mode???

        netd_loss_sum = 0
        netg_loss_sum = 0
        start_time = time()
        for i, (images, labels, captions,
                _) in enumerate(self.trainset_loader):
            batch_size = images.size(
                0
            )  # !batch size may differ (from self.batch_size) for the last batch
            images, labels, captions = Variable(images), Variable(
                labels), Variable(captions)  # !labels should be LongTensor
            labels = labels.type(
                torch.FloatTensor
            )  # convert to FloatTensor (from DoubleTensor)
            lbl_real = Variable(torch.ones(batch_size, 1))
            lbl_fake = Variable(torch.zeros(batch_size, 1))
            noise = Variable(torch.randn(batch_size,
                                         self.n_z))  # create random noise
            noise.data.normal_(0, 1)  # resample the noise in place, N(0, 1)
            rnd_perm1 = torch.randperm(
                batch_size
            )  # random permutations for different sets of training tuples
            rnd_perm2 = torch.randperm(batch_size)
            rnd_perm3 = torch.randperm(batch_size)
            rnd_perm4 = torch.randperm(batch_size)
            if self.cuda:
                images, labels, captions = images.cuda(), labels.cuda(
                ), captions.cuda()
                lbl_real, lbl_fake = lbl_real.cuda(), lbl_fake.cuda()
                noise = noise.cuda()
                rnd_perm1, rnd_perm2, rnd_perm3, rnd_perm4 = rnd_perm1.cuda(
                ), rnd_perm2.cuda(), rnd_perm3.cuda(), rnd_perm4.cuda()

            ############### Update NetD ###############
            self.netD.zero_grad()
            # train with wrong image, wrong label, real caption
            outD_wrong, outC_wrong = self.netD(images[rnd_perm1],
                                               captions[rnd_perm2])
            lossD_wrong = self.bce_loss(outD_wrong, lbl_fake)
            lossC_wrong = self.bce_loss(outC_wrong, labels[rnd_perm1])

            # train with real image, real label, real caption
            outD_real, outC_real = self.netD(images, captions)
            lossD_real = self.bce_loss(outD_real, lbl_real)
            lossC_real = self.bce_loss(outC_real, labels)

            # train with fake image, real label, real caption
            fake = self.netG(noise, captions)
            outD_fake, outC_fake = self.netD(fake.detach(),
                                             captions[rnd_perm3])
            lossD_fake = self.bce_loss(outD_fake, lbl_fake)
            lossC_fake = self.bce_loss(outC_fake, labels[rnd_perm3])

            # backward and forward for NetD
            netD_loss = lossC_wrong + lossC_real + lossC_fake + lossD_wrong + lossD_real + lossD_fake
            netD_loss.backward()
            self.optimizerD.step()

            ########## Update NetG ##########
            # train with fake data
            self.netG.zero_grad()
            noise.data.normal_(0, 1)  # resample the noise vector
            fake = self.netG(noise, captions[rnd_perm4])
            d_fake, c_fake = self.netD(fake, captions[rnd_perm4])
            lossD_fake_G = self.bce_loss(d_fake, lbl_real)
            lossC_fake_G = self.bce_loss(c_fake, labels[rnd_perm4])
            netG_loss = lossD_fake_G + lossC_fake_G
            netG_loss.backward()
            self.optimizerG.step()

            # BUGFIX: ``loss.data[0]`` fails for 0-dim loss tensors on
            # PyTorch >= 0.5; use .item() (consistent with the other
            # training scripts in this file).
            netd_loss_sum += netD_loss.item()
            netg_loss_sum += netG_loss.item()
            ### print progress info ###
            print(
                'Epoch %d/%d, %.2f%% completed. Loss_NetD: %.4f, Loss_NetG: %.4f'
                % (epoch, self.epochs,
                   (float(i) / len(self.trainset_loader)) * 100,
                   netD_loss.item(), netG_loss.item()))

        end_time = time()
        netd_avg_loss = netd_loss_sum / len(self.trainset_loader)
        netg_avg_loss = netg_loss_sum / len(self.trainset_loader)
        epoch_time = (end_time - start_time) / 60
        log_msg = '-------------------------------------------\n'
        log_msg += 'Epoch %d took %.2f minutes\n' % (epoch, epoch_time)
        log_msg += 'NetD average loss: %.4f, NetG average loss: %.4f\n\n' % (
            netd_avg_loss, netg_avg_loss)
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),
                  'a') as log_file:
            log_file.write(log_msg)
        return netd_avg_loss, netg_avg_loss

    # eval epoch (currently a stub — evaluation is disabled)
    def evalEpoch(self, epoch):
        #self.netD.eval()
        #self.netG.eval()
        return 0

    # draws and saves the loss graph up to the current epoch
    def saveGraph(self, netd_losses, netg_losses):
        plt.plot(netd_losses, color='red', label='NetD Loss')
        plt.plot(netg_losses, color='blue', label='NetG Loss')
        plt.xlabel('epoch')
        plt.ylabel('error')
        plt.legend(loc='best')
        plt.savefig(os.path.join(self.save_dir, 'loss_graph.png'))
        plt.close()

    # save checkpoints every ``self.save_after`` epochs
    def saveCheckpoints(self, epoch):
        if epoch % self.save_after == 0:
            name_netD = "netd_checkpoints/netD_" + self.save_prefix + "_epoch_" + str(
                epoch) + ".pth"
            name_netG = "netg_checkpoints/netG_" + self.save_prefix + "_epoch_" + str(
                epoch) + ".pth"
            torch.save(self.netD.state_dict(),
                       os.path.join(self.save_dir, name_netD))
            torch.save(self.netG.state_dict(),
                       os.path.join(self.save_dir, name_netG))
            print("Checkpoints for epoch %d saved successfully" % (epoch))

    # load checkpoints to continue training
    def loadCheckpoints(self):
        self.netG.load_state_dict(torch.load(self.netG_path))
        self.netD.load_state_dict(torch.load(self.netD_path))
        print("Checkpoints loaded successfully")