Exemplo n.º 1
0
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = False
        loss_config['m_mask'] = 0.0
        loss_config['lr_factor'] = 0.1
        reset_session(models_dir)
        print("Building new loss funcitons...")
        show_loss_config(loss_config)
        model.build_train_functions(loss_weights=loss_weights, **loss_config)
        print("Done.")

    if gen_iterations == 5:
        print ("working.")

    # Train dicriminators for one batch
    data_A = train_batchA.get_next_batch()
    data_B = train_batchB.get_next_batch()
    errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
    errDA_sum +=errDA[0]
    errDB_sum +=errDB[0]

    # Train generators for one batch
    data_A = train_batchA.get_next_batch()
    data_B = train_batchB.get_next_batch()
    errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
    errGA_sum += errGA[0]
    errGB_sum += errGB[0]
    for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
        errGAs[k] += errGA[i]
        errGBs[k] += errGB[i]
    gen_iterations+=1
Exemplo n.º 2
0
def train_person(person, gen_person):
    """Train the face-swap GAN whose weights live in ``./models/{gen_person}2{person}``.

    Face images are globbed from ``./faces/<name>/raw_faces`` and eye masks are
    looked up under ``./binary_masks/<name>`` for both ``person`` and
    ``gen_person``.  Only the "A" side of the model is stepped
    (``train_one_batch_D`` / ``train_one_batch_G`` are called with ``data_A``
    only).  At fixed points of the run the Keras session is rebuilt, the loss
    configuration is changed and the loader feeding the "A" side is replaced.

    The function relies on module-level names that are not defined here
    (``RESOLUTION``, ``get_model_params``, ``FaceswapGANModel``, ``VGGFace``,
    ``DataLoader``, ``K``, ``clear_output``, ``show_loss_config``).

    Args:
        person: Folder name under ``./faces`` / ``./binary_masks`` whose images
            feed the "A" loader when training starts.
        gen_person: Folder name of the second identity; its images feed the
            "A" loader after the first five schedule stages.

    Raises:
        AssertionError: If no image is found for ``person`` or, when
            ``da_config["use_bm_eyes"]`` is set, if the number of binary masks
            is zero or differs from the number of images.
    """
    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Batch size
    batchSize = 2
    assert (batchSize != 1 and batchSize % 2 == 0), "batchSize should be an even number."

    # Path to training images
    img_dir = f'./faces/{person}'
    img_dir_bm_eyes = f"./binary_masks/{person}"

    gen_img_dir = f'./faces/{gen_person}'
    gen_img_dir_bm_eyes = f"./binary_masks/{gen_person}"

    # Path to saved model weights
    models_dir = f"./models/{gen_person}2{person}"

    da_config, arch_config, loss_weights, loss_config = get_model_params()

    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=models_dir)

    # VGGFace ResNet50 used for the perceptual loss.  (An earlier revision also
    # built a RESNET50 model from a local .h5 file here, but that object was
    # overwritten by this one before ever being used.)
    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    # Create the model directory
    Path(models_dir).mkdir(parents=True, exist_ok=True)

    # Get filenames
    person_img = glob.glob(img_dir + "/raw_faces/*.*")
    gen_person_img = glob.glob(gen_img_dir + "/raw_faces/*.*")
    all_img = person_img + gen_person_img

    assert len(person_img), "No image found in " + str(img_dir)
    print("Number of images in folder: " + str(len(person_img)))

    if da_config["use_bm_eyes"]:
        mask_count = len(glob.glob(img_dir_bm_eyes + "/*.*"))
        assert mask_count, "No binary mask found in " + str(img_dir_bm_eyes)
        assert mask_count == len(person_img), \
            "Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder."

    train_batchA = DataLoader(person_img, all_img, batchSize, img_dir_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    # NOTE(review): train_batchB is never read in this function; it is still
    # constructed because DataLoader's side effects are not visible from here.
    train_batchB = DataLoader(gen_person_img, all_img, batchSize, gen_img_dir_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)

    t0 = time.time()
    gen_iterations = 0

    # Running sums since the last report.  The *B* accumulators are only reset
    # and printed here, never incremented, so they are always reported as 0.
    loss_keys = ['ttl', 'adv', 'recon', 'edge', 'pl']
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = dict.fromkeys(loss_keys, 0)
    errGBs = dict.fromkeys(loss_keys, 0)

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = 10000

    def reset_session(save_path, model, person='A'):
        """Save weights, clear the Keras session and rebuild model and loader.

        With ``person == 'A'`` the returned loader reads ``gen_person`` images;
        any other value yields a loader over ``person`` images.
        """
        model.save_weights(path=save_path)
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
        if person == 'A':
            train_batch = DataLoader(gen_person_img, all_img, batchSize, gen_img_dir_bm_eyes,
                                     RESOLUTION, num_cpus, K.get_session(), **da_config)
        else:
            train_batch = DataLoader(person_img, all_img, batchSize, img_dir_bm_eyes,
                                     RESOLUTION, num_cpus, K.get_session(), **da_config)

        return model, vggface, train_batch

    # Loss function automation.  Each entry is
    #   (trigger iteration, loss_config overrides, `person` for reset_session,
    #    reload decoder_A weights first?, completion message).
    # Triggers sit half a display interval before the nominal milestone.
    half_display = display_iters // 2
    loss_schedule = [
        (TOTAL_ITERS // 5 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0},
         'A', False, "Done."),
        (TOTAL_ITERS // 5 + TOTAL_ITERS // 10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.5},
         'A', False, "Complete."),
        (2 * TOTAL_ITERS // 5 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.2},
         'A', False, "Done."),
        (TOTAL_ITERS // 2 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.4},
         'A', False, "Done."),
        (2 * TOTAL_ITERS // 3 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0., 'lr_factor': 0.3},
         'A', False, "Done."),
        # Author's note (translated): swapping the data seems to be the key
        # point.  Obvious in hindsight: the goal is not a conventional
        # autoencoder -- feeding in one person's image and getting the same
        # person back is useless, the identity has to change.
        (8 * TOTAL_ITERS // 10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.1, 'lr_factor': 0.3},
         'B', True, "Done."),
        (9 * TOTAL_ITERS // 10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0, 'lr_factor': 0.1},
         'B', False, "Done."),
    ]

    while gen_iterations <= TOTAL_ITERS:
        # First matching entry wins, mirroring an if/elif chain.
        stage = next((s for s in loss_schedule if s[0] == gen_iterations), None)
        if stage is not None:
            _, overrides, loader_person, reload_decoder_a, done_msg = stage
            clear_output()
            if reload_decoder_a:
                # NOTE(review): originally labelled "swap decoders", but this
                # reloads decoder_A from its own saved file -- confirm intent.
                model.decoder_A.load_weights(f"{models_dir}/decoder_A.h5")
            loss_config.update(overrides)
            model, vggface, train_batchA = reset_session(models_dir, model, person=loader_person)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights, **loss_config)
            print(done_msg)

        if gen_iterations == 5:
            print("working.")

        # Train discriminators for one batch
        data_A = train_batchA.get_next_batch()
        errDA = model.train_one_batch_D(data_A=data_A)
        errDA_sum += errDA[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        errGA = model.train_one_batch_G(data_A=data_A)
        errGA_sum += errGA[0]
        for i, k in enumerate(loss_keys):
            errGAs[k] += errGA[i]
        gen_iterations += 1

        # Visualization
        if gen_iterations % display_iters == 0:
            clear_output()

            # Display loss information
            show_loss_config(loss_config)
            print("----------")
            print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                  % (gen_iterations, errDA_sum / display_iters, errDB_sum / display_iters,
                     errGA_sum / display_iters, errGB_sum / display_iters, time.time() - t0))
            print("----------")
            print("Generator loss details:")
            print('[Adversarial loss]')
            print(f'GA: {errGAs["adv"] / display_iters:.4f} GB: {errGBs["adv"] / display_iters:.4f}')
            print('[Reconstruction loss]')
            print(f'GA: {errGAs["recon"] / display_iters:.4f} GB: {errGBs["recon"] / display_iters:.4f}')
            print('[Edge loss]')
            print(f'GA: {errGAs["edge"] / display_iters:.4f} GB: {errGBs["edge"] / display_iters:.4f}')
            if loss_config['use_PL'] == True:
                print('[Perceptual loss]')
                # The accumulated value may be indexable or a plain scalar;
                # fall back to the scalar form only for indexing failures.
                try:
                    print(f'GA: {errGAs["pl"][0] / display_iters:.4f}')
                except (TypeError, IndexError):
                    print(f'GA: {errGAs["pl"] / display_iters:.4f}')

            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
            for k in loss_keys:
                errGAs[k] = 0
                errGBs[k] = 0

            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)
Exemplo n.º 3
0
def run(img_dirA, img_dirB, TOTAL_ITERS = 40000):
    """Train the face-swap GAN on two face folders with a staged loss schedule.

    Images are globbed from ``{img_dirX}/raw_faces`` and eye masks are looked
    up in ``{img_dirX}/binary_masks_eyes2``.  Weights are loaded from and
    saved to ``./models``; a backup copy is written every 5000 iterations.
    The model, the VGGFace network and both data loaders are stored in the
    module globals ``model``, ``vggface``, ``train_batchA`` and
    ``train_batchB`` so that the nested ``reset_session`` can delete and
    rebuild them.

    Relies on module-level names not defined here (``VGGFace``, ``K``,
    ``clear_output``, ``showG``, ``showG_mask``, ``glob``, ``os``, ``time``,
    ``Path``).

    Args:
        img_dirA: Root folder of identity A.
        img_dirB: Root folder of identity B.
        TOTAL_ITERS: Training runs while the iteration counter is
            ``<= TOTAL_ITERS``; the loss-schedule milestones are fractions
            of this value.

    Raises:
        AssertionError: If either image folder is empty, or (eye-aware
            training is enabled) a mask folder is empty or its file count
            differs from the matching image folder.
    """
    global model, vggface
    global train_batchA, train_batchB

    img_dirA_bm_eyes = f"{img_dirA}/binary_masks_eyes2"
    img_dirB_bm_eyes = f"{img_dirB}/binary_masks_eyes2"

    # Path to saved model weights
    models_dir = "./models"

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64 # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    batchSize = 8
    assert (batchSize != 1 and batchSize % 2 == 0) , "batchSize should be an even number."

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = True

    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    # Architecture configuration
    arch_config = {
        'IMAGE_SHAPE': (RESOLUTION, RESOLUTION, 3),
        'use_self_attn': True,
        'norm': "instancenorm", # instancenorm, batchnorm, layernorm, groupnorm, none
        'model_capacity': "standard", # standard, lite
    }

    # Loss function weights configuration
    loss_weights = {
        'w_D': 0.1, # Discriminator
        'w_recon': 1., # L1 reconstruction loss
        'w_edge': 0.1, # edge loss
        'w_eyes': 30., # reconstruction and edge loss on eyes area
        'w_pl': (0.01, 0.1, 0.3, 0.1), # perceptual loss (0.003, 0.03, 0.3, 0.3)
    }

    # Init. loss config.
    loss_config = {
        "gan_training": "mixup_LSGAN", # "mixup_LSGAN" or "relativistic_avg_LSGAN"
        'use_PL': False,
        "PL_before_activ": False,
        'use_mask_hinge_loss': False,
        'm_mask': 0.,
        'lr_factor': 1.,
        'use_cyclic_loss': False,
    }

    from networks.faceswap_gan_model import FaceswapGANModel
    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=models_dir)

    # Perceptual-loss backbone built from a local weights file.  (A VGGFace
    # instance used to be created before this point as well, but it was
    # overwritten here without ever being read.)
    from colab_demo.vggface_models import RESNET50
    vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))
    vggface.load_weights("rcmalli_vggface_tf_notop_resnet50.h5")

    model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
    model.build_train_functions(loss_weights=loss_weights, **loss_config)
    from data_loader.data_loader import DataLoader
    # Create ./models directory
    Path("models").mkdir(parents=True, exist_ok=True)

    # Get filenames
    faces_A = f"{img_dirA}/raw_faces"
    faces_B = f"{img_dirB}/raw_faces"

    train_A = glob.glob(f"{faces_A}/*.*")
    train_B = glob.glob(f"{faces_B}/*.*")

    train_AnB = train_A + train_B

    assert len(train_A), f"No image found in {faces_A}"
    assert len(train_B), f"No image found in {faces_B}"
    print ("Number of images in folder A: " + str(len(train_A)))
    print ("Number of images in folder B: " + str(len(train_B)))

    if use_bm_eyes:
        masks_A = len(glob.glob(img_dirA_bm_eyes+"/*.*"))
        masks_B = len(glob.glob(img_dirB_bm_eyes+"/*.*"))
        assert masks_A, "No binary mask found in " + str(img_dirA_bm_eyes)
        assert masks_B, "No binary mask found in " + str(img_dirB_bm_eyes)
        assert masks_A == len(train_A), \
        "Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder."
        assert masks_B == len(train_B), \
        "Number of faceB images does not match number of their binary masks. Can be caused by any none image file in the folder."

    def show_loss_config(loss_config):
        """Print every ``key = value`` pair of the loss configuration."""
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def reset_session(save_path):
        """Save weights, clear the Keras session and rebuild all globals."""
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
        train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,
                                  RESOLUTION, num_cpus, K.get_session(), **da_config)
        train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes,
                                  RESOLUTION, num_cpus, K.get_session(), **da_config)

    def rebuild_losses(overrides, done_msg):
        """Apply ``overrides`` to loss_config, reset the session, rebuild losses."""
        loss_config.update(overrides)
        reset_session(models_dir)
        print("Building new loss funcitons...")
        show_loss_config(loss_config)
        model.build_train_functions(loss_weights=loss_weights, **loss_config)
        print(done_msg)

    # Start training
    t0 = time.time()

    # The counter always starts at zero.  The former "resume" probe read the
    # name before assignment; since it is local to this function that always
    # raised UnboundLocalError and fell back to 0.
    gen_iterations = 0

    loss_keys = ['ttl', 'adv', 'recon', 'edge', 'pl']
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = dict.fromkeys(loss_keys, 0)
    errGBs = dict.fromkeys(loss_keys, 0)

    display_iters = 300
    backup_iters = 5000

    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes,
                              RESOLUTION, num_cpus, K.get_session(), **da_config)

    # Enable the perceptual loss before the first iteration.
    clear_output()
    rebuild_losses({'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0}, "Done.")

    # Loss function automation.  Each entry is
    #   (trigger iteration, loss_config overrides, swap decoders first?,
    #    completion message).
    # Triggers sit half a display interval before the nominal milestone.
    half_display = display_iters//2
    loss_schedule = [
        (TOTAL_ITERS//5 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0},
         False, "Done."),
        (TOTAL_ITERS//5 + TOTAL_ITERS//10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.5},
         False, "Complete."),
        (2*TOTAL_ITERS//5 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.2},
         False, "Done."),
        (TOTAL_ITERS//2 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.4},
         False, "Done."),
        (2*TOTAL_ITERS//3 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0., 'lr_factor': 0.3},
         False, "Done."),
        (8*TOTAL_ITERS//10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.1, 'lr_factor': 0.3},
         True, "Done."),
        (9*TOTAL_ITERS//10 - half_display,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0, 'lr_factor': 0.1},
         False, "Done."),
    ]

    while gen_iterations <= TOTAL_ITERS:

        # First matching entry wins, mirroring an if/elif chain (milestones
        # can coincide for small TOTAL_ITERS).
        stage = next((s for s in loss_schedule if s[0] == gen_iterations), None)
        if stage is not None:
            _, overrides, swap_decoders, done_msg = stage
            clear_output()
            if swap_decoders:
                model.decoder_A.load_weights("models/decoder_B.h5") # swap decoders
                model.decoder_B.load_weights("models/decoder_A.h5") # swap decoders
            rebuild_losses(overrides, done_msg)

        if gen_iterations == 5:
            print ("working.")

        # Train discriminators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(loss_keys):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        # Visualization
        if gen_iterations % display_iters == 0:
            clear_output()

            # Display loss information
            show_loss_config(loss_config)
            print("----------")
            print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
            % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,
               errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))
            print("----------")
            print("Generator loss details:")
            print('[Adversarial loss]')
            print(f'GA: {errGAs["adv"]/display_iters:.4f} GB: {errGBs["adv"]/display_iters:.4f}')
            print('[Reconstruction loss]')
            print(f'GA: {errGAs["recon"]/display_iters:.4f} GB: {errGBs["recon"]/display_iters:.4f}')
            print('[Edge loss]')
            print(f'GA: {errGAs["edge"]/display_iters:.4f} GB: {errGBs["edge"]/display_iters:.4f}')
            if loss_config['use_PL'] == True:
                print('[Perceptual loss]')
                # The accumulated value may be indexable or a plain scalar;
                # fall back to the scalar form only for indexing failures.
                try:
                    print(f'GA: {errGAs["pl"][0]/display_iters:.4f} GB: {errGBs["pl"][0]/display_iters:.4f}')
                except (TypeError, IndexError):
                    print(f'GA: {errGAs["pl"]/display_iters:.4f} GB: {errGBs["pl"]/display_iters:.4f}')

            # Display images
            print("----------")
            wA, tA, _ = train_batchA.get_next_batch()
            wB, tB, _ = train_batchB.get_next_batch()
            print("Transformed (masked) results:")
            showG(tA, tB, model.path_A, model.path_B, batchSize)
            print("Masks:")
            showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)
            print("Reconstruction results:")
            showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)
            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
            for k in loss_keys:
                errGAs[k] = 0
                errGBs[k] = 0

            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)
Exemplo n.º 4
0
def train(ITERS):
    K.set_learning_phase(1)
    #K.set_learning_phase(0) # set to 0 in inference phase

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64  # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    # batchSize = 8
    batchSize = 2

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = True
    use_bm_eyes = False
    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    session_config = tf.ConfigProto(log_device_placement=False,
                                    allow_soft_placement=True)
    # please do not use the totality of the GPU memory
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.90

    # Path to training images
    dataDir = '/home/wh/hdd/work/xj/AI_anchors/faceswap-GAN-master'
    img_dirA = './faces/faceA/aligned_faces'  #经过对齐后人脸的文件夹
    img_dirB = './faces/faceB/aligned_faces'
    img_dirA_bm_eyes = "./faces/faceA/binary_masks_eyes"  #与对齐人脸对应的mask,用于合成,主要是GAN如果不加mask会差一些
    img_dirB_bm_eyes = "./faces/faceB/binary_masks_eyes"

    # Path to saved model weights
    models_dir = "./models"
    #Path(f"models").mkdir(parents=True, exist_ok=True)
    if not os.path.exists(models_dir):
        os.mkdir(models_dir)

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    arch_config[
        'norm'] = "instancenorm"  # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['model_capacity'] = "standard"  # standard, lite

    # Loss function weights configuration,各种loss的权重系数
    loss_weights = {}
    loss_weights['w_D'] = 0.1  # Discriminator
    loss_weights['w_recon'] = 1.  # L1 reconstruction loss
    loss_weights['w_edge'] = 0.1  # edge loss
    loss_weights['w_eyes'] = 30.  # reconstruction and edge loss on eyes area
    loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1
                            )  # perceptual loss (0.003, 0.03, 0.3, 0.3)

    # Init. loss config.  迭代多少次后会用哪种loss
    loss_config = {}
    loss_config[
        "gan_training"] = "mixup_LSGAN"  # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    loss_config['use_PL'] = False
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.
    loss_config['lr_factor'] = 1.
    loss_config['use_cyclic_loss'] = False

    print('CONFIGURE DONE')
    print('step1')

    # Get filenames
    #train_A = glob.glob(img_dirA+"/*.*")
    #train_B = glob.glob(img_dirB+"/*.*")
    train_A = getDirFnPath(img_dirA)
    train_B = getDirFnPath(img_dirB)

    train_AnB = train_A + train_B

    assert len(train_A), "No image found in " + str(img_dirA)
    assert len(train_B), "No image found in " + str(img_dirB)
    print("Number of images in folder A: " + str(len(train_A)))
    print("Number of images in folder B: " + str(len(train_B)))

    print('step2')

    #define models
    global model
    model = FaceswapGANModel(**arch_config)

    print('DEFINE MODELS DONE')

    model.load_weights(path=models_dir)

    # VGGFace ResNet50,
    # 用vgg的目的是为了对比提取的特征是否一致(将原始图和生成图用VGG提取固定层feat,然后比较feat的距离)
    global vggface
    vggface = VGGFace(include_top=False,
                      model='resnet50',
                      input_shape=(224, 224, 3))

    #vggface.summary()

    model.build_pl_model(vggface_model=vggface)

    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    print('BUILD MODELS DONE')

    if use_bm_eyes:
        assert len(glob.glob(
            img_dirA_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirA_bm_eyes)
        assert len(glob.glob(
            img_dirB_bm_eyes +
            "/*.*")), "No binary mask found in " + str(img_dirB_bm_eyes)
        assert len(glob.glob(img_dirA_bm_eyes+"/*.*")) == len(train_A), \
        "Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder."
        assert len(glob.glob(img_dirB_bm_eyes+"/*.*")) == len(train_B), \
        "Number of faceB images does not match number of their binary masks. Can be caused by any none image file in the folder."

    def show_loss_config(loss_config):
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def reset_session(save_path):
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False,
                          model='resnet50',
                          input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface)
        train_batchA = DataLoader(train_A, train_AnB, batchSize,
                                  img_dirA_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)
        train_batchB = DataLoader(train_B, train_AnB, batchSize,
                                  img_dirB_bm_eyes, RESOLUTION, num_cpus,
                                  K.get_session(), **da_config)

    print('RESET_SESSION DONE')
    print('step3')

    # Start training
    t0 = time.time()
    gen_iterations = 0
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    errGAs = {}
    errGBs = {}
    # Dictionaries are ordered in Python 3.6
    for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
        errGAs[k] = 0
        errGBs[k] = 0

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = ITERS // 1
    # TOTAL_ITERS = 10000

    print('step4')

    global train_batchA, train_batchB
    train_batchA = DataLoader(train_A, train_AnB, batchSize,
                              img_dirA_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)
    train_batchB = DataLoader(train_B, train_AnB, batchSize,
                              img_dirB_bm_eyes, RESOLUTION, num_cpus,
                              K.get_session(), **da_config)

    print('step5')
    print('DATALOADER DONE')

    show_iters = 50
    print("START TRAINING")
    while gen_iterations <= TOTAL_ITERS:
        #print(gen_iterations)
        # Loss function automation
        if gen_iterations == (TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 5 + TOTAL_ITERS // 10 -
                                display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.5
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Complete.")
        elif gen_iterations == (2 * TOTAL_ITERS // 5 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.2
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (TOTAL_ITERS // 2 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.4
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (2 * TOTAL_ITERS // 3 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (8 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            model.decoder_A.load_weights(
                "models/decoder_B.h5")  # swap decoders
            model.decoder_B.load_weights(
                "models/decoder_A.h5")  # swap decoders
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = True
            loss_config['m_mask'] = 0.1
            loss_config['lr_factor'] = 0.3
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")
        elif gen_iterations == (9 * TOTAL_ITERS // 10 - display_iters // 2):
            #clear_output()
            loss_config['use_PL'] = True
            loss_config['use_mask_hinge_loss'] = False
            loss_config['m_mask'] = 0.0
            loss_config['lr_factor'] = 0.1
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print("Done.")

        # Train dicriminators for one batch
        data_A = train_batchA.get_next_batch(
        )  # 每次返回一张原始图,一张随机扭曲的图,以及与原始图像对应的mask
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        if gen_iterations % show_iters == 0:
            #print('iter: %d, Loss_GA: %f, Loss_G: %f' % (gen_iterations, errGA_sum / display_iters, errGB_sum / display_iters))
            # Display loss information
            #show_loss_config(loss_config)
            print(
                '[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                %
                (gen_iterations, float(errDA_sum / show_iters),
                 float(errDB_sum / show_iters), float(errGA_sum / show_iters),
                 float(errGB_sum / show_iters), time.time() - t0))
            for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:
                errGAs[k] = 0
                errGBs[k] = 0
            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0

            # Display images
            wA, tA, _ = train_batchA.get_next_batch()
            wB, tB, _ = train_batchB.get_next_batch()
            tran_res = showG(tA, tB, model.path_A, model.path_B, batchSize)
            mask_res = showG_mask(tA, tB, model.path_mask_A, model.path_mask_B,
                                  batchSize)
            rec_res = showG(wA, wB, model.path_bgr_A, model.path_bgr_B,
                            batchSize)
            #mask_res1 = cv2.cvtColor(mask_res, cv2.COLOR_GRAY2BGR)
            fname = "./logs/images/tran_mask_rec_%d.jpg" % (gen_iterations)
            res = np.vstack([tran_res, mask_res, rec_res])
            cv2.imwrite(fname, res)
            #fname = "./logs/images/mask_%d.jpg" % (gen_iterations)
            #cv2.imwrite(fname, mask_res)

        # Visualization delete
        if gen_iterations % display_iters == 0:
            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            #Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            if not os.path.exists(bkup_dir):
                os.mkdir(bkup_dir)
            model.save_weights(path=bkup_dir)

    print('TRAIN DONE')
Exemplo n.º 5
0
def _loss_schedule(total_iters, display_iters):
    """Return the ordered loss-configuration milestones for one training run.

    Each entry is a tuple
    ``(trigger_iter, loss_config_updates, swap_decoders, done_message)``:

    * ``trigger_iter`` -- value of the iteration counter at which the stage
      fires. Every trigger is shifted back by ``display_iters // 2``.
    * ``loss_config_updates`` -- keys to overwrite in the loss config; keys
      that are not listed keep their current value.
    * ``swap_decoders`` -- whether the two decoders exchange their saved
      weights before the training functions are rebuilt.
    * ``done_message`` -- text printed once the stage has been applied.

    Triggers may be negative (never reached) or coincide for small
    ``total_iters``; the caller decides how to resolve coinciding triggers.
    """
    offset = display_iters // 2
    return [
        (total_iters // 5 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0},
         False, "Done."),
        (total_iters // 5 + total_iters // 10 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.5},
         False, "Complete."),
        (2 * total_iters // 5 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.2},
         False, "Done."),
        (total_iters // 2 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.4},
         False, "Done."),
        (2 * total_iters // 3 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.,
          'lr_factor': 0.3},
         False, "Done."),
        (8 * total_iters // 10 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': True, 'm_mask': 0.1,
          'lr_factor': 0.3},
         True, "Done."),
        (9 * total_iters // 10 - offset,
         {'use_PL': True, 'use_mask_hinge_loss': False, 'm_mask': 0.0,
          'lr_factor': 0.1},
         False, "Done."),
    ]


def train(ITERS):
    """Train the faceswap GAN on the images in ``./faceA`` and ``./faceB``.

    The function builds a ``FaceswapGANModel``, loads any weights found in
    ``./models``, and then alternates one discriminator batch and one
    generator batch per iteration. At fixed fractions of the run (see
    ``_loss_schedule``) the loss configuration is changed, the Keras session
    is rebuilt and the training functions are recompiled.

    Args:
        ITERS: Requested number of iterations. ``ITERS // 1`` is used as the
            iteration budget; because the loop condition is ``<=``, the loop
            body runs once more than that budget.

    Side effects:
        * Rebinds the module-level globals ``model``, ``vggface``,
          ``train_batchA`` and ``train_batchB``.
        * Saves weights to ``./models`` every 300 iterations and to
          ``./models/backup_iter<N>`` every 5000 iterations.
        * Prints progress and averaged losses to stdout.

    Raises:
        AssertionError: If the resolution is not a multiple of 64, if either
            image folder is empty, or (with eye masks enabled) if a mask
            folder is empty or its file count differs from the image count.
    """
    K.set_learning_phase(1)  # 1 = training phase; 0 would be inference

    # Number of CPU cores
    num_cpus = os.cpu_count()

    # Input/Output resolution
    RESOLUTION = 64  # 64x64, 128x128, 256x256
    assert (RESOLUTION % 64) == 0, "RESOLUTION should be 64, 128, or 256."

    # Batch size
    batchSize = 4

    # Use motion blurs (data augmentation)
    # set True if training data contains images extracted from videos
    use_da_motion_blur = False

    # Use eye-aware training
    # require images generated from prep_binary_masks.ipynb
    use_bm_eyes = True
    # Probability of random color matching (data augmentation)
    prob_random_color_match = 0.5

    da_config = {
        "prob_random_color_match": prob_random_color_match,
        "use_da_motion_blur": use_da_motion_blur,
        "use_bm_eyes": use_bm_eyes
    }

    # NOTE(review): this config is constructed but never handed to a session
    # anywhere in this function, so the 0.90 memory fraction is not applied
    # here -- confirm whether it is meant to be installed via K.set_session.
    session_config = tf.ConfigProto(log_device_placement=False,
                                    allow_soft_placement=True)
    # please do not use the totality of the GPU memory
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.90

    # Path to training images
    img_dirA = './faceA'
    img_dirB = './faceB'
    img_dirA_bm_eyes = "./faceA/binary_masks_eyes"
    img_dirB_bm_eyes = "./faceB/binary_masks_eyes"

    # Path to saved model weights
    models_dir = "./models"
    Path(models_dir).mkdir(parents=True, exist_ok=True)

    # Architecture configuration
    arch_config = {}
    arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)
    arch_config['use_self_attn'] = True
    # instancenorm, batchnorm, layernorm, groupnorm, none
    arch_config['norm'] = "instancenorm"
    arch_config['model_capacity'] = "standard"  # standard, lite

    # Loss function weights configuration
    loss_weights = {}
    loss_weights['w_D'] = 0.1  # Discriminator
    loss_weights['w_recon'] = 1.  # L1 reconstruction loss
    loss_weights['w_edge'] = 0.1  # edge loss
    loss_weights['w_eyes'] = 30.  # reconstruction and edge loss on eyes area
    # perceptual loss (0.003, 0.03, 0.3, 0.3)
    loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1)

    # Init. loss config.
    loss_config = {}
    # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    loss_config["gan_training"] = "mixup_LSGAN"
    loss_config['use_PL'] = False
    loss_config['use_mask_hinge_loss'] = False
    loss_config['m_mask'] = 0.
    loss_config['lr_factor'] = 1.
    loss_config['use_cyclic_loss'] = False

    print('CONFIGURE DONE')

    # define models
    from networks.faceswap_gan_model import FaceswapGANModel
    global model
    model = FaceswapGANModel(**arch_config)

    print('DEFINE MODELS DONE')

    model.load_weights(path=models_dir)

    from keras_vggface.vggface import VGGFace
    # VGGFace ResNet50
    global vggface
    vggface = VGGFace(include_top=False,
                      model='resnet50',
                      input_shape=(224, 224, 3))

    model.build_pl_model(vggface_model=vggface)

    model.build_train_functions(loss_weights=loss_weights, **loss_config)

    from data_loader.data_loader import DataLoader

    # Get filenames
    train_A = glob.glob(img_dirA + "/*.*")
    train_B = glob.glob(img_dirB + "/*.*")

    train_AnB = train_A + train_B

    assert len(train_A), "No image found in " + str(img_dirA)
    assert len(train_B), "No image found in " + str(img_dirB)
    print("Number of images in folder A: " + str(len(train_A)))
    print("Number of images in folder B: " + str(len(train_B)))

    if use_bm_eyes:
        # Glob each mask folder once and reuse the result for both checks.
        masks_A = glob.glob(img_dirA_bm_eyes + "/*.*")
        masks_B = glob.glob(img_dirB_bm_eyes + "/*.*")
        assert len(masks_A), "No binary mask found in " + str(img_dirA_bm_eyes)
        assert len(masks_B), "No binary mask found in " + str(img_dirB_bm_eyes)
        assert len(masks_A) == len(train_A), \
        "Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder."
        assert len(masks_B) == len(train_B), \
        "Number of faceB images does not match number of their binary masks. Can be caused by any none image file in the folder."

    def show_loss_config(loss_config):
        """Print every ``key = value`` pair of the loss configuration."""
        for config, value in loss_config.items():
            print(f"{config} = {value}")

    def make_loader(filenames, bm_eyes_dir):
        """Create a DataLoader bound to the current Keras session."""
        return DataLoader(filenames, train_AnB, batchSize,
                          bm_eyes_dir, RESOLUTION, num_cpus,
                          K.get_session(), **da_config)

    def reset_session(save_path):
        """Save weights, clear the Keras session and rebuild all globals.

        The model, the VGGFace network and both data loaders are dropped
        before ``K.clear_session()`` and recreated afterwards; the model
        reloads the weights that were just written to ``save_path``.
        """
        global model, vggface
        global train_batchA, train_batchB
        model.save_weights(path=save_path)
        del model
        del vggface
        del train_batchA
        del train_batchB
        K.clear_session()
        model = FaceswapGANModel(**arch_config)
        model.load_weights(path=save_path)
        vggface = VGGFace(include_top=False,
                          model='resnet50',
                          input_shape=(224, 224, 3))
        model.build_pl_model(vggface_model=vggface)
        train_batchA = make_loader(train_A, img_dirA_bm_eyes)
        train_batchB = make_loader(train_B, img_dirB_bm_eyes)

    print('RESET_SESSION DONE')

    # Start training
    t0 = time.time()
    gen_iterations = 0
    errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
    loss_names = ['ttl', 'adv', 'recon', 'edge', 'pl']
    errGAs = dict.fromkeys(loss_names, 0)
    errGBs = dict.fromkeys(loss_names, 0)

    display_iters = 300
    backup_iters = 5000
    TOTAL_ITERS = ITERS // 1

    # Map trigger iteration -> stage. setdefault keeps the FIRST stage when
    # two triggers coincide (possible for small TOTAL_ITERS), which is what
    # an if/elif chain over the same conditions would select.
    stages = {}
    for trigger, updates, swap_decoders, done_msg in _loss_schedule(
            TOTAL_ITERS, display_iters):
        stages.setdefault(trigger, (updates, swap_decoders, done_msg))

    global train_batchA, train_batchB
    train_batchA = make_loader(train_A, img_dirA_bm_eyes)
    train_batchB = make_loader(train_B, img_dirB_bm_eyes)

    print('DATALOADER DONE')

    print("START TRAINING")
    while gen_iterations <= TOTAL_ITERS:
        # Loss function automation
        stage = stages.get(gen_iterations)
        if stage is not None:
            updates, swap_decoders, done_msg = stage
            if swap_decoders:
                # Swap decoders using the weight files on disk. These are the
                # most recently SAVED weights, not the in-memory ones.
                # NOTE(review): assumes save_weights wrote decoder_A.h5 /
                # decoder_B.h5 into models_dir -- confirm against the model.
                model.decoder_A.load_weights(
                    os.path.join(models_dir, "decoder_B.h5"))
                model.decoder_B.load_weights(
                    os.path.join(models_dir, "decoder_A.h5"))
            loss_config.update(updates)
            reset_session(models_dir)
            print("Building new loss funcitons...")
            show_loss_config(loss_config)
            model.build_train_functions(loss_weights=loss_weights,
                                        **loss_config)
            print(done_msg)

        # Train dicriminators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
        errDA_sum += errDA[0]
        errDB_sum += errDB[0]

        # Train generators for one batch
        data_A = train_batchA.get_next_batch()
        data_B = train_batchB.get_next_batch()
        errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
        errGA_sum += errGA[0]
        errGB_sum += errGB[0]
        for i, k in enumerate(loss_names):
            errGAs[k] += errGA[i]
            errGBs[k] += errGB[i]
        gen_iterations += 1

        if gen_iterations % display_iters == 0:
            # Report losses averaged over the last display_iters iterations,
            # then reset the accumulators so each report covers one window.
            print(
                '[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                % (gen_iterations, float(errDA_sum / display_iters),
                   float(errDB_sum / display_iters),
                   float(errGA_sum / display_iters),
                   float(errGB_sum / display_iters), time.time() - t0))
            for label, parts in (('GA', errGAs), ('GB', errGBs)):
                print('    %s ' % label + ' '.join(
                    '%s: %f' % (k, float(parts[k] / display_iters))
                    for k in loss_names))
            errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
            errGAs = dict.fromkeys(loss_names, 0)
            errGBs = dict.fromkeys(loss_names, 0)

            # Save models
            model.save_weights(path=models_dir)

        # Backup models
        if gen_iterations % backup_iters == 0:
            bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
            Path(bkup_dir).mkdir(parents=True, exist_ok=True)
            model.save_weights(path=bkup_dir)

    print('TRAIN DONE')