Ejemplo n.º 1
0
def train(**kwargs):
    """Prepare the data pipeline and the GAN networks for training.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` config object before anything else runs.

    NOTE(review): this snippet ends after the optional checkpoint loading;
    the actual training loop is not part of the visible code.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Data pipeline. It is bound to ``transform`` (singular) on purpose:
    # assigning to ``transforms`` anywhere in this function would make that
    # name local, and the ``transforms.Compose`` lookup would then raise
    # UnboundLocalError.
    transform = transforms.Compose([
        transforms.Resize(opt.image_size),      # resize to opt.image_size
        transforms.CenterCrop(opt.image_size),  # centre crop of size opt.image_size
        transforms.ToTensor(),                  # to a tensor with values in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    # ImageFolder assigns each image the label of its sub-folder under opt.data_path.
    dataset = datasets.ImageFolder(opt.data_path, transform=transform)
    # Batch the images; batch size and worker count come from opt.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True,  # drop the final incomplete batch so every batch is full-sized
    )

    # Networks: gnet is the generator, dnet is the discriminator.
    gnet, dnet = GNet(opt), DNet(opt)
    # Keep loaded checkpoint tensors on the CPU, whatever device they were saved from.
    map_location = lambda storage, loc: storage
    if opt.dnet_path:
        dnet.load_state_dict(torch.load(opt.dnet_path, map_location=map_location))
    if opt.gnet_path:
        gnet.load_state_dict(torch.load(opt.gnet_path, map_location=map_location))
Ejemplo n.º 2
0
def generate(**kwargs):  # used for validation
    """Randomly generate images and keep the ones the discriminator scores highest.

    Keyword arguments override the same-named attributes of the module-level
    ``opt`` config. ``opt.get_search_num`` candidates are generated from noise.
    The ``opt.get_num`` candidates with the highest ``dnet`` scores are saved
    to ``opt.get_img``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Use the GPU when present; fall back to the CPU instead of crashing on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    gnet, dnet = GNet(opt).eval(), DNet(opt).eval()

    # normal_ redraws the noise in place from N(opt.noise_mean, opt.noise_std).
    noises = torch.randn(opt.get_search_num, opt.nd, 1, 1).normal_(opt.noise_mean, opt.noise_std)
    noises = noises.to(device)

    # Load checkpoints onto the CPU first, then move the models to `device`.
    map_location = lambda storage, loc: storage
    dnet.load_state_dict(torch.load(opt.dnet_path, map_location=map_location))
    gnet.load_state_dict(torch.load(opt.gnet_path, map_location=map_location))
    dnet.to(device)
    gnet.to(device)

    # Pure inference: skip autograd bookkeeping to save memory on large searches.
    with torch.no_grad():
        # Generate the images and score them with the discriminator.
        fake_img = gnet(noises)
        scores = dnet(fake_img)

    # topk returns (values, indices); keep the indices of the opt.get_num best scores.
    indexs = scores.topk(opt.get_num)[1]
    result = torch.stack([fake_img.data[i] for i in indexs])
    # Save the selected images.
    # NOTE(review): newer torchvision versions renamed make_grid's `range`
    # keyword to `value_range`; confirm against the installed version.
    tv.utils.save_image(result, opt.get_img, normalize=True, range=(-1, 1))
Ejemplo n.º 3
0
def generate():
    """Generate one batch of images with the pretrained GAN and report its scores.

    Loads discriminator/generator weights from 'dd.pth' / 'gg.pth', generates
    ``opt.batch_size`` images, and prints the discriminator's BCE loss against
    ``opt.real_label`` and its mean score. It also writes the image grid to
    TensorBoard.
    """
    opt = Config()
    criterion = nn.BCELoss()
    # Training may have used several GPUs; a single device is enough for generation.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _strip_module_prefix(state_dict):
        """Remove the 'module.' key prefix added by nn.DataParallel, if present."""
        prefix = 'module.'
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )

    dnet = DNet(opt).to(device)
    gnet = GNet(opt).to(device)
    # map_location remaps tensors saved on another GPU (or a GPU-only
    # checkpoint on a CPU-only machine) onto `device`.
    dnet.load_state_dict(_strip_module_prefix(torch.load('dd.pth', map_location=device)))
    gnet.load_state_dict(_strip_module_prefix(torch.load('gg.pth', map_location=device)))

    dnet.eval()
    gnet.eval()
    noise = torch.randn(opt.batch_size, opt.nd, 1, 1, device=device)
    with torch.no_grad():
        fake = gnet(noise)
        output = dnet(fake)
    # Match the output's dtype: with an int real_label, torch.full would infer
    # an integer tensor, which BCELoss rejects.
    label = torch.full((opt.batch_size, ), opt.real_label, dtype=output.dtype, device=device)
    d_err_fake = criterion(output, label)  # loss of the generated images (still a tensor)
    mean_score = output.mean()  # mean discriminator score of the generated images (still a tensor)
    fake_img = vutils.make_grid(fake, normalize=True)

    # NOTE(review): 'generate_rusult' looks like a typo for 'generate_result'
    # (the spelling used elsewhere); kept so existing logs stay in place.
    writer = SummaryWriter(log_dir='generate_rusult')
    writer.add_image('fake_img', fake_img)
    writer.close()
    print('生成图像的平均损失值:%.4f'%d_err_fake.item())
    print('生成图像的平均得分:%.4f'%mean_score.item())
Ejemplo n.º 4
0
def generate(opt, device):
    """Generate one batch of images with the pretrained GAN and report its scores.

    Args:
        opt: config object providing ``batch_size``, ``nd`` and ``real_label``
            (plus whatever DNet/GNet read from it).
        device: torch device to run the networks on.

    Loads weights from 'dnet.pth' / 'gnet.pth', prints the discriminator's BCE
    loss against ``opt.real_label`` and its mean score for the generated
    images, and writes the image grid to TensorBoard under 'generate_result'.
    """
    criterion = nn.BCELoss()

    def _strip_module_prefix(state_dict):
        """Remove the 'module.' key prefix added by nn.DataParallel, if present."""
        prefix = 'module.'
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )

    dnet = DNet(opt).to(device)
    gnet = GNet(opt).to(device)
    # map_location moves tensors saved on another GPU (or on a GPU when
    # running CPU-only) onto `device`.
    dnet.load_state_dict(_strip_module_prefix(torch.load('dnet.pth', map_location=device)))
    gnet.load_state_dict(_strip_module_prefix(torch.load('gnet.pth', map_location=device)))

    dnet.eval()
    gnet.eval()

    noise = torch.randn(opt.batch_size, opt.nd, 1, 1, device=device)
    with torch.no_grad():
        fake = gnet(noise)
        output = dnet(fake)
    # Match the output's dtype: with an int real_label, torch.full would infer
    # an integer tensor, which BCELoss rejects.
    label = torch.full((opt.batch_size, ), opt.real_label, dtype=output.dtype, device=device)
    d_err_fake = criterion(output, label)  # loss of the generated images (still a tensor)
    mean_score = output.mean()  # mean discriminator score of the generated images (still a tensor)
    fake_img = vutils.make_grid(fake, normalize=True)

    writer = SummaryWriter(log_dir='generate_result')
    writer.add_image('fake_img', fake_img)
    writer.close()
    print('生成图像的平均损失值:%.4f' % d_err_fake.item())
    print('生成图像的平均得分:%.4f' % mean_score.item())
Ejemplo n.º 5
0
def _str2bool(value):
    """Parse a command-line boolean string.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so a real parser is needed.
    Raises ValueError (reported by argparse as an invalid value) otherwise.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise ValueError(value)


parser.add_argument('--show_img',
                    type=_str2bool,
                    default=False,
                    metavar='S',
                    help='whether or not to show the images')
# Required: without it torch.load(None) below would fail with an obscure error.
parser.add_argument('--load',
                    type=str,
                    required=True,
                    metavar='M',
                    help='model file to load for evaluating.')
args = parser.parse_args()

# Device selection (both names are kept for code that may use them later).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Model
model_gcn = GNet()
state_dict = torch.load(args.load, map_location=device)
model_gcn.load_state_dict(state_dict)

# Turn batch norm into eval mode
# for child in model_gcn.feat_extr.children():
#     for ii in range(len(child)):
#         if type(child[ii]) == torch.nn.BatchNorm2d:
#             child[ii].track_running_stats = False
model_gcn.eval()

# Place the model on the same device its weights were mapped to.
model_gcn.to(device)
if use_cuda:
    print('Using GPU')
else:
    print('Using CPU')