Example #1
def conversion(video_path_A):
    if len(sys.argv) <= 2:
        output_video_path_A = "./Aout.mp4"
    elif len(sys.argv) == 3:
        output_video_path_A = sys.argv[2]  # the last argument when len(sys.argv) == 3

    K.set_learning_phase(0)

    # Input/Output resolution
    RESOLUTION = 64 # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    arch_config['norm'] = "instancenorm" # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard" # standard, lite

    model = FaceswapGANModel(**arch_config)

    model.load_weights(path="./models")

    mtcnn_weights_dir = "./mtcnn_weights/"

    fd = MTCNNFaceDetector(sess=K.get_session(), model_path=mtcnn_weights_dir)
    vc = VideoConverter()

    vc.set_face_detector(fd)
    vc.set_gan_model(model)

    options = {
        # ===== Fixed =====
        "use_smoothed_bbox": True,
        "use_kalman_filter": True,
        "use_auto_downscaling": False,
        "bbox_moving_avg_coef": 0.65,
        "min_face_area": 35 * 35,
        "IMAGE_SHAPE": model.IMAGE_SHAPE,
        # ===== Tunable =====
        "kf_noise_coef": 3e-3,
        "use_color_correction": "hist_match",
        "detec_threshold": 0.7,
        "roi_coverage": 0.9,
        "enhance": 0.5,
        "output_type": 3,
        "direction": "AtoB",
    }

    input_fn = video_path_A
    output_fn = output_video_path_A
    duration = None
    vc.convert(input_fn=input_fn, output_fn=output_fn, options=options, duration=duration)
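
A hedged sketch of wiring this up as a script entry point; the argument layout mirrors the sys.argv checks inside conversion(), and the paths are placeholders, not part of the original example:

if __name__ == "__main__":
    import sys
    # argv[1]: input video; argv[2] (optional): output path, otherwise ./Aout.mp4
    conversion(sys.argv[1])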
Example #2
def reset_session(save_path, model, vggface, train_batchA, train_batchB):
    model.save_weights(path=save_path)
    K.clear_session()
    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=save_path)
    #vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
    vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))
    vggface.load_weights("rcmalli_vggface_tf_notop_resnet50.h5")
    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    # Rebinding the parameters above has no effect on the caller, so return the
    # rebuilt objects for the caller to rebind (as the variant in Example #3 does).
    return model, vggface, train_batchA, train_batchB
Example #3
    def reset_session(save_path, model, person='A'):
        model.save_weights(path=save_path)
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
        if person == 'A':
            train_batch = DataLoader(gen_person_img, all_img, batchSize, gen_img_dir_bm_eyes,
                                      RESOLUTION, num_cpus, K.get_session(), **da_config)
        else:
            train_batch = DataLoader(person_img, all_img, batchSize, img_dir_bm_eyes,
                                      RESOLUTION, num_cpus, K.get_session(), **da_config)

        return model, vggface, train_batch
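
A hedged usage sketch: rebuild the model and data loader after clearing the Keras session, rebinding the returned objects (the save path is a placeholder; arch_config, loss_config, batchSize, and the image lists are assumed to be in the enclosing scope, as in the surrounding examples):

model, vggface, train_batchA = reset_session("./models", model, person='A')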
Example #4
def reset_session(save_path):
    global model, vggface
    global train_batchA, train_batchB
    model.save_weights(path=save_path)
    del model
    del vggface
    del train_batchA
    del train_batchB
    K.clear_session()
    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=save_path)
    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
Example #5
    # Counter
    count = 1
    # Path to saved model weights
    models_dir = "./models"
    RESOLUTION = 256  # 64x64, 128x128, 256x256
    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    # TODO: normalization settings
    arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    model = FaceswapGANModel(**arch_config)

    options = {
        # ===== Fixed =====
        "use_smoothed_bbox": True,
        "use_kalman_filter": True,
        "use_auto_downscaling": False,
        "bbox_moving_avg_coef": 0.70,  # 0.65
        "min_face_area": 128 * 128,
        "IMAGE_SHAPE": model.IMAGE_SHAPE,
        # ===== Tunable =====
        "kf_noise_coef": 1e-3,
        "use_color_correction": "hist_match",
        "detec_threshold": 0.8,
        "roi_coverage": 0.92,
        "enhance": 0.,
Example #6
# Init. loss config.
loss_config = {}
loss_config["gan_training"] = "mixup_LSGAN"
loss_config['use_PL'] = False
loss_config["PL_before_activ"] = True
loss_config['use_mask_hinge_loss'] = False
loss_config['m_mask'] = 0.
loss_config['lr_factor'] = 1.
loss_config['use_cyclic_loss'] = False
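
Example #6 passes loss_weights to build_train_functions below without defining it; a minimal definition, using the same weight values as Examples #7, #10, and #11:

# Loss function weights configuration
loss_weights = {}
loss_weights['w_D'] = 0.1      # Discriminator
loss_weights['w_recon'] = 1.   # L1 reconstruction loss
loss_weights['w_edge'] = 0.1   # edge loss
loss_weights['w_eyes'] = 30.   # reconstruction and edge loss on eyes area
loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1)  # perceptual loss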

from networks.faceswap_gan_model import FaceswapGANModel
from data_loader.data_loader import DataLoader
from utils import showG, showG_mask, showG_eyes

model = FaceswapGANModel(**arch_config)

from keras_vggface.vggface import VGGFace
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
model.build_train_functions(loss_weights=loss_weights, **loss_config)


# Create ./models directory
Path(f"models").mkdir(parents=True, exist_ok=True)

# Get filenames
train_A = glob.glob(img_dirA+"/*.*")
train_B = glob.glob(img_dirB+"/*.*")

train_AnB = train_A + train_B
Example #7
loss_weights['w_eyes'] = 30.  # reconstruction and edge loss on eyes area
loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1)  # perceptual loss (0.003, 0.03, 0.3, 0.3)

# Init. loss config.
loss_config = {}
loss_config["gan_training"] = "mixup_LSGAN"  # "mixup_LSGAN" or "relativistic_avg_LSGAN"
loss_config['use_PL'] = False
loss_config["PL_before_activ"] = False
loss_config['use_mask_hinge_loss'] = False
loss_config['m_mask'] = 0.
loss_config['lr_factor'] = 1.
loss_config['use_cyclic_loss'] = False

model = FaceswapGANModel(**arch_config)
model.load_weights(path=models_dir)

# VGGFace ResNet50
vggface = VGGFace(include_top=False,
                  model='resnet50',
                  input_shape=(224, 224, 3))

#vggface.summary()

model.build_pl_model(vggface_model=vggface,
                     before_activ=loss_config["PL_before_activ"])

model.build_train_functions(loss_weights=loss_weights, **loss_config)

model.load_weights(path=models_dir)
Example #8
def train_person(person, gen_person):
    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Batch size
    batchSize = 2
    assert (batchSize != 1 and batchSize % 2 == 0), "batchSize should be an even number."

    # Path to training images
    img_dir = f'./faces/{person}'
    img_dir_bm_eyes = f"./binary_masks/{person}"
    
    gen_img_dir = f'./faces/{gen_person}'
    gen_img_dir_bm_eyes = f"./binary_masks/{gen_person}"

    # Path to saved model weights
    models_dir = f"./models/{gen_person}2{person}"

    da_config, arch_config, loss_weights, loss_config = get_model_params()

    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=models_dir)

    vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))
    vggface.load_weights("rcmalli_vggface_tf_notop_resnet50.h5")

    # VGGFace ResNet50 (note: this replaces the RESNET50 instance loaded just above)
    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    # Create ./models directory
    Path(models_dir).mkdir(parents=True, exist_ok=True)

    # Get filenames
    person_img = glob.glob(img_dir + f"/raw_faces/*.*")
    gen_person_img = glob.glob(gen_img_dir + f"/raw_faces/*.*")
    all_img = person_img + gen_person_img

    assert len(person_img), "No image found in " + str(img_dir)
    print("Number of images in folder: " + str(len(person_img)))

    if da_config["use_bm_eyes"]:
        assert len(glob.glob(img_dir_bm_eyes + "/*.*")), "No binary mask found in " + str(img_dir_bm_eyes)
        assert len(glob.glob(img_dir_bm_eyes + "/*.*")) == len(person_img), \
            "Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder."

    train_batchA = DataLoader(person_img, all_img, batchSize, img_dir_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    train_batchB = DataLoader(gen_person_img, all_img, batchSize, gen_img_dir_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    # _, tA, bmA = train_batchA.get_next_batch()
    # _, tB, bmB = train_batchB.get_next_batch()
    # showG_eyes(tA, tB, bmA, bmB, batchSize)

    t0 = time.time()
    gen_iterations = 0

    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = {}
    errGBs = {}
    # Dictionaries are ordered in Python 3.6
    for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
        errGAs[k] = 0
        errGBs[k] = 0

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = 10000

    def reset_session(save_path, model, person='A'):
        model.save_weights(path=save_path)
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
        if person == 'A':
            train_batch = DataLoader(gen_person_img, all_img, batchSize, gen_img_dir_bm_eyes,
                                      RESOLUTION, num_cpus, K.get_session(), **da_config)
        else:
            train_batch = DataLoader(person_img, all_img, batchSize, img_dir_bm_eyes,
                                      RESOLUTION, num_cpus, K.get_session(), **da_config)

        return model, vggface, train_batch

    while gen_iterations <= TOTAL_ITERS:
        # Loss function automation
        if gen_iterations == (TOTAL_ITERS // 5 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            model, vggface, train_batchA = reset_session(models_dir, model)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 5 + TOTAL_ITERS // 10 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.5
            model, vggface, train_batchA = reset_session(models_dir, model)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Complete.")
        elif gen_iterations == (2 * TOTAL_ITERS // 5 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.2
            model, vggface, train_batchA = reset_session(models_dir, model)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 2 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.4
            model, vggface, train_batchA = reset_session(models_dir, model)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (2 * TOTAL_ITERS // 3 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.
            loss_config['lr_factor'] = 0.3
            model, vggface, train_batchA = reset_session(models_dir, model)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (8 * TOTAL_ITERS // 10 - display_iters // 2):
            clear_output()
            # Swapping the data turns out to be the crucial step here.
            # Obvious in hindsight: the goal is not to build a conventional DA.
            # Feeding in an image of Senga and getting Senga back is useless; the person has to change.
            model.decoder_A.load_weights(f"{models_dir}/decoder_A.h5")  # swap decoders
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.1
            loss_config['lr_factor'] = 0.3
            model, vggface, train_batchA = reset_session(models_dir, model, person='B')
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (9 * TOTAL_ITERS // 10 - display_iters // 2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            loss_config['lr_factor'] = 0.1
            model, vggface, train_batchA = reset_session(models_dir, model, person='B')
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")

        if gen_iterations == 5:
            print("working.")

        # Train discriminators for one batch
        data_A = train_batchA.get_next_batch()
        errDA = model.train_one_batch_D(data_A=data_A)
        errDA_sum += errDA[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        errGA = model.train_one_batch_G(data_A=data_A)
        errGA_sum += errGA[0]
        for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
            errGAs[k] += errGA[i]
        gen_iterations += 1

        # Visualization
        if gen_iterations % display_iters == 0:
            clear_output()

            # Display loss information
            show_loss_config(loss_config)
            print("----------")
            print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                  % (gen_iterations, errDA_sum / display_iters, errDB_sum / display_iters,
                     errGA_sum / display_iters, errGB_sum / display_iters, time.time() - t0))
            print("----------")
            print("Generator loss details:")
            print(f'[Adversarial loss]')
            print(f'GA: {errGAs["adv"] / display_iters:.4f} GB: {errGBs["adv"] / display_iters:.4f}')
            print(f'[Reconstruction loss]')
            print(f'GA: {errGAs["recon"] / display_iters:.4f} GB: {errGBs["recon"] / display_iters:.4f}')
            print(f'[Edge loss]')
            print(f'GA: {errGAs["edge"] / display_iters:.4f} GB: {errGBs["edge"] / display_iters:.4f}')
            if loss_config['use_PL']:
                print(f'[Perceptual loss]')
                try:
                    print(f'GA: {errGAs["pl"][0] / display_iters:.4f}')
                except TypeError:  # errGAs["pl"] may be a scalar rather than a sequence
                    print(f'GA: {errGAs["pl"] / display_iters:.4f}')

            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
            for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
                errGAs[k] = 0
                errGBs[k] = 0

            # Display images
            # print("----------")
            # wA, tA, _ = train_batchA.get_next_batch()
            # print("Transformed (masked) results:")
            # showG(tA, tB, model.path_A, model.path_B, batchSize)
            # print("Masks:")
            # showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)
            # print("Reconstruction results:")
            # showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)

            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)
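
Example #8 calls get_model_params() without defining it. The sketch below shows what such a helper plausibly returns, assembled from the da_config, arch_config, loss_weights, and loss_config dictionaries that Examples #10 and #11 use; the values in the original helper are unknown and RESOLUTION is assumed here, so treat this as a stand-in rather than the author's implementation:

def get_model_params():
    RESOLUTION = 64  # assumed; 64, 128, or 256 in the other examples
    da_config = {
        "prob_random_color_match": 0.5,   # probability of random color matching (augmentation)
        "use_da_motion_blur": False,      # motion-blur augmentation
        "use_bm_eyes": True,              # eye-aware training
    }
    arch_config = {
        'IMAGE_SHAPE': (RESOLUTION, RESOLUTION, 3),
        'use_self_attn': True,
        'norm': "instancenorm",           # instancenorm, batchnorm, layernorm, groupnorm, none
        'model_capacity': "standard",     # standard, lite
    }
    loss_weights = {
        'w_D': 0.1,                       # Discriminator
        'w_recon': 1.,                    # L1 reconstruction loss
        'w_edge': 0.1,                    # edge loss
        'w_eyes': 30.,                    # reconstruction and edge loss on eyes area
        'w_pl': (0.01, 0.1, 0.3, 0.1),    # perceptual loss
    }
    loss_config = {
        "gan_training": "mixup_LSGAN",    # or "relativistic_avg_LSGAN"
        'use_PL': False,
        "PL_before_activ": True,          # Example #8 reads this key when building the PL model
        'use_mask_hinge_loss': False,
        'm_mask': 0.,
        'lr_factor': 1.,
        'use_cyclic_loss': False,
    }
    return da_config, arch_config, loss_weights, loss_config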
Example #9
def test_faceswap(person, model_path, test_path, save_path):
    mtcnn_weights_dir = "./mtcnn_weights/"
    fd = MTCNNFaceDetector(sess=K.get_session(), model_path=mtcnn_weights_dir)

    da_config, arch_config, loss_weights, loss_config = get_model_params()

    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=model_path)

    ftrans = FaceTransformer()
    ftrans.set_model(model)

    # Read input image
    test_imgs = glob.glob(test_path + '/*.jpg')
    Path(save_path).mkdir(parents=True, exist_ok=True)

    for test_img in test_imgs:
        input_img = plt.imread(test_img)[..., :3]

        if input_img.dtype == np.float32:
            print("input_img has dtype np.float32 (perhaps the image format is PNG). Scale it to uint8.")
            input_img = (input_img * 255).astype(np.uint8)

        # Display detected face
        faces, lms = fd.detect_face(input_img)
        if len(faces) == 0:
            continue
        x0, y1, x1, y0, _ = faces[0]
        det_face_im = input_img[int(x0):int(x1), int(y0):int(y1), :]
        try:
            src_landmarks = get_src_landmarks(x0, x1, y0, y1, lms)
            tar_landmarks = get_tar_landmarks(det_face_im)
            aligned_det_face_im = landmarks_match_mtcnn(det_face_im, src_landmarks, tar_landmarks)
        except Exception:
            print("An error occurred during face alignment.")
            aligned_det_face_im = det_face_im
        # plt.imshow(aligned_det_face_im)
        # Transform detected face
        result_img, result_rgb, result_mask = ftrans.transform(
            aligned_det_face_im,
            direction="BtoA",
            roi_coverage=0.93,
            color_correction="adain_xyz",
            IMAGE_SHAPE=(RESOLUTION, RESOLUTION, 3)
        )
        try:
            result_img = landmarks_match_mtcnn(result_img, tar_landmarks, src_landmarks)
            result_rgb = landmarks_match_mtcnn(result_rgb, tar_landmarks, src_landmarks)
            result_mask = landmarks_match_mtcnn(result_mask, tar_landmarks, src_landmarks)
        except Exception:
            print("An error occurred during face alignment.")

        # Alpha-blend the swapped face back into the original image using the predicted mask
        mask = result_mask.astype(np.float32) / 255
        result_input_img = input_img.copy()
        result_input_img[int(x0):int(x1), int(y0):int(y1), :] = (
            mask * result_rgb
            + (1 - mask) * result_input_img[int(x0):int(x1), int(y0):int(y1), :]
        )

        img_name = os.path.basename(test_img)
        plt.imshow(result_input_img)
        plt.imsave(f'{save_path}/{img_name}', result_input_img)
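
A hedged invocation sketch (all four arguments are hypothetical placeholders):

test_faceswap("personA", "./models/personB2personA", "./test_images", "./results")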
Example #10
def train(ITERS):
    K.set_learning_phase(1)
    #K.set_learning_phase(0) # set to 0 in inference phase

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64  # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    # batchSize = 8
    batchSize = 2

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = False  # eye-aware training is disabled in this example
    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    session_config = tf.ConfigProto(log_device_placement=False,
                                    allow_soft_placement=True)
    # do not use all of the GPU memory
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.90

    # Path to training images
    dataDir = '/home/wh/hdd/work/xj/AI_anchors/faceswap-GAN-master'
    img_dirA = './faces/faceA/aligned_faces'  # folder of aligned faces
    img_dirB = './faces/faceB/aligned_faces'
    img_dirA_bm_eyes = "./faces/faceA/binary_masks_eyes"  #与对齐人脸对应的mask,用于合成,主要是GAN如果不加mask会差一些
    img_dirB_bm_eyes = "./faces/faceB/binary_masks_eyes"

    # Path to saved model weights
    models_dir = "./models"
    #Path(f"models").mkdir(parents=True, exist_ok=True)
    if not os.path.exists(models_dir):
        os.mkdir(models_dir)

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    # Loss function weights configuration (weight coefficients for each loss term)
    loss_weights = {}
    loss_weights['w_D'] = 0.1  # Discriminator
    loss_weights['w_recon'] = 1.  # L1 reconstruction loss
    loss_weights['w_edge'] = 0.1  # edge loss
    loss_weights['w_eyes'] = 30.  # reconstruction and edge loss on eyes area
    loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1)  # perceptual loss (0.003, 0.03, 0.3, 0.3)

    # Init. loss config. (which loss is switched on after how many iterations)
    loss_config = {}
    loss_config["gan_training"] = "mixup_LSGAN"  # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    loss_config['use_PL'] = False
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.
    loss_config['lr_factor'] = 1.
    loss_config['use_cyclic_loss'] = False

    print('CONFIGURE DONE')
    print('step1')

    # Get filenames
    #train_A = glob.glob(img_dirA+"/*.*")
    #train_B = glob.glob(img_dirB+"/*.*")
    train_A = getDirFnPath(img_dirA)
    train_B = getDirFnPath(img_dirB)

    train_AnB = train_A + train_B

    assert len(train_A), "No image found in " + str(img_dirA)
    assert len(train_B), "No image found in " + str(img_dirB)
    print("Number of images in folder A: " + str(len(train_A)))
    print("Number of images in folder B: " + str(len(train_B)))

    print('step2')

    #define models
    global model
    model = FaceswapGANModel(**arch_config)

    print('DEFINE MODELS DONE')

    model.load_weights(path=models_dir)

    # VGGFace ResNet50: VGG is used to check that extracted features agree (extract
    # fixed-layer features from the original and generated images, then compare the feature distances)
    global vggface
    vggface = VGGFace(include_top=False,
                      model='resnet50',
                      input_shape=(224, 224, 3))

    #vggface.summary()

    model.build_pl_model(vggface_model=vggface)

    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    print('BUILD MODELS DONE')

    if use_bm_eyes:
        assert len(glob.glob(
            img_dirA_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirA_bm_eyes)
        assert len(glob.glob(
            img_dirB_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirB_bm_eyes)
        assert len(glob.glob(img_dirA_bm_eyes+"/*.*")) == len(train_A), \
        "Number of faceA images does not match the number of their binary masks. This can be caused by a non-image file in the folder."
        assert len(glob.glob(img_dirB_bm_eyes+"/*.*")) == len(train_B), \
        "Number of faceB images does not match the number of their binary masks. This can be caused by a non-image file in the folder."

    def show_loss_config(loss_config):
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def reset_session(save_path):
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False,
                          model='resnet50',
                          input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface)
        train_batchA = DataLoader(train_A, train_AnB, batchSize,
                                  img_dirA_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)
        train_batchB = DataLoader(train_B, train_AnB, batchSize,
                                  img_dirB_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)

    print('RESET_SESSION DONE')
    print('step3')

    # Start training
    t0 = time.time()
    gen_iterations = 0
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = {}
    errGBs = {}
    # Dictionaries are ordered in Python 3.6
    for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
        errGAs[k] = 0
        errGBs[k] = 0

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = ITERS // 1
    # TOTAL_ITERS = 10000

    print('step4')

    global train_batchA, train_batchB
    train_batchA = DataLoader(train_A, train_AnB, batchSize,
                              img_dirA_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize,
                              img_dirB_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)

    print('step5')
    print('DATALOADER DONE')

    show_iters = 50
    print("START TRAINING")
    while gen_iterations <= TOTAL_ITERS:
        #print(gen_iterations)
        # Loss function automation
        if gen_iterations == (TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 5 + TOTAL_ITERS // 10 -
                                display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.5
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Complete.")
        elif gen_iterations == (2 * TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.2
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 2 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.4
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (2 * TOTAL_ITERS // 3 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (8 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            model.decoder_A.load_weights(
                "models/decoder_B.h5")  # swap decoders
            model.decoder_B.load_weights(
                "models/decoder_A.h5")  # swap decoders
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.1
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (9 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            loss_config['lr_factor'] = 0.1
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")

        # Train discriminators for one batch
        # each call returns an original image, a randomly warped image, and the mask for the original image
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        if gen_iterations % show_iters == 0:
            #print('iter: %d, Loss_GA: %f, Loss_G: %f' % (gen_iterations, errGA_sum / display_iters, errGB_sum / display_iters))
            # Display loss information
            #show_loss_config(loss_config)
            print(
                '[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                %
                (gen_iterations, float(errDA_sum / show_iters),
                 float(errDB_sum / show_iters), float(errGA_sum / show_iters),
                 float(errGB_sum / show_iters), time.time() - t0))
            for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
                errGAs[k] = 0
                errGBs[k] = 0
            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0

            # Display images
            wA, tA, _ = train_batchA.get_next_batch()
            wB, tB, _ = train_batchB.get_next_batch()
            tran_res = showG(tA, tB, model.path_A, model.path_B, batchSize)
            mask_res = showG_mask(tA, tB, model.path_mask_A, model.path_mask_B,
                                  batchSize)
            rec_res = showG(wA, wB, model.path_bgr_A, model.path_bgr_B,
                            batchSize)
            #mask_res1 = cv2.cvtColor(mask_res, cv2.COLOR_GRAY2BGR)
            fname = "./logs/images/tran_mask_rec_%d.jpg" % (gen_iterations)
            res = np.vstack([tran_res, mask_res, rec_res])
            cv2.imwrite(fname, res)
            #fname = "./logs/images/mask_%d.jpg" % (gen_iterations)
            #cv2.imwrite(fname, mask_res)

        # Visualization removed; just save the model periodically
        if gen_iterations % display_iters == 0:
            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            #Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            if not os.path.exists(bkup_dir):
                os.mkdir(bkup_dir)
            model.save_weights(path=bkup_dir)

    print('TRAIN DONE')
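
A hedged invocation sketch: the loss-schedule switches above are keyed to fractions of TOTAL_ITERS, so the whole run is driven by a single iteration count (10000 here is an arbitrary example value):

train(10000)  # saves weights every display_iters (300) and backs up every backup_iters (5000)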
Example #11
def run(img_dirA, img_dirB, TOTAL_ITERS = 40000):
    global model, vggface
    
    img_dirA_bm_eyes = f"{img_dirA}/binary_masks_eyes2"
    img_dirB_bm_eyes = f"{img_dirB}/binary_masks_eyes2"

    # Path to saved model weights
    models_dir = "./models"

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64 # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    batchSize = 8
    assert (batchSize != 1 and batchSize % 2 == 0) , "batchSize should be an even number."

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False 

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = True

    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    arch_config['norm'] = "instancenorm" # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard" # standard, lite

    # VGGFace ResNet50
    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

    # Loss function weights configuration
    loss_weights = {}
    loss_weights['w_D'] = 0.1 # Discriminator
    loss_weights['w_recon'] = 1. # L1 reconstruction loss
    loss_weights['w_edge'] = 0.1 # edge loss
    loss_weights['w_eyes'] = 30. # reconstruction and edge loss on eyes area
    loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1) # perceptual loss (0.003, 0.03, 0.3, 0.3)

    # Init. loss config.
    loss_config = {}
    loss_config["gan_training"] = "mixup_LSGAN" # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    loss_config['use_PL'] = False
    loss_config["PL_before_activ"] = False
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.
    loss_config['lr_factor'] = 1.
    loss_config['use_cyclic_loss'] = False

    from networks.faceswap_gan_model import FaceswapGANModel
    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=models_dir)

    from colab_demo.vggface_models import RESNET50
    vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))
    vggface.load_weights("rcmalli_vggface_tf_notop_resnet50.h5")

    #vggface.summary()

    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    model.build_train_functions(loss_weights=loss_weights, **loss_config)
    from data_loader.data_loader import DataLoader
    # Create ./models directory
    Path(f"models").mkdir(parents=True, exist_ok=True)

    # Get filenames
    faces_A = f"{img_dirA}/raw_faces"
    faces_B = f"{img_dirB}/raw_faces"

    train_A = glob.glob(f"{faces_A}/*.*")
    train_B = glob.glob(f"{faces_B}/*.*")

    train_AnB = train_A + train_B

    assert len(train_A), f"No image found in {faces_A}"
    assert len(train_B), f"No image found in {faces_B}"
    print ("Number of images in folder A: " + str(len(train_A)))
    print ("Number of images in folder B: " + str(len(train_B)))

    if use_bm_eyes:
        assert len(glob.glob(img_dirA_bm_eyes+"/*.*")), "No binary mask found in " + str(img_dirA_bm_eyes)
        assert len(glob.glob(img_dirB_bm_eyes+"/*.*")), "No binary mask found in " + str(img_dirB_bm_eyes)
        assert len(glob.glob(img_dirA_bm_eyes+"/*.*")) == len(train_A), \
        "Number of faceA images does not match the number of their binary masks. This can be caused by a non-image file in the folder."
        assert len(glob.glob(img_dirB_bm_eyes+"/*.*")) == len(train_B), \
        "Number of faceB images does not match the number of their binary masks. This can be caused by a non-image file in the folder."


    def show_loss_config(loss_config):
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def reset_session(save_path):
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
        train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,
                                  RESOLUTION, num_cpus, K.get_session(), **da_config)
        train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, 
                                  RESOLUTION, num_cpus, K.get_session(), **da_config)

    # Start training
    t0 = time.time()

    # This try/except is meant to resume training that was accidentally interrupted
    try:
        gen_iterations
        print(f"Resume training from iter {gen_iterations}.")
    except NameError:
        gen_iterations = 0

    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = {}
    errGBs = {}
    # Dictionaries are ordered in Python 3.6
    for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
        errGAs[k] = 0
        errGBs[k] = 0

    display_iters = 300
    backup_iters = 5000

    global train_batchA, train_batchB
    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes, 
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, 
                              RESOLUTION, num_cpus, K.get_session(), **da_config)

    clear_output()
    loss_config['use_PL'] = True
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.0
    reset_session(models_dir)
    print("Building new loss funcitons...")
    show_loss_config(loss_config)
    model.build_train_functions(loss_weights=loss_weights, **loss_config)
    print("Done.")

    while gen_iterations <= TOTAL_ITERS: 

        # Loss function automation
        if gen_iterations == (TOTAL_ITERS//5 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS//5 + TOTAL_ITERS//10 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.5
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Complete.")
        elif gen_iterations == (2*TOTAL_ITERS//5 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.2
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS//2 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.4
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (2*TOTAL_ITERS//3 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (8*TOTAL_ITERS//10 - display_iters//2):
            clear_output()
            model.decoder_A.load_weights("models/decoder_B.h5") # swap decoders
            model.decoder_B.load_weights("models/decoder_A.h5") # swap decoders
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.1
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")
        elif gen_iterations == (9*TOTAL_ITERS//10 - display_iters//2):
            clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            loss_config['lr_factor'] = 0.1
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print("Done.")

        if gen_iterations == 5:
            print ("working.")

        # Train discriminators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        # Visualization
        if gen_iterations % display_iters == 0:
            clear_output()

            # Display loss information
            show_loss_config(loss_config)
            print("----------") 
            print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
            % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,
               errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))  
            print("----------") 
            print("Generator loss details:")
            print(f'[Adversarial loss]')  
            print(f'GA: {errGAs["adv"]/display_iters:.4f} GB: {errGBs["adv"]/display_iters:.4f}')
            print(f'[Reconstruction loss]')
            print(f'GA: {errGAs["recon"]/display_iters:.4f} GB: {errGBs["recon"]/display_iters:.4f}')
            print(f'[Edge loss]')
            print(f'GA: {errGAs["edge"]/display_iters:.4f} GB: {errGBs["edge"]/display_iters:.4f}')
            if loss_config['use_PL']:
                print(f'[Perceptual loss]')
                try:
                    print(f'GA: {errGAs["pl"][0]/display_iters:.4f} GB: {errGBs["pl"][0]/display_iters:.4f}')
                except TypeError:  # errGAs["pl"]/errGBs["pl"] may be scalars rather than sequences
                    print(f'GA: {errGAs["pl"]/display_iters:.4f} GB: {errGBs["pl"]/display_iters:.4f}')

            # Display images
            print("----------") 
            wA, tA, _ = train_batchA.get_next_batch()
            wB, tB, _ = train_batchB.get_next_batch()
            print("Transformed (masked) results:")
            showG(tA, tB, model.path_A, model.path_B, batchSize)   
            print("Masks:")
            showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)  
            print("Reconstruction results:")
            showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)           
            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
            for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
                errGAs[k] = 0
                errGBs[k] = 0

            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0: 
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)
Example #12
def face_transform():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # do not grab all GPU memory up front; allocate on demand
    session = tf.Session(config=config)
    # set the Keras session
    K.set_session(session)
    models_dir = "../models"
    RESOLUTION = 256  # 64x64, 128x128, 256x256
    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True

    arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    model = FaceswapGANModel(**arch_config)

    options = {
        # ===== Fixed =====
        "use_smoothed_bbox": True,
        "use_kalman_filter": True,
        "use_auto_downscaling": False,
        "bbox_moving_avg_coef": 0.70,  # 0.65
        "min_face_area": 128 * 128,
        "IMAGE_SHAPE": model.IMAGE_SHAPE,
        # ===== Tunable =====
        "kf_noise_coef": 1e-3,
        "use_color_correction": "hist_match",
        "detec_threshold": 0.8,
        "roi_coverage": 0.92,
        "enhance": 0.,
        "output_type": 1,
        "direction":
        "BtoA",  # ==================== This line determines the transform direction ====================
    }

    model.load_weights(path=models_dir)
    fd = MTCNNFaceDetector(sess=K.get_session(),
                           model_path="../mtcnn_weights/")
    vc = VideoConverter()
    vc.set_face_detector(fd)
    vc.set_gan_model(model)
    vc._init_kalman_filters(options["kf_noise_coef"])

    def transform(imageB64):
        """
        :param imageB64: 图片base64编码
        :return: 转换后的图片base64编码
        """
        rgb_img = base64_to_image(imageB64)

        result = vc.process_video(rgb_img, options)
        result = normalization(result)

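        # NOTE: img_bgr below is only consumed by the commented-out base64 return path;
        # the active return sends the RGB result as a nested Python list.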
        r, g, b = cv2.split(result)
        img_bgr = cv2.merge([b, g, r])
        return np.array(result).tolist()
        # bgr base64
        # res_base64 = image_to_base64(img_bgr)
        # return res_base64

    return transform
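
A hedged usage sketch (image_b64 stands for a base64-encoded RGB image and is a placeholder):

transform = face_transform()
result = transform(image_b64)  # nested-list RGB result, per the return statement above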
from networks.faceswap_gan_model import FaceswapGANModel

import numpy as np
K.set_learning_phase(0)
RESOLUTION = 128  # 64x64, 128x128, 256x256
assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

# Architecture configuration
arch_config = {}
arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
arch_config['use_self_attn'] = True
arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
arch_config['model_capacity'] = "standard"  # standard, lite

model = FaceswapGANModel(**arch_config)

model.load_weights(path="./endeweight")

from converter.video_converter import VideoConverter

from detector.face_detector import MTCNNFaceDetector

mtcnn_weights_dir = "./weights/"

fd = MTCNNFaceDetector(sess=K.get_session(), model_path=mtcnn_weights_dir)

frames = 0
x0 = x1 = y0 = y1 = 0
vc = VideoConverter(x0, x1, y0, y1, frames)
Example #14
def train(ITERS):
    K.set_learning_phase(1)
    #K.set_learning_phase(0) # set to 0 in inference phase

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64  # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    # batchSize = 8
    batchSize = 4

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = True
    #use_bm_eyes = False
    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    session_config = tf.ConfigProto(log_device_placement=False,
                                    allow_soft_placement=True)
    # do not use all of the GPU memory
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.90

    # Path to training images
    img_dirA = './faceA'
    img_dirB = './faceB'
    #img_dirA_bm_eyes = "./binary_masks/faceA_eyes"
    #img_dirB_bm_eyes = "./binary_masks/faceB_eyes"
    img_dirA_bm_eyes = "./faceA/binary_masks_eyes"
    img_dirB_bm_eyes = "./faceB/binary_masks_eyes"

    # Path to saved model weights
    models_dir = "./models"
    Path(f"models").mkdir(parents=True, exist_ok=True)

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    # Loss function weights configuration
    loss_weights = {}
    loss_weights['w_D'] = 0.1  # Discriminator
    loss_weights['w_recon'] = 1.  # L1 reconstruction loss
    loss_weights['w_edge'] = 0.1  # edge loss
    loss_weights['w_eyes'] = 30.  # reconstruction and edge loss on eyes area
    loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1)  # perceptual loss (0.003, 0.03, 0.3, 0.3)

    # Init. loss config.
    loss_config = {}
    loss_config["gan_training"] = "mixup_LSGAN"  # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    loss_config['use_PL'] = False
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.
    loss_config['lr_factor'] = 1.
    loss_config['use_cyclic_loss'] = False

    print('CONFIGURE DONE')

    #define models
    from networks.faceswap_gan_model import FaceswapGANModel
    global model
    model = FaceswapGANModel(**arch_config)

    print('DEFINE MODELS DONE')

    model.load_weights(path=models_dir)

    from keras_vggface.vggface import VGGFace
    # VGGFace ResNet50
    global vggface
    vggface = VGGFace(include_top=False,
                      model='resnet50',
                      input_shape=(224, 224, 3))

    #vggface.summary()

    model.build_pl_model(vggface_model=vggface)

    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    from data_loader.data_loader import DataLoader

    # Get filenames
    train_A = glob.glob(img_dirA + "/*.*")
    train_B = glob.glob(img_dirB + "/*.*")

    train_AnB = train_A + train_B

    assert len(train_A), "No image found in " + str(img_dirA)
    assert len(train_B), "No image found in " + str(img_dirB)
    print("Number of images in folder A: " + str(len(train_A)))
    print("Number of images in folder B: " + str(len(train_B)))

    if use_bm_eyes:
        assert len(glob.glob(
            img_dirA_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirA_bm_eyes)
        assert len(glob.glob(
            img_dirB_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirB_bm_eyes)
        assert len(glob.glob(img_dirA_bm_eyes+"/*.*")) == len(train_A), \
        "Number of faceA images does not match the number of their binary masks. This can be caused by a non-image file in the folder."
        assert len(glob.glob(img_dirB_bm_eyes+"/*.*")) == len(train_B), \
        "Number of faceB images does not match the number of their binary masks. This can be caused by a non-image file in the folder."

    def show_loss_config(loss_config):
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def reset_session(save_path):
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False,
                          model='resnet50',
                          input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface)
        train_batchA = DataLoader(train_A, train_AnB, batchSize,
                                  img_dirA_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)
        train_batchB = DataLoader(train_B, train_AnB, batchSize,
                                  img_dirB_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)

    print('RESET_SESSION DONE')

    # Start training
    t0 = time.time()
    gen_iterations = 0
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = {}
    errGBs = {}
    # Dictionaries are ordered in Python 3.6
    for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
        errGAs[k] = 0
        errGBs[k] = 0

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = ITERS // 1
    # TOTAL_ITERS = 10000

    global train_batchA, train_batchB
    train_batchA = DataLoader(train_A, train_AnB, batchSize,
                              img_dirA_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize,
                              img_dirB_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)

    print('DATALOADER DONE')

    print("START TRAINING")
    while gen_iterations <= TOTAL_ITERS:
        print(gen_iterations)
        # Loss function automation
        if gen_iterations == (TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 5 + TOTAL_ITERS // 10 -
                                display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.5
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Complete.")
        elif gen_iterations == (2 * TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.2
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 2 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.4
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (2 * TOTAL_ITERS // 3 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (8 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            model.decoder_A.load_weights(
                "models/decoder_B.h5")  # swap decoders
            model.decoder_B.load_weights(
                "models/decoder_A.h5")  # swap decoders
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.1
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (9 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            loss_config['lr_factor'] = 0.1
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")

        # Train discriminators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        # Visualization code removed; only the periodic model save remains
        if gen_iterations % display_iters == 0:
            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)

    print('TRAIN DONE')
Ejemplo n.º 15
0
def transform_img(inStack, outStack):
    print('Transform process PID: %s' % os.getpid(), time.time())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # don't grab all GPU memory; allocate on demand
    session = tf.Session(config=config)
    # Register the session with Keras
    K.set_session(session)
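    # transform_img runs in its own process (see the PID print above), so the
    # session and GPU-memory policy are configured per worker.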

    models_dir = "../models_transform"
    RESOLUTION = 256  # 64x64, 128x128, 256x256
    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    # TODO: choose the normalization layer
    arch_config['norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    model = FaceswapGANModel(**arch_config)

    options = {
        # ===== Fixed =====
        "use_smoothed_bbox": True,
        "use_kalman_filter": True,
        "use_auto_downscaling": False,
        "bbox_moving_avg_coef": 0.65,  # 0.65
        "min_face_area": 35 * 35,
        "IMAGE_SHAPE": model.IMAGE_SHAPE,
        # ===== Tunable =====
        "kf_noise_coef": 1e-3,
        "use_color_correction": "hist_match",
        "detec_threshold": 0.8,
        "roi_coverage": 0.90,
        "enhance": 0.,
        "output_type": 1,
        "direction":
        "BtoA",  # ==================== This line determines the transform direction ====================
    }

    model.load_weights(path=models_dir)
    fd = MTCNNFaceDetector(sess=K.get_session(),
                           model_path="../mtcnn_weights/")
    vc = VideoConverter()
    vc.set_face_detector(fd)
    vc.set_gan_model(model)
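    # Prime the converter's Kalman filters (a private helper) with the
    # measurement-noise coefficient used for bounding-box smoothing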
    vc._init_kalman_filters(options["kf_noise_coef"])

    while True:
        if len(inStack) != 0:  # busy-wait until a frame is available
            start_time = time.time()
            rgb_img = inStack.pop()
            # print("transform_img inputQ size:", inputQ.qsize())
            # Get the face-swapped result for this frame
            result = vc.process_video(rgb_img, options)
            result = normalization(result) * 255  # rescale to [0, 255] for uint8
            result = np.uint8(result)

            print(time.time() - start_time)
            outStack.append(result)
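# --- Usage sketch (illustrative; not part of the original example) ---
# transform_img() busy-waits on a shared input stack and pushes converted
# frames to a shared output stack. A plausible wiring, assuming the stacks
# are multiprocessing.Manager() lists:
#
#   from multiprocessing import Manager, Process
#   manager = Manager()
#   inStack, outStack = manager.list(), manager.list()
#   worker = Process(target=transform_img, args=(inStack, outStack))
#   worker.start()
#   inStack.append(rgb_frame)   # producer pushes an RGB frame (hypothetical)
#   swapped = outStack.pop()    # consumer reads the face-swapped frame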