def facenet_distance(img1, img2, percentage=False, embedding=False):
    """
    Outputs distance between 0 and 2 with 0 is identical and 2 is most different
    :param embedding: bool - return embedding
    :param percentage: bool - return percentage
    :param img1: img1 tensor
    :param img2: img2 tensor
    :return: dist between 0 (identical) and 2 (most dissimilar), + embedding/percentange if True
    """
    img1 = mtcnn(Image.fromarray(tensor2im(img1)))
    img2 = mtcnn(Image.fromarray(tensor2im(img2)))

    if img1 is None or img2 is None:
        dist = 2
        embedding_list = [torch.zeros(512), torch.zeros(512)]
    else:
        img_stack = torch.stack([img1, img2]).to(device)
        embedding_list = resnet(img_stack)  # .detach().cpu()
        embedding_dist = embedding_list[0] - embedding_list[1]
        dist = embedding_dist.norm().item()

    if percentage:
        return dist, ((2 - dist) * 50)
    elif embedding:
        return dist, embedding_list
    else:
        return dist
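# Example (a sketch; assumes the module-level `mtcnn`, `resnet`, `device` and
# `tensor2im` helpers used above are already set up):
#   dist, similarity = facenet_distance(tensor_a, tensor_b, percentage=True)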
def run_segmentation_model(model, opt, root_folder, warp=False):
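    """
    Runs the Face2Mask segmentation model over the test dataset: for every
    test image the predicted mask is cleaned, its black background is made
    transparent, the mask is overlaid on the original image and the result
    is saved as <root_folder>/<image name>.png.
    :param model: segmentation model exposing set_input/test/get_current_visuals
    :param opt: test options used to create the dataset
    :param root_folder: output folder for the overlaid masks
    :param warp: not used in this function
    """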
    print("\n---- Running Face2Mask Model ----\n")

    dataset = create_dataset(opt)
    for i, data in tqdm(enumerate(dataset)):
        if i >= opt.num_test:
            break

        # run model and get processed image
        model.set_input(data)
        model.test()

        # get mask
        visuals_mask = model.get_current_visuals()
        mask = tensor2im(visuals_mask['fake_B'])
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = clean_mask(mask)
        mask = Image.fromarray(mask)
        mask = make_color_transparent(mask, (0, 0, 0), tolerance=50)

        # open original
        img_path = model.get_image_paths()[0]
        img_name = img_path.split("/")[-1][:-4]
        image = Image.open(img_path)

        # add background to mask
        overlayed_mask = overlay_two_images(image, mask)

        # save file
        overlayed_mask.save(f'{root_folder}/{img_name}.png')

    print("SAVED GENERATED MASKS\n")
def run_blending_model(model,
                       opt,
                       root_folder,
                       result_folder,
                       blend_back=False):
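    """
    Runs the Face2Face blending model over the test dataset and saves each
    blended face to <root_folder>/<blended_folder>/<result_folder>.
    If blend_back is True, the blended face is additionally resized and
    reinserted into the original image at the coordinates stored in
    <root_folder>/<generated_folder>/coordinates.txt and the result is saved
    to <root_folder>/<blended_reinserted_folder>/<result_folder>.
    """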
    print("\n---- Running Face2Face Bleding Model ---- \n")

    if blend_back:
        with open(f'{root_folder}/{generated_folder}/coordinates.txt',
                  'r') as f:
            reinsert_rect = json.load(f)
        Path(f"{root_folder}/{blended_reinserted_folder}/{result_folder}"
             ).mkdir(parents=True, exist_ok=True)

    dataset = create_dataset(opt)
    for i, data in tqdm(enumerate(dataset)):
        if i >= opt.num_test:
            break

        # test model
        model.set_input(data)
        model.test()

        # get visuals
        visuals_face = model.get_current_visuals()
        img_path = model.get_image_paths()[0]
        file_name = img_path.split('/')[-1]

        # Tensor to image
        generated_image = tensor2im(visuals_face['fake_B'])
        generated_image = cv2.cvtColor(generated_image, cv2.COLOR_RGB2BGR)

        # get key and file name
        key_name = file_name[:-4]
        raw_file_name = "_".join(file_name.split("_")[1:])

        # get original image
        out = _get_maybe_modified_file(raw_file_name, root_folder)
        # save generated image
        cv2.imwrite(
            f"{root_folder}/{blended_folder}/{result_folder}/{file_name}",
            generated_image)

        if blend_back:
            # get face bounding box
            x, y, x_end, y_end = reinsert_rect[key_name]["reinsert_range"]
            shape = out.shape

            # reinsert blended image
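            # (the min(..., 0) terms shrink the target size when the bounding
            #  box extends past the image border, so the resized face matches
            #  the clipped slice below)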
            generated_image = cv2.resize(
                generated_image, (x_end - x + min(shape[1] - x_end, 0),
                                  y_end - y + min(shape[0] - y_end, 0)))
            out[y:y_end, x:x_end, :] = generated_image

            # save
            cv2.imwrite(
                f"{root_folder}/{blended_reinserted_folder}/{result_folder}/{file_name}",
                out)

    print("SAVED BLENDED FACE\n")
def run_generation_model(model,
                         opt,
                         root_folder,
                         smoothEdge=20,
                         min_width=30,
                         margin=0):
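    """
    Runs the Mask2Face generation model over the test dataset. For every face
    whose bounding-box width exceeds min_width the generated face is saved,
    re-aligned using the stored alignment parameters and reinserted into the
    original image (smoothed over smoothEdge pixels, with an optional margin).
    The reinsertion coordinates are written to
    <root_folder>/<generated_folder>/coordinates.txt for the blending step.
    """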
    print("\n---- Running Mask2Face Model ---- \n")

    with open(f'{root_folder}/{test_folder}/coordinates.txt', 'r') as f:
        face_extraction_data = json.load(f)

    # logging
    num_too_small = 0
    num_replaced = 0
    insert_data = {}

    dataset = create_dataset(opt)
    for i, data in tqdm(enumerate(dataset)):
        if i >= opt.num_test:
            break

        # test model
        model.set_input(data)
        model.test()

        # get visuals
        visuals_face = model.get_current_visuals()
        img_path = model.get_image_paths()[0]
        key_name = img_path.split('/')[-1][:-4]
        idx = key_name.index("_")
        raw_file_name = key_name[idx + 1:] + ".png"

        # Tensor to image
        generated_image = tensor2im(visuals_face['fake_B'])
        generated_image = cv2.cvtColor(generated_image, cv2.COLOR_RGB2BGR)

        # get face bounding box
        try:
            x, y, w, h = face_extraction_data[key_name]["rect_cv"]
        except KeyError:
            # handles files that were not deleted but are no longer in the test set
            print(
                f"{key_name} key not found in dict face_extraction_data! Skipping."
            )
            continue

        # only continue if the width of the reconstructed face is larger than min_width pixels
        if w > min_width:
            out = _get_maybe_modified_file(raw_file_name, root_folder)

            # save generated image
            cv2.imwrite(f"{root_folder}/{generated_folder}/{key_name}.png",
                        generated_image)

            # load keypoint data
            keypoints = face_extraction_data[key_name]["keypoints"]
            alignment_params = face_extraction_data[key_name][
                "alignment_params"]

            out, capture, out_margin, capture_margin, reinsert_range = reinsert_aligned_into_image(
                generated_image,
                out,
                alignment_params,
                keypoints,
                smoothEdge=smoothEdge,
                margin=margin,
                clean_merge=True)

            insert_data[key_name] = {'reinsert_range': reinsert_range}

            cv2.imwrite(
                f"{root_folder}/{generated_reinserted_folder}/{raw_file_name}",
                out)
            cv2.imwrite(f"{root_folder}/{to_blend_folder}/{key_name}.png",
                        out_margin)

            num_replaced += 1
        else:
            num_too_small += 1

    with open(f"{root_folder}/{generated_folder}/coordinates.txt",
              'w') as outfile:
        json.dump(insert_data, outfile)

    print(
        f"In total {num_replaced} faces replaced. "
        f"{num_too_small} incedents of faces with width < {min_width} were NOT replaced. \n"
    )

    print("SAVED GENERATED FACE\n")
    def align_fake(self, margin=40, alignUnaligned=True):
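        """
        Warps the aligned fake face (self.fake_B) back into the coordinate
        frame of the original, unaligned image using the stored alignment
        parameters (eye centre, rotation angle and scale) and reinserts it
        into self.real_B_unaligned_full, keeping a safety margin around the
        face so the crop never leaves the image bounds.
        """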

        # get params
        desiredLeftEye = [
            float(self.alignment_params["desiredLeftEye"][0]),
            float(self.alignment_params["desiredLeftEye"][1])
        ]
        rotation_point = self.alignment_params["eyesCenter"]
        angle = -self.alignment_params["angle"]
        h, w = self.fake_B.shape[2:]
        # get original positions
        m1 = round(w * 0.5)
        m2 = round(desiredLeftEye[1] * w)
        # define the scale factor
        scale = 1 / self.alignment_params["scale"]
        width = int(self.alignment_params["shape"][0])
        long_edge_size = width / abs(np.cos(np.deg2rad(angle)))
        w_original = int(scale * long_edge_size)
        h_original = int(scale * long_edge_size)
        # get offset
        tX = w_original * 0.5
        tY = h_original * desiredLeftEye[1]
        # get rotation center
        center = torch.ones(1, 2)
        center[..., 0] = m1
        center[..., 1] = m2
        # compute the transformation matrix
        M: torch.Tensor = kornia.get_rotation_matrix2d(center, angle,
                                                       scale).to(self.device)
        M[0, 0, 2] += (tX - m1)
        M[0, 1, 2] += (tY - m2)
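        # shift the translation so the aligned eye centre (m1, m2) maps to its
        # position (tX, tY) in the original-size crop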

        # get insertion point
        x_start = int(rotation_point[0] - (0.5 * w_original))
        y_start = int(rotation_point[1] - (desiredLeftEye[1] * h_original))
        # _, _, h_tensor, w_tensor = self.real_B_unaligned_full.shape

        # # # # # # # # # # # # # # # # # # ## # # # # # # # ## # # # # ## # # # # # # # # # ## # #
        # get safe margin
        h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
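        # shrink the margin (never below 0) so that the crop window
        # [y_start - margin : y_start + h_original + margin,
        #  x_start - margin : x_start + w_original + margin] stays inside the image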
        margin = max(
            min(
                y_start - max(0, y_start - margin),
                x_start - max(0, x_start - margin),
                min(y_start + h_original + margin, h_size_tensor) - y_start -
                h_original,
                min(x_start + w_original + margin, w_size_tensor) - x_start -
                w_original,
            ), 0)
        # get face + margin in unaligned space
        self.real_B_aligned_margin = self.real_B_unaligned_full[
            :, :,
            y_start - margin:y_start + h_original + margin,
            x_start - margin:x_start + w_original + margin]
        # invert matrix
        M_inverse = kornia.invert_affine_transform(M)
        # update output size to fit the 256 + scaled margin
        old_size = self.real_B_aligned_margin.shape[2]
        new_size = old_size + 2 * round(float(margin * scale))

        _, _, h_tensor, w_tensor = self.real_B_aligned_margin.shape
        self.real_B_aligned_margin = kornia.warp_affine(
            self.real_B_aligned_margin, M_inverse, dsize=(new_size, new_size))
        # padding_mode="reflection")
        self.fake_B_aligned_margin = self.real_B_aligned_margin.clone(
        ).requires_grad_(True)

        # update margin as we now scale the image!
        # update start point
        start = round(float(margin * scale * new_size / old_size))
        print(start)

        # point = torch.tensor([0, 0, 1], dtype=torch.float)
        # M_ = M_inverse[0].clone().detach()
        # M_ = torch.cat((M_, torch.tensor([[0, 0, 1]], dtype=torch.float)))
        #
        # M_n = M[0].clone().detach()
        # M_n = torch.cat((M_n, torch.tensor([[0, 0, 1]], dtype=torch.float)))
        #
        # start_tensor = torch.matmul(torch.matmul(point, M_) + margin, M_n)
        # print(start_tensor)
        # start_y, start_x = round(float(start_tensor[0])), round(float(start_tensor[1]))

        # reinsert into tensor
        self.fake_B_aligned_margin[0, :, start:start + 256,
                                   start:start + 256] = self.real_B

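        # debug: dump the aligned crops to disk and stop; the branch below is
        # unreachable while this exit() is in place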
        Image.fromarray(tensor2im(self.real_B_aligned_margin)).save(
            "/home/mo/datasets/ff_aligned_unaligned/real.png")
        Image.fromarray(tensor2im(self.fake_B_aligned_margin)).save(
            "/home/mo/datasets/ff_aligned_unaligned/fake.png")

        exit()
        # # # # # # # # # # # # # # # # # # ## # # # # # # # ## # # # # ## # # # # # # # # # ## # #
        if not alignUnaligned:
            # Now apply the transformation to original image
            # clone fake
            fake_B_clone = self.fake_B.clone().requires_grad_(True)
            # apply warp
            fake_B_warped: torch.Tensor = kornia.warp_affine(
                fake_B_clone, M, dsize=(h_original, w_original))

            # make sure warping does not exceed real_B_unaligned_full dimensions
            if y_start < 0:
                fake_B_warped = fake_B_warped[:, :, abs(y_start):h_original, :]
                h_original += y_start
                y_start = 0
            if x_start < 0:
                fake_B_warped = fake_B_warped[:, :, :, abs(x_start):w_original]
                w_original += x_start
                x_start = 0
            if y_start + h_original > h_tensor:
                h_original -= (y_start + h_original - h_tensor)
                fake_B_warped = fake_B_warped[:, :, 0:h_original, :]
            if x_start + w_original > w_tensor:
                w_original -= (x_start + w_original - w_tensor)
                fake_B_warped = fake_B_warped[:, :, :, 0:w_original]

            # create mask that is true where fake_B_warped is 0
            # This is the background that is not filled with image after the transformation
            mask = ((fake_B_warped[0][0] == 0) & (fake_B_warped[0][1] == 0) &
                    (fake_B_warped[0][2] == 0))
            # where mask is True (background left empty by the warp) take pixels
            # from real_B_unaligned_full, otherwise keep fake_B_warped
            fake_B_filled = torch.where(
                mask,
                self.real_B_unaligned_full[:, :, y_start:y_start + h_original,
                                           x_start:x_start + w_original],
                fake_B_warped)

            # reinsert into tensor
            self.fake_B_unaligned = self.real_B_unaligned_full.clone(
            ).requires_grad_(True)
            mask = torch.zeros_like(self.fake_B_unaligned, dtype=torch.bool)
            mask[0, :, y_start:y_start + h_original,
                 x_start:x_start + w_original] = True
            self.fake_B_unaligned = self.fake_B_unaligned.masked_scatter(
                mask, fake_B_filled)

            # cutout tensor
            h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
            margin = max(
                min(
                    y_start - max(0, y_start - margin),
                    x_start - max(0, x_start - margin),
                    min(y_start + h_original + margin, h_size_tensor) -
                    y_start - h_original,
                    min(x_start + w_original + margin, w_size_tensor) -
                    x_start - w_original,
                ), 0)
            self.fake_B_unaligned = self.fake_B_unaligned[
                :, :,
                y_start - margin:y_start + h_original + margin,
                x_start - margin:x_start + w_original + margin]
            self.real_B_unaligned = self.real_B_unaligned_full[
                :, :,
                y_start - margin:y_start + h_original + margin,
                x_start - margin:x_start + w_original + margin]