Code example #1
        # save the current frame to disk
        cv2.imwrite(frame_path, frame)
        if frame_count % 50 == 0:
            print("%dth frame has been processed" % frame_count)
        frame_count += 1
except Exception as e:
    # the read loop above is terminated by an exception once frames run out
    print("video has been processed to frames")
cap.release()
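
# NOTE: this example begins mid-loop. A commented-out sketch of the kind of
# frame-extraction loop the fragment above belongs to; video_path and
# frames_dir are hypothetical names, not from the original script:
#
#     cap = cv2.VideoCapture(video_path)
#     frame_count = 0
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         frame_path = os.path.join(frames_dir, '%06d.jpg' % frame_count)
#         cv2.imwrite(frame_path, frame)
#         frame_count += 1
#     cap.release()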

print("start load models")
# load models
detector = MTCNN()
device = torch.device('cuda')
G = AEI_Net(c_id=512)
G.eval()
G.load_state_dict(
    torch.load('./saved_models/G_latest.pth',
               map_location=torch.device('cpu')))
G = G.cuda()

arcface = Backbone(50, 0.6, 'ir_se').to(device)
arcface.eval()
arcface.load_state_dict(torch.load('./face_modules/model_ir_se50.pth',
                                   map_location=device),
                        strict=False)

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# load source image
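# The example is cut off at the comment above; a minimal sketch of the
# source-image step it leads into, following the pattern of code example #5
# below. source_path is a hypothetical name, and the usual imports from the
# original script (cv2, PIL.Image, numpy, torch.nn.functional as F) are
# assumed to be present:
Xs_raw = cv2.imread(source_path)
Xs = detector.align(Image.fromarray(Xs_raw[:, :, ::-1]), crop_size=(256, 256))
Xs = test_transform(Xs).unsqueeze(0).cuda()
with torch.no_grad():
    # ArcFace identity embedding of the aligned source face
    embeds = arcface(F.interpolate(Xs[:, :, 19:237, 19:237], (112, 112),
                                   mode='bilinear', align_corners=True))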
Code example #2
    D.train()

    arcface = Backbone(50, 0.6, 'ir_se').to(device)
    arcface.eval()
    arcface.load_state_dict(torch.load('./id_model/model_ir_se50.pth', map_location=device), strict=False)

    opt_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0, 0.999))

    G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

    try:
        p_G = './saved_mask_models/G_latest.pth'
        p_D = './saved_mask_models/D_latest.pth'
        G.load_state_dict(torch.load(p_G, map_location=torch.device('cpu')), strict=False)
        D.load_state_dict(torch.load(p_D, map_location=torch.device('cpu')), strict=False)
        
        print('p_G : ',p_G)
        print('p_D : ',p_D)
        
    except Exception as e:
        print(e)
    
    dataset = FaceEmbed(['./train_datasets/Foreign-2020-09-06/'], same_prob=0.35)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)


    MSE = torch.nn.MSELoss()
    L1 = torch.nn.L1Loss()
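
    # The snippet ends with the loss modules. A heavily simplified sketch of a
    # FaceShifter-style generator update that uses them together with apex amp;
    # the batch unpacking, the 10/5 weights, the omission of the adversarial and
    # attribute terms, and the assumption that torch.nn.functional is imported
    # as F are all simplifications, not the original code:
    for Xs, Xt, same_person in dataloader:
        Xs, Xt = Xs.to(device), Xt.to(device)
        with torch.no_grad():
            src_id = arcface(F.interpolate(Xs[:, :, 19:237, 19:237], (112, 112),
                                           mode='bilinear', align_corners=True))
        Y, _ = G(Xt, src_id)
        out_id = arcface(F.interpolate(Y[:, :, 19:237, 19:237], (112, 112),
                                       mode='bilinear', align_corners=True))
        L_id = (1 - torch.cosine_similarity(src_id, out_id, dim=1)).mean()
        L_rec = L1(Y, Xt)               # only meaningful for same-identity pairs
        loss_G = 10 * L_rec + 5 * L_id  # hypothetical weights
        opt_G.zero_grad()
        with amp.scale_loss(loss_G, opt_G) as scaled_loss:  # apex mixed precision
            scaled_loss.backward()
        opt_G.step()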
Code example #3
arcface = Backbone(50, 0.6, 'ir_se').to(device)
arcface.eval()
arcface.load_state_dict(torch.load('./face_modules/model_ir_se50.pth',
                                   map_location=device),
                        strict=False)
arcface.requires_grad_(False)

opt_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0, 0.999))

scaler = GradScaler()

try:
    G.load_state_dict(torch.load('./saved_models/AEI_G_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
    D.load_state_dict(torch.load('./saved_models/AEI_D_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
    opt_G.load_state_dict(
        torch.load('./saved_models/AEI_optG_latest.pth',
                   map_location=torch.device('cpu')))
    opt_D.load_state_dict(
        torch.load('./saved_models/AEI_optD_latest.pth',
                   map_location=torch.device('cpu')))
    scaler.load_state_dict(
        torch.load('./saved_models/AEI_scaler_latest.pth',
                   map_location=torch.device('cpu')))
except Exception as e:
    print(e)
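
# The snippet stops after checkpoint loading. A minimal sketch of how the
# GradScaler set up above is typically driven inside the training loop;
# dataloader, the batch unpacking and compute_g_loss are hypothetical
# stand-ins, and the discriminator update (which follows the same scale/step
# pattern) is omitted:
from torch.cuda.amp import autocast

for Xs, Xt, same_person in dataloader:
    opt_G.zero_grad()
    with autocast():                  # run the forward pass in mixed precision
        loss_G = compute_g_loss(G, arcface, Xs, Xt, same_person)  # hypothetical
    scaler.scale(loss_G).backward()   # scale the loss to avoid fp16 underflow
    scaler.step(opt_G)                # unscale gradients, then step
    scaler.update()                   # adjust the scale factor for the next step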
Code example #4
File: train_AEI.py Project: hDluffy/FaceShifter
arcface = Backbone(50, 0.6, 'ir_se').to(device)
arcface.eval()
arcface.load_state_dict(torch.load('./face_modules/model_ir_se50.pth',
                                   map_location=device),
                        strict=False)

opt_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0, 0.999))

G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

try:
    G.load_state_dict(torch.load(os.path.join(args.saved_models,
                                              'G_latest.pth'),
                                 map_location=torch.device('cpu')),
                      strict=False)
    D.load_state_dict(torch.load(os.path.join(args.saved_models,
                                              'D_latest.pth'),
                                 map_location=torch.device('cpu')),
                      strict=False)
except Exception as e:
    print(e)

# if not fine_tune_with_identity:
# dataset = FaceEmbed(['../celeb-aligned-256_0.85/', '../ffhq_256_0.85/', '../vgg_256_0.85/', '../stars_256_0.85/'], same_prob=0.5)
# else:
# dataset = With_Identity('../washed_img/', 0.8)
dataset = FaceEmbed([args.images_path], same_prob=args.same_prob)

dataloader = DataLoader(dataset,
Code example #5
def serve():
    data = {"success": False}
    detector = MTCNN()
    device = torch.device('cuda')
    G = AEI_Net(c_id=512)
    G.eval()
    G.load_state_dict(
        torch.load('./saved_models/G_latest.pth',
                   map_location=torch.device('cpu')))
    G = G.cuda()
    arcface = Backbone(50, 0.6, 'ir_se').to(device)
    arcface.eval()
    arcface.load_state_dict(torch.load('./face_modules/model_ir_se50.pth',
                                       map_location=device),
                            strict=False)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if flask.request.method == 'POST':
        st = time.time()
        metadata = flask.request.form
        source_image = metadata['source_image']
        target_image = metadata['target_image']
        source_image = base64.b64decode(source_image.encode('utf-8'))
        Xs_raw = cv2.imdecode(np.frombuffer(source_image, np.uint8),
                              cv2.IMREAD_COLOR)
        target_image = base64.b64decode(target_image.encode('utf-8'))
        Xt_raw = cv2.imdecode(np.frombuffer(target_image, np.uint8),
                              cv2.IMREAD_COLOR)
        Xs = detector.align(Image.fromarray(Xs_raw[:, :, ::-1]),
                            crop_size=(256, 256))
        Xs_raw = np.array(Xs)[:, :, ::-1]
        Xs = test_transform(Xs)
        Xs = Xs.unsqueeze(0).cuda()
        with torch.no_grad():
            embeds = arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], (112, 112),
                              mode='bilinear',
                              align_corners=True))
        Xt, trans_inv = detector.align(Image.fromarray(Xt_raw[:, :, ::-1]),
                                       crop_size=(256, 256),
                                       return_trans_inv=True)
        # np.float was removed from recent NumPy releases; np.float64 keeps the
        # original behaviour
        Xt_raw = Xt_raw.astype(np.float64) / 255.0
        Xt = test_transform(Xt)
        Xt = Xt.unsqueeze(0).cuda()
        # soft radial mask used later to blend the swapped face into the frame
        mask = np.zeros([256, 256], dtype=np.float64)
        for i in range(256):
            for j in range(256):
                dist = np.sqrt((i - 128)**2 + (j - 128)**2) / 128
                dist = np.minimum(dist, 1)
                mask[i, j] = 1 - dist
        mask = cv2.dilate(mask, None, iterations=20)

        with torch.no_grad():
            Yt, _ = G(Xt, embeds)
            Yt = Yt.squeeze().detach().cpu().numpy()
            Yt = Yt.transpose([1, 2, 0]) * 0.5 + 0.5
            Yt = Yt[:, :, ::-1]
            Yt_trans_inv = cv2.warpAffine(
                Yt,
                trans_inv, (np.size(Xt_raw, 1), np.size(Xt_raw, 0)),
                borderValue=(0, 0, 0))
            mask_ = cv2.warpAffine(mask,
                                   trans_inv,
                                   (np.size(Xt_raw, 1), np.size(Xt_raw, 0)),
                                   borderValue=(0, 0, 0))
            mask_ = np.expand_dims(mask_, 2)
            Yt_trans_inv = mask_ * Yt_trans_inv + (1 - mask_) * Xt_raw
        img_data = Yt_trans_inv * 255
        retval, buffer = cv2.imencode('.jpg', img_data)
        pic_str = base64.b64encode(buffer)
        pic_str = pic_str.decode()
        data['success'] = True
        data['image'] = pic_str
        st = time.time() - st
        print(f'process time: {st} sec')
        return flask.jsonify(data)
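
# serve() reads flask.request directly, but the Flask app and route wiring sit
# outside this snippet; a minimal sketch of how such a handler is typically
# registered (the app object, route and port are hypothetical):
app = flask.Flask(__name__)
app.add_url_rule('/swap', 'serve', serve, methods=['POST'])

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)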
Code example #6
File: train_HEAR.py Project: XLEric/FaceSwap
if __name__ == '__main__':
    # vis = visdom.Visdom(server='127.0.0.1', env='faceshifter', port=8099)
    batch_size = 32
    lr = 4e-4
    max_epoch = 2000
    show_step = 10
    save_epoch = 1
    model_save_path = './saved_models/'
    optim_level = 'O0'

    device = torch.device('cuda')

    G = AEI_Net(c_id=512).to(device)
    G.eval()
    G.load_state_dict(torch.load('./saved_models/G_latest-e.pth', map_location=torch.device('cpu')), strict=True)

    net = HearNet()
    net.train()
    net.to(device)

    arcface = Backbone(50, 0.6, 'ir_se').to(device)
    arcface.eval()
    arcface.load_state_dict(torch.load('./id_model/model_ir_se50.pth', map_location=device), strict=False)

    opt = optim.Adam(net.parameters(), lr=lr, betas=(0, 0.999))

    net, opt = amp.initialize(net, opt, opt_level=optim_level)

    try:
        net.load_state_dict(torch.load('./saved_models/HEAR_latest.pth', map_location=torch.device('cpu')), strict=False)
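    # (the snippet is cut off here; judging by the other examples in this
    # listing, the try block presumably closes with the usual fallback)
    except Exception as e:
        print(e)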
Code example #7
    opt_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0, 0.999))

    G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

    try:
        if Flag_256:
            Finetune_G_Model = './saved_models/G_latest.pth'
            Finetune_D_Model = './saved_models/D_latest.pth'
        else:
            Finetune_G_Model = './saved_models/G_latest_512.pth'
            Finetune_D_Model = './saved_models/D_latest_512.pth'

        G.load_state_dict(torch.load(Finetune_G_Model,
                                     map_location=torch.device('cpu')),
                          strict=False)
        D.load_state_dict(torch.load(Finetune_D_Model,
                                     map_location=torch.device('cpu')),
                          strict=False)

        print('Finetune_G_Model : ', Finetune_G_Model)
        print('Finetune_D_Model : ', Finetune_D_Model)

    except Exception as e:
        print(e)

    # if not fine_tune_with_identity:
    # dataset = FaceEmbed(['../celeb-aligned-256_0.85/', '../ffhq_256_0.85/', '../vgg_256_0.85/', '../stars_256_0.85/'], same_prob=0.5)
    # else:
    # dataset = With_Identity('../washed_img/', 0.8)