def generate_video(config, first_frm_file, json_file):
    # For fast training.
    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version

    G = ExpressionGenerater()

    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/fusion-ckpt-{}".format(
        config.fusion_version)
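    # The map_location lambda keeps the loaded tensors on CPU; eval() fixes
    # batch-norm and dropout in inference mode.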
    FG = FusionGenerater()
    FG_path = os.path.join(ckpt_dir,
                           '{}-G.ckpt'.format(config.fusion_resume_iter))
    FG.load_state_dict(
        torch.load(FG_path, map_location=lambda storage, loc: storage))
    FG.eval()

    #############  process data  ##########
    def crop_face(img, bbox, keypoint):
        flags = list()
        points = list()

        # the face may not be detected in some images
        if len(bbox) == 0:
            return None

        # draw bbox
        x, y, w, h = [int(v) for v in bbox]
        crop_img = img[y:y + h, x:x + w]
        return crop_img

    def load_json(path):
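        # Annotation JSON layout, as inferred from the keys used below (an
        # illustration, not taken from a real annotation file):
        # [{"video": "<name>.mp4", "video_path": "...",
        #   "annotations": [{"index": 1, "bbox": [x, y, w, h],
        #                    "keypoint": [x1, y1, flag1, x2, y2, flag2, ...]}]}]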
        with open(path, 'r') as f:
            data = json.load(f)
        print("load %d video annotations totally." % len(data))

        vid_anns = dict()
        for anns in data:
            name = anns['video']
            path = anns['video_path']
            vid_anns[name] = {'path': path}
            for ann in anns['annotations']:
                idx = ann['index']
                keypoints = ann['keypoint']
                bbox = ann['bbox']
                vid_anns[name][idx] = [bbox, keypoints]

        return vid_anns

    def draw_bbox_keypoints(img, bbox, keypoint):
        flags = list()
        points = list()

        # the face may not be detected in some images
        if len(bbox) == 0:
            return None, None

        points_image = np.zeros_like(img, np.uint8)

        # draw bbox
        # x, y, w, h = [int(v) for v in bbox]
        # cv2.rectangle(img, (x, y), (x+w, y+h), (0, 0, 255), 2)

        # draw points
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1:  # keypoint exists but is invisible
                cv2.circle(points_image, (x, y), 3, (0, 0, 255), -1)
            elif flag == 2:  # keypoint exists and is visible
                cv2.circle(points_image, (x, y), 3, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")

        return crop_face(points_image, bbox,
                         keypoint), crop_face(img, bbox, keypoint)

    def extract_image(img, bbox, keypoint):
        flags = list()
        points = list()

        # the face may not be detected in some images
        if len(bbox) == 0:
            return None, None, None, None, None, None

        mask_image = np.zeros_like(img, np.uint8)
        mask = np.zeros_like(img, np.uint8)
        Knockout_image = img.copy()

        # draw bbox
        x, y, w, h = [int(v) for v in bbox]
        print(x, y, w, h)
        cv2.rectangle(mask_image, (x, y), (x + w, y + h), (255, 255, 255),
                      cv2.FILLED)

        cv2.rectangle(Knockout_image, (x, y), (x + w, y + h), (0, 0, 0),
                      cv2.FILLED)

        mask = cv2.rectangle(mask, (x, y), (x + w, y + h), (1, 1, 1),
                             cv2.FILLED)

        onlyface = img * mask

        mask_points = mask_image.copy()
        # draw points
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1:  # keypoint exists but is invisible
                cv2.circle(mask_points, (x, y), 3, (0, 0, 255), -1)
            elif flag == 2:  # keypoint exists and is visible
                cv2.circle(mask_points, (x, y), 3, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")

        return mask_image, Knockout_image, onlyface, img, mask_points, mask

    first_frm = Image.open(first_frm_file)

    vid_annos = load_json(json_file)
    vid_name = os.path.basename(first_frm_file)
    vid_name = "{}.mp4".format(vid_name.split(".")[0])
    anno = vid_annos[vid_name]
    first_bbox, first_keypoint = anno[1]

    _, first_knockout_image, _, first_img, _, _ = extract_image(
        np.array(first_frm), first_bbox, first_keypoint)

    frm = cv2.imread(first_frm_file)
    _, _, channels = frm.shape
    size = (544, 720)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    sample_dir = "gan-sample-{}".format(version)
    if not os.path.isdir(sample_dir):
        os.mkdir(sample_dir)
    out_path = os.path.join(
        sample_dir,
        'out_{}_{}_{}.mp4'.format(config.version, config.test_image,
                                  config.resume_iter))
    vid_writer = cv2.VideoWriter(out_path, fourcc, 10, size)

    if not os.path.isdir("test_result"):
        os.mkdir("test_result")
    print(len(anno))
    frm_crop = None
    is_first = True
    for idx in range(len(anno) - 1):
        print(idx)
        # if (idx+1) % 5 != 1:
        #     continue

        bbox, keypoint = anno[idx + 1]

        ###### Fusion

        cv2.imwrite("test_result/maskface_{}.jpg".format(idx),
                    cv2.cvtColor(np.asarray(first_frm), cv2.COLOR_RGB2BGR))

        # extract_image expects a numpy array, and the binary face mask it
        # builds is needed by the fusion generator below.
        mask_image, Knockout_image, onlyface, img, mask_points, mask = extract_image(
            np.asarray(first_frm), bbox, keypoint)

        cv2.imwrite("test_result/mask_{}.jpg".format(idx), mask_points)

        ### TODO: at this point we have mask_points, maskface, and Knockout_image
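        # NOTE: T.Resize takes (height, width), so the tensors below are
        # 1x3x720x544, matching the (544, 720) size given to cv2.VideoWriter
        # above (OpenCV sizes are (width, height)).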
        def transform_images(a, b, c):
            transform = []
            transform.append(T.Resize((720, 544)))
            transform.append(T.ToTensor())
            transform.append(
                T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
            transform = T.Compose(transform)

            a_copy = Image.fromarray(a, 'RGB')
            b = Image.fromarray(b, 'RGB')
            c = Image.fromarray(c, 'RGB')

            a_copy = transform(a_copy)
            b = transform(b)
            c = transform(c)

            a_copy = a_copy.unsqueeze(0)
            b = b.unsqueeze(0)
            c = c.unsqueeze(0)

            return a_copy, b, c

        first_img_copy, mask_image, mask = transform_images(
            first_img, mask_image, mask)

        fake_frm = FG(first_img_copy, mask_image, mask)
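
        # denorm (defined elsewhere in the repo) maps generator output from
        # [-1, 1] back to [0, 1]; a typical implementation would be (an
        # assumption, not necessarily this repo's exact code):
        #     def denorm(x):
        #         return ((x + 1) / 2).clamp_(0, 1)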

        fake_frm = denorm(fake_frm.data.cpu())

        toPIL = T.ToPILImage()
        fake_frm = toPIL(fake_frm.squeeze())
        #
        fake_frm = cv2.cvtColor(np.asarray(fake_frm), cv2.COLOR_RGB2BGR)

        cv2.imwrite("test_result/fake_frm_{}.jpg".format(idx), fake_frm)

        vid_writer.write(fake_frm)
    vid_writer.release()
def main(config):
    # For fast training.
    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version
    beta1 = 0.5
    beta2 = 0.999

    loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)
    G = ExpressionGenerater()
    D = RealFakeDiscriminator()
    #FEN = FeatureExtractNet()
    id_D = IdDiscriminator()
    kp_D = KeypointDiscriminator()
    points_G = LandMarksDetect()

    #######   Load pretrained networks   ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
        version)
    if not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))
    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
        G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))

        D_path = os.path.join(ckpt_dir, '{}-D.ckpt'.format(resume_iter))
        D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

        IdD_path = os.path.join(ckpt_dir, '{}-idD.ckpt'.format(resume_iter))
        id_D.load_state_dict(
            torch.load(IdD_path, map_location=lambda storage, loc: storage))

        kp_D_path = os.path.join(ckpt_dir, '{}-kpD.ckpt'.format(resume_iter))
        kp_D.load_state_dict(
            torch.load(kp_D_path, map_location=lambda storage, loc: storage))

        points_G_path = os.path.join(ckpt_dir,
                                     '{}-pG.ckpt'.format(resume_iter))
        points_G.load_state_dict(
            torch.load(points_G_path,
                       map_location=lambda storage, loc: storage))
    else:
        resume_iter = 0

    #####  Train face2keypoint   ####
    points_G_optimizer = torch.optim.Adam(points_G.parameters(),
                                          lr=0.0001,
                                          betas=(0.5, 0.9))
    kp_D_optimizer = torch.optim.Adam(kp_D.parameters(),
                                      lr=0.0001,
                                      betas=(0.5, 0.9))
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(id_D.parameters(),
                                     lr=0.001,
                                     betas=(0.5, 0.9))
    G.to(device)
    id_D.to(device)
    D.to(device)
    kp_D.to(device)
    points_G.to(device)
    #FEN.to(device)

    #FEN.eval()

    # Start training from scratch or resume training.
    start_iters = resume_iter
    trigger_rec = 1
    data_iter = iter(loader)

    # Start training.
    print('Start training...')
    for i in range(start_iters, 150000):
        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #

        #faces, origin_points = next(data_iter)
        #_, target_points = next(data_iter)
        try:
            faces, origin_points = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            faces, origin_points = next(data_iter)
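        # Pair each face with a randomly permuted sample from the same batch;
        # its keypoints serve as the target expression.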
        rand_idx = torch.randperm(origin_points.size(0))
        target_points = origin_points[rand_idx]
        target_faces = faces[rand_idx]

        faces = faces.to(device)
        target_faces = target_faces.to(device)
        origin_points = origin_points.to(device)
        target_points = target_points.to(device)

        # =================================================================================== #
        #                               3. Train the discriminator                            #
        # =================================================================================== #
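
        # The three discriminators below (D, id_D, kp_D) are trained with the
        # WGAN-GP objective: maximize D(real) - D(fake) while penalizing the
        # gradient norm of D at points x_hat interpolated between real and
        # fake samples, a soft 1-Lipschitz constraint.
        # gradient_penalty is defined elsewhere in the repo; a standard
        # sketch (an assumption, not necessarily this repo's exact code):
        #     def gradient_penalty(out, x):
        #         grad = torch.autograd.grad(outputs=out, inputs=x,
        #                                    grad_outputs=torch.ones_like(out),
        #                                    create_graph=True)[0]
        #         grad = grad.view(grad.size(0), -1)
        #         return torch.mean((grad.norm(2, dim=1) - 1) ** 2)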

        # Real fake Dis
        real_loss = -torch.mean(D(faces))  # big for real
        faces_fake = G(faces, target_points)
        fake_loss = torch.mean(D(faces_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * faces.data +
                 (1 - alpha) * faces_fake.data).requires_grad_(True)
        out_src = D(x_hat)
        d_loss_gp = gradient_penalty(out_src, x_hat)

        lambda_gp = 10
        Dis_loss = real_loss + fake_loss + lambda_gp * d_loss_gp

        D_optimizer.zero_grad()
        Dis_loss.backward()
        D_optimizer.step()

        # ID Dis
        id_real_loss = -torch.mean(id_D(faces, target_faces))  # big for real
        faces_fake = G(faces, target_points)
        id_fake_loss = torch.mean(id_D(faces, faces_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * target_faces.data +
                 (1 - alpha) * faces_fake.data).requires_grad_(True)
        out_src = id_D(faces, x_hat)
        id_d_loss_gp = gradient_penalty(out_src, x_hat)

        id_lambda_gp = 10
        id_Dis_loss = id_real_loss + id_fake_loss + id_lambda_gp * id_d_loss_gp

        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # Keypoints Dis
        kp_real_loss = -torch.mean(kp_D(target_faces,
                                        target_points))  # big for real
        faces_fake = G(faces, target_points)
        points_fake = points_G(faces_fake)
        kp_fake_loss = torch.mean(kp_D(target_faces,
                                       points_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * target_points.data +
                 (1 - alpha) * points_fake.data).requires_grad_(True)
        out_src = kp_D(target_faces, x_hat)
        kp_d_loss_gp = gradient_penalty(out_src, x_hat)

        kp_lambda_gp = 10
        kp_Dis_loss = kp_real_loss + kp_fake_loss + kp_lambda_gp * kp_d_loss_gp

        kp_D_optimizer.zero_grad()
        kp_Dis_loss.backward()
        kp_D_optimizer.step()

        # if (i + 1) % 5 == 0:
        #     print("iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}".format(i,real_loss.item(),
        #                                                                                              fake_loss.item(),
        #                                                                                              lambda_gp * d_loss_gp
        #                                                                                              ))

        # =================================================================================== #
        #                               3. Train the keypointsDetecter                        #
        # =================================================================================== #

        points_detect = points_G(faces)
        detecter_loss_clear = torch.mean(
            torch.abs(points_detect - origin_points))

        detecter_loss = detecter_loss_clear
        points_G_optimizer.zero_grad()
        detecter_loss.backward()
        points_G_optimizer.step()

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #

        n_critic = 4
        if (i + 1) % n_critic == 0:
            # Original-to-target domain.
            faces_fake = G(faces, target_points)
            predict_points = points_G(faces_fake)
            g_keypoints_loss = -torch.mean(kp_D(target_faces, predict_points))

            g_fake_loss = -torch.mean(D(faces_fake))

            # reconstructs = G(faces_fake, origin_points)
            # g_cycle_loss = torch.mean(torch.abs(reconstructs - faces))
            g_id_loss = -torch.mean(id_D(faces, faces_fake))

            l1_loss = torch.mean(torch.abs(faces_fake - target_faces))

            #feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(target_faces)))

            # Alternate training
            # if (i+1) % 50 == 0:
            #     trigger_rec = 1 - trigger_rec
            #     print("trigger_rec : ", trigger_rec)
            lambda_rec = config.lambda_rec  # 2 to 4 to 8
            lambda_l1 = config.lambda_l1
            lambda_keypoint = config.lambda_keypoint  # 100 to 50
            lambda_fake = config.lambda_fake
            lambda_id = config.lambda_id
            lambda_feature = config.lambda_feature
            g_loss = (lambda_keypoint * g_keypoints_loss + lambda_fake * g_fake_loss
                      + lambda_id * g_id_loss + lambda_l1 * l1_loss)  # + lambda_feature * feature_loss

            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

            # Print out training information.
            if (i + 1) % 4 == 0:
                print(
                    "iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}, id_real_loss {:.2}, "
                    "id_fake_loss {:.2}, id_loss_gp {:.2} , g_keypoints_loss {:.2}, "
                    "g_fake_loss {:.2}, g_id_loss {:.2}, L1_loss {:.2}".format(
                        i, real_loss.item(), fake_loss.item(),
                        lambda_gp * d_loss_gp, id_real_loss.item(),
                        id_fake_loss.item(), id_lambda_gp * id_d_loss_gp,
                        lambda_keypoint * g_keypoints_loss.item(),
                        lambda_fake * g_fake_loss.item(),
                        lambda_id * g_id_loss.item(), lambda_l1 * l1_loss.item()))

            sample_dir = "gan-sample-{}".format(version)
            if not os.path.isdir(sample_dir):
                os.mkdir(sample_dir)
            if (i + 1) % 24 == 0:
                with torch.no_grad():
                    target_point = target_points[0]
                    fake_face = faces_fake[0]
                    face = faces[0]
                    #reconstruct = reconstructs[0]
                    predict_point = predict_points[0]

                    sample_path_face = os.path.join(
                        sample_dir, '{}-image-face.jpg'.format(i + 1))
                    save_image(denorm(face.data.cpu()), sample_path_face)

                    # sample_path_rec = os.path.join(sample_dir, '{}-image-reconstruct.jpg'.format(i + 1))
                    # save_image(denorm(reconstruct.data.cpu()), sample_path_rec)

                    sample_path_fake = os.path.join(
                        sample_dir, '{}-image-fake.jpg'.format(i + 1))
                    save_image(denorm(fake_face.data.cpu()), sample_path_fake)

                    sample_path_target = os.path.join(
                        sample_dir, '{}-image-target_point.jpg'.format(i + 1))
                    save_image(denorm(target_point.data.cpu()),
                               sample_path_target)

                    sample_path_predict_points = os.path.join(
                        sample_dir, '{}-image-predict_point.jpg'.format(i + 1))
                    save_image(denorm(predict_point.data.cpu()),
                               sample_path_predict_points)

                    print('Saved real and fake images into {}...'.format(
                        sample_path_face))

        # Save model checkpoints.
        model_save_dir = "ckpt-{}".format(version)

        if (i + 1) % 1000 == 0:
            if not os.path.isdir(model_save_dir):
                os.mkdir(model_save_dir)
            point_G_path = os.path.join(model_save_dir,
                                        '{}-pG.ckpt'.format(i + 1))
            torch.save(points_G.state_dict(), point_G_path)
            kp_D_path = os.path.join(model_save_dir,
                                     '{}-kpD.ckpt'.format(i + 1))
            torch.save(kp_D.state_dict(), kp_D_path)
            G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1))
            torch.save(G.state_dict(), G_path)
            D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i + 1))
            torch.save(D.state_dict(), D_path)
            idD_path = os.path.join(model_save_dir,
                                    '{}-idD.ckpt'.format(i + 1))
            torch.save(id_D.state_dict(), idD_path)
            print('Saved model checkpoints into {}...'.format(model_save_dir))
def main(config):
    # For fast training.
    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version
    beta1 = 0.5
    beta2 = 0.999

    loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)
    G = ExpressionGenerater()
    FEN = FeatureExtractNet()
    color_D = SNResIdDiscriminator()

    #######   Load pretrained networks   ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
        version)
    if not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))
    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
        G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))

        IdD_path = os.path.join(ckpt_dir, '{}-idD.ckpt'.format(resume_iter))
        color_D.load_state_dict(
            torch.load(IdD_path, map_location=lambda storage, loc: storage))

    else:
        resume_iter = 0

    #####  Train face2keypoint   ####
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(color_D.parameters(),
                                     lr=0.001,
                                     betas=(0.5, 0.9))
    G.to(device)
    color_D.to(device)
    FEN.to(device)

    FEN.eval()

    log.print(config)

    # Start training from scratch or resume training.
    start_iters = resume_iter
    trigger_rec = 1
    data_iter = iter(loader)

    # Start training.
    print('Start training...')
    for i in range(start_iters, 150000):
        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #

        #faces, origin_points = next(data_iter)
        #_, target_points = next(data_iter)
        try:
            color_faces, faces = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            color_faces, faces = next(data_iter)
        rand_idx = torch.randperm(faces.size(0))
        condition_faces = faces[rand_idx]

        faces = faces.to(device)
        color_faces = color_faces.to(device)
        condition_faces = condition_faces.to(device)

        # =================================================================================== #
        #                               3. Train the discriminator                            #
        # =================================================================================== #
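
        # Unlike the WGAN-GP script above, this discriminator uses the
        # non-saturating logistic GAN loss: softplus(-D(real)) +
        # softplus(D(fake)) for D, and softplus(-D(fake)) for G.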

        # ID Dis
        id_real_loss = torch.mean(
            softplus(-color_D(faces, condition_faces)))  # big for real
        faces_fake = G(color_faces, condition_faces)
        id_fake_loss = torch.mean(
            softplus(color_D(faces_fake, condition_faces)))  # small for fake

        id_Dis_loss = id_real_loss + id_fake_loss

        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #

        n_critic = 1
        if (i + 1) % n_critic == 0:
            # Original-to-target domain.
            faces_fake = G(color_faces, condition_faces)

            g_id_loss = torch.mean(
                softplus(-color_D(faces_fake, condition_faces)))
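
            # Reconstruction loss: per-pixel L1 plus a structural term
            # (1 - SSIM); ssim is assumed to return a batch-averaged
            # similarity in [0, 1].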

            l1_loss = torch.mean(
                torch.abs(faces_fake - faces)) + (1 - ssim(faces_fake, faces))

            feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(faces)))

            lambda_l1 = config.lambda_l1
            lambda_id = config.lambda_id
            lambda_feature = config.lambda_feature
            g_loss = lambda_id * g_id_loss + lambda_l1 * l1_loss + lambda_feature * feature_loss

            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

            # Print out training information.
            if (i + 1) % 4 == 0:
                log.print(
                    "iter {} - id_real_loss {:.2}, "
                    "id_fake_loss {:.2} ,  g_id_loss {:.2}, L1_loss {:.2}, feature_loss {:.2}"
                    .format(i, id_real_loss.item(), id_fake_loss.item(),
                            lambda_id * g_id_loss.item(), lambda_l1 * l1_loss.item(),
                            lambda_feature * feature_loss.item()))

            sample_dir = "gan-sample-{}".format(version)
            if not os.path.isdir(sample_dir):
                os.mkdir(sample_dir)
            if (i + 1) % 24 == 0:
                with torch.no_grad():
                    fake_face = faces_fake[0]
                    condition_face = condition_faces[0]
                    color_face = color_faces[0]
                    #reconstruct = reconstructs[0]

                    sample_path_face = os.path.join(
                        sample_dir, '{}-image-face.jpg'.format(i + 1))
                    save_image(denorm(condition_face.data.cpu()),
                               sample_path_face)

                    sample_path_rec = os.path.join(
                        sample_dir, '{}-image-color.jpg'.format(i + 1))
                    save_image(denorm(color_face.data.cpu()), sample_path_rec)

                    sample_path_fake = os.path.join(
                        sample_dir, '{}-image-fake.jpg'.format(i + 1))
                    save_image(denorm(fake_face.data.cpu()), sample_path_fake)

                    print('Saved real and fake images into {}...'.format(
                        sample_path_face))

        # Save model checkpoints.
        model_save_dir = "ckpt-{}".format(version)

        if (i + 1) % 1000 == 0:
            if not os.path.isdir(model_save_dir):
                os.mkdir(model_save_dir)
            G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1))
            torch.save(G.state_dict(), G_path)
            idD_path = os.path.join(model_save_dir,
                                    '{}-idD.ckpt'.format(i + 1))
            torch.save(color_D.state_dict(), idD_path)
            print('Saved model checkpoints into {}...'.format(model_save_dir))
class VideoGenerator():
    def __init__(self, config, json_file):
        # For fast training.
        self.config = config
        cudnn.benchmark = True
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        if len(sys.argv) > 1:
            os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "5"
        global device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.version = config.version
        self.color_version = config.color_version
        self.sf_version = config.sf_version
        self.point_threshold = config.point_threshold

        self.G = ExpressionGenerater()
        self.sfG = ExpressionGenerater()
        self.colorG = NoTrackExpressionGenerater()

        #######   Load pretrained networks   ######
        self.load()

        self.G.to(device)
        self.colorG.to(device)
        if config.eval == "1":
            self.G.eval()
        self.colorG.eval()
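
        # Preprocessing for the generators: resize to 224x224 and normalize
        # each channel to [-1, 1] (matching the tanh output range the
        # generators presumably produce; see denorm where frames are saved).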

        self.transform = []
        self.transform.append(T.Resize((224, 224)))
        self.transform.append(T.ToTensor())
        self.transform.append(
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        self.transform = T.Compose(self.transform)

        self.vid_annos = self.load_json(json_file)

    def load(self):
        resume_iter = self.config.resume_iter
        color_resume_iter = self.config.color_resume_iter
        ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
            self.version)
        sf_ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
            self.sf_version)
        color_ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
            self.color_version)
        if os.path.exists(
                os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
            G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
            self.G.load_state_dict(
                torch.load(G_path, map_location=lambda storage, loc: storage))
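            # NOTE: the sfG checkpoint iteration is hard-coded to 121000 below.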
            G_path = os.path.join(sf_ckpt_dir, '{}-G.ckpt'.format(121000))
            self.sfG.load_state_dict(
                torch.load(G_path, map_location=lambda storage, loc: storage))
            cG_path = os.path.join(color_ckpt_dir,
                                   '{}-G.ckpt'.format(color_resume_iter))
            self.colorG.load_state_dict(
                torch.load(cG_path, map_location=lambda storage, loc: storage))
            print("load ckpt")
        else:
            print("found no ckpt")
            return None

    #############  process data  ##########
    def crop_face(self, img, bbox, keypoint):
        flags = list()
        points = list()
        # the face may not be detected in some images
        if len(bbox) == 0:
            return None
        # draw bbox
        x, y, w, h = [int(v) for v in bbox]
        crop_img = img[y:y + h, x:x + w]
        return crop_img

    def load_json(self, path):
        with open(path, 'r') as f:
            data = json.load(f)
        print("load %d video annotations totally." % len(data))
        vid_anns = dict()
        for anns in data:
            name = anns['video']
            path = anns['video_path']
            vid_anns[name] = {'path': path}
            for ann in anns['annotations']:
                idx = ann['index']
                keypoints = ann['keypoint']
                bbox = ann['bbox']
                vid_anns[name][idx] = [bbox, keypoints]
        return vid_anns

    def draw_rotate_keypoints(self, img, bbox, keypoint, first_bbox,
                              first_keypoint):
        flags = list()
        points = list()
        first_flags = list()
        first_points = list()
        # the face may not be detected in some images
        if len(bbox) == 0:
            return None, None
        points_image = np.zeros((224, 224, 3), np.uint8)
        # draw bbox
        bx, by, bw, bh = [int(v) for v in bbox]
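
        # Map absolute keypoint coordinates into the 224x224 crop frame.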

        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            x = int((x - bx) / bw * 224)
            y = int((y - by) / bh * 224)
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1 or flag == 2:  # keypoint exists (visible or not)
                cv2.circle(points_image, (x, y), 2, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")

        fbx, fby, fbw, fbh = [int(v) for v in first_bbox]

        for i in range(0, len(first_keypoint), 3):
            x, y, flag = [int(k) for k in first_keypoint[i:i + 3]]
            x = int((x - fbx) / fbw * 224)
            y = int((y - fby) / fbh * 224)
            first_flags.append(flag)
            first_points.append([x, y])

        # Face alignment.
        # Keypoint 52 is the leftmost corner of the left eye and keypoint 61
        # is the rightmost corner of the right eye; x runs along the width
        # axis and y along the height axis.
        if not (flags[52] == 0 or flags[61] == 0):
            left_x, left_y = points[52]
            right_x, right_y = points[61]
            if left_x > 224 / 4 or right_x < 224 - 224 / 4:
                return np.asarray(points_image), 0
            deltaH = right_y - left_y
            deltaW = right_x - left_x
            if math.sqrt(deltaW**2 + deltaH**2) < 1:
                return np.asarray(points_image), 0
            angle = math.asin(deltaH / math.sqrt(deltaW**2 + deltaH**2))
            angle = angle / math.pi * 180  # radians to degrees
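            # NOTE: asin(deltaH / hypot(deltaW, deltaH)) matches
            # atan2(deltaH, deltaW) only while deltaW > 0, i.e. for a roughly
            # upright face; atan2 would be the more robust choice.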

            # Compute the eye-line angle of the first frame
            first_angle = 0
            if not (first_flags[52] == 0 or first_flags[61] == 0):
                left_x, left_y = first_points[52]
                right_x, right_y = first_points[61]
                deltaH = right_y - left_y
                deltaW = right_x - left_x
                if not math.sqrt(deltaW**2 + deltaH**2) < 1:
                    first_angle = math.asin(deltaH /
                                            math.sqrt(deltaW**2 + deltaH**2))
                    first_angle = first_angle / math.pi * 180  # radians to degrees

            #print("angle", angle)
            #print("first_angle", first_angle)

            angle = angle - first_angle

            if abs(angle) < 5:
                return np.asarray(points_image), 0
            points_image = Image.fromarray(np.uint8(points_image))
            points_image = points_image.rotate(angle)
        else:
            angle = 0

        return np.asarray(points_image), angle

    def draw_bbox_keypoints(self, img, bbox, keypoint):
        flags = list()
        points = list()
        # the face may not be detected in some images
        if len(bbox) == 0:
            return None, None
        points_image = np.zeros((224, 224, 3), np.uint8)
        # draw bbox
        bx, by, bw, bh = [int(v) for v in bbox]

        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            x = int((x - bx) / bw * 224)
            y = int((y - by) / bh * 224)
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1 or flag == 2:  # keypoint exists (visible or not)
                cv2.circle(points_image, (x, y), 2, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")

        return points_image, self.crop_face(img, bbox, keypoint)

    def extract_image(self, img, bbox):
        # the face may not be detected in some images
        Knockout_image = img.copy()
        # draw bbox
        x, y, w, h = [int(v) for v in bbox]
        cv2.rectangle(Knockout_image, (x, y), (x + w, y + h), (0, 0, 0),
                      cv2.FILLED)
        return Knockout_image, img

    def generate(self, first_frm_file, bg_file, alpha_file, body_file,
                 first_frm_id, special_vids, hard2middle):
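        """Generate one output video: per target frame, draw its keypoint
        image, synthesize the face with G (or sfG when the rotation branch
        selects it), optionally color-match and sharpen, then fuse the face
        back into the body/background and write the frame."""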
        first_frm = Image.open(first_frm_file)
        bg_body = np.array(first_frm)
        background = np.asarray(Image.open(bg_file))
        # print("shape", background.shape)
        body = np.asarray(Image.open(body_file))
        body_alpha = np.load(alpha_file)
        ## print("shape", body.shape)

        vid_name = os.path.basename(first_frm_file)
        vid_name = "{}.mp4".format(vid_name.split(".")[0])
        anno = self.vid_annos[vid_name]
        first_bbox, first_keypoint = anno[1]

        # Create the results directory
        if not os.path.isdir("test_result"):
            os.mkdir("test_result")
        sample_dir = "test_result/gan-sample-{}-{}".format(
            self.version, self.config.resume_iter)
        if not os.path.isdir(sample_dir):
            os.mkdir(sample_dir)

        # Set up the video writer (format, frame size, fps)
        frm = cv2.imread(first_frm_file)
        height, width, channels = frm.shape
        size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out_path = os.path.join(sample_dir, '{}.mp4'.format(first_frm_id))
        vid_writer = cv2.VideoWriter(out_path, fourcc, 25, size)

        first_face_crop = None
        is_first = True
        fake_frm = None
        angle = None
        G = None

        # Special-case video handling
        if first_frm_id.startswith("12") and int(first_frm_id) in special_vids:
            print("special {}-->{}".format(first_frm_id,
                                           hard2middle[first_frm_id]))
            vid_name = "{}.mp4".format(hard2middle[first_frm_id])
            temp_anno = self.vid_annos[vid_name]
            level2_first_bbox, level2_first_keypoint = temp_anno[1]
            level2_first_frm = os.path.join(
                first_frm_file[:-9], "{}.jpg".format(int(first_frm_id) - 1000))
            level2_first_frm = Image.open(level2_first_frm)
            _, first_face_crop = self.draw_bbox_keypoints(
                np.asarray(level2_first_frm), level2_first_bbox,
                level2_first_keypoint)

        match_cnt = 0
        rotate_match_cnt = 0
        # #########################
        # Main loop: generate the video
        # #########################
        for idx in tqdm(range(len(anno) - 1)):
            #if idx % 4 == 0:
            #    self.load()
            bbox, keypoint = anno[idx + 1]
            if len(bbox) >= 2:
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])

            if is_first:
                is_first = False
                _, real_first_face_crop = self.draw_bbox_keypoints(
                    np.asarray(first_frm), bbox, keypoint)
                if first_face_crop is None:
                    _, first_face_crop = self.draw_bbox_keypoints(
                        np.asarray(first_frm), bbox, keypoint)
                if len(bbox) == 0:
                    #print("no first bbox")
                    vid_writer.write(frm)
                    is_first = True  # reset the first-frame flag
                    continue
                if len(first_bbox) == 0:
                    first_bbox = bbox
                    first_keypoint = keypoint
                    #print("first_bbox = bbox")
                vid_writer.write(
                    cv2.cvtColor(np.asarray(first_frm), cv2.COLOR_RGB2BGR))
                continue
            else:
                # Get the keypoint image of frame idx
                draw_kp, angle = self.draw_rotate_keypoints(
                    np.asarray(first_frm), bbox, keypoint, first_bbox,
                    first_keypoint)
                draw_kp_norotate, _ = self.draw_bbox_keypoints(
                    np.asarray(first_frm), bbox, keypoint)

                #print(angle)
                if draw_kp is None or first_face_crop is None:
                    print("no bbox")
                    if first_face_crop is None:
                        vid_writer.write(frm)
                        continue
                    if fake_frm is None:
                        vid_writer.write(frm)
                    else:
                        vid_writer.write(fake_frm)
                    continue

                first_face_crop_arr = Image.fromarray(first_face_crop, 'RGB')
                first_face_tensor = self.transform(first_face_crop_arr)
                first_face_tensor = first_face_tensor.unsqueeze(0)

            first_draw_kp, _ = self.draw_bbox_keypoints(
                np.asarray(first_frm), first_bbox, first_keypoint)
            # print(np.mean(first_draw_kp - draw_kp_norotate))

            #print(angle)
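            # NOTE: draw_rotate_keypoints only ever returns 0 or the computed
            # relative angle, so this sfG branch appears unreachable as
            # written.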
            if angle == 360:
                angle = 0
                G = self.sfG
            else:
                G = self.G

            # Special case: if the target keypoints are very close to the
            # first frame's, reuse the original face crop directly
            point_threshold = self.point_threshold
            if np.mean(first_draw_kp - draw_kp_norotate) < point_threshold:
                match_cnt += 1
                ##print(np.mean(first_draw_kp - draw_kp_norotate))
                #print("match ")
                fake_frm = knn_fusion(bg_body, background, body,
                                      body_alpha, first_bbox,
                                      np.asarray(first_face_crop), bbox)
                fake_frm = cv2.cvtColor(np.asarray(fake_frm),
                                        cv2.COLOR_RGB2BGR)
                fake_frm = cv2.resize(fake_frm, (width, height))
                vid_writer.write(fake_frm)
                continue

            # Call the model to generate one frame
            def generate_frm(draw_kp, first_face_tensor):
                img_kp = Image.fromarray(draw_kp, 'RGB')
                # img.show()
                key_points = self.transform(img_kp)
                key_points = key_points.unsqueeze(0)
                first_face_tensor = first_face_tensor.to(device)
                key_points = key_points.to(device)
                face_fake = G(first_face_tensor, key_points)
                return face_fake

            def to_PIL(face_fake):
                frm = denorm(face_fake.data.cpu())
                frm = frm[0]
                toPIL = T.ToPILImage()
                frm = toPIL(frm)
                return frm

            # Special case: if the rotated keypoints are very close to the
            # first frame's, reuse the original face crop
            if np.mean(first_draw_kp - draw_kp) < point_threshold:
                rotate_match_cnt += 1
                #print(np.mean(first_draw_kp - draw_kp))
                frm = Image.fromarray(np.array(first_face_crop))
                frm = frm.resize(size=(224, 224))
                #print("match rotation")
            else:
                face_fake = generate_frm(draw_kp, first_face_tensor)
                frm = to_PIL(face_fake)

            # If the face is rotated
            if angle != 0:
                face_fake_norotate = generate_frm(draw_kp_norotate,
                                                  first_face_tensor)
                frm = frm.rotate(-angle)
                norotate_frm = to_PIL(face_fake_norotate)

                # Blend the rotated face with the unrotated one: where the
                # rotation left black borders (mask near 0), take pixels from
                # the unrotated frame
                frm = np.array(frm)
                frm.flags.writeable = True
                norotate_frm = np.asarray(norotate_frm)
                mask = np.mean(frm, axis=2)
                frm[:, :, 0] = np.where(mask < 0.1, norotate_frm[:, :, 0],
                                        frm[:, :, 0])
                frm[:, :, 1] = np.where(mask < 0.1, norotate_frm[:, :, 1],
                                        frm[:, :, 1])
                frm[:, :, 2] = np.where(mask < 0.1, norotate_frm[:, :, 2],
                                        frm[:, :, 2])
                frm = Image.fromarray(frm)

            #save_image(frm, "test_result/fake_face_{}.jpg".format(idx))
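
            # Optionally match the generated face's color statistics to the
            # first frame's face; color_transfer is defined elsewhere in the
            # repo (presumably a Reinhard-style statistics transfer).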

            if self.config.color_transfer == "1":
                fake_face = cv2.cvtColor(np.asarray(frm), cv2.COLOR_RGB2BGR)
                first_face = cv2.cvtColor(
                    np.asarray(Image.fromarray(first_face_crop, 'RGB')),
                    cv2.COLOR_RGB2BGR)
                frm = color_transfer(first_face, fake_face)
                frm = cv2.cvtColor(np.asarray(frm), cv2.COLOR_BGR2RGB)
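
            # Optional sharpening: a negative-Laplacian high-pass kernel
            # scaled by sharpen_lambda, with 1 added at the center so the
            # result is the identity image plus lambda-weighted edges.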

            if self.config.sharpen == "1":
                kernel = np.array([[-1, -1, -1], [-1, 8, -1],
                                   [-1, -1, -1]]) * self.config.sharpen_lambda
                kernel[1, 1] = kernel[1, 1] + 1
                frm = cv2.filter2D(np.asarray(frm), -1, kernel)

            ###### Fusion: composite the generated face into the background
            #fake_frm = fusion(first_img, first_bbox, np.asarray(frm), bbox)
            fake_frm = knn_fusion(bg_body, background, body, body_alpha,
                                  first_bbox, np.asarray(frm), bbox)
            #
            fake_frm = cv2.cvtColor(np.asarray(fake_frm), cv2.COLOR_RGB2BGR)
            #cv2.imwrite("test_result/fake_frm_{}.jpg".format(idx), fake_frm)
            fake_frm = cv2.resize(fake_frm, (width, height))

            # cv2.imwrite("test_result/fake_frm_{}.jpg".format(idx), fake_frm)

            vid_writer.write(fake_frm)
        vid_writer.release()
        print("match_cnt {}, rotate_match_cnt {}".format(
            match_cnt, rotate_match_cnt))