예제 #1
0
파일: shiyan.py 프로젝트: hduyuanfu/GAN
def generate(**kwargs):  # run generation / validation
    """
    Randomly generate anime avatars and keep the ones DNet scores highest.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` config object before anything else happens.

    Attributes read from ``opt``:
        get_search_num: number of candidate images to sample.
        nd: number of noise channels fed to GNet.
        noise_mean, noise_std: parameters of the normal noise distribution.
        dnet_path, gnet_path: checkpoint files loaded into DNet / GNet.
        get_num: how many top-scoring images to keep; must not exceed
            get_search_num (``topk`` raises otherwise).
        get_img: output path of the saved image grid.
        gpu (optional): set falsy to force CPU; defaults to GPU like before.

    The ``get_num`` best-scoring images are written as one grid to
    ``opt.get_img``. Returns None.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Honour opt.gpu if the config defines it (GPU by default, as the
    # original always used CUDA), but fall back to CPU instead of crashing
    # when no CUDA device is available.
    use_gpu = getattr(opt, 'gpu', True) and torch.cuda.is_available()
    device = torch.device('cuda') if use_gpu else torch.device('cpu')

    gnet, dnet = GNet(opt).eval(), DNet(opt).eval()

    noises = torch.randn(opt.get_search_num, opt.nd, 1, 1).normal_(opt.noise_mean, opt.noise_std)
    noises = noises.to(device)

    # Load checkpoints onto the CPU regardless of the device they were saved
    # from, then move the networks to the chosen device.
    dnet.load_state_dict(torch.load(opt.dnet_path, map_location='cpu'))
    gnet.load_state_dict(torch.load(opt.gnet_path, map_location='cpu'))
    dnet.to(device)
    gnet.to(device)

    # Pure inference: no autograd graph is needed for the generated images.
    with torch.no_grad():
        fake_img = gnet(noises)
        scores = dnet(fake_img)

    # topk() returns (values, indices); keep the indices of the best
    # opt.get_num images (default 64). Assumes DNet yields one score per
    # image, i.e. a 1-D tensor of length get_search_num — TODO confirm.
    indexs = scores.topk(opt.get_num)[1]
    best = fake_img[indexs]

    # NOTE: newer torchvision releases deprecate `range` in favour of
    # `value_range`; kept as-is for compatibility with the installed version.
    tv.utils.save_image(best, opt.get_img, normalize=True, range=(-1, 1))
예제 #2
0
model_gcn = GNet()

# Choose the device once and use it both to load the checkpoint and to place
# the model, rather than repeating the CUDA availability check.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"  # kept for code that may rely on this flag
state_dict = torch.load(args.load, map_location=device)
model_gcn.load_state_dict(state_dict)

# Disabled experiment: stop BatchNorm2d layers inside feat_extr from
# tracking running statistics.
# for child in model_gcn.feat_extr.children():
#     for ii in range(len(child)):
#         if type(child[ii]) == torch.nn.BatchNorm2d:
#             child[ii].track_running_stats = False
model_gcn.eval()

# Move the model to the selected device.
model_gcn.to(device)
print('Using GPU' if use_cuda else 'Using CPU')

# Graph
graph = Graph("./ellipsoid/init_info.pickle")

# Data loader: one sample per batch.
folder = CustomDatasetFolder(args.data, extensions=["dat"], print_ref=False)
val_loader = torch.utils.data.DataLoader(folder, batch_size=1, shuffle=True)

# Running totals of the evaluation metrics.
tot_loss_norm = 0
tot_loss_unorm = 0
tot_f1_1 = 0
tot_f1_2 = 0
예제 #3
0
                    help='frequency of saving model\'s parameters')

opt = parser.parse_args()

if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )
# Fail fast with a clear message instead of an obscure error from .cuda().
if opt.cuda and not torch.cuda.is_available():
    raise RuntimeError("--cuda was requested but no CUDA device is available")

# Generators for the A->B and B->A directions (judging by their names) and
# one discriminator per domain.
Gnet_AB = GNet(opt.G_init_filter, opt.G_depth, opt.G_width)
Gnet_BA = GNet(opt.G_init_filter, opt.G_depth, opt.G_width)
Dnet_A = DNet(opt.D_init_filter, opt.D_depth)
Dnet_B = DNet(opt.D_init_filter, opt.D_depth)

_all_nets = (Gnet_AB, Gnet_BA, Dnet_A, Dnet_B)

if opt.cuda:
    for _net in _all_nets:
        _net.cuda()

# Weight initialization with weights_init_normal (intended as a Gaussian
# N(0, 0.02); see that function for the exact scheme).
for _net in _all_nets:
    _net.apply(weights_init_normal)

# Losses
L_GAN = nn.MSELoss()  # least-squares adversarial loss
L_cyc = nn.L1Loss()  # L1 loss, presumably for cycle consistency
L_identity = nn.L1Loss()  # L1 loss, presumably for identity mapping