# Beispiel #1
# 0
def train(**kwargs):
    """Adversarial training loop for a DCGAN-style generator/discriminator.

    NOTE(review): relies on module-level globals not visible in this chunk
    (num_epoch, dataloader, D, G, D_optimizer, G_optimizer, criterion,
    true_labels, fake_labels, noises, fix_noises, batch_size, z_dimension,
    errorD_meter, errorG_meter) -- confirm they are defined at import time.
    """
    vis = Visualizer('GAN')  # visdom-style visualizer for losses and images
    for epoch in range(num_epoch):
        for i, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.cuda()

            # ---- train the discriminator D ----
            D_optimizer.zero_grad()

            # push D's output on real images toward the "real" label (1)
            output = D(real_img)
            D_loss_real = criterion(output, true_labels)
            D_loss_real.backward()

            # push D's output on generated images toward the "fake" label (0)
            noises.data.copy_(torch.randn(batch_size, z_dimension, 1, 1))
            fake_img = G(noises).detach()  # detach: no gradient flows into G here
            output = D(fake_img)
            D_loss_fake = criterion(output, fake_labels)
            D_loss_fake.backward()
            D_optimizer.step()

            D_loss = D_loss_real + D_loss_fake
            errorD_meter.add(D_loss.item())

            if i % 5 == 0:  # train G once every 5 discriminator batches
                G_optimizer.zero_grad()
                noises.data.copy_(torch.randn(batch_size, z_dimension, 1, 1))
                fake_img = G(noises)
                output = D(fake_img)
                # G is rewarded when D classifies its output as real
                G_loss = criterion(output, true_labels)
                G_loss.backward()
                G_optimizer.step()
                errorG_meter.add(G_loss.item())

        if (epoch + 1) % 10 == 0:  # visualize and checkpoint every 10 epochs

            # drop into the debugger when the sentinel path exists
            if os.path.exists('/tmp/debugGAN'):
                ipdb.set_trace()
            fix_fake_imgs = G(fix_noises)
            # inputs were normalized to [-1, 1]; map back to [0, 1] for display
            vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5,
                       win='fake image')
            vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                       win='real image')
            vis.plot('errorD', errorD_meter.value()[0])
            vis.plot('errorG', errorG_meter.value()[0])

            torchvision.utils.save_image(fix_fake_imgs.data[:64],
                                         '%s/%s.png' % ('imgs', epoch),
                                         normalize=True,
                                         range=(-1, 1))
            torch.save(D.state_dict(), 'checkpoints/D_%s.pth' % epoch)
            torch.save(G.state_dict(), 'checkpoints/G_%s.pth' % epoch)
            errorD_meter.reset()
            errorG_meter.reset()
# Beispiel #2
# 0
def train(**kwargs):
    """Train the lyrics language model.

    Keyword arguments override the matching attributes on the module-level
    ``opt`` configuration object.

    Fixes vs. original:
    - the running loss was printed with ``%d``, which truncates the float
      loss to an integer; ``%.4f`` is used instead.
    - dropped the deprecated ``Variable`` wrapper (a no-op since
      PyTorch 0.4, which this code already targets via ``.item()``).
    """
    # Override configuration attributes from kwargs
    for k, v in kwargs.items():
        setattr(opt, k, v)
    if opt.vis:
        visualizer = Visualizer()

    # Data
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = DataLoader(data, batch_size=opt.batch_size, shuffle=True)

    # Model (optionally resumed from a checkpoint, loaded onto CPU)
    model = LyricsModel(len(word2ix), opt.embedding_dim, opt.latent_dim)
    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path, map_location="cpu"))

    # Optimizer, loss, and a running-average meter
    optimizer = Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()
    loss_meter = meter.AverageValueMeter()

    if opt.use_gpu:
        model.cuda()
        criterion.cuda()

    #================================================#
    #               Start Training                   #
    #================================================#

    for epoch in tqdm.tqdm(range(opt.num_epoch)):

        for (ii, data) in enumerate(dataloader):
            # (seq_len, batch) layout: inputs are tokens [:-1],
            # targets are the same sequence shifted by one ([1:])
            data = data.long().transpose(1, 0).contiguous()
            if opt.use_gpu:
                data = data.cuda()
            inputs, targets = data[:-1, :], data[1:, :]
            outputs, hidden = model(inputs)

            # Standard backprop step
            optimizer.zero_grad()
            loss = criterion(outputs, targets.view(-1))
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())

            if (1 + ii) % opt.print_every == 0:
                # %.4f: the loss is a float; %d would truncate it
                print("Current Loss: %.4f" % loss.item())
                if opt.vis:
                    visualizer.plot('loss', loss_meter.value()[0])
        # Checkpoint every 20 epochs
        if (epoch + 1) % 20 == 0:
            t.save(model.state_dict(), 'checkpoints/%s.pth' % epoch)
def train_pbigan(args):
    """Train a Partial Bidirectional GAN (PBiGAN) on masked CelebA.

    Alternates ``n_critic`` Wasserstein-critic updates (with gradient
    penalty) with one encoder/decoder update, optionally adding an
    autoencoding loss (enabled from ``args.ae_start`` on) and an MMD
    penalty between encoded and sampled latents.  Losses and
    reconstructions are plotted via ``Visualizer`` and the models are
    checkpointed every epoch.

    Fixes vs. original:
    - ``iter(test_loader).next()`` is Python-2 iterator syntax and raises
      AttributeError on Python 3; replaced with the ``next()`` builtin.
    - an unrecognized ``args.mask`` previously fell through and crashed
      later with NameError on ``data``; it now fails fast with ValueError.
    """
    torch.manual_seed(args.seed)

    # Dataset: independently dropped pixels or a square block mask
    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'
    else:
        raise ValueError(f'unknown mask type: {args.mask!r}')

    # Two independent shuffled loaders over the same dataset: one supplies
    # images+masks, the other supplies masks for the generated samples.
    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    # WGAN-style Adam betas (.5, .9) for both optimizers
    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    # Output directory named by prefix, timestamp and mask configuration
    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'celeba-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    # Record the full argument set for reproducibility
    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))

    # Fixed test batch for periodic qualitative plots.
    # next(iter(...)): Python 3 -- iterators no longer have a .next() method.
    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        # Autoencoding loss is switched on only after ae_start epochs
        if epoch >= args.ae_start:
            ae_weight = args.ae

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            if critic_updates < n_critic:
                # ---- critic update: Wasserstein loss + gradient penalty ----
                z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                w_dist = real_score - fake_score
                D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                                (x_gen * mask_gen, z_gen))

                critic_optimizer.zero_grad()
                D_loss.backward()
                critic_optimizer.step()

                loss_breakdown['D'] += D_loss.item()

                critic_updates += 1
            else:
                critic_updates = 0

                # ---- generator/encoder update: freeze the critic first ----
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x,
                                                             mask,
                                                             ae=(args.ae > 0))

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                mmd_loss = 0
                if args.mmd > 0:
                    mmd_loss = mmd(z_enc, z_gen)
                    loss += mmd_loss * args.mmd

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                if torch.is_tensor(ae_loss):
                    loss_breakdown['AE'] += ae_loss.item()
                if torch.is_tensor(mmd_loss):
                    loss_breakdown['MMD'] += mmd_loss.item()
                loss_breakdown['total'] += loss.item()

                # Unfreeze the critic for the next round of its updates
                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            # Qualitative check on the fixed test batch (eval mode, no grads)
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        # Always keep a rolling "latest" checkpoint; optionally keep
        # per-interval snapshots as well.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
# Beispiel #4
# 0
def train(**kwargs):
    """Train a DCGAN generator/discriminator pair on an image-folder dataset.

    Keyword arguments override the matching attributes of the module-level
    ``opt`` configuration object.

    Fix vs. original: ``fix_fake_imgs`` was only assigned inside the
    visualization branch (which requires ``opt.vis``), so the image-saving
    code in the ``save_every`` branch raised NameError whenever
    visualization was disabled; the save branch now generates the samples
    from ``fix_noises`` itself.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize/center-crop to image_size, scale pixels to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks, optionally resumed from checkpoints (loaded onto CPU first)
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    # Move the networks to the GPU when one is available
    netd.to(device)
    netg.to(device)

    # Optimizers and binary cross-entropy loss
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1 and fakes 0, one label per batch element.
    # `noises` is the generator input; `fix_noises` stays fixed so the
    # saved samples are comparable across epochs.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # ---- train the discriminator ----
                optimizer_d.zero_grad()
                ## push real images toward the "real" label
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## push generated images toward the "fake" label;
                ## detach() keeps gradients out of the generator here
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()

                # Both backward passes have accumulated grads; this step
                # updates only the discriminator's parameters.
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # ---- train the generator ----
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                # G is rewarded when D classifies its output as real;
                # only the generator's parameters are updated here.
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualization (drop into debugger if sentinel file exists)
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save sample images and model checkpoints.
            # Generate from the fixed noise here so this branch also works
            # when visualization is disabled (was a NameError before).
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
# Beispiel #5
# 0
    def train(self, dataset):
        """Adversarial training loop for an image-matting GAN.

        Each batch supplies an image ('I'), a trimap ('T') and a ground-truth
        alpha matte ('A'); when ``self.com_loss`` is set it also supplies
        background ('B') and foreground ('F') images used to composite a fake
        image.  D is trained every ``d_every`` batches and G every ``g_every``
        batches; checkpoints are saved every 5 epochs.
        """
        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(self.epoch):
            for ii, data in tqdm.tqdm(enumerate(dataset)):
                real_img = data['I']
                tri_img = data['T']

                if self.com_loss:
                    bg_img = data['B'].to(self.device)
                    fg_img = data['F'].to(self.device)

                # Generator input: image and trimap concatenated on the
                # channel axis (3 + 1 = 4 channels)
                input_img = t.tensor(np.append(real_img.numpy(), tri_img.numpy(), axis=1)).to(self.device)

                # ground-truth alpha matte
                real_alpha = data['A'].to(self.device)

                # vis.images(real_img.numpy()*0.5 + 0.5, win='input_real_img')
                # vis.images(real_alpha.cpu().numpy()*0.5 + 0.5, win='real_alpha')
                # vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')

                # train D
                if ii % self.d_every == 0:
                    self.D_optimizer.zero_grad()

                    # real_img_d = input_img[:, 0:3, :, :]
                    tri_img_d = input_img[:, 3:4, :, :]

                    # feed the real alpha (or the full 4-channel input) to D
                    if self.com_loss:
                        real_d = self.D(input_img)
                    else:
                        real_d = self.D(t.cat([real_alpha, tri_img_d], dim=1))

                    target_real_label = t.tensor(1.0)
                    target_real = target_real_label.expand_as(real_d).to(self.device)

                    loss_d_real = self.D_criterion(real_d, target_real)
                    #loss_d_real.backward()

                    # alpha predicted by the generator, judged by D
                    fake_alpha = self.G(input_img)
                    if self.com_loss:
                        fake_img = fake_alpha*fg_img + (1 - fake_alpha) * bg_img
                        fake_d = self.D(t.cat([fake_img, tri_img_d], dim=1))
                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_d], dim=1))
                    target_fake_label = t.tensor(0.0)

                    target_fake = target_fake_label.expand_as(fake_d).to(self.device)

                    loss_d_fake = self.D_criterion(fake_d, target_fake)

                    # single backward over the combined real+fake loss
                    loss_D = loss_d_real + loss_d_fake
                    loss_D.backward()
                    self.D_optimizer.step()
                    self.D_error_meter.add(loss_D.item())

                # train G
                if ii % self.g_every == 0:
                    self.G_optimizer.zero_grad()

                    real_img_g = input_img[:, 0:3, :, :]
                    tri_img_g = input_img[:, 3:4, :, :]

                    fake_alpha = self.G(input_img)
                    # loss between predicted and ground-truth alpha
                    loss_g_alpha = self.G_criterion(fake_alpha, real_alpha)
                    loss_G = loss_g_alpha
                    self.Alpha_loss_meter.add(loss_g_alpha.item())

                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 - fake_alpha) * bg_img
                        loss_g_cmp = self.G_criterion(fake_img, real_img_g)

                        # adversarial term: try to fool the discriminator
                        fake_d = self.D(t.cat([fake_img, tri_img_g], dim=1))
                        self.Com_loss_meter.add(loss_g_cmp.item())
                        loss_G = loss_G + loss_g_cmp

                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_g], dim=1))
                    target_fake = t.tensor(1.0).expand_as(fake_d).to(self.device)
                    loss_g_d = self.D_criterion(fake_d, target_fake)

                    self.Adv_loss_meter.add(loss_g_d.item())

                    loss_G = loss_G + loss_g_d

                    loss_G.backward()
                    self.G_optimizer.step()
                    self.G_error_meter.add(loss_G.item())

                if self.visual and ii % 20 == 0:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    #vis.plot('errorg', self.G_error_meter.value()[0])
                    vis.plot('errorg', np.array([self.Adv_loss_meter.value()[0], self.Alpha_loss_meter.value()[0],
                                                 self.Com_loss_meter.value()[0]]), legend=['adv_loss', 'alpha_loss',
                                                                                           'com_loss'])

                    vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')
                    vis.images(real_img.cpu().numpy() * 0.5 + 0.5, win='relate_real_input')
                    vis.images(real_alpha.cpu().numpy() * 0.5 + 0.5, win='relate_real_alpha')
                    vis.images(fake_alpha.detach().cpu().numpy(), win='fake_alpha')
                    if self.com_loss:
                        vis.images(fake_img.detach().cpu().numpy()*0.5 + 0.5, win='fake_img')
            self.G_error_meter.reset()
            self.D_error_meter.reset()

            self.Alpha_loss_meter.reset()
            self.Com_loss_meter.reset()
            self.Adv_loss_meter.reset()
            if epoch % 5 == 0:
                t.save(self.D.state_dict(), self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(), self.save_dir + '/netG' + '/netG_%s.pth' % epoch)

        return
# Beispiel #6
# 0
def train(**kwargs):
    """Train a DCGAN generator/discriminator pair on an image-folder dataset.

    Keyword arguments override attributes of the module-level ``opt``.

    NOTE(review): the tail of this function (from ``clip_gradient`` onward)
    references names never defined here and appears to be an unrelated
    snippet fused in by the source scrape -- see the inline note below.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device=t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize/center-crop to image_size, scale pixels to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks, optionally resumed from checkpoints
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        print("use pretrained models.")
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # Optimizers and binary cross-entropy loss
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, fakes 0.
    # `noises` is the generator input; `fix_noises` stays fixed for plotting.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()


    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # ---- train the discriminator ----
                optimizer_d.zero_grad()
                ## push real images toward the "real" label
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## push generated images toward the "fake" label
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # detach: no gradient into G
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # ---- train the generator: reward G when D says "real" ----
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualization (drop into debugger if sentinel file exists)
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch+1) % opt.save_every == 0:
            # Save model checkpoints (sample-image saving is commented out)
            # tv.utils.save_image(fix_fake_imgs.data[:64], '{0}_{1}.png'.format(opt.save_path, epoch), normalize=True, range=(-1, 1))
            t.save(netd.state_dict(), './checkpoints/tiny_netd_%s.pth' % (epoch))
            t.save(netg.state_dict(), './checkpoints/tiny_netg_%s.pth' % (epoch))
            errord_meter.reset()
            errorg_meter.reset()
        # NOTE(review): everything from here to the end of this function
        # references names that are never defined above (clip_gradient, optim,
        # loss, train_epoch, i, model, test_iter, evaluate, best_accuaracy).
        # It looks like an unrelated training snippet fused in by the scrape
        # and would raise NameError if reached -- confirm and remove upstream.
        # training trick: clip_gradient
        # https://blog.csdn.net/u010814042/article/details/76154391
        # mitigates the gradient-explosion problem
        clip_gradient(optimizer=optim, grad_clip=opt.grad_clip)

        # step optimizer
        optim.step()

        # plot loss and accuracy
        if train_epoch % 50 == 0:
            if opt.use_cuda:
                loss_data = loss.cpu().data[0]
            else:
                loss_data = loss.data[0]
            print("{} EPOCH {} batch: train loss {}".format(
                i, train_epoch, loss_data))

            # vis loss
            vis.plot('loss', loss_data)

    # evaluate on test for this epoch
    accuracy = evaluate(model, test_iter, opt)
    vis.log("{} EPOCH, accuaracy : {}".format(i, accuracy))
    vis.plot('accuracy', accuracy)

    # track the best model seen so far and persist it
    if accuracy > best_accuaracy:
        best_accuaracy = accuracy
        torch.save(model.state_dict(), './best_{}.pth'.format(opt.model))
print('best accuracy: {}'.format(best_accuaracy))
    def train(self, dataset):
        """Adversarial training loop for an image-matting GAN.

        Each batch supplies an image ('I'), a trimap ('T') and a ground-truth
        alpha matte ('A'); when ``self.com_loss`` is set it also supplies
        background ('B') and foreground ('F') images used to composite a fake
        image.  D is trained every ``d_every`` batches and G every ``g_every``
        batches; checkpoints are saved every 5 epochs.
        """
        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(self.epoch):
            for ii, data in tqdm.tqdm(enumerate(dataset)):
                real_img = data['I']
                tri_img = data['T']  #Trimap

                if self.com_loss:
                    bg_img = data['B'].to(self.device)  #Background image
                    fg_img = data['F'].to(self.device)  #Foreground image

                # input to the G, 4 Channel, Image and Trimap concatenated
                input_img = t.tensor(
                    np.append(real_img.numpy(), tri_img.numpy(),
                              axis=1)).to(self.device)

                # real_alpha: ground-truth alpha matte
                real_alpha = data['A'].to(self.device)

                # vis.images(real_img.numpy()*0.5 + 0.5, win='input_real_img')
                # vis.images(real_alpha.cpu().numpy()*0.5 + 0.5, win='real_alpha')
                # vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')

                # train D
                if ii % self.d_every == 0:
                    self.D_optimizer.zero_grad()

                    # real_img_d = input_img[:, 0:3, :, :]
                    tri_img_d = input_img[:, 3:4, :, :]

                    # real alpha (or the full 4-channel input) judged by D
                    if self.com_loss:
                        real_d = self.D(input_img)
                    else:
                        real_d = self.D(t.cat([real_alpha, tri_img_d], dim=1))

                    target_real_label = t.tensor(1.0)  #1 for real
                    #The shape of real_d would be NxN
                    target_real = target_real_label.expand_as(real_d).to(
                        self.device)

                    loss_d_real = self.D_criterion(real_d, target_real)

                    #fake_alpha, is the predicted alpha by the generator
                    fake_alpha = self.G(input_img)
                    if self.com_loss:
                        #Constructing the fake Image
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        fake_d = self.D(t.cat([fake_img, tri_img_d], dim=1))
                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_d], dim=1))
                    target_fake_label = t.tensor(0.0)

                    target_fake = target_fake_label.expand_as(fake_d).to(
                        self.device)

                    loss_d_fake = self.D_criterion(fake_d, target_fake)

                    loss_D = loss_d_real + loss_d_fake
                    #Backpropagation of the combined discriminator loss
                    loss_D.backward()
                    self.D_optimizer.step()
                    self.D_error_meter.add(loss_D.item())

                # train G
                if ii % self.g_every == 0:
                    #Reset accumulated gradients for the generator
                    self.G_optimizer.zero_grad()

                    real_img_g = input_img[:, 0:3, :, :]
                    tri_img_g = input_img[:, 3:4, :, :]

                    fake_alpha = self.G(input_img)
                    # fake_alpha is the output of the Generator
                    loss_g_alpha = self.G_criterion(fake_alpha, real_alpha)
                    #alpha loss: difference between predicted and real alpha
                    loss_G = loss_g_alpha
                    self.Alpha_loss_meter.add(loss_g_alpha.item())

                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        loss_g_cmp = self.G_criterion(
                            fake_img, real_img_g)  #Composition Loss

                        fake_d = self.D(t.cat([fake_img, tri_img_g], dim=1))
                        self.Com_loss_meter.add(loss_g_cmp.item())
                        loss_G = loss_G + loss_g_cmp

                    else:
                        fake_d = self.D(t.cat([fake_alpha, tri_img_g], dim=1))
                    target_fake = t.tensor(1.0).expand_as(fake_d).to(
                        self.device)
                    #The generator's target is to make the discriminator output 1
                    loss_g_d = self.D_criterion(fake_d, target_fake)

                    self.Adv_loss_meter.add(loss_g_d.item())

                    loss_G = loss_G + loss_g_d

                    loss_G.backward()
                    self.G_optimizer.step()
                    self.G_error_meter.add(loss_G.item())

                if self.visual and ii % 20 == 0:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    #vis.plot('errorg', self.G_error_meter.value()[0])
                    vis.plot('errorg',
                             np.array([
                                 self.Adv_loss_meter.value()[0],
                                 self.Alpha_loss_meter.value()[0],
                                 self.Com_loss_meter.value()[0]
                             ]),
                             legend=['adv_loss', 'alpha_loss', 'com_loss'])

                    vis.images(tri_img.numpy() * 0.5 + 0.5, win='tri_map')
                    vis.images(real_img.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_input')
                    vis.images(real_alpha.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_alpha')
                    vis.images(fake_alpha.detach().cpu().numpy(),
                               win='fake_alpha')
                    if self.com_loss:
                        vis.images(fake_img.detach().cpu().numpy() * 0.5 + 0.5,
                                   win='fake_img')
            self.G_error_meter.reset()
            self.D_error_meter.reset()

            self.Alpha_loss_meter.reset()
            self.Com_loss_meter.reset()
            self.Adv_loss_meter.reset()
            if epoch % 5 == 0:
                t.save(self.D.state_dict(),
                       self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(),
                       self.save_dir + '/netG' + '/netG_%s.pth' % epoch)

        return
# Beispiel #9
# 0
def train(**kwargs):
    """Train a fast-neural-style ``TransformerNet`` on an image folder.

    Keyword arguments override attributes on the global ``opt`` config
    object (e.g. ``train(lr=1e-3, epoches=2)``).  A checkpoint is written
    to ``checkpoints/<epoch>_style.pth`` after every epoch.

    BUGFIX: ``vis`` was previously created only when ``opt.vis`` was set,
    but was then used unconditionally, raising ``NameError`` whenever
    visualization was disabled.  All ``vis`` calls are now guarded.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = None
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Pixel values are kept in (0, 255), which is what the VGG16 feature
    # extractor used below expects (after normalize_batch).
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),  # scales to (0, 1)
        tv.transforms.Lambda(lambda x: x * 255)  # back to (0, 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    transformer = TransformerNet()
    if opt.model_path:
        # Resume from a checkpoint; map_location keeps tensors on CPU.
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))

    # VGG16 is a frozen feature extractor.
    vgg = VGG16().eval()
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style target: Gram matrices of the style image's VGG features.
    style = utils.get_style_data(opt.style_path)
    if vis is not None:
        vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # NOTE(review): volatile=True is the legacy (pre-0.4) way to disable
    # autograd for inference; kept for compatibility with this file's era.
    style_v = Variable(style.unsqueeze(0), volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()  # (0, 255)
            x = Variable(x)
            y = transformer(x)  # (0, 255)
            y = utils.normalize_batch(y)  # roughly (-2, 2)
            x = utils.normalize_batch(x)  # roughly (-2, 2)

            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss: only relu2_2 is compared.
            # TODO: consider a weighted sum over several layers, e.g.
            # w1*relu2_2 + w2*relu3_2 + w3*relu3_3 + w4*relu4_3.
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            content_meter.add(content_loss.data)

            # Style loss: Gram-matrix MSE summed over all feature maps.
            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_meter.add(style_loss.data)

            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if (ii + 1) % (opt.plot_every) == 0 and vis is not None:
                # Drop into ipdb if the debug sentinel file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])

                # Undo the approximate ImageNet normalization for display.
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        if vis is not None:
            vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Beispiel #10
0
        # loss = torch.std(loss_matrix) + loss_matrix.mean()

        # use smooth_l1_loss
        # use sum(abs(loss-mean)) to evaluate the volatility of the model prediction
        # use l1 loss to make
        # the problem is
        # loss_matrix = F.smooth_l1_loss(output, target, reduce=False)
        # loss = torch.abs(loss_matrix - loss_matrix.mean()).mean() + loss_matrix.mean()

        # loss_matrix
        loss.backward()

        optimizer.step()
        # plot loss for the epoch for every plot_batch
        if batch_index % plot_batch == 0:
            vis.plot('mixed_loss', min(loss.data, 10))
            vis.plot('origin_loss', min(loss_matrix.mean().data, 10))
            vis.plot('std', torch.std(loss_matrix).data)

    if task_type == 'regression':
        # 评估本轮的ic值
        model.eval()
        y_pred = torch.Tensor()
        y_true = torch.Tensor()

        for batch_index, (data, target) in enumerate(test_loader):
            data, target = Variable(data), Variable(target)
            data, target = data.float(), target.float()
            output = model(data)
            # update y_pred and y_true
            y_pred = torch.cat([y_pred, output])
Beispiel #11
0
def train(**kwargs):
	"""Train a DCGAN (generator ``netg`` + discriminator ``netd``).

	Keyword arguments override attributes of the global ``opt`` config.
	A sample grid and both model checkpoints are written once every
	``opt.save_every`` epochs.

	BUGFIX: the checkpoint/save block used to sit *inside* the batch loop,
	so on a qualifying epoch it re-saved after every single batch, reset
	the loss meters after every batch, and raised ``NameError`` on
	``fix_fake_img`` whenever visualization had not run.  It now executes
	once per qualifying epoch.
	"""
	for k_,v_ in kwargs.items():
		setattr(opt,k_,v_)

	if opt.vis:
		from visualize import Visualizer
		vis = Visualizer(opt.env)

	# Images scaled to [-1, 1] to match the generator's tanh output range.
	transforms = tv.transforms.Compose([
				tv.transforms.Scale(opt.image_size),
				tv.transforms.CenterCrop(opt.image_size),
				tv.transforms.ToTensor(),
				tv.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

	dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)

	dataloader = t.utils.data.DataLoader(dataset,
									batch_size= opt.batch_size,
									shuffle = True,
									num_workers = opt.num_workers,
									drop_last = True)
	netg,netd = NetG(opt),NetD(opt)

	# Load checkpoints onto CPU first; .cuda() below moves them if needed.
	map_location = lambda storage, loc: storage
	if opt.netd_path:
		netd.load_state_dict(t.load(opt.netd_path,map_location = map_location))
	if opt.netg_path:
		netg.load_state_dict(t.load(opt.netg_path,map_location= map_location))

	optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1,0.999))
	optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1,0.999))

	criterion = t.nn.BCELoss()

	# Real images are labelled 1, generated images 0.
	true_labels = Variable(t.ones(opt.batch_size))
	fake_labels = Variable(t.zeros(opt.batch_size))

	# Fixed noise gives comparable sample grids across epochs.
	fix_noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))
	noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))

	if opt.use_gpu:
		netd.cuda()
		netg.cuda()
		criterion.cuda()
		true_labels,fake_labels = true_labels.cuda(),fake_labels.cuda()
		fix_noises,noises = fix_noises.cuda(),noises.cuda()

	errord_meter = AverageValueMeter()
	errorg_meter = AverageValueMeter()

	for epoch in range(opt.max_epoch):

		for ii, (img,_) in tqdm(enumerate(dataloader)):

			real_img = Variable(img)
			if opt.use_gpu:
				real_img = real_img.cuda()

			if (ii +1 )% opt.d_every == 0:
				# --- discriminator step: real -> 1, fake (detached) -> 0 ---
				optimizer_d.zero_grad()
				output = netd(real_img)
				error_d_real = criterion(output,true_labels)
				error_d_real.backward()

				noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
				fake_img = netg(noises).detach()  # no gradients into G here
				fake_output = netd(fake_img)
				error_d_fake = criterion(fake_output,fake_labels)
				error_d_fake.backward()

				optimizer_d.step()

				error_d =error_d_real + error_d_fake 
				errord_meter.add(error_d.data)

			if (ii + 1)%opt.g_every == 0:
				# --- generator step: make D output 1 for fakes ---
				optimizer_g.zero_grad()
				noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
				fake_img = netg(noises)
				fake_output = netd(fake_img)

				error_g = criterion(fake_output,true_labels)
				error_g.backward()
				optimizer_g.step()

				errorg_meter.add(error_g.data)

			if opt.vis and ii%opt.plot_every == opt.plot_every -1 :
				# Periodic visualization; drop into ipdb if the debug
				# sentinel file exists.
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()

				fix_fake_img = netg(fix_noises)
				vis.images(fix_fake_img.data.cpu().numpy()[:16]*0.5 + 0.5 ,win='fixfake')
				vis.images(real_img.data.cpu().numpy()[:16]*0.5+0.5,win='real')

				vis.plot('errord',errord_meter.value()[0])
				vis.plot('errorg',errorg_meter.value()[0])

		# Save once per qualifying epoch (moved out of the batch loop);
		# regenerate the fixed-noise samples so this never depends on the
		# visualization branch having run.
		if (epoch+1) % opt.save_every ==0:
			fix_fake_img = netg(fix_noises)
			tv.utils.save_image(fix_fake_img.data[:64],'%s/%s.png'%(opt.save_path,epoch),normalize=True,range=(-1,1))
			t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
			t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
			errord_meter.reset()
			errorg_meter.reset()
Beispiel #12
0
    def train(self, dataset):
        """Run the alphaGAN training loop over ``dataset``.

        Each batch is a dict with keys ``I`` (image), ``T`` (trimap),
        ``B`` (background), ``F`` (foreground) and ``A`` (ground-truth
        alpha matte).  The discriminator is updated every 5 batches, the
        generator on every batch; losses are plotted every 20 batches and
        both networks are checkpointed every 5 epochs.
        """
        vis = Visualizer('alphaGAN_train')

        for epoch in range(self.epoch):
            for ii, data in tqdm.tqdm(enumerate(dataset)):
                real_img = data['I'].to(self.device)
                tri_img = data['T']

                bg_img = data['B'].to(self.device)
                fg_img = data['F'].to(self.device)

                # Generator input: image and trimap concatenated along the
                # channel axis (axis=1, NCHW).
                input_img = t.tensor(
                    np.append(real_img.cpu().numpy(), tri_img.numpy(),
                              axis=1)).to(self.device)

                # Ground-truth alpha matte.
                real_alpha = data['A'].to(self.device)

                # vis.images(real_img.numpy()*0.5 + 0.5, win='input_real_img')
                # vis.images(real_alpha.cpu().numpy()*0.5 + 0.5, win='real_alpha')
                # vis.images(tri_img.numpy()*0.5 + 0.5, win='tri_map')

                if ii % 5 == 0:
                    # --- discriminator update (every 5 batches) ---
                    self.D_optimizer.zero_grad()

                    # Real sample for D: the composited image when the
                    # composition loss is enabled, else the real alpha.
                    if self.com_loss:
                        real_d = self.D(real_img)
                    else:
                        real_d = self.D(real_alpha)

                    target_real_label = t.tensor(1.0)
                    target_real = target_real_label.expand_as(real_d).to(
                        self.device)

                    loss_d_real = self.D_criterion(real_d, target_real)
                    #loss_d_real.backward()

                    # Generate a fake alpha and let D judge it (composited
                    # with the true fg/bg when com_loss is on).
                    fake_alpha = self.G(input_img)
                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        fake_d = self.D(fake_img)
                    else:
                        fake_d = self.D(fake_alpha)
                    target_fake_label = t.tensor(0.0)

                    target_fake = target_fake_label.expand_as(fake_d).to(
                        self.device)

                    loss_d_fake = self.D_criterion(fake_d, target_fake)

                    loss_D = (loss_d_real + loss_d_fake) * 0.5
                    loss_D.backward()
                    self.D_optimizer.step()
                    self.D_error_meter.add(loss_D.item())

                if ii % 1 == 0:
                    # --- generator update (every batch) ---
                    self.G_optimizer.zero_grad()

                    fake_alpha = self.G(input_img)
                    # L1-style loss between predicted and true alpha,
                    # weighted 0.8 vs 0.2 for the other terms.
                    loss_g_alpha = self.G_criterion(fake_alpha, real_alpha)
                    loss_G = 0.8 * loss_g_alpha
                    if self.com_loss:
                        fake_img = fake_alpha * fg_img + (1 -
                                                          fake_alpha) * bg_img
                        loss_g_cmp = self.G_criterion(fake_img, real_img)

                        # Adversarial part: fool the discriminator.
                        fake_d = self.D(fake_img)

                        loss_G = loss_G + 0.2 * loss_g_cmp
                    else:
                        fake_d = self.D(fake_alpha)
                    target_fake = t.tensor(1.0).expand_as(fake_d).to(
                        self.device)
                    loss_g_d = self.D_criterion(fake_d, target_fake)

                    loss_G = loss_G + 0.2 * loss_g_d

                    loss_G.backward()
                    self.G_optimizer.step()
                    self.G_error_meter.add(loss_G.item())

                if ii % 20 == 0:
                    # Periodic visualization of inputs and predictions.
                    vis.plot('errord', self.D_error_meter.value()[0])
                    vis.plot('errorg', self.G_error_meter.value()[0])
                    vis.images(tri_img.numpy() * 0.5 + 0.5, win='tri_map')
                    vis.images(real_img.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_input')
                    vis.images(real_alpha.cpu().numpy() * 0.5 + 0.5,
                               win='relate_real_alpha')
                    vis.images(fake_alpha.detach().cpu().numpy(),
                               win='fake_alpha')
                    if self.com_loss:
                        vis.images(fake_img.detach().cpu().numpy() * 0.5 + 0.5,
                                   win='fake_img')
            self.G_error_meter.reset()
            self.D_error_meter.reset()
            if epoch % 5 == 0:
                # Checkpoint both networks every 5 epochs.
                # NOTE(review): absolute /home/zzl/... paths are hard-coded;
                # consider making them configurable.
                t.save(
                    self.D.state_dict(), '/home/zzl/model/alphaGAN/netD/' +
                    self.save_dir + '/netD_%s.pth' % epoch)
                t.save(
                    self.G.state_dict(), '/home/zzl/model/alphaGAN/netG/' +
                    self.save_dir + '/netG_%s.pth' % epoch)

        return
Beispiel #13
0
# Dataset: iris split (swap in digits() to try the alternative).
x_train, x_test, y_train, y_test = iris()
# x_train, x_test, y_train, y_test = digits()

# Sweep the regularization constant c over powers of two, 2**-20 .. 2**19,
# tracking the best R@1 and plotting every run.
max_recall = np.zeros([4])
max_c = None
fig_x = list(range(-20, 20))
vis = Visualizer(num_x=len(fig_x))

for idx, exponent in enumerate(fig_x):
    model = MetricELM(hidden_num=4, c=2**exponent)
    model.fit(x_train, y_train)
    recall = model.validation(x_test, y_test)

    # Report and record this run.
    print("c=", exponent, "R@{1,2,4,8}:", np.around(recall, decimals=3))
    vis.set(idx, recall)

    # Keep the run with the best recall@1.
    if recall[0] > max_recall[0]:
        max_recall = recall
        max_c = exponent

print('best:')
print("c=", max_c, "R@{1,2,4,8}:", np.around(max_recall, decimals=3))

vis.plot(x=fig_x, max_c=max_c, max_recall=max_recall)

vis.show()
# vis.save('resources/iris.png')
Beispiel #14
0
def train(**kwargs):
    """Train a DCGAN variant whose generator loss adds a Huber term tying
    the discriminator output back to the (flattened) input noise.

    Keyword arguments override attributes of the global ``opt`` config.
    Checkpoints go to ``chenck20190529/`` every ``opt.save_every`` epochs.

    Cleanup: removed leftover per-batch debug ``print`` calls (tensor
    shapes, save announcement) and the unused ``localtime`` local.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env, server="172.16.6.194")

    # Data pipeline: resize/crop, then scale pixels into [-1, 1].
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = mydataset.AADBDataset(opt.data_path, opt.lable_path, transforms=transforms)
    # dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks, optionally warm-started from checkpoints (CPU-mapped load).
    netg, netd = NetmyG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Optimizers and losses: BCE for the adversarial game, SmoothL1 (Huber)
    # for the extra noise-reconstruction term.
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)
    huber = t.nn.SmoothL1Loss().to(device)

    # Real -> 1, fake -> 0, one label per output dimension (opt.weidu).
    true_labels = t.ones([opt.batch_size, opt.weidu]).to(device)
    fake_labels = t.zeros([opt.batch_size, opt.weidu]).to(device)

    # Fixed noise gives comparable sample grids across epochs.
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, opt.weidu).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, opt.weidu).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()
    errorgs_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # --- discriminator step: real -> 1, fake (detached) -> 0 ---
                optimizer_d.zero_grad()
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward(retain_graph=True)

                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, opt.weidu))
                fake_img = netg(noises).detach()  # no gradients into G here
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward(retain_graph=True)

                optimizer_d.step()

                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # --- generator step: adversarial loss plus Huber term
                # pulling D's output toward the flattened input noise ---
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, opt.weidu))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward(retain_graph=True)

                flat_noise = noises.view(opt.batch_size, -1)
                # NOTE(review): this assumes D's output width matches
                # nz * weidu -- confirm against the NetD definition.
                error_s = huber(output, flat_noise)
                error_s.backward(retain_graph=True)

                error_gs = error_s + error_g
                optimizer_g.step()
                errorg_meter.add(error_g.item())
                errorgs_meter.add(error_gs.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Periodic visualization; drop into ipdb if the debug file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises).detach()
                vis.images(fix_fake_imgs.cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])
                vis.plot('errorgs', errorgs_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Checkpoint both networks and restart the meters.
            t.save(netd.state_dict(), 'chenck20190529/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'chenck20190529/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            errorgs_meter.reset()
Beispiel #15
0
        # solve Gradient explosion problem
        clip_gradient(optimizer=optim, grad_clip=opt.grad_clip)

        # step optimizer
        optim.step()

        # plot for loss and accuracy
        if train_epoch % 50 == 0:
            if opt.use_cuda:
                loss_data = loss.cpu().data
            else:
                loss_data = loss.data
            print("{} EPOCH {} batch: train loss {}".format(idx_epoch, train_epoch, loss_data))

            # vis loss
            vis.plot('loss', loss_data)


    # evaluate on test for this epoch
    # accuracy np.array (classes_size)
    accuracy, AUC_scores = evaluate(model, test_iter, opt) # TODO update load data in evaluate and validation
    vis.log("{} EPOCH, accuaracy : {}".format(idx_epoch, accuracy.mean()))
    vis.plot('accuracy', accuracy)
    label_list = ['toxic', 'severe_toxic', 'obscene',
            'threat', 'insult', 'identity_hate']
    print(accuracy)
    for i in range(len(label_list)):
        vis.plot(label_list[i], accuracy[i])
    # for AUC
    print(AUC_scores)
    vis.plot('AUC', AUC_scores)
Beispiel #16
0
def train(**kwargs):
    """Standard DCGAN training loop.

    Keyword arguments override attributes of the global ``opt`` config.
    Alternates discriminator and generator updates, visualizes progress
    every ``opt.plot_every`` batches, and checkpoints every
    ``opt.save_every`` epochs.
    """
    for key, value in kwargs.items():
        setattr(opt, key, value)

    device = t.device('cuda' if opt.gpu else 'cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Dataset: images scaled to [-1, 1] to match the generator's tanh output.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True,
    )

    # Networks, optionally restored from checkpoints (loaded onto CPU first).
    netg, netd = NetG(opt), NetD(opt)
    cpu_map = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=cpu_map))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=cpu_map))
    netd.to(device)
    netg.to(device)

    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, generated images 0; fixed noise gives
    # comparable sample grids across epochs.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for batch_i, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if batch_i % opt.d_every == 0:
                # Discriminator update: push real toward 1, fake toward 0.
                optimizer_d.zero_grad()
                d_real = netd(real_img)
                error_d_real = criterion(d_real, true_labels)
                error_d_real.backward()

                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # block gradients into G
                d_fake = netd(fake_img)
                error_d_fake = criterion(d_fake, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            if batch_i % opt.g_every == 0:
                # Generator update: make D classify fresh fakes as real.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                d_fake = netd(fake_img)
                error_g = criterion(d_fake, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and batch_i % opt.plot_every == opt.plot_every - 1:
                # Periodic visualization; ipdb hook via the debug file.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Persist a sample grid and both networks, then restart meters.
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Beispiel #17
0
        clip_gradient(optimizer=optim, grad_clip=opt.grad_clip)

        # step optimizer
        optim.step()

        # plot for loss and accuracy
        if train_epoch % 50 == 0:
            if opt.use_cuda:
                loss_data = loss.cpu().data[0]
            else:
                loss_data = loss.data[0]
            print("{} EPOCH {} batch: train loss {}".format(
                i, train_epoch, loss_data))

            # vis loss
            vis.plot('loss', loss_data)

    # evaluate on test for this epoch
    accuracy = evaluate(model, test_iter, opt)
    vis.log("{} EPOCH, accuaracy : {}".format(i, accuracy))
    vis.plot('accuracy', accuracy)

    # handel best model, update best model , best_lstm.pth
    if accuracy > best_accuaracy:
        best_accuaracy = accuracy
        torch.save(model.state_dict(), './best_{}.pth'.format(opt.model))

    # for mixed classification part
    print(cum_tensor.size())
    feature_tensor = cum_tensor.mean(dim=1)
    print(feature_tensor.size())
Beispiel #18
0
def train(opt):
    """Train a video caption generator, switching from cross-entropy to
    self-critical (RL) training after ``opt.self_critical_after`` epochs.

    Checkpoints and an ``infos`` pickle are written every
    ``opt.save_checkpoint_every`` iterations; training stops when the
    validation score fails to improve ``opt.patience`` times in a row.
    """
    vis = Visualizer(env='Caption_generator')
    opt.vocab_size = get_nwords(opt.data_path)
    opt.category_size = get_nclasses(opt.data_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataset, valid_dataset, test_dataset = load_dataset_cap(opt=opt)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn_cap)
    model = Caption_generator(opt=opt)
    embeddings_path = os.path.join(opt.data_path, 'sentence_embeddings.pkl')
    total_embeddings = load_pkl(embeddings_path)
    infos = {}
    best_score = None
    # Criteria: language-model XE, auxiliary category classifier, and the
    # REINFORCE-style reward criterion for self-critical training.
    crit = LanguageModelCriterion()
    classify_crit = ClassifierCriterion()
    rl_crit = RewardCriterion()

    # InferSent sentence encoder, used for the semantics-based reward.
    model_version = 2
    MODEL_PATH = opt.infersent_model_path
    assert MODEL_PATH is not None, '--infersent_model_path is None!'
    MODEL_PATH = os.path.join(MODEL_PATH, 'infersent%s.pkl' % model_version)
    params_model = {
        'bsize': 64,
        'word_emb_dim': 300,
        'enc_lstm_dim': 2048,
        'pool_type': 'max',
        'dpout_model': 0.0,
        'version': model_version
    }
    infersent_model = InferSent(params_model)
    infersent_model.load_state_dict(torch.load(MODEL_PATH))
    infersent_model = infersent_model.to(device)
    W2V_PATH = opt.w2v_path
    assert W2V_PATH is not None, '--w2v_path is None!'
    infersent_model.set_w2v_path(W2V_PATH)
    # sentences_path = os.path.join(opt.data_path, 'sentences.pkl')
    # sentences = load_pkl(sentences_path)
    # infersent_model.build_vocab(sentences, tokenize=True)
    infersent_model.build_vocab_k_words(K=100000)
    id_word = get_itow(opt.data_path)

    # Optionally resume from the best checkpoint in opt.start_from.
    if opt.start_from is not None:
        assert os.path.isdir(opt.start_from), 'opt.start_from must be a dir!'
        state_dict_path = os.path.join(opt.start_from,
                                       opt.model_name + '-bestmodel.pth')
        assert os.path.isfile(state_dict_path), 'bestmodel don\'t exist!'
        model.load_state_dict(torch.load(state_dict_path), strict=True)

        infos_path = os.path.join(opt.start_from,
                                  opt.model_name + '_infos_best.pkl')
        assert os.path.isfile(infos_path), 'infos of bestmodel don\'t exist!'
        with open(infos_path, 'rb') as f:
            infos = pickle.load(f)

        if opt.seed == 0: opt.seed = infos['opt'].seed

        best_score = infos.get('best_score', None)

    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    train_patience = 0
    epoch = infos.get('epoch', 0)
    loss_meter = meter.AverageValueMeter()

    while True:
        # Early stopping on validation-score patience.
        if train_patience > opt.patience: break
        loss_meter.reset()
        # Step-wise learning-rate decay after a configured warm period.
        if opt.learning_rate_decay_start != -1 and epoch > opt.learning_rate_decay_start:
            frac = int((epoch - opt.learning_rate_decay_start) /
                       opt.learning_rate_decay_every)
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.learning_rate * decay_factor
            set_lr(optimizer, opt.current_lr)
        else:
            opt.current_lr = opt.learning_rate

        # Scheduled sampling: gradually raise the probability of feeding
        # the model its own predictions during teacher forcing.
        if opt.scheduled_sampling_start != -1 and epoch > opt.scheduled_sampling_start:
            frac = int((epoch - opt.scheduled_sampling_start) /
                       opt.scheduled_sampling_increase_every)
            opt.sample_probability = min(
                opt.scheduled_sampling_increase_probability * frac,
                opt.scheduled_sampling_max_probability)
            model.sample_probability = opt.sample_probability

        # Switch to self-critical (RL) training after the configured epoch.
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
            sc_flag = True
            opt.save_checkpoint_every = 250
        else:
            sc_flag = False
        # sc_flag = True

        for i, (data, caps, caps_mask, cap_classes, class_masks, feats0,
                feats1, feat_mask, pos_feat, lens, gts,
                video_id) in enumerate(train_loader):

            feats0 = feats0.to(device)
            feats1 = feats1.to(device)
            feat_mask = feat_mask.to(device)
            pos_feat = pos_feat.to(device)
            caps = caps.to(device)
            caps_mask = caps_mask.to(device)
            cap_classes = cap_classes.to(device)
            class_masks = class_masks.to(device)

            optimizer.zero_grad()
            if not sc_flag:
                # Cross-entropy phase: word loss + weighted category loss.
                words, categories = model(feats0, feats1, feat_mask, pos_feat,
                                          caps, caps_mask)
                loss_words = crit(words, caps, caps_mask)
                loss_cate = classify_crit(categories, cap_classes, caps_mask,
                                          class_masks)
                loss = loss_words + opt.weight_class * loss_cate
            elif not opt.eval_semantics:
                # Self-critical phase with a CIDEr-style reward.
                sample_dict = {}
                sample_dict.update(vars(opt))
                sample_dict.update({'sample_max': 0})
                probability_sample, sample_logprobs = model.sample(
                    feats0, feats1, feat_mask, pos_feat, sample_dict)
                reward = get_self_critical_reward(model, feats0, feats1,
                                                  feat_mask, pos_feat, gts,
                                                  probability_sample)
                reward = torch.from_numpy(reward).float()
                reward = reward.to(device)
                loss = rl_crit(sample_logprobs, probability_sample, reward)
            else:
                # Self-critical phase with the InferSent semantics reward.
                sample_dict = vars(opt)
                sample_dict.update(vars(opt))
                sample_dict.update({'sample_max': 0})
                probability_sample, sample_logprobs = model.sample(
                    feats0, feats1, feat_mask, pos_feat, sample_dict)
                reward = get_self_critical_semantics_reward(
                    id_word, infersent_model, model, feats0, feats1, feat_mask,
                    pos_feat, gts, video_id, total_embeddings,
                    probability_sample, sample_dict)
                reward = torch.from_numpy(reward).float()
                reward = reward.to(device)
                loss = rl_crit(sample_logprobs, probability_sample, reward)
            loss.backward()
            clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.detach()
            loss_meter.add(train_loss.item())

            if i % opt.visualize_every == 0:
                vis.plot('train_loss', loss_meter.value()[0])
                information = 'best_score is ' + (str(best_score) if best_score
                                                  is not None else '0.0')
                information += (' reward is ' if sc_flag else
                                ' loss is ') + str(train_loss.item())
                if sc_flag is False:
                    information += ' category_loss is ' + str(
                        loss_cate.cpu().item())
                vis.log(information)

            is_best = False
            if (i + 1) % opt.save_checkpoint_every == 0:
                # Validate; score is CIDEr or the semantics score.
                # NOTE(review): ``eval`` here is a project helper that
                # shadows the builtin -- confirm the import.
                if not opt.eval_semantics:
                    current_loss, current_language_state = eval(
                        infersent_model, model, crit, classify_crit,
                        valid_dataset, vars(opt))
                else:
                    current_semantics_score, current_language_state = eval(
                        infersent_model, model, crit, classify_crit,
                        valid_dataset, vars(opt))
                current_score = current_language_state[
                    'CIDEr'] if not opt.eval_semantics else current_semantics_score
                vis.log('{}'.format(
                    'cider score is ' if not opt.eval_semantics else
                    '###################semantics_score is ') +
                        str(current_score) + ' ' + '+' + str(i))
                if best_score is None or current_score > best_score:
                    is_best = True
                    best_score = current_score
                    train_patience = 0
                else:
                    train_patience += 1

                # Always snapshot this iteration; additionally overwrite the
                # "bestmodel" files when the score improved.
                infos['opt'] = opt
                infos['iteration'] = i
                infos['best_score'] = best_score
                infos['language_state'] = current_language_state
                infos['epoch'] = epoch
                save_state_path = os.path.join(
                    opt.checkpoint_path,
                    opt.model_name + '_' + str(i) + '.pth')
                torch.save(model.state_dict(), save_state_path)
                save_infos_path = os.path.join(
                    opt.checkpoint_path,
                    opt.model_name + '_' + 'infos_' + str(i) + '.pkl')
                with open(save_infos_path, 'wb') as f:
                    pickle.dump(infos, f)

                if is_best:
                    save_state_path = os.path.join(
                        opt.checkpoint_path, opt.model_name + '-bestmodel.pth')
                    save_infos_path = os.path.join(
                        opt.checkpoint_path,
                        opt.model_name + '_' + 'infos_best.pkl')
                    torch.save(model.state_dict(), save_state_path)
                    with open(save_infos_path, 'wb') as f:
                        pickle.dump(infos, f)

        epoch += 1
Beispiel #19
0
            gen_hr = make_grid(gen_hr, nrow=8, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=8, normalize=True)
            imgs_hr = make_grid(imgs_hr, nrow=8, normalize=True)
            img_grid_gl = torch.cat((gen_hr, imgs_lr), -1)
            img_grid_hg = torch.cat((imgs_hr, gen_hr), -1)
            save_image(img_grid_hg,
                       "images/%d_hg.png" % batches_done,
                       normalize=True)
            save_image(img_grid_gl,
                       "images/%d_gl.png" % batches_done,
                       normalize=True)

            # vis.images(imgs_lr.detach().cpu().numpy()[:1] * 0.5 + 0.5, win='imgs_lr_train')
            # vis.images(gen_hr.data.cpu().numpy()[:1] * 0.5 + 0.5, win='img_gen_train')
            # vis.images(imgs_hr.data.cpu().numpy()[:1] * 0.5 + 0.5, win='img_hr_train')
            vis.plot('loss_G_train', loss_G_meter.value()[0])
            vis.plot('loss_D_train', loss_D_meter.value()[0])
            vis.plot('loss_GAN_train', loss_GAN_meter.value()[0])
            vis.plot('loss_content_train', loss_content_meter.value()[0])
            vis.plot('loss_real_train', loss_real_meter.value()[0])
            vis.plot('loss_fake_train', loss_fake_meter.value()[0])

    loss_GAN_meter.reset()
    loss_content_meter.reset()
    loss_G_meter.reset()
    loss_real_meter.reset()
    loss_fake_meter.reset()
    loss_D_meter.reset()

    # validate the generator model
    generator.eval()
def train_pvae(args):
    """Train a partial VAE (PVAE) on partially observed MNIST.

    Builds the masked dataset selected by ``args.mask``, optimizes the
    PVAE with Adam under a KL weight annealed from 0 (before epoch
    ``args.kl_off``) to 1 (after epoch ``args.kl_on``) via a sigmoid
    ramp, plots imputations/generations every ``args.plot_interval``
    epochs, and checkpoints the model after every epoch.

    Args:
        args: parsed argument namespace; fields used here are seed,
            mask, obs_prob(_max), block_len(_max), batch_size, latent,
            flow, lr, min_lr, epoch, prefix, save_interval, kl_on,
            kl_off, k and plot_interval.

    Raises:
        ValueError: if ``args.mask`` is neither ``'indep'`` nor
            ``'block'``.
    """
    torch.manual_seed(args.seed)

    # Choose the masking scheme used to hide pixels of the digits.
    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        # Fail fast: the original code would otherwise crash later with
        # an UnboundLocalError on `data` / `mask_str`.
        raise ValueError(f'unknown mask type: {args.mask!r}')

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=True)
    pvae = PVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    # Reusable latent buffer for plotting unconditional generations.
    rand_z = torch.empty(args.batch_size, args.latent, device=device)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pvae' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Record the full hyper-parameter set next to the results.
    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    # Fixed evaluation batch. `.next()` on the loader iterator is
    # Python 2 style and raises AttributeError on Python 3; use the
    # built-in next() instead.
    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    # Sigmoid-annealing parameters for the KL weight.
    kl_center = (args.kl_on + args.kl_off) / 2
    # NOTE(review): `min` caps the denominator at 1, so the ramp is never
    # gentler than exp(-12*(epoch - center)). Confirm that `max` was not
    # intended (it would also guard against a non-positive denominator).
    kl_scale = 12 / min(args.kl_on - args.kl_off, 1)

    for epoch in range(args.epoch):
        # Piecewise KL schedule: 0 before kl_off, 1 after kl_on,
        # sigmoid ramp in between.
        if epoch >= args.kl_on:
            kl_weight = 1
        elif epoch < args.kl_off:
            kl_weight = 0
        else:
            kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
        loss_breakdown = defaultdict(float)
        for x, mask, _ in data_loader:
            x = x.to(device)
            mask = mask.to(device).float()

            optimizer.zero_grad()
            loss, _, _, _, loss_info = pvae(x,
                                            mask,
                                            args.k,
                                            kl_weight=kl_weight)
            loss.backward()
            optimizer.step()
            # Accumulate per-term losses for epoch-level logging.
            for name, val in loss_info.items():
                loss_breakdown[name] += val

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            x_recon = pvae.impute(test_x, test_mask, args.k)
            with torch.no_grad():
                pvae.eval()
                rand_z.normal_()
                _, x_gen = decoder(rand_z)
                pvae.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)

        # Always overwrite the latest checkpoint; optionally keep
        # per-interval snapshots as well.
        model_dict = {
            'pvae': pvae.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
Beispiel #21
0
def train(opt):
    """Train the Pos_generator POS-tag sequence model.

    Optionally resumes from ``opt.start_from``, then loops over epochs
    until ``opt.patience`` consecutive validation checks fail to improve
    the best score. Applies step learning-rate decay and scheduled
    sampling, periodically evaluates on the validation set, and saves
    both per-iteration and best-model checkpoints plus an `infos` pickle.

    Args:
        opt: option namespace; mutated in place (vocab_size,
            category_size, current_lr, scheduled_sample_probability).
    """
    vis = Visualizer(env='Pos_generator')
    opt.vocab_size = get_nwords(opt.data_path)
    opt.category_size = get_nclasses(opt.data_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataset, valid_dataset, test_dataset = load_dataset_pos(opt)
    train_loader = DataLoader(train_dataset,
                              opt.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn_pos)
    model = Pos_generator(opt)
    infos = {}
    best_score = None
    crit = LanguageModelCriterion()
    classify_crit = ClassifierCriterion()

    # Resume weights and bookkeeping from a previous best checkpoint.
    if opt.start_from is not None:
        assert os.path.isdir(opt.start_from), 'opt.start_from must be a dir'
        state_dict_path = os.path.join(opt.start_from,
                                       opt.model_name + '-bestmodel.pth')
        assert os.path.isfile(state_dict_path), 'bestmodel doesn\'t exist!'
        model.load_state_dict(torch.load(state_dict_path), strict=True)

        infos_path = os.path.join(opt.start_from,
                                  opt.model_name + '_infos_best.pkl')
        assert os.path.isfile(infos_path), 'infos of bestmodel doesn\'t exist!'
        with open(infos_path, 'rb') as f:
            infos = pickle.load(f)

        # Seed 0 means "reuse the seed recorded in the checkpoint".
        if opt.seed == 0: opt.seed = infos['opt'].seed
        best_score = infos.get('best_score', None)

    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    train_patience = 0
    epoch = infos.get('epoch', 0)
    loss_meter = meter.AverageValueMeter()

    while True:
        # Early stopping: give up after `patience` non-improving checks.
        if train_patience > opt.patience: break
        loss_meter.reset()

        # Step learning-rate decay once past the configured start epoch.
        if opt.learning_rate_decay_start != -1 and epoch > opt.learning_rate_decay_start:
            frac = int((epoch - opt.learning_rate_decay_start) /
                       opt.learning_rate_decay_every)
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.learning_rate * decay_factor
            set_lr(optimizer, opt.current_lr)
        else:
            opt.current_lr = opt.learning_rate

        # Increase the scheduled-sampling probability over time, capped
        # at opt.scheduled_sampling_max.
        if opt.scheduled_sampling_start != -1 and epoch > opt.scheduled_sampling_start:
            frac = int((epoch - opt.scheduled_sampling_start) /
                       opt.scheduled_sampling_increase_every)
            opt.scheduled_sample_probability = min(
                opt.scheduled_sampling_increase_probability * frac,
                opt.scheduled_sampling_max)
            model.scheduled_sample_probability = opt.scheduled_sample_probability

        # if torch.cuda.is_available():
        #     torch.cuda.synchronize()
        for i, (data, caps, caps_mask, cap_classes, class_masks, feats0,
                feats1, feat_mask, lens, gts,
                video_id) in enumerate(train_loader):
            caps = caps.to(device)
            caps_mask = caps_mask.to(device)
            cap_classes = cap_classes.to(device)
            feats0 = feats0.to(device)
            feats1 = feats1.to(device)
            feat_mask = feat_mask.to(device)
            class_masks = class_masks.to(device)

            # Rotate class labels right by one position along the time
            # axis (last label becomes the first input).
            cap_classes = torch.cat([cap_classes[:, -1:], cap_classes[:, :-1]],
                                    dim=1)
            new_mask = torch.zeros_like(class_masks)

            cap_classes = cap_classes.to(device)
            new_mask = new_mask.to(device)

            # Rebuild the mask so every position up to (and including)
            # the last non-zero entry of the original mask is active.
            for j in range(class_masks.size(0)):
                index = np.argwhere(
                    class_masks.cpu().numpy()[j, :] != 0)[0][-1]
                new_mask[j, :index + 1] = 1.0

            optimizer.zero_grad()

            out = model(feats0, feats1, feat_mask, caps, caps_mask,
                        cap_classes, new_mask)
            loss = classify_crit(out, cap_classes, caps_mask, class_masks)

            loss.backward()
            clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.detach()

            loss_meter.add(train_loss.item())

            if (i + 1) % opt.visualize_every == 0:
                vis.plot('loss', loss_meter.value()[0])
                information = 'best_score is ' + (str(best_score) if best_score
                                                  is not None else '0.0')
                information += ' current_score is ' + str(train_loss.item())
                vis.log(information)

            if (i + 1) % opt.save_checkpoint_every == 0:
                # print('i am saving!!')
                eval_kwargs = {}
                eval_kwargs.update(vars(opt))
                val_loss = eval_extract.eval_and_extract(
                    model, classify_crit, valid_dataset, device, 'validation',
                    eval_kwargs, False)
                # The value here is the cross-entropy loss, a positive
                # number: the smaller the loss, the higher the accuracy.
                # (To turn it into a score one would negate it so that a
                # larger score means higher accuracy.)
                current_score = val_loss.cpu().item()
                is_best = False
                # Lower validation loss is better.
                if best_score is None or best_score > current_score:
                    best_score = current_score
                    is_best = True
                    train_patience = 0
                else:
                    train_patience += 1
                path_of_save_model = os.path.join(opt.checkpoint_path,
                                                  str(i) + '_model.pth')
                torch.save(model.state_dict(), path_of_save_model)

                infos['iteration'] = i
                infos['epoch'] = epoch
                infos['best_val_score'] = best_score
                infos['opt'] = opt

                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.model_name + '.pkl'),
                        'wb') as f:
                    pickle.dump(infos, f)

                # Keep a separate copy of the best model and its infos.
                if is_best:
                    path_of_save_model = os.path.join(
                        opt.checkpoint_path, opt.model_name + '-bestmodel.pth')
                    torch.save(model.state_dict(), path_of_save_model)
                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'infos_' + opt.model_name + '-best.pkl'),
                            'wb') as f:
                        pickle.dump(infos, f)

        epoch += 1
Beispiel #22
0
        clip_gradient(optimizer=optim, grad_clip=opt.grad_clip)

        # step optimizer
        optim.step()

        # plot for loss and accuracy
        if train_epoch % 50 == 0:
            if opt.use_cuda:
                loss_data = loss.cpu().data
            else:
                loss_data = loss.data
            print("{} EPOCH {} batch: train loss {}".format(
                idx_round, train_epoch, loss_data))

            # vis loss
            vis.plot('loss', loss_data)

    # evaluate on test for this epoch
    # accuracy, AUC_scores = evaluate(model, test_iter, opt)
    # vis.log("{} EPOCH, accuaracy : {}".format(i, accuracy))
    # vis.plot('accuracy', accuracy)
    #
    # # handel best model, update best model , best_lstm.pth
    # if accuracy > best_accuaracy:
    #     best_accuaracy = accuracy
    #     torch.save(model.state_dict(), './best_{}.pth'.format(opt.model))

    # for mixed classification part
    print(cum_tensor.size())
    feature_tensor = cum_tensor
    print(feature_tensor.size())
def train_pbigan(args):
    """Train a partial BiGAN (PBiGAN) with a WGAN-GP critic on masked MNIST.

    Alternates ``n_critic`` critic updates (Wasserstein loss plus
    gradient penalty) with one generator/encoder update whose loss adds
    an autoencoding term whose weight ramps up after ``ae_flat`` epochs.
    Plots reconstructions/generations every ``args.plot_interval``
    epochs and checkpoints the model after every epoch.

    Args:
        args: parsed argument namespace; fields used here are seed,
            mask, obs_prob(_max), block_len(_max), batch_size, latent,
            flow, aeloss, lr, min_lr, epoch, prefix, save_interval, ae
            and plot_interval.

    Raises:
        ValueError: if ``args.mask`` is neither ``'indep'`` nor
            ``'block'``.
    """
    torch.manual_seed(args.seed)

    # Choose the masking scheme used to hide pixels of the digits.
    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        # Fail fast: the original code would otherwise crash later with
        # an UnboundLocalError on `data` / `mask_str`.
        raise ValueError(f'unknown mask type: {args.mask!r}')

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    # Second, independently shuffled loader supplying masks for the
    # generated samples.
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    # WGAN-GP settings from the literature: lr 1e-4, betas (.5, .9).
    lrate = 1e-4
    optimizer = optim.Adam(pbigan.parameters(), lr=lrate, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=lrate,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Record the full hyper-parameter set next to the results.
    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    # Fixed evaluation batch. `.next()` on the loader iterator is
    # Python 2 style and raises AttributeError on Python 3; use the
    # built-in next() instead.
    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5        # critic updates per generator update
    critic_updates = 0
    ae_weight = 0       # AE term disabled until epoch > ae_flat
    ae_flat = 100

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        # Linearly ramp the autoencoding weight from 0 at ae_flat up to
        # args.ae at the final epoch.
        if epoch > ae_flat:
            ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat)

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

            real_score = critic((x * mask, z_enc)).mean()
            fake_score = critic((x_gen * mask_gen, z_gen)).mean()

            # Critic maximizes the Wasserstein distance estimate; its
            # loss is the negated distance plus the gradient penalty.
            w_dist = real_score - fake_score
            D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                            (x_gen * mask_gen, z_gen))

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            loss_breakdown['D'] += D_loss.item()

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                loss_breakdown['AE'] += ae_loss.item()
                loss_breakdown['total'] += loss.item()

                # Re-enable critic gradients for the next critic update.
                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        # Always overwrite the latest checkpoint; optionally keep
        # per-interval snapshots as well.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
Beispiel #24
0
def train(**kwargs):
    """Train a DCGAN on an image folder.

    Keyword arguments override fields of the global ``opt`` config.
    Alternates discriminator updates (every ``opt.d_every`` batches)
    with generator updates (every ``opt.g_every`` batches), optionally
    visualizes progress through visdom, and saves sample images plus
    network weights every ``opt.save_every`` epochs.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    # Visualization (optional).
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize/crop to opt.image_size and normalize to [-1, 1].
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Networks; optionally resume from checkpoints (loaded on CPU first).
    netg, netd = NetG(opt), NetD(opt)
    if opt.netd_path:
        netd.load_state_dict(
            t.load(opt.netd_path, map_location=t.device('cpu')))
    if opt.netg_path:
        netg.load_state_dict(
            t.load(opt.netg_path, map_location=t.device('cpu')))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, fake images 0.
    # `noises` is the generator input; `fix_noises` is kept fixed so the
    # same latent codes can be tracked across epochs.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    # Running loss accumulators, reset after each plot.
    errord = 0
    errorg = 0

    epochs = range(opt.max_epoch)
    for epoch in epochs:
        for ii, (img, _) in tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if (ii + 1) % opt.d_every == 0:
                # Train the discriminator.
                optimizer_d.zero_grad()
                # Push real images toward the "real" label.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Push generated (fake) images toward the "fake" label;
                # detach so no gradients flow into the generator here.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                errord += (error_d_fake + error_d_real).item()

            if (ii + 1) % opt.g_every == 0:
                # Train the generator to fool the discriminator.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg += error_g.item()

            if opt.vis and (ii + 1) % opt.plot_every == 0:
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord / (opt.plot_every))
                vis.plot('errorg', errorg / (opt.plot_every))
                errord = 0
                errorg = 0

            # NOTE(review): this saves once per *batch* of every
            # save_every-th epoch; moving it outside the inner loop would
            # change behavior, so it is kept as-is.
            if (epoch + 1) % opt.save_every == 0:
                # Regenerate from the fixed noise: previously this read
                # `fix_fake_imgs`, which is only assigned in the
                # visualization branch and raised NameError when
                # opt.vis was off or plotting had not run yet.
                fix_fake_imgs = netg(fix_noises)
                tv.utils.save_image(fix_fake_imgs.data[:64],
                                    '%s%s.png' % (opt.save_path, epoch),
                                    normalize=True,
                                    range=(-1, 1))
                t.save(netd.state_dict(),
                       'checkpoints/netd_%s.pth' % (epoch + 1))
                t.save(netg.state_dict(),
                       'checkpoints/netg_%s.pth' % (epoch + 1))
Beispiel #25
0
    def train(self, dataset):
        """Train the matting GAN: alternate generator (alpha prediction)
        and discriminator updates over `dataset`, then evaluate SAD/MSE
        on the Adobe test set after every epoch.

        Args:
            dataset: iterable of dicts with keys 'I' (image), 'T'
                (trimap), 'B' (background), 'F' (foreground), 'A'
                (ground-truth alpha).

        Returns:
            None.
        """
        print('---------netG------------')
        print(self.G)
        print('---------netD------------')
        print(self.D)

        if self.visual:
            vis = Visualizer(self.env)

        for epoch in range(1, self.epoch):

            self.adjust_learning_rate(epoch)

            self.G.train()

            for ii, data in enumerate(dataset):
                # Release cached GPU memory between batches.
                t.cuda.empty_cache()

                real_img = data['I']
                tri_img = data['T']
                bg_img = data['B']
                fg_img = data['F']

                # input to the G: image concatenated with trimap.
                input_img = t.cat([real_img, tri_img], dim=1).cuda()

                # real_alpha
                real_alpha = data['A'].cuda()

                #####################################
                # train G
                #####################################
                # Freeze D while updating G.
                self.set_requires_grad([self.D], False)
                self.G_optimizer.zero_grad()

                real_img_g = input_img[:, 0:3, :, :]
                tri_img_g = input_img[:, 3:4, :, :]

                # tri_img_original = tri_img_g * 0.5 + 0.5

                fake_alpha = self.G(input_img)

                # Binary weight map selecting the trimap's "unknown"
                # region (pixels whose scaled trimap value equals 128).
                wi = t.zeros(tri_img_g.shape)
                wi[(tri_img_g * 255) == 128] = 1.
                t_wi = wi.cuda()

                unknown_size = t_wi.sum()

                # Keep known trimap values; predict only in the
                # unknown region.
                fake_alpha = (1 - t_wi) * tri_img_g + t_wi * fake_alpha

                # alpha loss
                loss_g_alpha = self.G_criterion(fake_alpha, real_alpha, unknown_size)
                self.Alpha_loss_meter.add(loss_g_alpha.item())

                # compositional loss: recompose the image from the
                # predicted alpha and compare against the real image.
                comp = fake_alpha * fg_img.cuda() + (1. - fake_alpha) * bg_img.cuda()
                loss_g_com = self.G_criterion(comp, real_img_g, unknown_size) / 3.
                self.Com_loss_meter.add(loss_g_com.item())

                '''
                vis.images(real_img.numpy() * 0.5 + 0.5, win='real_image', opts=dict(title='real_image'))
                vis.images(bg_img.numpy() * 0.5 + 0.5, win='bg_image', opts=dict(title='bg_image'))
                vis.images(fg_img.numpy() * 0.5 + 0.5, win='fg_image', opts=dict(title='fg_image'))
                vis.images(tri_img.numpy() * 0.5 + 0.5, win='trimap', opts=dict(title='trimap'))
                vis.images(real_alpha.detach().cpu().numpy() * 0.5 + 0.5, win='real_alpha', opts=dict(title='real_alpha'))
                vis.images(fake_alpha.detach().cpu().numpy(), win='fake_alpha', opts=dict(title='fake_alpha'))
                '''

                # trick D: adversarial term pushes the composite toward
                # being classified as real.
                input_d = t.cat([comp, tri_img_g], dim=1)
                fake_d = self.D(input_d)

                target_fake = t.tensor(1.0).expand_as(fake_d).cuda()
                loss_g_d = self.D_criterion(fake_d, target_fake)

                self.Adv_loss_meter.add(loss_g_d.item())

                loss_G = 0.5 * loss_g_alpha + 0.5 * loss_g_com + 0.01 * loss_g_d

                # retain_graph=True because `input_d` (built from the
                # generator's graph) is reused for the D update below.
                loss_G.backward(retain_graph=True)
                self.G_optimizer.step()
                self.G_error_meter.add(loss_G.item())


                #########################################
                # train D
                #########################################
                self.set_requires_grad([self.D], True)
                self.D_optimizer.zero_grad()

                # real [real_img, tri]
                real_d = self.D(input_img)

                target_real_label = t.tensor(1.0)
                target_real = target_real_label.expand_as(real_d).cuda()

                loss_d_real = self.D_criterion(real_d, target_real)

                # fake [fake_img, tri]
                # NOTE(review): input_d is not detached here, so D's
                # backward also traverses G's graph — confirm intended.
                fake_d = self.D(input_d)
                target_fake_label = t.tensor(0.0)

                target_fake = target_fake_label.expand_as(fake_d).cuda()
                loss_d_fake = self.D_criterion(fake_d, target_fake)

                loss_D = 0.5 * (loss_d_real + loss_d_fake)
                loss_D.backward()
                self.D_optimizer.step()
                self.D_error_meter.add(loss_D.item())

                if self.visual:
                    vis.plot('errord', self.D_error_meter.value()[0])
                    vis.plot('errorg', np.array([self.Alpha_loss_meter.value()[0],
                                                 self.Com_loss_meter.value()[0]]),
                             legend=['alpha_loss', 'com_loss'])
                    vis.plot('errorg_d', self.Adv_loss_meter.value()[0])

                # Meters are reset every batch, so the plots above show
                # per-batch values rather than running averages.
                self.G_error_meter.reset()
                self.D_error_meter.reset()

                self.Alpha_loss_meter.reset()
                self.Com_loss_meter.reset()
                self.Adv_loss_meter.reset()

            ##############################
            # test
            ##############################

            self.G.eval()
            tester = Tester(net_G=self.G,
                            test_root='/home/zzl/dataset/Combined_Dataset/Test_set/Adobe-licensed_images')
            test_result = tester.test(vis)
            print('sad : {0}, mse : {1}'.format(test_result['sad'], test_result['mse']))
            self.SAD_meter.add(test_result['sad'])
            self.MSE_meter.add(test_result['mse'])

            vis.plot('test_result', np.array([self.SAD_meter.value()[0], self.MSE_meter.value()[0]]),
                     legend=['SAD', 'MSE'])
            if self.save_model:
                t.save(self.D.state_dict(), self.save_dir + '/netD' + '/netD_%s.pth' % epoch)
                t.save(self.G.state_dict(), self.save_dir + '/netG' + '/netG_%s.pth' % epoch)
            self.SAD_meter.reset()
            self.MSE_meter.reset()

        return
Beispiel #26
0
def train(opt):
    """Train the `zin` 5-class eye-image classifier.

    Runs 100 epochs of SGD, tracking loss and a 5-class confusion
    matrix, validating after every epoch, and saving the whole model
    whenever validation accuracy improves.

    Args:
        opt: option namespace; fields used are visname, lr, batchsize
            and printloss.
    """
    vis = Visualizer(opt.visname)

    # model = make_model('inceptionresnetv2',num_classes=5,pretrained=True,input_size=(512,512),dropout_p=0.6)
    model = zin()
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.cuda()

    loss_fun = t.nn.CrossEntropyLoss()
    # loss_fun = My_loss()
    # NOTE(review): the learning rate below is hard-coded and ignores
    # opt.lr (the commented-out Adam line used it).
    # optim = t.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-5)
    optim = t.optim.SGD(model.parameters(),
                        lr=0.00001,
                        weight_decay=5e-5,
                        momentum=0.9)

    loss_meter = meter.AverageValueMeter()
    con_matx = meter.ConfusionMeter(5)
    #     con_matx_temp = meter.ConfusionMeter(5)
    pre_acc = 0  # best validation accuracy seen so far

    for epoch in range(100):
        loss_meter.reset()
        con_matx.reset()
        # NOTE(review): datasets and loaders are rebuilt every epoch;
        # hoisting them out would be faster but could change the data
        # split if Eye_img randomizes it — confirm before moving.
        train_data = Eye_img(train=True)
        val_data = Eye_img(train=False)
        train_loader = DataLoader(train_data, opt.batchsize, True)
        val_loader = DataLoader(val_data, opt.batchsize, False)

        for ii, (imgs, data, label) in enumerate(train_loader):
            print('epoch:{}/{}'.format(epoch, ii))
            #     con_matx_temp.reset()
            # print(data.size())
            imgs = imgs.cuda()
            data = data.cuda()
            label = label.cuda()

            optim.zero_grad()
            pre_label = model(imgs, data)

            #     con_matx_temp.add(pre_label.detach(), label.detach())

            #     temp_value = con_matx_temp.value()
            #     temp_kappa = kappa(temp_value,5)
            loss = loss_fun(pre_label, label)
            loss.backward()
            optim.step()

            loss_meter.add(loss.item())
            con_matx.add(pre_label.detach(), label.detach())

            if (ii + 1) % opt.printloss == 0:
                vis.plot('loss', loss_meter.value()[0])
        val_cm, acc, kappa_ = val(model, val_loader)
        # Save the full model whenever validation accuracy improves.
        if pre_acc < acc:
            pre_acc = acc
            t.save(model, 'zoom3_{}.pth'.format(epoch))
        vis.plot('acc', acc)
        vis.plot('kappa', kappa_)
        # Log the learning rate actually used by the optimizer
        # (previously this reported opt.lr, which the hard-coded SGD
        # setting above does not use).
        vis.log(
            "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}"
            .format(epoch=epoch,
                    loss=loss_meter.value()[0],
                    val_cm=str(val_cm.value()),
                    train_cm=str(con_matx.value()),
                    lr=optim.param_groups[0]['lr']))
Beispiel #27
0
		show_num = 10  # 20个batch显示一次
		if (i + 1) % show_num == 0:
			print('\n ---- batch: %03d of epoch: %03d ----' % (i + 1, epoch + 1))
			print('cla_loss: %f, reg_loss: %f, box_loss: %f, aux_seg_loss: %f, aux_offset_loss: %f' \
			% (batch_cla_loss / show_num, batch_reg_loss / show_num, batch_box_loss / show_num, \
				batch_aux_seg_loss / show_num, batch_aux_offset_loss / show_num))
			print('accuracy: %f' % (batch_correct / float(batch_num)))
			print('true accuracy: %f' % (batch_true_correct / show_num))			
			# vis.plot('cla_loss', (batch_cla_loss / show_num).item())
			# vis.plot('reg_loss', (batch_reg_loss / show_num).item())
			# vis.plot('box_loss', (batch_box_loss / show_num).item())
			# vis.plot('aux_seg_loss', (batch_aux_seg_loss / show_num).item())
			# vis.plot('aux_offset_loss', (batch_aux_offset_loss / show_num).item())
			# vis.plot('accuracy', (batch_correct / float(batch_num)))
			# vis.plot('true accuracy', (batch_true_correct / show_num))
			vis.plot('cla_loss', fromDlpack(to_dlpack(batch_cla_loss / show_num)).item())
			vis.plot('reg_loss', fromDlpack(to_dlpack(batch_reg_loss / show_num)).item())
			vis.plot('box_loss', fromDlpack(to_dlpack(batch_box_loss / show_num)).item())
			vis.plot('objective_loss', fromDlpack(to_dlpack(batch_objective_loss / show_num)).item())			
			vis.plot('aux_seg_loss', fromDlpack(to_dlpack(batch_aux_seg_loss / show_num)).item())
			vis.plot('aux_offset_loss', fromDlpack(to_dlpack(batch_aux_offset_loss / show_num)).item())
			vis.plot('accuracy', fromDlpack(to_dlpack(batch_correct / float(batch_num))).item())
			vis.plot('true accuracy', fromDlpack(to_dlpack(batch_true_correct / show_num)).item())		
			batch_correct = 0.0
			batch_cla_loss = 0.0
			batch_reg_loss = 0.0
			batch_box_loss = 0.0
			batch_aux_seg_loss = 0.0
			batch_aux_offset_loss = 0.0
			batch_num = 0.0
			batch_true_correct = 0.0
Beispiel #28
0
def train(**kwargs):
    """Train a DCGAN: alternately optimize the discriminator and generator.

    Any keyword argument overrides the matching attribute on the global
    ``opt`` config object.  Relies on module-level names defined elsewhere
    in this file: ``opt``, ``NetG``, ``NetD``, ``AverageValueMeter``,
    ``t`` (torch), ``tv`` (torchvision), ``tqdm``, ``os``, ``ipdb``.
    """
    # Apply keyword overrides to the config.
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Optional visdom-based visualization.
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Preprocessing: resize, center-crop, to-tensor, then normalize each
    # channel to [-1, 1] (mean 0.5, std 0.5) to match the generator's tanh
    # output range.  Resize replaces the deprecated (later removed)
    # transforms.Scale with identical behavior.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # ImageFolder reads images with torchvision's native loader.
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # drop_last keeps every batch exactly opt.batch_size, matching the
    # pre-allocated label/noise tensors below.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Build the networks; load any pretrained weights onto the CPU first.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage  # force loaded tensors onto CPU
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Optimizers and binary cross-entropy loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # Labels: real images -> 1, fake images -> 0.
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    # fix_noises is held constant so generator progress is comparable
    # across epochs; noises is resampled every step.
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    # Running mean/std of each loss for the dashboard.
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        # Move networks, loss, labels and noise to the GPU.
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    for epoch in range(opt.max_epoch):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()

            # Train the discriminator every d_every batches.
            if ii % opt.d_every == 0:
                optimizer_d.zero_grad()
                # Push real images towards label 1.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Push generated images towards label 0.  detach() stops
                # gradients from reaching the generator in this step.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real
                # BUGFIX: .item() replaces the removed 0-dim indexing idiom
                # .data[0] (IndexError on PyTorch >= 0.5); also matches the
                # other training loop in this file.
                errord_meter.add(error_d.item())

            # Train the generator every g_every batches.
            if ii % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                # Reward fakes that the discriminator scores as real, so the
                # generator learns to fool it.
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Drop into the debugger when the sentinel file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                # Show fixed-noise fakes next to a real batch, plus losses.
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        # Checkpoint every decay_every epochs.
        if epoch % opt.decay_every == 0:
            # BUGFIX: regenerate fix_fake_imgs here — the original reused the
            # variable from the visualization branch, which is a NameError
            # when opt.vis is False (and a stale image otherwise).
            fix_fake_imgs = netg(fix_noises)
            # Denormalize from [-1, 1] back to [0, 1] when saving.
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            # Save both networks.
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            # Clear the meters and restart both optimizers from their
            # initial learning rates.
            errord_meter.reset()
            errorg_meter.reset()
            optimizer_g = t.optim.Adam(netg.parameters(),
                                       opt.lr1,
                                       betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),
                                       opt.lr2,
                                       betas=(opt.beta1, 0.999))