Example #1
def generate(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises, volatile=True)  # legacy pre-0.4 API: volatile=True disabled autograd

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    fake_img = netg(noises)
    scores = netd(fake_img).data
    # ipdb.set_trace()  # debugging breakpoint
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])

    tv.utils.save_image(t.stack(result),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
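Example #1 uses the legacy pre-0.4 Variable(..., volatile=True) API. A minimal sketch of the modern equivalent, assuming the same opt, netg and netd as above; inference simply runs inside torch.no_grad() instead:

import torch as t

noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
with t.no_grad():  # replaces Variable(noises, volatile=True)
    fake_img = netg(noises)
    scores = netd(fake_img)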
Example #2
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    tv.utils.save_image(t.stack(result),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
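A usage sketch for the pattern all of these examples share: generate(**kwargs) overwrites fields on a module-level opt config object via setattr before running. The Config fields below are assumptions inferred from the attributes the code reads, not the project's real config:

class Config:  # hypothetical stand-in for the module-level opt
    gpu = False
    nz = 100
    gen_search_num = 512   # candidates to generate
    gen_num = 64           # best images to keep
    gen_mean, gen_std = 0, 1
    netd_path = 'checkpoints/netd_199.pth'  # placeholder paths
    netg_path = 'checkpoints/netg_199.pth'
    gen_img = 'result.png'

opt = Config()
generate(gen_num=32)  # overrides opt.gen_num for this call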
Example #3
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    
    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
Example #4
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the better results according to netd's scores.
    :param kwargs: config overrides applied to opt
    :return:
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    if opt.gpu:
        device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
        netd.to(device)
        netg.to(device)
        noises = noises.to(device)  # .to() is not in-place; the result must be reassigned

    # generate images
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
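The map_location = lambda storage, loc: storage idiom used throughout keeps every deserialized storage where torch.load places it, i.e. on the CPU, so GPU-trained checkpoints load on CPU-only machines. The string form is equivalent:

state = t.load(opt.netd_path, map_location='cpu')  # same effect as the lambda
netd.load_state_dict(state)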
Example #5
def test():
    #netG = torch.load('checkpoints/netG_013.pth')
    netG = NetG(conf.batch_size, conf.nz)
    netG.load_state_dict(torch.load('checkpoints/netG_013.pth'))
    netG.eval()  # eval mode matters at inference (BatchNorm/Dropout behave differently)

    for i in range(1):  # one batch; increase the range to sample more
        noise = torch.randn(conf.batch_size, conf.nz, 1, 1)
        fake_img = netG(noise)  # generate fake images
        #utils.img_show(fake_img)

        save_image(fake_img.data, './checkpoints/hehe.png')
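Most examples in this collection call tv.utils.save_image(..., normalize=True, range=(-1, 1)) to map tanh outputs back to [0, 1]. Newer torchvision releases renamed the range keyword to value_range; a hedged compatibility sketch:

import torchvision as tv

try:
    tv.utils.save_image(fake_img.data, 'sample.png', normalize=True, value_range=(-1, 1))
except TypeError:
    # older torchvision releases still use the `range` keyword
    tv.utils.save_image(fake_img.data, 'sample.png', normalize=True, range=(-1, 1))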
Example #6
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    
    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    epoch_num = opt.netd_path[opt.netd_path.find("_") + 1:opt.netd_path.rfind(".")]

    randid = random.randint(0, 1000)
    pic_stack_name = "".join(["result_", str(epoch_num), "_", str(randid), ".png"])
    tv.utils.save_image(t.stack(result), pic_stack_name, normalize=True, range=(-1, 1))

    base_dir = "./imgs/tiny_epoch_{0}_{1}".format(epoch_num, randid)
    if not os.path.exists(base_dir):
        os.mkdir(base_dir)
    for picx in range(len(result)):
        picx_name = "".join([base_dir, "/picid_", str(picx), ".png"])
        tv.utils.save_image(result[picx], picx_name, normalize=True, range=(-1, 1))
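The slice-based epoch extraction above assumes exactly one underscore and one dot in netd_path. A regex is sturdier; parse_epoch is a hypothetical helper, not part of the original project:

import re

def parse_epoch(path):
    # matches e.g. 'checkpoints/netd_199.pth' -> '199'
    m = re.search(r'_(\d+)\.\w+$', path)
    return m.group(1) if m else 'unknown'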
Example #7
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # switch the networks to inference (eval) mode
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    # draw gen_search_num noise vectors, i.e. gen_search_num candidate images
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises, volatile=True)  # legacy pre-0.4 API: volatile=True disabled autograd
    # load the model parameters onto the CPU
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    # move the models and the input noise to the GPU
    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # pick the best few: rank the 512 generated images by score and take the indices of the top 64
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    tv.utils.save_image(t.stack(result),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
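Tensor.topk returns a (values, indices) pair, which is why these examples index [1] to get the positions of the highest-scoring images. A quick illustration:

scores = t.tensor([0.1, 0.9, 0.4, 0.7])
values, indices = scores.topk(2)
# values  -> tensor([0.9000, 0.7000])
# indices -> tensor([1, 3])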
Example #8
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the better ones according to netd's scores.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    # t.manual_seed(100)
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises, volatile=True)  # legacy pre-0.4 API: volatile=True disabled autograd

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    tv.utils.save_image(t.stack(result),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
Example #9
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device("cuda") if opt.gpu else t.device("cpu")

    # data preprocessing; outputs are normalized to [-1, 1]
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # define optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # real images get label 1, fake images label 0; noise is the generator's input
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    # meters for the running mean/std of the losses
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## try to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## try to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach so no gradients flow into G, which speeds up training
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                # try to get the fake images classified as 1 (real)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            # visualization (hook intentionally left empty here)

        # save models and images
        if (epoch + 1) % opt.save_every == 0:
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                "%s%s.png" % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), r"./checkpoints/netd_%s.pth" % epoch)
            t.save(netg.state_dict(), r"./checkpoints/netg_%s.pth" % epoch)
            errord_meter.reset()
            errorg_meter.reset()
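AverageValueMeter comes from torchnet. A minimal stand-in with the same add/value/reset surface, assuming torchnet's semantics of value() returning the running (mean, std):

import statistics

class AverageValueMeter:
    # minimal sketch of torchnet.meter.AverageValueMeter's assumed interface
    def __init__(self):
        self.reset()

    def add(self, value):
        self.values.append(float(value))

    def value(self):
        mean = statistics.mean(self.values) if self.values else float('nan')
        std = statistics.stdev(self.values) if len(self.values) > 1 else float('nan')
        return mean, std

    def reset(self):
        self.values = []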
Example #10
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # define optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # real images get label 1, fake images label 0
    # noises is the generator's input
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()


    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## try to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## try to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from the noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualization
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch+1) % opt.save_every == 0:
            # save models and sample images
            fix_fake_imgs = netg(fix_noises)  # regenerate here: the vis branch may never have run
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
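Visualizer here is a project-local wrapper around visdom. A hedged sketch of just the two methods these training loops call, assuming a visdom server is running; the real class in the source project has more features:

import visdom
import numpy as np

class Visualizer:
    def __init__(self, env='default'):
        self.vis = visdom.Visdom(env=env)
        self.index = {}  # per-window x coordinate

    def images(self, imgs, win):
        # imgs: ndarray of shape (N, C, H, W), values in [0, 1]
        self.vis.images(imgs, win=win)

    def plot(self, win, y):
        x = self.index.get(win, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]), win=win,
                      update='append' if x > 0 else None)
        self.index[win] = x + 1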
Example #11
def test_network():
    threshold = ct.THRESHOLD
    test_dir = ct.TEST_TXT
    path = os.path.join(ct.BEST_WEIGHT_SAVE_DIR, 'netG.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pretrained_dict = torch.load(
        path, map_location=torch.device(device))['model_state_dict']
    test_data = OSCD_TEST(test_dir)
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)
    net = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF, ct.EXTRALAYERS).to(device)
    #     net = nn.DataParallel(net)
    net.load_state_dict(pretrained_dict, False)
    torch.set_grad_enabled(False)  # a bare torch.no_grad() call is a no-op; disable autograd explicitly
    net.eval()
    i = 0
    TP = 0
    FN = 0
    FP = 0
    TN = 0
    for i, data in enumerate(test_dataloader):
        INPUT_SIZE = [ct.ISIZE, ct.ISIZE]
        x1, x2, gt = data
        x1 = x1.to(device, dtype=torch.float)
        x2 = x2.to(device, dtype=torch.float)
        gt = gt.to(device, dtype=torch.float)
        gt = gt[:, 0, :, :].unsqueeze(1)

        x = torch.cat((x1, x2), 1)
        fake = net(x)

        save_path = os.path.join(ct.IM_SAVE_DIR, 'test_output_images')
        if not os.path.isdir(save_path):
            os.makedirs(save_path)

        if ct.SAVE_TEST_IAMGES:
            vutils.save_image(x1.data,
                              os.path.join(save_path, '%d_x1.png' % i),
                              normalize=True)
            vutils.save_image(x2.data,
                              os.path.join(save_path, '%d_x2.png' % i),
                              normalize=True)
            vutils.save_image(fake.data,
                              os.path.join(save_path, '%d_gt_fake.png' % i),
                              normalize=True)
            vutils.save_image(gt,
                              os.path.join(save_path, '%d_gt.png' % i),
                              normalize=True)

        tp, fp, tn, fn = eva.f1(fake, gt)
        TP += tp
        FN += fn
        TN += tn
        FP += fp
        i += 1
        print('testing image {}'.format(i))
    iou = TP / (FN + TP + FP + 1e-8)
    precision = TP / (TP + FP + 1e-8)
    oa = (TP + TN) / (TP + FN + TN + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    P = ((TP + FP) * (TP + FN) + (FN + TN) *
         (FP + TN)) / ((TP + TN + FP + FN)**2 + 1e-8)
    Kappa = (oa - P) / (1 - P + 1e-8)
    results = {
        'iou': iou,
        'precision': precision,
        'oa': oa,
        'recall': recall,
        'f1': f1,
        'kappa': Kappa
    }

    with open(os.path.join(ct.OUTPUTS_DIR, 'test_score.txt'), 'a') as f:
        f.write('-----test results on the best model {}-----'.format(
            time.strftime('%Y-%m-%d %H:%M:%S')))
        f.write('\n')
        for key, value in results.items():
            print(key, value)
            f.write('{}: {}'.format(key, value))
            f.write('\n')
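The scores computed at the end of Example #11 are the standard confusion-matrix metrics. The same formulas, factored into a small self-contained helper:

def change_detection_scores(TP, FP, TN, FN, eps=1e-8):
    iou = TP / (FN + TP + FP + eps)
    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)
    oa = (TP + TN) / (TP + FN + TN + FP + eps)  # overall accuracy
    f1 = 2 * precision * recall / (precision + recall + eps)
    # expected chance agreement for Cohen's kappa
    p_e = ((TP + FP) * (TP + FN) + (FN + TN) * (FP + TN)) / ((TP + TN + FP + FN) ** 2 + eps)
    kappa = (oa - p_e) / (1 - p_e + eps)
    return {'iou': iou, 'precision': precision, 'oa': oa,
            'recall': recall, 'f1': f1, 'kappa': kappa}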
Example #12
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    # visualization
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    netg, netd = NetG(opt), NetD(opt)
    if opt.netd_path:
        netd.load_state_dict(
            t.load(opt.netd_path, map_location=t.device('cpu')))
    if opt.netg_path:
        netg.load_state_dict(
            t.load(opt.netg_path, map_location=t.device('cpu')))
    netd.to(device)
    netg.to(device)

    # define optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # real images get label 1, fake images label 0
    # noise is the generator's input
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord = 0
    errorg = 0

    epochs = range(opt.max_epoch)
    for epoch in epochs:
        for ii, (img, _) in tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if (ii + 1) % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                # classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # classify fake images (generated by netg from noise) as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                errord += (error_d_fake + error_d_real).item()

            if (ii + 1) % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg += error_g.item()

            if opt.vis and (ii + 1) % opt.plot_every == 0:
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord / (opt.plot_every))
                vis.plot('errorg', errorg / (opt.plot_every))
                errord = 0
                errorg = 0

        if (epoch + 1) % opt.save_every == 0:
            # save models and sample images once per epoch
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(),
                   'checkpoints/netd_%s.pth' % (epoch + 1))
            t.save(netg.state_dict(),
                   'checkpoints/netg_%s.pth' % (epoch + 1))
Example #13
wandb.init(
    config=CONFIG,
    name=CONFIG.name,
    project="mnist",  # change this when switching projects
    job_type="anomaly detection",
)

state_dict_G = torch.load("weights/G.prm",
                          map_location=lambda storage, loc: storage)
state_dict_D = torch.load("weights/D.prm",
                          map_location=lambda storage, loc: storage)

G_update = NetG(CONFIG)
D_update = NetD(CONFIG)

G_update.load_state_dict(state_dict_G)
D_update.load_state_dict(state_dict_D)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed: `device` is never defined in this snippet
G_update.to(device)
D_update.to(device)

mean = (0.5, )
std = (0.5, )

test_dataset = Dataset(csv_file=CONFIG.test_csv_file,
                       transform=ImageTransform(mean, std))

test_dataloader = DataLoader(test_dataset,
                             batch_size=CONFIG.test_batch_size,
                             shuffle=True)
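Example #13 stops after building the dataloaders. A hedged sketch of how results might be reported to the wandb run initialized above; wandb.log and wandb.finish are the real API, while the metric name is an assumption:

import wandb

# hypothetical metric name; log whatever the anomaly-detection loop produces
wandb.log({"anomaly_score": 0.42, "batch": 0})
wandb.finish()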
Example #14
"""""
In order to be able to verify the trained generator model, 
here I will verify the robustness of the model 
through an example-by generating an image data and saving it in a local folder
""" ""

import argparse
from model import NetG
import torch
from torch.autograd import Variable
import torchvision.utils as vutils

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
opt = parser.parse_args()
if __name__ == '__main__':
    print(opt)
    net = NetG(64, 100)
    net.load_state_dict(torch.load("./model_save/xxxx.pth"))
    data = torch.randn(100, 1, 1)
    data = Variable(data.unsqueeze(0))  # shape (1, nz, 1, 1)
    output = net(data)
    vutils.save_image(output.data, './test.png', normalize=True)
Example #15
def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = torch.device('cuda') if opt.gpu else torch.device('cpu')

    hash_id = hash(opt.gen_id) % 51603
    torch.manual_seed(hash_id)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()

    noises = torch.randn(opt.gen_search_num, opt.nz, 1,
                         1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(torch.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # generate images and score them with the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # pick the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the images
    # epoch_num = opt.netd_path[opt.netd_path.find("_") + 1:opt.netd_path.rfind(".")]

    # pic_stack_name = "".join(["result_", str(epoch_num), "_", str(randid), ".png"])
    print(result[0])
    print(result[0].size())
    # np_array = result[0].cpu().numpy()
    #
    # trans_array = np.zeros((96,96,3))
    #
    # trans_array[:,:,0] = np_array[0,:,:]
    # trans_array[:,:,1] = np_array[1,:,:]
    # trans_array[:,:,2] = np_array[2,:,:]
    #
    # print(trans_array.shape)
    #
    # img_ = Image.fromarray(np.uint8(trans_array*255))
    # img_ = img_.resize((224, 224), Image.BILINEAR) #resize to
    # t_img = torch.from_numpy(np.array(img_))
    # tv.utils.save_image(t_img, "PIL_test_img.png", normalize=True, range=(-0.8, 0.8))

    #tv.utils.save_image(torch.stack(result), "save_test.png", normalize=True, range=(-1, 1))

    base_dir = opt.gen_dst
    base_name = opt.base_name
    if not os.path.exists(base_dir):
        os.mkdir(base_dir)
    for picx in range(len(result)):
        picx_name = "".join([base_dir, base_name, str(picx), ".png"])
        tv.utils.save_image(result[picx],
                            picx_name,
                            normalize=True,
                            range=(-0.8, 0.8))
        img = Image.open(picx_name)
        img = img.resize((188, 188), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
        img.save(base_dir + "/" + str(picx) + ".png", "png")
        os.remove(picx_name)
Example #16
def train_network():

    init_epoch = 0
    best_f1 = 0
    total_steps = 0
    train_dir = ct.TRAIN_TXT
    val_dir = ct.VAL_TXT
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    train_data = OSCD_TRAIN(train_dir)
    train_dataloader = DataLoader(train_data,
                                  batch_size=ct.BATCH_SIZE,
                                  shuffle=True)
    val_data = OSCD_TEST(val_dir)
    val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)
    netg = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                ct.EXTRALAYERS).to(device=device)
    netd = NetD(ct.ISIZE, ct.GT_C, 1, ct.NGF, ct.EXTRALAYERS).to(device=device)
    netg.apply(weights_init)
    netd.apply(weights_init)
    if ct.RESUME:
        assert os.path.exists(os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth')) \
                and os.path.exists(os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netD.pth')), \
                'No saved weights found'
        print("\nLoading pre-trained networks.")
        init_epoch = torch.load(
            os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth'))['epoch']
        netg.load_state_dict(
            torch.load(os.path.join(ct.WEIGHTS_SAVE_DIR,
                                    'current_netG.pth'))['model_state_dict'])
        netd.load_state_dict(
            torch.load(os.path.join(ct.WEIGHTS_SAVE_DIR,
                                    'current_netD.pth'))['model_state_dict'])
        with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt')) as f:
            lines = f.readlines()
            best_f1 = float(lines[-2].strip().split(':')[-1])
        print("\tDone.\n")

    l_adv = l2_loss
    l_con = nn.L1Loss()
    l_enc = l2_loss
    l_bce = nn.BCELoss()
    l_cos = cos_loss
    dice = DiceLoss()
    optimizer_d = optim.Adam(netd.parameters(), lr=ct.LR, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(netg.parameters(), lr=ct.LR, betas=(0.5, 0.999))

    start_time = time.time()
    for epoch in range(init_epoch + 1, ct.EPOCH):
        loss_g = []
        loss_d = []
        netg.train()
        netd.train()
        epoch_iter = 0
        for i, data in enumerate(train_dataloader):
            INPUT_SIZE = [ct.ISIZE, ct.ISIZE]
            x1, x2, gt = data
            x1 = x1.to(device, dtype=torch.float)
            x2 = x2.to(device, dtype=torch.float)
            gt = gt.to(device, dtype=torch.float)
            gt = gt[:, 0, :, :].unsqueeze(1)
            x = torch.cat((x1, x2), 1)

            epoch_iter += ct.BATCH_SIZE
            total_steps += ct.BATCH_SIZE
            real_label = torch.ones(size=(x1.shape[0], ),
                                    dtype=torch.float32,
                                    device=device)
            fake_label = torch.zeros(size=(x1.shape[0], ),
                                     dtype=torch.float32,
                                     device=device)

            #forward

            fake = netg(x)
            pred_real = netd(gt)
            pred_fake = netd(fake).detach()
            err_d_fake = l_bce(pred_fake, fake_label)
            err_g = l_con(fake, gt)
            err_g_total = ct.G_WEIGHT * err_g + ct.D_WEIGHT * err_d_fake

            pred_fake_ = netd(fake.detach())
            err_d_real = l_bce(pred_real, real_label)
            err_d_fake_ = l_bce(pred_fake_, fake_label)
            err_d_total = (err_d_real + err_d_fake_) * 0.5

            #backward
            optimizer_g.zero_grad()
            err_g_total.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_d.zero_grad()
            err_d_total.backward()
            optimizer_d.step()

            errors = utils.get_errors(err_d_total, err_g_total)
            loss_g.append(err_g_total.item())
            loss_d.append(err_d_total.item())

            counter_ratio = float(epoch_iter) / len(train_dataloader.dataset)
            if (i % ct.DISPOLAY_STEP == 0 and i > 0):
                print(
                    'epoch:', epoch, 'iteration:', i,
                    ' G|D loss is {}|{}'.format(np.mean(loss_g[-51:]),
                                                np.mean(loss_d[-51:])))
                if ct.DISPLAY:
                    utils.plot_current_errors(epoch, counter_ratio, errors,
                                              vis)
                    utils.display_current_images(gt.data, fake.data, vis)
        utils.save_current_images(epoch, gt.data, fake.data, ct.IM_SAVE_DIR,
                                  'training_output_images')

        with open(os.path.join(ct.OUTPUTS_DIR, 'train_loss.txt'), 'a') as f:
            f.write(
                'after %s epoch, loss is %g,loss1 is %g,loss2 is %g,loss3 is %g'
                % (epoch, np.mean(loss_g), np.mean(loss_d), np.mean(loss_g),
                   np.mean(loss_d)))
            f.write('\n')
        if not os.path.exists(ct.WEIGHTS_SAVE_DIR):
            os.makedirs(ct.WEIGHTS_SAVE_DIR)
        utils.save_weights(epoch, netg, optimizer_g, ct.WEIGHTS_SAVE_DIR,
                           'netG')
        utils.save_weights(epoch, netd, optimizer_d, ct.WEIGHTS_SAVE_DIR,
                           'netD')
        duration = time.time() - start_time
        print('training duration is %g' % duration)

        #val phase
        print('Validating.................')
        pretrained_dict = torch.load(
            os.path.join(ct.WEIGHTS_SAVE_DIR,
                         'current_netG.pth'))['model_state_dict']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                   ct.EXTRALAYERS).to(device=device)
        net.load_state_dict(pretrained_dict, False)
        net.eval()
        with torch.no_grad():
            TP = 0
            FN = 0
            FP = 0
            TN = 0
            for k, data in enumerate(val_dataloader):
                x1, x2, label = data
                x1 = x1.to(device, dtype=torch.float)
                x2 = x2.to(device, dtype=torch.float)
                label = label.to(device, dtype=torch.float)
                label = label[:, 0, :, :].unsqueeze(1)
                x = torch.cat((x1, x2), 1)
                time_i = time.time()
                v_fake = net(x)

                tp, fp, tn, fn = eva.f1(v_fake, label)
                TP += tp
                FN += fn
                TN += tn
                FP += fp

            precision = TP / (TP + FP + 1e-8)
            oa = (TP + TN) / (TP + FN + TN + FP + 1e-8)
            recall = TP / (TP + FN + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            if not os.path.exists(ct.BEST_WEIGHT_SAVE_DIR):
                os.makedirs(ct.BEST_WEIGHT_SAVE_DIR)
            if f1 > best_f1:
                best_f1 = f1
                shutil.copy(
                    os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth'),
                    os.path.join(ct.BEST_WEIGHT_SAVE_DIR, 'netG.pth'))
            print('current F1: {}'.format(f1))
            print('best f1: {}'.format(best_f1))
            with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt'), 'a') as f:
                f.write('current epoch:{},current f1:{},best f1:{}'.format(
                    epoch, f1, best_f1))
                f.write('\n')
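From the way checkpoints are read back in Example #16 (torch.load(...)['epoch'] and ['model_state_dict']), utils.save_weights presumably writes a dict of that shape. A hedged sketch of the assumed layout:

def save_weights(epoch, net, optimizer, save_dir, name):
    # assumed layout, matching the ['epoch'] / ['model_state_dict'] reads above
    torch.save({
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(save_dir, 'current_%s.pth' % name))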
Example #17
    ndf = 64
    nz = 100


opt = Config()

for gen in range(gen_num):
    Result = random.randint(1000000, 100000000)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1,
                     1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    fake_img = netg(noises)
    scores = netd(fake_img).detach()
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # save the selected images once, after collecting them all
    tv.utils.save_image(t.stack(result),
                        str(Result) + opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
Example #18
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    if opt.vis:
        from visualizer import Visualizer
        vis = Visualizer(opt.env)

    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),  # Scale is the deprecated name for Resize
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()
            if ii % opt.d_every == 0:
                optimizer_d.zero_grad()
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.data[0])

            if ii % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)  # do not detach here, or the generator gets no gradient
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch % opt.decay_every == 0:
            fix_fake_imgs = netg(fix_noises)  # regenerate: the vis branch may never have run
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Example #19
class TACGAN():

    def __init__(self, args):
        self.lr = args.lr
        self.cuda = args.use_cuda
        self.batch_size = args.batch_size
        self.image_size = args.image_size
        self.epochs = args.epochs
        self.data_root = args.data_root
        self.dataset = args.dataset
        self.save_dir = args.save_dir
        self.save_prefix = args.save_prefix
        self.continue_training = args.continue_training
        self.continue_epoch = args.continue_epoch
        self.netG_path = args.netg_path
        self.netD_path = args.netd_path
        self.save_after = args.save_after
        self.trainset_loader = None
        self.evalset_loader = None  
        self.num_workers = args.num_workers
        self.docvec_size = args.docvec_size
        self.n_z = args.n_z # length of the noise vector
        self.nl_d = args.nl_d
        self.nl_g = args.nl_g
        self.nf_g = args.nf_g
        self.nf_d = args.nf_d
        self.bce_loss = nn.BCELoss()
        self.nll_loss = nn.NLLLoss()
        self.mse_loss = nn.MSELoss()
        self.class_filename = args.class_filename
        class_path = os.path.join(self.data_root, self.dataset, self.class_filename)
        with open(class_path) as f:
            self.num_classes = len([l for l in f])
        print(self.num_classes)
        self.netD = NetD(n_cls=self.num_classes, n_t=self.nl_d, n_f=self.nf_d, docvec_size=self.docvec_size)
        self.netG = NetG(n_z=self.n_z, n_l=self.nl_g, n_c=self.nf_g, n_t=self.docvec_size)

        if self.continue_training:
            self.loadCheckpoints()

        # convert to cuda tensors
        if self.cuda and torch.cuda.is_available():
            print('CUDA is enabled')
            self.netD = nn.DataParallel(self.netD).cuda()
            self.netG = nn.DataParallel(self.netG).cuda()
            self.bce_loss = self.bce_loss.cuda()
            self.nll_loss = self.nll_loss.cuda()
            self.mse_loss = self.mse_loss.cuda()

        # optimizers for netD and netG
        self.optimizerD = optim.Adam(params=self.netD.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(params=self.netG.parameters(), lr=self.lr, betas=(0.5, 0.999))

        # create dir for saving checkpoints and other results if do not exist
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        if not os.path.exists(os.path.join(self.save_dir,'netd_checkpoints')):
            os.makedirs(os.path.join(self.save_dir,'netd_checkpoints'))
        if not os.path.exists(os.path.join(self.save_dir,'netg_checkpoints')):            
            os.makedirs(os.path.join(self.save_dir,'netg_checkpoints')) 
        if not os.path.exists(os.path.join(self.save_dir,'generated_images')):            
            os.makedirs(os.path.join(self.save_dir,'generated_images'))

    # start training process
    def train(self):
        # write to the log file and print it
        log_msg = '********************************************\n'
        log_msg += '            Training Parameters\n'
        log_msg += 'Dataset:%s\nImage size:%dx%d\n'%(self.dataset, self.image_size, self.image_size)
        log_msg += 'Batch size:%d\n'%(self.batch_size)
        log_msg += 'Number of epochs:%d\nlr:%f\n'%(self.epochs,self.lr)
        log_msg += 'nz:%d\nnl-d:%d\nnl-g:%d\n'%(self.n_z, self.nl_d, self.nl_g)
        log_msg += 'nf-g:%d\nnf-d:%d\n'%(self.nf_g, self.nf_d)  
        log_msg += '********************************************\n\n'
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),'a') as log_file:
            log_file.write(log_msg)
        # load trainset and evalset
        imtext_ds = ImTextDataset(data_dir=self.data_root, dataset=self.dataset, train=True, image_size=self.image_size)
        self.trainset_loader = DataLoader(dataset=imtext_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        print("Dataset loaded successfuly")
        # load checkpoints for continuing training

        # repeat for the number of epochs
        netd_losses = []
        netg_losses = []
        for epoch in range(self.epochs):
            netd_loss, netg_loss = self.trainEpoch(epoch)
            netd_losses.append(netd_loss)
            netg_losses.append(netg_loss)
            self.saveGraph(netd_losses,netg_losses)
            #self.evalEpoch(epoch)
            self.saveCheckpoints(epoch)

    # train epoch
    def trainEpoch(self, epoch):
        self.netD.train() # set to train mode
        self.netG.train()  # generator in train mode as well
    
        netd_loss_sum = 0
        netg_loss_sum = 0
        start_time = time()
        for i, (images, labels, captions) in enumerate(self.trainset_loader):
            batch_size = images.size(0)  # the last batch may be smaller than self.batch_size
            images, labels, captions = Variable(images), Variable(labels), Variable(captions)  # labels should be a LongTensor
            labels = labels.type(torch.FloatTensor) # convert to FloatTensor (from DoubleTensor)
            lbl_real = Variable(torch.ones(batch_size, 1))
            lbl_fake = Variable(torch.zeros(batch_size, 1))
            noise = Variable(torch.randn(batch_size, self.n_z)) # create random noise
            noise.data.normal_(0, 1)  # resample the noise from N(0, 1)
            rnd_perm1 = torch.randperm(batch_size) # random permutations for different sets of training tuples
            rnd_perm2 = torch.randperm(batch_size)
            rnd_perm3 = torch.randperm(batch_size)
            rnd_perm4 = torch.randperm(batch_size)
            if self.cuda:
                images, labels, captions = images.cuda(), labels.cuda(), captions.cuda()
                lbl_real, lbl_fake = lbl_real.cuda(), lbl_fake.cuda()
                noise = noise.cuda()
                rnd_perm1, rnd_perm2, rnd_perm3, rnd_perm4 = rnd_perm1.cuda(), rnd_perm2.cuda(), rnd_perm3.cuda(), rnd_perm4.cuda()
            
            ############### Update NetD ###############
            self.netD.zero_grad()       
            # train with wrong image, wrong label, real caption
            outD_wrong, outC_wrong = self.netD(images[rnd_perm1], captions[rnd_perm2])
            # lossD_wrong = self.bce_loss(outD_wrong, lbl_fake)
            lossD_wrong = self.bce_loss(outD_wrong, lbl_fake) + self.mse_loss(outD_wrong, lbl_fake)
            lossC_wrong = self.bce_loss(outC_wrong, labels[rnd_perm1])

            # train with real image, real label, real caption
            outD_real, outC_real = self.netD(images, captions)
            #lossD_real = self.bce_loss(outD_real, lbl_real)
            lossD_real = self.bce_loss(outD_real, lbl_real) + self.mse_loss(outD_real, lbl_real)
            lossC_real = self.bce_loss(outC_real, labels)

            # train with fake image, real label, real caption
            fake = self.netG(noise, captions)
            outD_fake, outC_fake = self.netD(fake.detach(), captions[rnd_perm3])
            #lossD_fake = self.bce_loss(outD_fake, lbl_fake)
            lossD_fake = self.bce_loss(outD_fake, lbl_fake) + self.mse_loss(outD_fake, lbl_fake)
            lossC_fake = self.bce_loss(outC_fake, labels[rnd_perm3])
            
            # backward pass and optimizer step for NetD
            netD_loss = lossC_wrong+lossC_real+lossC_fake + lossD_wrong+lossD_real+lossD_fake
            netD_loss.backward()
            self.optimizerD.step()      

            ########## Update NetG ##########
            # train with fake data
            self.netG.zero_grad()
            noise.data.normal_(0, 1)  # resample the noise vector from N(0, 1)
            fake = self.netG(noise, captions[rnd_perm4])
            d_fake, c_fake = self.netD(fake, captions[rnd_perm4])
            #lossD_fake_G = self.bce_loss(d_fake, lbl_real)
            lossD_fake_G = self.mse_loss(d_fake, lbl_real)
            lossC_fake_G = self.bce_loss(c_fake, labels[rnd_perm4])
            netG_loss = lossD_fake_G + lossC_fake_G 
            netG_loss.backward()    
            self.optimizerG.step()
            
            netd_loss_sum += netD_loss.item()
            netg_loss_sum += netG_loss.item()
            ### print progress info ###
            print('Epoch %d/%d, %.2f%% completed. Loss_NetD: %.4f, Loss_NetG: %.4f'
                  %(epoch, self.epochs,(float(i)/len(self.trainset_loader))*100, netD_loss.item(), netG_loss.item()))

        end_time = time()
        netd_avg_loss = netd_loss_sum / len(self.trainset_loader)
        netg_avg_loss = netg_loss_sum / len(self.trainset_loader)
        epoch_time = (end_time-start_time)/60
        log_msg = '-------------------------------------------\n'
        log_msg += 'Epoch %d took %.2f minutes\n'%(epoch, epoch_time)
        log_msg += 'NetD average loss: %.4f, NetG average loss: %.4f\n\n' %(netd_avg_loss, netg_avg_loss)
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),'a') as log_file:
            log_file.write(log_msg)
        return netd_avg_loss, netg_avg_loss

    # eval epoch                   
    def evalEpoch(self, epoch):
        #self.netD.eval()
        #self.netG.eval()
        return 0
    
    # draws and saves the loss graph upto the current epoch
    def saveGraph(self, netd_losses, netg_losses):
        plt.plot(netd_losses, color='red', label='NetD Loss')
        plt.plot(netg_losses, color='blue', label='NetG Loss')
        plt.xlabel('epoch')
        plt.ylabel('error')
        plt.legend(loc='best')
        plt.savefig(os.path.join(self.save_dir,'loss_graph.png'))
        plt.close()

    # save after each epoch
    def saveCheckpoints(self, epoch):
        if epoch % self.save_after == 0:
            name_netD = "netd_checkpoints/netD_" + self.save_prefix + "_epoch_" + str(epoch) + ".pth"
            name_netG = "netg_checkpoints/netG_" + self.save_prefix + "_epoch_" + str(epoch) + ".pth"
            torch.save(self.netD.module.state_dict(), os.path.join(self.save_dir, name_netD))
            torch.save(self.netG.module.state_dict(), os.path.join(self.save_dir, name_netG))
            print("Checkpoints for epoch %d saved successfuly" %(epoch))

    # SAVE: data parallel model => add .module
    # LOAD: create model and load checkpoints(not add .module) and wrap nn.DataParallel
    # this is for fitting prefix

    # load checkpoints to continue training
    def loadCheckpoints(self):
        name_netD = "netd_checkpoints/netD_" + self.save_prefix + "_epoch_" + str(self.continue_epoch) + ".pth"
        name_netG = "netg_checkpoints/netG_" + self.save_prefix + "_epoch_" + str(self.continue_epoch) + ".pth"
        self.netG.load_state_dict(torch.load(os.path.join(self.save_dir, name_netG)))
        self.netD.load_state_dict(torch.load(os.path.join(self.save_dir, name_netD)))
        print("Checkpoints loaded successfuly")
Example #20
import torch
from torch.autograd import Variable
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from model import NetG  # assumed import; NetG is not defined in this snippet

nz = 100
cl = 10
bs = 1
transform = transforms.Compose([
    transforms.ToPILImage(),
])
labels_list = [np.random.randint(0, 10)]
noise = Variable(torch.rand(bs, nz))
onehot_labels = np.zeros((bs, cl))
onehot_labels[np.arange(bs), labels_list] = 1
onehot_labels = Variable(
    torch.from_numpy(onehot_labels).type(torch.FloatTensor))

noise = torch.cat([noise, onehot_labels], 1)
noise = noise.view(bs, nz + cl, 1, 1)

netG = NetG(nz + cl)
netG.load_state_dict(
    torch.load('checkpoints/netG__epoch_25.pth',
               map_location=lambda storage, loc: storage))

output = netG(noise)
imgs = transform(output[0].data)
print('class: ' + str(labels_list[0]))
imgs.show()
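The manual numpy one-hot in Example #20 can be written with torch.nn.functional.one_hot on PyTorch >= 1.1; a hedged equivalent using the same bs, nz and cl:

import torch
import torch.nn.functional as F

labels = torch.randint(0, cl, (bs,))
onehot = F.one_hot(labels, num_classes=cl).float()
noise = torch.cat([torch.rand(bs, nz), onehot], dim=1).view(bs, nz + cl, 1, 1)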
Example #21
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    drop_last=True,
)

netG = NetG(opt.ngf, opt.nz).to(device)
netD = NetD(opt.ndf).to(device)
netD.load_state_dict(torch.load("imgs/netD_007.pth"))
netG.load_state_dict(torch.load("imgs/netG_007.pth"))
criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))

label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

for epoch in range(8, opt.epoch + 1):
    for i, (imgs, _) in enumerate(dataloader):
        # freeze the generator G and train the discriminator D
Example #22
    state_epoch = args.resume_epoch

    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.0, 0.9))
    optimizerD_enc = torch.optim.Adam(netD.feature_encoder.parameters(),
                                      lr=0.0004,
                                      betas=(0.0, 0.9))
    optimizerD_proj = torch.optim.Adam(netD.COND_DNET.parameters(),
                                       lr=0.004,
                                       betas=(0.0, 0.9))

    if state_epoch != 0:
        netG.load_state_dict(
            torch.load('%s/models/netG_%03d.pth' % (output_dir, state_epoch),
                       map_location='cpu'))
        netD.load_state_dict(
            torch.load('%s/models/netD_%03d.pth' % (output_dir, state_epoch),
                       map_location='cpu'))
        netG = netG.cuda()
        netD = netD.cuda()
        optimizerG.load_state_dict(
            torch.load('%s/models/optimizerG.pth' % (output_dir)))
        # optimizerD is not defined in this fragment (only optimizerD_enc / optimizerD_proj are)
        optimizerD.load_state_dict(
            torch.load('%s/models/optimizerD.pth' % (output_dir)))

    netG.cuda()
    netD.cuda()

    if cfg.B_VALIDATION:
Example #23
class AnoGAN:
    """AnoGAN Class
    """
    def __init__(self, opt):
        # super(AnoGAN, self).__init__(opt, dataloader)

        # Initialize variables.
        self.opt = opt

        self.niter = self.opt.niter
        self.start_iter = 0
        self.netd_niter = 5
        self.test_iter = 100
        self.lr = self.opt.lr
        self.batchsize = {'train': self.opt.batchsize, 'test': 1}

        self.pretrained = False

        self.phase = 'train'
        self.outf = self.opt.experiment_group
        self.algorithm = 'wgan'

        # LOAD DATA SET
        self.dataloader = {
            'train':
            provider('train',
                     opt.category,
                     batch_size=self.batchsize['train'],
                     num_workers=4),
            'test':
            provider('test',
                     opt.category,
                     batch_size=self.batchsize['test'],
                     num_workers=4)
        }

        self.trn_dir = os.path.join(self.outf, self.opt.experiment_name,
                                    'train')
        self.tst_dir = os.path.join(self.outf, self.opt.experiment_name,
                                    'test')

        self.test_img_dir = os.path.join(self.outf, self.opt.experiment_name,
                                         'test', 'images')
        if not os.path.isdir(self.test_img_dir):
            os.makedirs(self.test_img_dir)

        self.best_test_dir = os.path.join(self.outf, self.opt.experiment_name,
                                          'test', 'best_images')
        if not os.path.isdir(self.best_test_dir):
            os.makedirs(self.best_test_dir)

        self.weight_dir = os.path.join(self.trn_dir, 'weights')
        if not os.path.exists(self.weight_dir): os.makedirs(self.weight_dir)

        # -- Misc attributes
        self.epoch = 0

        self.l_con = l1_loss
        self.l_enc = l2_loss

        ##
        # Create and initialize networks.
        self.netg = NetG().cuda()
        self.netd = NetD().cuda()

        # Setup optimizer
        self.optimizer_d = optim.RMSprop(self.netd.parameters(), lr=self.lr)
        self.optimizer_g = optim.Adam(self.netg.parameters(), lr=self.lr)

        ##
        self.weight_path = os.path.join(self.outf, self.opt.experiment_name,
                                        'train', 'weights')
        if os.path.exists(self.weight_path) and len(
                os.listdir(self.weight_path)) == 2:
            print("Loading pre-trained networks...\n")
            self.netg.load_state_dict(
                torch.load(os.path.join(self.weight_path,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.weight_path,
                                        'netD.pth'))['state_dict'])

            self.optimizer_g.load_state_dict(
                torch.load(os.path.join(self.weight_path,
                                        'netG.pth'))['optimizer'])
            self.optimizer_d.load_state_dict(
                torch.load(os.path.join(self.weight_path,
                                        'netD.pth'))['optimizer'])

            self.start_iter = torch.load(
                os.path.join(self.weight_path, 'netG.pth'))['epoch']

    ##
    def start(self):
        """ Train the model
        """

        ##
        # TRAIN
        # self.total_steps = 0
        best_criterion = -1
        best_auc = -1

        # Train for niter epochs.
        # print(">> Training model %s." % self.name)
        for self.epoch in range(self.start_iter, self.niter):
            # Train for one epoch
            mean_wass = self.train()

            (auc, res, best_rec, best_threshold), res_total = self.test()
            message = ''
            # message += 'criterion: (%.3f+%.3f)/2=%.3f ' % (best_rec[0], best_rec[1], res)
            # message += 'best threshold: %.3f ' % best_threshold
            message += 'Wasserstein Distance: %.3f ' % mean_wass
            message += 'AUC: %.3f ' % auc

            print(message)

            torch.save(
                {
                    'epoch': self.epoch + 1,
                    'state_dict': self.netg.state_dict(),
                    'optimizer': self.optimizer_g.state_dict()
                }, '%s/netG.pth' % (self.weight_dir))

            torch.save(
                {
                    'epoch': self.epoch + 1,
                    'state_dict': self.netd.state_dict(),
                    'optimizer': self.optimizer_d.state_dict()
                }, '%s/netD.pth' % (self.weight_dir))

            if auc > best_auc:
                best_auc = auc
                new_message = "******** New optimal found, saving state ********"
                message = message + '\n' + new_message
                print(new_message)

                for img in os.listdir(self.best_test_dir):
                    os.remove(os.path.join(self.best_test_dir, img))

                for img in os.listdir(self.test_img_dir):
                    shutil.copyfile(os.path.join(self.test_img_dir, img),
                                    os.path.join(self.best_test_dir, img))

                shutil.copyfile('%s/netG.pth' % (self.weight_dir),
                                '%s/netg_best.pth' % (self.weight_dir))

            log_name = os.path.join(self.outf, self.opt.experiment_name,
                                    'loss_log.txt')
            message = 'Epoch%3d:' % self.epoch + ' ' + message
            with open(log_name, "a") as log_file:
                if self.epoch == 0:
                    log_file.write('\n\n')
                log_file.write('%s\n' % message)

        print(">> Training %s Done..." % self.opt.experiment_name)

    ##
    def train(self):
        """ Train the model for one epoch.
        """
        print("\n>>> Epoch %d/%d, Running " % (self.epoch + 1, self.niter) +
              self.opt.experiment_name)

        self.netg.train()
        self.netd.train()
        # for p in self.netg.parameters(): p.requires_grad = True

        mean_wass = 0

        tk0 = tqdm(self.dataloader['train'],
                   total=len(self.dataloader['train']))
        for i, itr in enumerate(tk0):
            input, _ = itr
            input = input.cuda()
            wasserstein_d = None
            # if self.algorithm == 'wgan':
            # train NetD
            for _ in range(self.netd_niter):
                # for p in self.netd.parameters(): p.requires_grad = True
                self.optimizer_d.zero_grad()

                # forward_g
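                # torch.rand draws uniform noise in [0, 1), so the latent
                # prior here is uniform rather than Gaussian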
                latent_i = torch.rand(self.batchsize['train'], 64, 1, 1).cuda()
                fake = self.netg(latent_i)
                # forward_d
                _, pred_real = self.netd(input)
                _, pred_fake = self.netd(fake)  # author's TODO: detaching fake would skip needless gradients through netg

                # Backward-pass
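                # the critic loss is the negative of the Wasserstein estimate;
                # minimizing it pushes D(real) up and D(fake) down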
                wasserstein_d = pred_fake.mean() - pred_real.mean()
                wasserstein_d.backward()
                self.optimizer_d.step()

                for p in self.netd.parameters():
                    p.data.clamp_(-0.01, 0.01)  # WGAN weight clipping keeps the critic roughly 1-Lipschitz

            # train netg
            # for p in self.netd.parameters(): p.requires_grad = False
            self.optimizer_g.zero_grad()
            noise = torch.rand(self.batchsize['train'], 64, 1, 1).cuda()
            fake = self.netg(noise)
            _, pred_fake = self.netd(fake)
            err_g_d = -pred_fake.mean()  # negative

            err_g_d.backward()
            self.optimizer_g.step()

            errors = {
                'loss_netD': wasserstein_d.item(),
                'loss_netG': round(err_g_d.item(), 3),
            }

            mean_wass += wasserstein_d.item()
            tk0.set_postfix(errors)

            if i % 50 == 0:
                img_dir = os.path.join(self.outf, self.opt.experiment_name,
                                       'train', 'images')
                if not os.path.isdir(img_dir):
                    os.makedirs(img_dir)
                self.save_image_cv2(input.data, '%s/reals.png' % img_dir)
                self.save_image_cv2(fake.data,
                                    '%s/fakes%03d.png' % (img_dir, i))

        mean_wass /= len(self.dataloader['train'])
        return mean_wass

    ##
    def test(self):
        """ Test the AnoGAN model on self.dataloader['test'].

        Raises:
            IOError: Model weights not found.
        """
        self.netg.eval()
        self.netd.eval()
        # for p in self.netg.parameters(): p.requires_grad = False
        # for p in self.netd.parameters(): p.requires_grad = False

        for img in os.listdir(self.test_img_dir):
            os.remove(os.path.join(self.test_img_dir, img))

        self.phase = 'test'
        meter = Meter_AnoGAN()
        tk1 = tqdm(self.dataloader['test'], total=len(self.dataloader['test']))
        for i, itr in enumerate(tk1):
            input, target = itr
            input = input.cuda()

            latent_i = torch.rand(self.batchsize['test'], 64, 1, 1).cuda()
            latent_i.requires_grad = True

            optimizer_latent = optim.Adam([latent_i], lr=self.lr)
            test_loss = None
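            # AnoGAN inference: with G and D frozen, optimize the latent code
            # so that G(z) reconstructs the input; the final weighted loss is
            # the anomaly score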
            for _ in range(self.test_iter):
                optimizer_latent.zero_grad()
                fake = self.netg(latent_i)
                residual_loss = self.l_con(input, fake)
                latent_o, _ = self.netd(fake)
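                # netd's first output acts as a feature encoding of the fake;
                # matching it to the latent code gives the discrimination loss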
                discrimination_loss = self.l_enc(latent_i, latent_o)
                alpha = 0.1
                test_loss = (
                    1 - alpha) * residual_loss + alpha * discrimination_loss
                test_loss.backward()
                optimizer_latent.step()

            abnormal_score = test_loss
            meter.update(abnormal_score, target)  # TODO

            # Save test images.
            combine = torch.cat([input.cpu(), fake.cpu()], dim=0)
            self.save_image_cv2(combine,
                                '%s/%05d.jpg' % (self.test_img_dir, i + 1))

        criterion, res_total = meter.get_metrics()

        # rename images
        for i, res in enumerate(res_total):
            os.rename('%s/%05d.jpg' % (self.test_img_dir, i + 1),
                      '%s/%05d_%s.jpg' % (self.test_img_dir, i + 1, res))

        return criterion, res_total

    @staticmethod
    def save_image_cv2(tensor, filename):
        # return
        from torchvision.utils import make_grid
        # tensor = (tensor + 1) / 2
        grid = make_grid(tensor, nrow=8, padding=2, normalize=False)
        ndarray = grid.mul_(255).clamp_(0, 255).permute(1, 2, 0).to(
            'cpu', torch.uint8).numpy()
        cv2.imwrite(filename, ndarray)
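
A minimal driver for the class above might look like the following sketch. The opt fields are inferred from the attributes the constructor reads (niter, lr, batchsize, category, experiment_group, experiment_name); the values are illustrative placeholders, not taken from the source.

from argparse import Namespace

# Hypothetical configuration: field names follow the constructor above,
# values are placeholders.
opt = Namespace(niter=100, lr=5e-5, batchsize=64, category='bottle',
                experiment_group='experiments', experiment_name='anogan_wgan')

model = AnoGAN(opt)  # builds the networks, data loaders, and output dirs
model.start()        # trains, evaluates AUC each epoch, keeps the best run
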
Ejemplo n.º 24
def train(**kwargs):
    # apply keyword-argument overrides to opt
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # visualization
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)
    # image preprocessing
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),  # Scale is deprecated in favor of Resize
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        # normalize each channel to mean 0.5, std 0.5, mapping pixels to [-1, 1]
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # ImageFolder reads the images with PyTorch's native loader and applies the transforms, wrapping the dataset
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # data loader
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # define the networks
    netg, netd = NetG(opt), NetD(opt)
    # map_location loads checkpoint tensors onto the CPU
    map_location = lambda storage, loc: storage
    # load any pretrained weights onto the CPU first
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # define the optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    # binary cross-entropy loss, the standard choice for real/fake discrimination
    criterion = t.nn.BCELoss()

    # real images are labeled 1, fake images 0
    # noises is the input to the generator
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    # fix_noises is held fixed so progress can be compared across epochs
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    # AverageValueMeter tracks the running mean/std of added values for the dashboard
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        # move the networks to the GPU
        netd.cuda()
        netg.cuda()
        # move the loss function to the GPU
        criterion.cuda()
        # move the labels to the GPU
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        # move the input noise to the GPU
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()
            # train the discriminator every d_every batches
            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## classify real images as real as reliably as possible
                # label a batch of real photos as 1 and backpropagate
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                # backpropagate
                error_d_real.backward()

                ## classify fake images as fake as reliably as possible
                # label a batch of fake photos as 0 and backpropagate
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from noise; detach so no gradient flows into G
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                # update the learnable parameters
                optimizer_d.step()
                # total error = real-image error + fake-image error
                error_d = error_d_fake + error_d_real
                # log the total error to the dashboard for visualization
                errord_meter.add(error_d.data[0])
            # train the generator every g_every batches
            if ii % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # generator: produce fake images from noise
                fake_img = netg(noises)
                # discriminator: score the fake images
                output = netd(fake_img)
                # push the fake images' scores toward the real label so the discriminator cannot tell them apart
                error_g = criterion(output, true_labels)
                error_g.backward()
                # update the parameters
                optimizer_g.step()
                # log the error to the dashboard for visualization
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualization
                # drop into the debugger if the debug file exists
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                # generate fake images from the fixed noise
                fix_fake_imgs = netg(fix_noises)
                # show the fakes produced by the fixed noise
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                # show real images for comparison
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                # dashboard plots: discriminator error and generator error
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])
        # save the model every decay_every epochs
        if epoch % opt.decay_every == 0:
            # save the model and sample images
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            # save the discriminator and the generator
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            # reset the error meters
            errord_meter.reset()
            errorg_meter.reset()
            # re-create the optimizers with their initial settings
            optimizer_g = t.optim.Adam(netg.parameters(),
                                       opt.lr1,
                                       betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),
                                       opt.lr2,
                                       betas=(opt.beta1, 0.999))
Ejemplo n.º 25
def main(args):
    # fix the random seed so the sampled noise is reproducible
    manualSeed = 100
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    with open(args.json_file, 'r') as f:
        dataset_json = json.load(f)

    # load rnn encoder
    text_encoder = RNN_ENCODER(dataset_json['n_words'], nhidden=dataset_json['text_embed_dim'])
    text_encoder_dir = args.rnn_encoder
    state_dict = torch.load(text_encoder_dir, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)

    # load netG
    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))
    # netG = NetG(int(dataset_json['n_channels']), int(dataset_json['cond_dim']))
    netG = NetG(64, int(dataset_json['cond_dim']))
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # strip the 'module.' prefix added by nn.DataParallel
        new_state_dict[name] = v
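    # keep only checkpoint entries whose names exist in this NetG, merge them
    # into the model's own state dict, and load the result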
    model_dict = netG.state_dict()
    pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    netG.load_state_dict(model_dict)

    # move the models to the GPU if requested; the input tensors are created
    # below and moved to the GPU right after creation
    if args.use_gpu:
        text_encoder.cuda()
        netG.cuda()

    text_encoder.eval()
    netG.eval()

    # generate noise
    num_noise = 100
    noise = torch.FloatTensor(num_noise, 100)
    if args.use_gpu:
        noise = noise.cuda()

    # cub bird captions
    # caption = 'this small bird has a light yellow breast and brown wings'
    # caption = 'this small bird has a short beak a light gray breast a darker gray crown and black wing tips'
    # caption = 'this small bird has wings that are gray and has a white belly'
    # caption = 'this bird has a yellow throat belly abdomen and sides with lots of brown streaks on them'
    # caption = 'this little bird has a yellow belly and breast with a gray wing with white wingbars'
    # caption = 'this bird has a white belly and breast wit ha blue crown and nape'
    # caption = 'a bird with brown and black wings red crown and throat and the bill is short and pointed'
    # caption = 'this small bird has a yellow crown and a white belly'
    # caption = 'this bird has a blue crown with white throat and brown secondaries'
    # caption = 'this bird has wings that are black and has a white belly'
    # caption = 'a yellow bird has wings with dark stripes and small eyes'
    # caption = 'a black bird has wings with dark stripes and small eyes'
    # caption = 'a red bird has wings with dark stripes and small eyes'
    # caption = 'a white bird has wings with dark stripes and small eyes'
    # caption = 'a blue bird has wings with dark stripes and small eyes'
    # caption = 'a pink bird has wings with dark stripes and small eyes'
    # caption = 'this is a white and grey bird with black wings and a black stripe by its eyes'
    # caption = 'a small bird with an orange bill and grey crown and breast'
    # caption = 'a small bird with black gray and white wingbars'
    # caption = 'this bird is white and light orange in color with a black beak'
    # caption = 'a small sized bird that has tones of brown and a short pointed bill' # beak?

    # MS coco captions
    # caption = 'two men skiing down a snow covered mountain in the evening'
    # caption = 'a man walking down a grass covered mountain'
    # caption = 'a close up of a boat on a field under a sunset'
    # caption = 'a close up of a boat on a field with a clear sky'
    # caption = 'a herd of black and white cattle standing on a field'
    # caption = 'a herd of black and white sheep standing on a field'
    # caption = 'a herd of black and white dogs standing on a field'
    # caption = 'a herd of brown cattle standing on a field'
    # caption = 'a herd of black and white cattle standing in a river'
    # caption = 'some horses in a field of green grass with a sky in the background'
    # caption = 'some horses in a field of yellow grass with a sky in the background'
    caption = 'some horses in a field of green grass with a sunset in the background'

    # convert caption to index
    caption_idx, caption_len = get_caption_idx(dataset_json, caption)
    caption_idx = torch.LongTensor(caption_idx)
    caption_len = torch.LongTensor([caption_len])
    caption_idx = caption_idx.view(1, -1)
    caption_len = caption_len.view(-1)
    if args.use_gpu:
        caption_idx = caption_idx.cuda()
        caption_len = caption_len.cuda()

    # use rnn encoder to get caption embedding
    hidden = text_encoder.init_hidden(1)
    words_embs, sent_emb = text_encoder(caption_idx, caption_len, hidden)

    # generate fake image
    noise.data.normal_(0, 1)
    sent_emb = sent_emb.repeat(num_noise, 1)
    words_embs = words_embs.repeat(num_noise, 1, 1)
    with torch.no_grad():
        fake_imgs, fusion_mask = netG(noise, sent_emb)

        # create path to save image, caption and mask
        cap_number = 10000
        main_path = 'result/mani/cap_%s_0_coco_ch64' % (str(cap_number))
        img_save_path = '%s/image' % main_path
        mask_save_path = '%s/mask_' % main_path
        mkdir_p(img_save_path)
        for i in range(7):
            mkdir_p(mask_save_path + str(i))

        # save caption as image
        ixtoword = {v: k for k, v in dataset_json['word2idx'].items()}
        cap_img = cap2img(ixtoword, caption_idx, caption_len)
        im = cap_img[0].data.cpu().numpy()
        im = (im + 1.0) * 127.5
        im = im.astype(np.uint8)
        im = np.transpose(im, (1, 2, 0))
        im = Image.fromarray(im)
        full_path = '%s/caption.png' % main_path
        im.save(full_path)

        # save generated images and masks
        for i in tqdm(range(num_noise)):
            full_path = '%s/image_%d.png' % (img_save_path, i)
            im = fake_imgs[i].data.cpu().numpy()
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            im = Image.fromarray(im)
            im.save(full_path)

            for j in range(7):
                full_path = '%s%1d/mask_%d.png' % (mask_save_path, j, i)
                im = fusion_mask[j][i][0].data.cpu().numpy()
                im = im * 255
                im = im.astype(np.uint8)
                im = Image.fromarray(im)
                im.save(full_path)
Ejemplo n.º 26
import torch
from torch.autograd import Variable
from PIL import Image
import torchvision.transforms as transforms
from model import NetG
import pickle

# convert to PIL Image
trans_toPIL = transforms.ToPILImage()

# load the model
checkpoint_path = 'checkpoints/netG__epoch_100.pth'
n_l = 150
n_z = 100
n_c = 128
netG = NetG(n_z=n_z, n_l=n_l, n_c=n_c)
netG.load_state_dict(
    torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


def generate_from_caption():
    caption_file = "enc_text.pkl"
    # load encoded captions
    train_ids = pickle.load(open(caption_file, 'rb'))
    num_captions = len(train_ids['features'])
    num_images = 2

    # create random noise
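    # (the snippet is truncated here; a plausible draw, assuming netG expects
    # an n_z-dimensional Gaussian noise vector, would be
    # z = Variable(torch.randn(num_images, n_z)))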

    # create a random caption embedding
    skv = Variable(torch.randn(num_images, 4800))
    skv.data.normal_(0, 1.1)
Ejemplo n.º 27
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers,
                                             drop_last=True)

    netg = NetG(opt)
    netd = NetD(opt)

    # load existing network parameters
    if opt.netd_path:
        print('Loading netd...', end='')
        netd.load_state_dict(torch.load(opt.netd_path))
        print('Successful!')
    if opt.netg_path:
        print('Loading netg...', end='')
        netg.load_state_dict(torch.load(opt.netg_path))
        print('Successful!')

    optimizer_g = torch.optim.Adam(netg.parameters(),
                                   opt.lr_netg,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr_netd,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss()

    # real images are labeled 1, fake images 0
    true_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))
Ejemplo n.º 28
def train(**kwargs):
    '''
    Training function.
    :param kwargs: training parameters passed in via fire
    :return:
    '''
    opt.parse(kwargs)
    for k_,v_ in kwargs.items():
        setattr(opt,k_,v_)
    if opt.vis:
        vis = Visualizer(opt.env)

    # step 1: data preprocessing
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                        batch_size=opt.batch_size,
                                        shuffle=True,
                                        num_workers=opt.num_workers,
                                        drop_last=True)

    # step 2: define the networks
    netg,netd = NetG(opt),NetD(opt)
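    # map_location keeps checkpoint tensors on the CPU even if they were
    # saved from a GPU run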
    map_location = lambda storage,loc:storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # define the optimizers and the loss function
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lrG, betas=(0.5, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lrD, betas=(0.5, 0.999))
    criterion = t.nn.BCELoss()

    # real images are labeled 1, fake images 0
    # noise is the network input
    true_labels = t.ones(opt.batch_size)
    fake_labels = t.zeros(opt.batch_size)
    fix_noises = t.randn(opt.batch_size,opt.nz,1,1)
    noises = t.randn(opt.batch_size,opt.nz,1,1)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
        netd.to(device)
        netg.to(device)
        criterion.to(device)
        true_labels,fake_labels = true_labels.to(device),fake_labels.to(device)
        fix_noises,noises = fix_noises.to(device),noises.to(device)

    epochs = range(140)
    for epoch in iter(epochs):
        for ii,(img,_) in tqdm.tqdm(enumerate(dataloader),total=len(dataloader)):
            if opt.gpu:
                real_img = img.to(device)
            if ii%opt.d_every == 0: # train the discriminator every d_every batches
                optimizer_d.zero_grad()
                output = netd(real_img) # score the real images (push the output toward 1)
                error_d_real = criterion(output,true_labels)
                error_d_real.backward()

                ## classify fake images as fake as reliably as possible
                noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
                fake_img = netg(noises).detach() # generate fakes from noise (detached)
                output = netd(fake_img)
                error_d_fake = criterion(output,fake_labels)
                error_d_fake.backward()

                optimizer_d.step()
                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            if ii%opt.g_every == 0: # update the generator every g_every batches
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output,true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii%opt.plot_time == opt.plot_time - 1:
                ## visualization
                fix_fake_img = netg(fix_noises) # generate images from the fixed noise
                vis.images(fix_fake_img.data.cpu().numpy()[:64]*0.5+0.5,win = 'fixfake')
                # vis.images(real_img.data.cpu().numpy()[:64]*0.5+0.5,win = 'real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch%opt.decay_every == opt.decay_every-1:
            # save the model and sample images
            tv.utils.save_image(fix_fake_img.data[:64],'%s/new%s.png'%(opt.save_path,epoch),
                                normalize=True,range=(-1,1))
            t.save(netd.state_dict(), 'checkpoints/new_netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/new_netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            optimizer_g = t.optim.Adam(netg.parameters(), opt.lrG, betas=(0.5, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(), opt.lrD, betas=(0.5, 0.999))
Ejemplo n.º 29
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(root=opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # define the optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # real images are labeled 1, fake images 0
    # noises is the input to the generator
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## classify real images as real as reliably as possible
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## classify fake images as fake as reliably as possible
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualization
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # save the model and sample images
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
Ejemplo n.º 30
z = Variable(z)

netG = netG.cuda()
netD = netD.cuda()
z = z.cuda()
realData = realData.cuda()
one = one.cuda()
mone = mone.cuda()

# setup optimizer
optimizerD = optim.RMSprop(netD.parameters(), lr=0.00005)  # 0.00005
optimizerG = optim.RMSprop(netG.parameters(), lr=0.00005)
if opt.ep != -1:
    netD.load_state_dict(
        torch.load(checkRoot + '/netD_epoch_' + str(opt.ep) + '.pth'))
    netG.load_state_dict(
        torch.load(checkRoot + '/netG_epoch_' + str(opt.ep) + '.pth'))
# train
ig = 0
for it in np.arange(iterNum) + opt.ep + 1:
    dataIter = iter(dataLoader)
    ib = 0
    while ib < len(dataLoader):
        ############################
        # (1) Update D network
        ###########################
        # train the discriminator Diters times
        if ig < 25 or ig % 500 == 0:
            Diters = 10
        else:
            Diters = 5
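        # warm-up schedule: run extra critic iterations early in training and
        # every 500 generator steps so the Wasserstein estimate stays accurate
        # before each generator update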
        id = 0
Ejemplo n.º 31
import torch as th

from my_dataSet import CASIABDatasetGenerate

netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)

device = th.device("cuda:0")
netg = netg.to(device)
netd = netd.to(device)
neta = neta.to(device)
fineSize = 64

checkpoint = '/home/mg/code/my_GAN_dataSet/snapshots/snapshot_449.t7'
checkpoint = th.load(checkpoint)
neta.load_state_dict(checkpoint['netA'])
netg.load_state_dict(checkpoint['netG'])
netd.load_state_dict(checkpoint['netD'])
neta.eval()
netg.eval()
netd.eval()

angles = [
    '000', '018', '036', '054', '072', '090', '108', '126', '144', '162', '180'
]

for cond in ['nm-01', 'nm-02', 'nm-03', 'nm-04', 'cl-01', 'cl-02']:
    dataset = CASIABDatasetGenerate(
        data_dir='/home/mg/code/data/GEI_CASIA_B/gei/', cond=cond)
    for i in range(1, 125):
        ass_label, img = dataset.getbatch(i, 11)
        img = img.to(device).to(th.float32)
Ejemplo n.º 32
def train():
    # change opt
    # for k_, v_ in kwargs.items():
    #     setattr(opt, k_, v_)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
        'cpu')

    if opt.vis:
        from visualizer import Visualizer
        vis = Visualizer(opt.env)

    # rescale to -1~1
    transform = transforms.Compose([
        transforms.Resize(opt.image_size),
        transforms.CenterCrop(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.ImageFolder(opt.data_path, transform=transform)

    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)

    netd = NetD(opt)
    netg = NetG(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(torch.load(opt.netd_path,
                                        map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(torch.load(opt.netg_path,
                                        map_location=map_location))

    if torch.cuda.is_available():
        netd.to(device)
        netg.to(device)

    # define the optimizers and the loss
    optimizer_g = torch.optim.Adam(netg.parameters(),
                                   opt.lr1,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr2,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss().to(device)

    # the real label is 1; noises is the input noise
    true_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if torch.cuda.is_available():
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    for epoch in range(opt.max_epoch):
        print("epoch:", epoch, end='\r')
        # sys.stdout.flush()
        for ii, (img, _) in enumerate(dataloader):
            real_img = Variable(img)
            if torch.cuda.is_available():
                real_img = real_img.cuda()

            # train the discriminator: real -> 1, fake -> 0
            if (ii + 1) % opt.d_every == 0:
                # real
                optimizer_d.zero_grad()
                output = netd(real_img)
                # print(output.shape, true_labels.shape)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()
                # fake
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fakes from random noise
                fake_output = netd(fake_img)
                error_d_fake = criterion(fake_output, fake_labels)
                error_d_fake.backward()
                # update optimizer
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            # train the generator so its images are judged real by the discriminator
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # visualization
                # if os.path.exists(opt.debug_file):
                #     import ipdb
                #     ipdb.set_trace()

                fix_fake_img = netg(fix_noises)
                vis.images(
                    fix_fake_img.detach().cpu().numpy()[:opt.batch_size] * 0.5
                    + 0.5,
                    win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:opt.batch_size] * 0.5 +
                           0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # save the model and sample images
            tv.utils.save_image(fix_fake_img.data[:opt.batch_size],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()