Example #1
0
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the ones ``netd`` scores highest.

    :param kwargs: option overrides; every key/value pair is set as an
        attribute on the module-level ``opt`` before anything else runs.
    :return: None. The selected images are written to ``opt.gen_img``.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Only use the GPU when it was requested AND CUDA is actually available.
    use_cuda = opt.gpu and t.cuda.is_available()
    device = t.device("cuda:0" if use_cuda else "cpu")

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)

    # Deserialize onto the CPU first so checkpoints saved on a GPU load anywhere.
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)
    # Unlike Module.to(), Tensor.to() returns a new tensor, so rebind the name.
    noises = noises.to(device)

    # Generate the images and score them; no gradients are needed here.
    with t.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img)

    # Keep the opt.gen_num best-scored images.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img[ii] for ii in indexs]

    # Save the selection as a single image grid.
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
Example #2
0
def generate(**kwargs):
    """
    Randomly generate anime avatars and keep the ones netd scores highest.

    Keyword arguments are copied onto the module-level ``opt`` first; the
    selected images are then saved as one grid to ``opt.gen_img``.
    """
    for name, value in kwargs.items():
        setattr(opt, name, value)

    device = t.device('cuda' if opt.gpu else 'cpu')

    netg = NetG(opt).eval()
    netd = NetD(opt).eval()

    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1)
    noises = noises.normal_(opt.gen_mean, opt.gen_std).to(device)

    def keep_storage(storage, loc):
        # Deserialize every tensor onto the storage it was read into.
        return storage

    for net, path in ((netd, opt.netd_path), (netg, opt.netg_path)):
        net.load_state_dict(t.load(path, map_location=keep_storage))
    for net in (netd, netg):
        net.to(device)

    # Generate the images and compute their discriminator scores.
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # Pick the best-scored ones.
    _, best_indices = scores.topk(opt.gen_num)
    best_images = [fake_img.data[idx] for idx in best_indices]

    # Save the selection.
    tv.utils.save_image(t.stack(best_images),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
Example #3
0
def generate(**kwargs):
    """
    Randomly generate anime avatars and keep the ones netd scores highest.

    Every keyword argument overrides the same-named attribute of the
    module-level ``opt``. The result is written to ``opt.gen_img``.
    """
    for key, val in kwargs.items():
        setattr(opt, key, val)

    device = t.device('cuda' if opt.gpu else 'cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()

    # Draw the candidate latent vectors, then re-sample them in place with
    # the configured mean and standard deviation.
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1)
    noises.normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    def map_location(storage, loc):
        return storage

    netd_state = t.load(opt.netd_path, map_location=map_location)
    netd.load_state_dict(netd_state)
    netg_state = t.load(opt.netg_path, map_location=map_location)
    netg.load_state_dict(netg_state)
    netd.to(device)
    netg.to(device)

    # Generate the images and compute their discriminator scores.
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # Pick the best-scored ones and stack them into a single batch.
    top = scores.topk(opt.gen_num)
    chosen = t.stack([fake_img.data[ii] for ii in top[1]])

    # Save the selection.
    tv.utils.save_image(chosen, opt.gen_img, normalize=True, range=(-1, 1))
Example #4
0
def generate(**kwargs):
    """
    Randomly generate anime avatars and keep the ones netd scores highest.

    Keyword arguments override same-named attributes of the module-level
    ``opt``. The selected images are saved twice: as one grid named
    ``result_<epoch>_<id>.png`` in the working directory, and one file per
    image under ``./imgs/tiny_epoch_<epoch>_<id>/``. ``<epoch>`` is parsed
    from the discriminator checkpoint file name and ``<id>`` is a random
    integer in [0, 1000].
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Generate the images and score them; no gradients are needed here.
    with t.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img)

    # Keep the opt.gen_num best-scored images.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img[ii] for ii in indexs]

    # Parse the tag between the first "_" and the last "." of the checkpoint
    # FILE name (e.g. "netd_199.pth" -> "199"). Only the base name is used so
    # that underscores in directory names cannot leak into the output paths.
    netd_file = os.path.basename(opt.netd_path)
    epoch_num = netd_file[netd_file.find("_") + 1:netd_file.rfind(".")]

    randid = random.randint(0, 1000)
    pic_stack_name = "".join(["result_", str(epoch_num), "_", str(randid), ".png"])
    tv.utils.save_image(t.stack(result), pic_stack_name, normalize=True, range=(-1, 1))

    base_dir = "./imgs/tiny_epoch_{0}_{1}".format(epoch_num, randid)
    # makedirs also creates "./imgs" when it is missing; os.mkdir would raise.
    os.makedirs(base_dir, exist_ok=True)
    for picx, picture in enumerate(result):
        picx_name = os.path.join(base_dir, "picid_" + str(picx) + ".png")
        tv.utils.save_image(picture, picx_name, normalize=True, range=(-1, 1))
Example #5
0
import cv2
import torch as th
from model import NetG, NetD, NetA
from my_dataSet import CASIABDatasetGenerate

# Build the three single-channel networks and place them on the first GPU.
device = th.device("cuda:0")
netg = NetG(nc=1).to(device)
netd = NetD(nc=1).to(device)
neta = NetA(nc=1).to(device)
fineSize = 64

# Restore all three networks from one snapshot, then switch to eval mode.
checkpoint = th.load('/home/mg/code/my_GAN_dataSet/snapshots/snapshot_449.t7')
neta.load_state_dict(checkpoint['netA'])
netg.load_state_dict(checkpoint['netG'])
netd.load_state_dict(checkpoint['netD'])
neta.eval()
netg.eval()
netd.eval()

# Zero-padded angle labels from '000' to '180' in steps of 18.
angles = ['%03d' % degree for degree in range(0, 181, 18)]

# One dataset per condition label.
for cond in ('nm-01', 'nm-02', 'nm-03', 'nm-04', 'cl-01', 'cl-02'):
    dataset = CASIABDatasetGenerate(
        data_dir='/home/mg/code/data/GEI_CASIA_B/gei/', cond=cond)
Example #6
0
def train(**kwargs):
    """
    Train the generator/discriminator pair with the standard BCE GAN loss.

    Keyword arguments override same-named attributes of the module-level
    ``opt``. The discriminator is updated every ``opt.d_every`` batches and
    the generator every ``opt.g_every`` batches. Every ``opt.save_every``
    epochs a grid of images generated from a fixed noise batch is written to
    ``opt.save_path`` and both networks are checkpointed under
    ``checkpoints/``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        # Imported lazily so the visualisation module is only required when
        # plotting was requested; `vis` is used in the plotting branch below.
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize, centre-crop and normalise every channel with
    # mean 0.5 / std 0.5.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(root=opt.data_path, transform=transforms)
    # drop_last keeps every batch at opt.batch_size so the fixed-size label
    # tensors below always match the discriminator output.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Networks, optionally resumed from checkpoints.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, fake images 0.
    # `noises` is the generator input and is refilled in place every step;
    # `fix_noises` stays constant so saved samples are comparable over time.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Train the discriminator.
                optimizer_d.zero_grad()
                ## Push real images towards the "real" label.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## Push generated images towards the "fake" label.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach(): the generator is not updated in this step.
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Train the generator: it wants its fakes labelled as real.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualisation; creating opt.debug_file drops into ipdb.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                with t.no_grad():
                    fix_fake_imgs = netg(fix_noises)
                # "* 0.5 + 0.5" undoes the mean/std 0.5 normalisation.
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save sample images and both models. The samples are generated
            # here so saving works even when visualisation is disabled.
            with t.no_grad():
                fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Example #7
0
def generate(**kwargs):
    """
    Generate anime avatars seeded by ``opt.gen_id`` and keep the best-scored.

    Keyword arguments override same-named attributes of the module-level
    ``opt``. The ``opt.gen_num`` images the discriminator scores highest are
    normalised, resized to 188x188 and written to
    ``<opt.gen_dst>/<index>.png``.
    """
    # Local import: only this function needs it.
    import zlib

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = torch.device('cuda') if opt.gpu else torch.device('cpu')

    # Derive the RNG seed from gen_id. hash() of a str is salted per process
    # (PYTHONHASHSEED), so it would give a different seed on every run; a
    # CRC of the text is stable. Integer ids keep their hash()-based seed.
    gen_id = opt.gen_id
    if isinstance(gen_id, int):
        hash_id = hash(gen_id) % 51603
    else:
        hash_id = zlib.crc32(str(gen_id).encode('utf-8')) % 51603
    torch.manual_seed(hash_id)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()

    noises = torch.randn(opt.gen_search_num, opt.nz, 1,
                         1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(torch.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Generate the images and score them; no gradients are needed here.
    with torch.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img)

    # Keep the opt.gen_num best-scored images.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img[ii] for ii in indexs]

    print(result[0])
    print(result[0].size())

    base_dir = opt.gen_dst
    base_name = opt.base_name
    # makedirs also creates missing parent directories and tolerates an
    # already existing target.
    os.makedirs(base_dir, exist_ok=True)
    for picx, picture in enumerate(result):
        # Write a normalised temporary PNG, then re-open it with PIL to resize.
        picx_name = "".join([base_dir, base_name, str(picx), ".png"])
        final_name = base_dir + "/" + str(picx) + ".png"
        tv.utils.save_image(picture,
                            picx_name,
                            normalize=True,
                            range=(-0.8, 0.8))
        with Image.open(picx_name) as img:
            # LANCZOS is the filter ANTIALIAS was an alias of; the alias was
            # removed in Pillow 10.
            resized = img.resize((188, 188), Image.LANCZOS)
        resized.save(final_name, "png")
        # Never delete the final image when both names address the same file.
        if os.path.abspath(picx_name) != os.path.abspath(final_name):
            os.remove(picx_name)
Example #8
0
def train(**kwargs):
    """
    Train the generator/discriminator pair with a BCE GAN objective.

    Keyword arguments override same-named attributes of the module-level
    ``opt``. Every ``opt.save_every`` epochs a sample grid and both state
    dicts are written to disk and the loss meters are reset.
    """
    for name, value in kwargs.items():
        setattr(opt, name, value)

    device = t.device("cuda" if opt.gpu else "cpu")

    # Pre-processing: resize, centre-crop and normalise each channel with
    # mean 0.5 / std 0.5.
    channel_stats = (0.5, 0.5, 0.5)
    preprocess = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_stats, channel_stats),
    ])
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=preprocess)
    dataloader = t.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True,
    )

    # Networks, optionally warm-started from checkpoints.
    netg = NetG(opt)
    netd = NetD(opt)

    def keep_storage(storage, loc):
        return storage

    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=keep_storage))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=keep_storage))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    adam_betas = (opt.beta1, 0.999)
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=adam_betas)
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=adam_betas)
    criterion = t.nn.BCELoss()

    # Real images are labelled 1, fakes 0; the noise tensors feed the generator.
    noise_shape = (opt.batch_size, opt.nz, 1, 1)
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(*noise_shape).to(device)
    noises = t.randn(*noise_shape).to(device)

    # Running statistics of both losses.
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for step, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if step % opt.d_every == 0:
                # Discriminator step.
                optimizer_d.zero_grad()

                ## Real images should be classified as real.
                real_loss = criterion(netd(real_img), true_labels)
                real_loss.backward()

                ## Generated images should be classified as fake.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach() keeps gradients out of the generator here.
                fake_img = netg(noises).detach()
                fake_loss = criterion(netd(fake_img), fake_labels)
                fake_loss.backward()
                optimizer_d.step()

                errord_meter.add((fake_loss + real_loss).item())

            if step % opt.g_every == 0:
                # Generator step: make the discriminator answer 1 for fakes.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                gen_loss = criterion(netd(fake_img), true_labels)
                gen_loss.backward()
                optimizer_g.step()
                errorg_meter.add(gen_loss.item())

            # Visualisation (not implemented).

        # Periodically save sample images and both models.
        if (epoch + 1) % opt.save_every != 0:
            continue
        fix_fake_imgs = netg(fix_noises)
        tv.utils.save_image(fix_fake_imgs.data[:64],
                            "%s%s.png" % (opt.save_path, epoch),
                            normalize=True,
                            range=(-1, 1))
        t.save(netd.state_dict(), r"./checkpoints/netd_%s.pth" % epoch)
        t.save(netg.state_dict(), r"./checkpoints/netg_%s.pth" % epoch)
        errord_meter.reset()
        errorg_meter.reset()
Example #9
0
class GANAgent(object):
    """PPO actor-critic agent with a GAN-based curiosity module.

    The generator is called as
    ``netG(obs) -> (reconstruction, latent_in, latent_out)`` and the
    discriminator as ``netD(img) -> (prediction, features)``. Half the
    squared distance between the two latents is the intrinsic reward.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False,
                 hidden_dim=512):
        """Build the policy network, the GAN and one Adam optimizer for each.

        The arguments are stored as same-named attributes. ``hidden_dim`` is
        forwarded to ``NetG`` as ``z_dim`` and ``use_cuda`` selects the device
        every network is moved to.
        """
        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.netG = NetG(z_dim=hidden_dim)
        self.netD = NetD(z_dim=1)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

        self.optimizer_policy = optim.Adam(list(self.model.parameters()),
                                           lr=learning_rate)
        self.optimizer_G = optim.Adam(list(self.netG.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(list(self.netD.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)

        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        """Return the generator's reconstruction of one state as an ndarray."""
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        # The agent owns no `vae` attribute; the reconstruction is the first
        # output of the generator.
        reconstructed = self.netG(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.detach().cpu().numpy()

    def get_action(self, state):
        """Sample one action per row of ``state`` from the current policy.

        Returns ``(action, value_ext, value_int, policy)`` where the two
        value estimates are squeezed ndarrays and ``policy`` holds the
        detached logits.
        """
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return action, value_ext.data.cpu().numpy().squeeze(
        ), value_int.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Draw one index along ``axis`` per distribution in the 2-D array ``p``.

        One uniform sample is drawn per distribution; the result is the first
        index whose cumulative probability exceeds that sample.
        """
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        """Return half the squared latent distance for every observation.

        Only the generator's two latent outputs are compared; the
        reconstruction itself is not part of the reward.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        reconstructed_img, embedding, reconstructed_embedding = self.netG(obs)

        intrinsic_reward = (embedding - reconstructed_embedding
                            ).pow(2).sum(1) / 2

        return intrinsic_reward.detach().cpu().numpy()

    def _gan_update(self, real_batch):
        """Run one generator step and one discriminator step on ``real_batch``.

        Returns the four scalar losses of the step as Python floats:
        ``(adversarial, contextual, encoder, discriminator)``.
        """
        # ---------------- generator step ----------------
        fake_batch, latent_i, latent_o = self.netG(real_batch)

        self.optimizer_G.zero_grad()
        feature_real = self.netD(real_batch)[1]
        feature_fake = self.netD(fake_batch)[1]

        # Per-sample losses: every dimension but the batch one is averaged.
        err_adv = F.mse_loss(feature_real, feature_fake,
                             reduction='none').mean(
                                 dim=list(range(1, feature_real.dim())))
        err_con = F.l1_loss(real_batch, fake_batch, reduction='none').mean(
            dim=list(range(1, fake_batch.dim())))
        err_enc = F.mse_loss(latent_i, latent_o, reduction='none').mean(-1)

        # Only a random `update_proportion` of the samples contributes to the
        # losses; the clamp avoids dividing by zero when none was selected.
        mask = torch.rand(len(err_con)).to(self.device)
        mask = (mask < self.update_proportion).float()
        selected = mask.sum().clamp(min=1.0)

        mean_err_g_adv = (err_adv * mask).sum() / selected
        mean_err_g_con = (err_con * mask).sum() / selected
        mean_err_g_enc = (err_enc * mask).sum() / selected

        # Loss weights (hyperparameters).
        w_adv = 1
        w_con = 50
        w_enc = 1

        mean_err_g = (mean_err_g_adv * w_adv + mean_err_g_con * w_con +
                      mean_err_g_enc * w_enc)
        mean_err_g.backward()
        self.optimizer_G.step()

        # ---------------- discriminator step ----------------
        self.optimizer_D.zero_grad()
        pred_real, _ = self.netD(real_batch)
        # detach(): the generator weights were just changed in place by
        # optimizer_G.step(), so back-propagating the discriminator loss into
        # the generator graph is both unnecessary and rejected by autograd.
        pred_fake, _ = self.netD(fake_batch.detach())

        err_d_real = F.binary_cross_entropy(pred_real,
                                            torch.ones_like(pred_real),
                                            reduction='none')
        err_d_fake = F.binary_cross_entropy(pred_fake,
                                            torch.zeros_like(pred_fake),
                                            reduction='none')
        mean_err_d_real = (err_d_real * mask).sum() / selected
        mean_err_d_fake = (err_d_fake * mask).sum() / selected

        mean_err_d = (mean_err_d_real + mean_err_d_fake) / 2
        mean_err_d.backward()
        self.optimizer_D.step()

        return (mean_err_g_adv.item(), mean_err_g_con.item(),
                mean_err_g_enc.item(), mean_err_d.item())

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        """Run ``self.epoch`` passes of GAN and PPO updates over one rollout.

        Each mini-batch first updates the GAN on the next observations, then
        performs one clipped-surrogate PPO step on the policy network.

        Returns four 1-D arrays with one entry per mini-batch: the
        adversarial, contextual and encoder generator losses and the
        discriminator loss.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        # Log-probabilities of the taken actions under the behaviour policy.
        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        # adversarial, contextual, encoder, discriminator
        histories = ([], [], [], [])

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(len(s_batch) // self.batch_size):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # ---------------- curiosity (GAN) update ----------------
                step_errors = self._gan_update(next_obs_batch[sample_idx])
                for history, error in zip(histories, step_errors):
                    history.append(error)

                # Re-initialise the discriminator when its loss collapses.
                if step_errors[3] < 1e-5:
                    self.netD.apply(weights_init)
                    print('Reloading net d')

                # ---------------- policy update ----------------
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer_policy.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                loss.backward()
                # NOTE(review): self.clip_grad_norm is stored but no gradient
                # clipping is applied here - confirm whether that is intended.
                self.optimizer_policy.step()

        return tuple(np.array(history) for history in histories)

    def train_just_vae(self, s_batch, next_obs_batch):
        """Run ``self.epoch`` passes of GAN-only updates (no policy update).

        ``s_batch`` is only used for its length. Returns the same four loss
        arrays as ``train_model``.
        """
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        num_samples = len(s_batch)
        sample_range = np.arange(num_samples)

        # adversarial, contextual, encoder, discriminator
        histories = ([], [], [], [])

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(num_samples // self.batch_size):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]
                step_errors = self._gan_update(next_obs_batch[sample_idx])
                for history, error in zip(histories, step_errors):
                    history.append(error)

        return tuple(np.array(history) for history in histories)
Example #10
0
def train(**kwargs):
    """Train the generator/discriminator pair configured by ``opt``.

    Every keyword argument is copied onto the module-level ``opt`` object
    before anything else happens, so callers can override any option.

    Side effects:
        * When ``opt.vis`` is set, generated/real images and the running
          loss averages are sent to a ``Visualizer`` every
          ``opt.plot_every`` batches.
        * Every ``opt.save_every`` epochs an image of samples generated from
          a fixed noise batch is written under ``opt.save_path`` and both
          state dicts are saved under ``checkpoints/``; the loss meters are
          then reset.

    :param kwargs: option overrides applied to ``opt`` via ``setattr``.
    :return: None
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize + center-crop to a square, then map pixel values to [-1, 1].
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # drop_last keeps every batch at exactly opt.batch_size, which the
    # pre-allocated label tensors below rely on.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks: optionally resume from checkpoints, then move to the device.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, generated images 0.
    # `noises` is the generator input and is refilled in place each step;
    # `fix_noises` stays constant so saved/plotted samples are comparable.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Discriminator step.
                optimizer_d.zero_grad()
                # Push the score of real images towards 1.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Push the score of generated images towards 0. The fake
                # batch is detached so no gradient reaches the generator.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Generator step: make the discriminator score fakes as real.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Visualization; dropping a file at opt.debug_file opens a
                # debugger on the next plotting step.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                # * 0.5 + 0.5 undoes the Normalize above for display.
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save samples and checkpoints. The samples are regenerated here
            # rather than reused from the plotting branch: that branch may
            # never have run (opt.vis off), which previously raised NameError,
            # and its output could be stale relative to the saved weights.
            with t.no_grad():
                fix_fake_imgs = netg(fix_noises)
            # NOTE(review): newer torchvision releases renamed `range` to
            # `value_range` — confirm against the pinned torchvision version.
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Example #11
0
def train(**kwargs):
    """Train the generator/discriminator pair configured by ``opt``.

    Every keyword argument is copied onto the module-level ``opt`` object
    first. Losses are accumulated as plain floats and, when ``opt.vis`` is
    set, plotted (divided by ``opt.plot_every``) every ``opt.plot_every``
    batches. Every ``opt.save_every`` epochs an image of samples generated
    from a fixed noise batch and both state dicts are written to disk.

    :param kwargs: option overrides applied to ``opt`` via ``setattr``.
    :return: None
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    # Visualization
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data: resize + center-crop to a square, then map pixel values to [-1, 1].
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transforms)
    # drop_last keeps every batch at exactly opt.batch_size, which the
    # pre-allocated label tensors below rely on.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    netg, netd = NetG(opt), NetD(opt)
    if opt.netd_path:
        netd.load_state_dict(
            t.load(opt.netd_path, map_location=t.device('cpu')))
    if opt.netg_path:
        netg.load_state_dict(
            t.load(opt.netg_path, map_location=t.device('cpu')))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    # NOTE(review): the discriminator also uses opt.lr1 here; if the config
    # defines a separate discriminator learning rate, this is presumably
    # meant to use it — confirm against the config.
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, generated images 0.
    # `noises` is the generator input and is refilled in place each step;
    # `fix_noises` stays constant so saved/plotted samples are comparable.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord = 0
    errorg = 0

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if (ii + 1) % opt.d_every == 0:
                # Discriminator step.
                optimizer_d.zero_grad()
                # Score real images as real.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Score generated images as fake. The fake batch is detached
                # so no gradient reaches the generator.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                errord += (error_d_fake + error_d_real).item()

            if (ii + 1) % opt.g_every == 0:
                # Generator step: make the discriminator score fakes as real.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg += error_g.item()

            if opt.vis and (ii + 1) % opt.plot_every == 0:
                fix_fake_imgs = netg(fix_noises)
                # * 0.5 + 0.5 undoes the Normalize above for display.
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord / (opt.plot_every))
                vis.plot('errorg', errorg / (opt.plot_every))
                errord = 0
                errorg = 0

        # Save once per qualifying epoch. This block used to sit inside the
        # batch loop, so every batch of such an epoch rewrote the image and
        # both checkpoints.
        if (epoch + 1) % opt.save_every == 0:
            # Regenerate the samples here: the plotting branch may never have
            # run (opt.vis off), which previously raised NameError.
            with t.no_grad():
                fix_fake_imgs = netg(fix_noises)
            # NOTE(review): newer torchvision releases renamed `range` to
            # `value_range` — confirm against the pinned torchvision version.
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(),
                   'checkpoints/netd_%s.pth' % (epoch + 1))
            t.save(netg.state_dict(),
                   'checkpoints/netg_%s.pth' % (epoch + 1))
Example #12
0
def train():
    """Train the generator/discriminator pair configured by ``opt``.

    Runs on CUDA when ``torch.cuda.is_available()`` reports a device and on
    the CPU otherwise. When ``opt.vis`` is set, sample images and running
    loss averages are sent to a ``Visualizer`` every ``opt.plot_every``
    batches. Every ``opt.save_every`` epochs an image of samples generated
    from a fixed noise batch and both state dicts are written to disk, and
    the loss meters are reset.

    :return: None
    """
    # `is_available` must be called: the bare function object is always
    # truthy, which made the device unconditionally 'cuda'.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
        'cpu')

    if opt.vis:
        from visualizer import Visualizer
        vis = Visualizer(opt.env)

    # Resize + center-crop to a square, then rescale pixel values to -1~1.
    transform = transforms.Compose([
        transforms.Resize(opt.image_size),
        transforms.CenterCrop(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.ImageFolder(opt.data_path, transform=transform)

    # drop_last keeps every batch at exactly opt.batch_size, which the
    # pre-allocated label tensors below rely on.
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)

    netd = NetD(opt)
    netg = NetG(opt)
    map_location = lambda storage, loc: storage
    # map_location is an argument of torch.load, not of load_state_dict;
    # passing it to the latter raised TypeError whenever a path was set.
    if opt.netd_path:
        netd.load_state_dict(
            torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(
            torch.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = torch.optim.Adam(netg.parameters(),
                                   opt.lr1,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr2,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss().to(device)

    # Real images are labelled 1, generated images 0. `noises` is the
    # generator input (refilled in place each step); `fix_noises` stays
    # constant so saved/plotted samples are comparable across epochs.
    # Plain tensors are used: the Variable wrapper is a no-op on the torch
    # versions that provide torch.device.
    true_labels = torch.ones(opt.batch_size).to(device)
    fake_labels = torch.zeros(opt.batch_size).to(device)

    fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = torch.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        print("epoch:", epoch, end='\r')
        for ii, (img, _) in enumerate(dataloader):
            real_img = img.to(device)

            # Discriminator step: real -> 1, fake -> 0.
            if (ii + 1) % opt.d_every == 0:
                # real
                optimizer_d.zero_grad()
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()
                # fake: detached so no gradient reaches the generator
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                fake_output = netd(fake_img)
                error_d_fake = criterion(fake_output, fake_labels)
                error_d_fake.backward()
                # update discriminator weights
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            # Generator step: make the discriminator score fakes as real.
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Visualization; * 0.5 + 0.5 undoes the Normalize above.
                fix_fake_img = netg(fix_noises)
                vis.images(
                    fix_fake_img.detach().cpu().numpy()[:opt.batch_size] * 0.5
                    + 0.5,
                    win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:opt.batch_size] * 0.5 +
                           0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save samples and checkpoints. The samples are regenerated here:
            # the plotting branch may never have run (opt.vis off), which
            # previously raised NameError at this point.
            with torch.no_grad():
                fix_fake_img = netg(fix_noises)
            # NOTE(review): newer torchvision releases renamed `range` to
            # `value_range` — confirm against the pinned torchvision version.
            tv.utils.save_image(fix_fake_img.data[:opt.batch_size],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Example #13
0
    ndf = 64
    nz = 100


opt = Config()

# Build the networks and restore their weights once. This setup does not
# depend on the loop variable, so repeating it for every generated file only
# re-read the same checkpoints from disk.
device = t.device('cuda') if opt.gpu else t.device('cpu')
netg, netd = NetG(opt).eval(), NetD(opt).eval()

map_location = lambda storage, loc: storage
netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
netd.to(device)
netg.to(device)

for gen in range(gen_num):
    # Random numeric prefix for this output file's name.
    Result = random.randint(1000000, 100000000)

    # Sample opt.gen_search_num candidate noise vectors from
    # N(opt.gen_mean, opt.gen_std).
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    # Inference only: generate candidates and score them with the
    # discriminator without recording gradients.
    with t.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img).detach()

    # Keep the opt.gen_num candidates with the highest discriminator score.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img.data[ii] for ii in indexs]

    # Write the selected images once. This call used to sit inside the
    # selection loop, rewriting the same file after every appended image.
    # NOTE(review): newer torchvision releases renamed `range` to
    # `value_range` — confirm against the pinned torchvision version.
    tv.utils.save_image(t.stack(result),
                        str(Result) + opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
Example #14
0
        project="mnist",  # have to change when you want to change project
        job_type="anomaly detection",
    )

# Deserialize both checkpoints, generator first. The map_location callable
# hands each storage back unchanged instead of remapping it.
state_dict_G, state_dict_D = (
    torch.load(weights_path, map_location=lambda storage, loc: storage)
    for weights_path in ("weights/G.prm", "weights/D.prm")
)

# Instantiate both networks from the shared config and restore their weights.
G_update, D_update = NetG(CONFIG), NetD(CONFIG)
G_update.load_state_dict(state_dict_G)
D_update.load_state_dict(state_dict_D)

# Move the restored networks to the target device.
G_update.to(device)
D_update.to(device)

# Normalization statistics handed to the image transform (one value each).
mean, std = (0.5, ), (0.5, )

# Test split: samples are listed in the CSV named by the config.
test_dataset = Dataset(
    csv_file=CONFIG.test_csv_file,
    transform=ImageTransform(mean, std),
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=CONFIG.test_batch_size,
    shuffle=True,
)

for sample in test_dataloader:
    test_img = sample["img"]
Example #15
0
def train(**kwargs):
    '''
    Train the generator/discriminator pair configured by ``opt``.

    The keyword arguments are first handed to ``opt.parse`` and then also
    copied onto ``opt`` attribute by attribute. Training runs for a fixed
    140 epochs. When ``opt.vis`` is set, generated samples and the running
    loss averages are sent to a ``Visualizer`` every ``opt.plot_time``
    batches. Every ``opt.decay_every`` epochs an image of samples generated
    from a fixed noise batch and both state dicts are written to disk, the
    loss meters are reset and both optimizers are re-created.

    :param kwargs: training options (passed in by fire)
    :return: None
    '''
    opt.parse(kwargs)
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    if opt.vis:
        vis = Visualizer(opt.env)

    # Step 1: data preprocessing — resize + center-crop to a square, then
    # map pixel values to [-1, 1].
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # drop_last keeps every batch at exactly opt.batch_size, which the
    # pre-allocated label tensors below rely on.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Step 2: networks, optionally resumed from checkpoints.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Optimizers and loss function.
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lrG, betas=(0.5, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lrD, betas=(0.5, 0.999))
    criterion = t.nn.BCELoss()

    # Real images are labelled 1, generated images 0.
    # `noises` is the generator input (refilled in place each step);
    # `fix_noises` stays constant so saved/plotted samples are comparable.
    true_labels = t.ones(opt.batch_size)
    fake_labels = t.zeros(opt.batch_size)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    # Always bind a device. Previously `device` (and `real_img` below) only
    # existed when opt.gpu was set, so CPU training raised NameError.
    device = t.device(
        "cuda:0" if opt.gpu and t.cuda.is_available() else "cpu")
    netd.to(device)
    netg.to(device)
    criterion.to(device)
    true_labels, fake_labels = true_labels.to(device), fake_labels.to(device)
    fix_noises, noises = fix_noises.to(device), noises.to(device)

    for epoch in range(140):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # Discriminator step, every opt.d_every batches.
                optimizer_d.zero_grad()
                # Push the score of real images towards 1.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Push the score of generated images towards 0. The fake
                # batch is detached so no gradient reaches the generator.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()

                optimizer_d.step()
                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Generator step, every opt.g_every batches: make the
                # discriminator score fakes as real.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                # The generator loss belongs in the generator meter; it was
                # being added to errord_meter, corrupting both curves.
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_time == opt.plot_time - 1:
                # Visualization; * 0.5 + 0.5 undoes the Normalize above.
                fix_fake_img = netg(fix_noises)
                vis.images(fix_fake_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch % opt.decay_every == opt.decay_every - 1:
            # Save samples and checkpoints. The samples are regenerated here:
            # the plotting branch may never have run (opt.vis off), which
            # previously raised NameError at this point.
            with t.no_grad():
                fix_fake_img = netg(fix_noises)
            # NOTE(review): newer torchvision releases renamed `range` to
            # `value_range` — confirm against the pinned torchvision version.
            tv.utils.save_image(fix_fake_img.data[:64], '%s/new%s.png' % (opt.save_path, epoch),
                                normalize=True, range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/new_netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/new_netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            # NOTE(review): the optimizers are rebuilt with the same learning
            # rates, which discards Adam's moment estimates but does not
            # decay anything — confirm whether a decay of opt.lrG/opt.lrD was
            # intended at this `decay_every` boundary.
            optimizer_g = t.optim.Adam(netg.parameters(), opt.lrG, betas=(0.5, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(), opt.lrD, betas=(0.5, 0.999))