Пример #1
0
    G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

    try:
        p_G = './saved_mask_models/G_latest.pth'
        p_D = './saved_mask_models/D_latest.pth'
        G.load_state_dict(torch.load(p_G, map_location=torch.device('cpu')), strict=False)
        D.load_state_dict(torch.load(p_D, map_location=torch.device('cpu')), strict=False)
        
        print('p_G : ',p_G)
        print('p_D : ',p_D)
        
    except Exception as e:
        print(e)
    
    dataset = FaceEmbed(['./train_datasets/Foreign-2020-09-06/'], same_prob=0.35)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)


    MSE = torch.nn.MSELoss()
    L1 = torch.nn.L1Loss()

    # print(torch.backends.cudnn.benchmark)
    torch.backends.cudnn.benchmark = True
    for epoch in range(0, max_epoch):
        # torch.cuda.empty_cache()
        for iteration, data in enumerate(dataloader):
            start_time = time.time()
            Xs, Xt, same_person = data
            Xs = Xs.to(device)
            Xt = Xt.to(device)
Пример #2
0
    G.load_state_dict(torch.load(os.path.join(args.saved_models,
                                              'G_latest.pth'),
                                 map_location=torch.device('cpu')),
                      strict=False)
    D.load_state_dict(torch.load(os.path.join(args.saved_models,
                                              'D_latest.pth'),
                                 map_location=torch.device('cpu')),
                      strict=False)
except Exception as e:
    print(e)

# Dataset selection. The active line takes the image folder and the
# ``same_prob`` value from ``args``. Earlier, disabled alternatives:
#   if not fine_tune_with_identity:
#       dataset = FaceEmbed(['../celeb-aligned-256_0.85/', '../ffhq_256_0.85/',
#                            '../vgg_256_0.85/', '../stars_256_0.85/'], same_prob=0.5)
#   else:
#       dataset = With_Identity('../washed_img/', 0.8)
dataset = FaceEmbed([args.images_path], same_prob=args.same_prob)

# Shuffled batches loaded in the main process (num_workers=0); a trailing
# batch smaller than ``batch_size`` is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    if positive:
        return torch.relu(1 - X).mean()
    else:
Пример #3
0
    print(e)
# Load the value pickled in AEI_niter.pkl into ``min_iter`` (presumably the
# iteration count reached by an earlier run). Any failure is only printed.
# NOTE(review): on failure ``min_iter`` must already be bound earlier in the
# script, otherwise the SummaryWriter call below raises NameError — confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load files this
# project wrote itself.
try:
    with open('./saved_models/AEI_niter.pkl', 'rb') as f:
        min_iter = pickle.load(f)
except Exception as e:
    print(e)

# Writer for this run's logs; ``purge_step`` is taken from ``min_iter``.
writer = SummaryWriter('runs/FaceShifterAEInet', purge_step=min_iter)

# Root folders of the training images (absolute, machine-specific paths).
TrainFaceSources = [
    '/home/olivier/Images/FaceShifter/celeba-256/',
    '/home/olivier/Images/FaceShifter/Perso/',
    '/home/olivier/Images/FaceShifter/VGGFaceTrain/',
    '/home/olivier/Images/FaceShifter/FFHQ/',
    '/home/olivier/Images/FaceShifter/Others/',
]
train_dataset = FaceEmbed(TrainFaceSources, same_prob=0.2)

# Four worker processes with pinned host memory; incomplete final batches
# are dropped. ``train_loader`` is an explicit iterator over one pass.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
train_loader = iter(train_dataloader)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    if positive:
Пример #4
0
    G.load_state_dict(torch.load('./saved_models/G_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
    D.load_state_dict(torch.load('./saved_models/D_latest.pth',
                                 map_location=torch.device('cpu')),
                      strict=False)
except Exception as e:
    print(e)

# Dataset selection. Earlier, disabled alternatives:
#   if not fine_tune_with_identity:
#       dataset = FaceEmbed([...same four folders...], same_prob=0.5)
#   else:
#       dataset = With_Identity('../washed_img/', 0.8)
# The active configuration uses the four folders below with same_prob=0.8.
dataset = FaceEmbed(
    [
        '../celeb-aligned-256_0.85/',
        '../ffhq_256_0.85/',
        '../vgg_256_0.85/',
        '../stars_256_0.85/',
    ],
    same_prob=0.8,
)

# Shuffled batches loaded in the main process (num_workers=0); a trailing
# batch smaller than ``batch_size`` is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    if positive:
Пример #5
0
# Discriminator optimizer: Adam with beta1 = 0.
opt_D = optim.Adam(D.parameters(), betas=(0, 0.999), lr=lr_D)

# Hand both networks and their optimizers to apex AMP; the returned objects
# replace the originals.
G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

# Best-effort resume: any failure (e.g. a missing file) is printed and the
# script continues with whatever weights the models hold at that point.
# strict=False ignores missing/unexpected state-dict keys.
try:
    G.load_state_dict(
        torch.load('./saved_models/G_latest.pth', map_location=torch.device('cpu')),
        strict=False,
    )
    D.load_state_dict(
        torch.load('./saved_models/D_latest.pth', map_location=torch.device('cpu')),
        strict=False,
    )
except Exception as e:
    print(e)

dataset = FaceEmbed(['../celeba_64/'], same_prob=0.8)

# Shuffled batches loaded in the main process (num_workers=0); a trailing
# batch smaller than ``batch_size`` is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    if positive:
        return torch.relu(1 - X).mean()
    else:
Пример #6
0
# Identity encoder weights, mapped straight onto ``device``.
arcface.load_state_dict(
    torch.load('./face_modules/model_ir_se50.pth', map_location=device),
    strict=False,
)

# Adam (beta1 = 0) with weight decay for both generator and discriminator.
opt_G = optim.Adam(G.parameters(), betas=(0, 0.999), weight_decay=1e-4, lr=lr_G)
opt_D = optim.Adam(D.parameters(), betas=(0, 0.999), weight_decay=1e-4, lr=lr_D)

# Hand both networks and their optimizers to apex AMP; the returned objects
# replace the originals.
G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

# Best-effort resume: any failure (e.g. a missing file) is printed and the
# script continues with whatever weights the models hold at that point.
# Unlike the arcface load above, these checkpoints are mapped to the CPU.
try:
    G.load_state_dict(
        torch.load('./saved_models/G_latest.pth', map_location=torch.device('cpu')),
        strict=False,
    )
    D.load_state_dict(
        torch.load('./saved_models/D_latest.pth', map_location=torch.device('cpu')),
        strict=False,
    )
except Exception as e:
    print(e)

dataset = FaceEmbed([dataset_path], same_prob=0.5)

# Shuffled batches loaded in the main process (num_workers=0); a trailing
# batch smaller than ``batch_size`` is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    """Element-wise hinge penalty (presumably applied to discriminator scores).

    Args:
        X: Tensor of scores.
        positive: When truthy, penalize scores below +1 via ``relu(1 - X)``;
            otherwise penalize scores above -1 via ``relu(X + 1)``.

    Returns:
        A tensor with the same shape as ``X``. No reduction is applied, so
        callers must reduce (e.g. ``.mean()``) themselves.
    """
    margin = (1 - X) if positive else (X + 1)
    return torch.relu(margin)

Пример #7
0
                          strict=False)

        print('Finetune_G_Model : ', Finetune_G_Model)
        print('Finetune_D_Model : ', Finetune_D_Model)

    except Exception as e:
        print(e)

    # if not fine_tune_with_identity:
    # dataset = FaceEmbed(['../celeb-aligned-256_0.85/', '../ffhq_256_0.85/', '../vgg_256_0.85/', '../stars_256_0.85/'], same_prob=0.5)
    # else:
    # dataset = With_Identity('../washed_img/', 0.8)
    # dataset = FaceEmbed(['../celeb-aligned-256_0.85/', '../ffhq_256_0.85/', '../vgg_256_0.85/', '../stars_256_0.85/'], same_prob=0.8)

    dataset = FaceEmbed(['./train_datasets/Foreign-2020-09-06/'],
                        same_prob=0.5,
                        Flag_256=Flag_256)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=3,
                            drop_last=True)

    MSE = torch.nn.MSELoss()
    L1 = torch.nn.L1Loss()

    # print(torch.backends.cudnn.benchmark)
    torch.backends.cudnn.benchmark = True
    for epoch in range(0, max_epoch):
        # torch.cuda.empty_cache()
        for iteration, data in enumerate(dataloader):
Пример #8
0
# Best-effort load of the *_train.pth checkpoints: any failure is printed and
# the script continues with whatever weights the models hold at that point.
# Previously used (now disabled) checkpoints:
#   ./saved_models/G_latest_Heonozis_original.pth
#   ./saved_models/D_latest_Heonozis_original.pth
try:
    print('load pretrained model')
    G.load_state_dict(
        torch.load('./saved_models/G_train.pth', map_location=torch.device('cpu')),
        strict=False,
    )
    D.load_state_dict(
        torch.load('./saved_models/D_train.pth', map_location=torch.device('cpu')),
        strict=False,
    )
except Exception as e:
    print(e)

dataset = FaceEmbed(['../datasets/img_align_celeba_64/'], same_prob=0.2)

# Shuffled batches loaded in the main process (num_workers=0); a trailing
# batch smaller than ``batch_size`` is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Loss modules (PyTorch default 'mean' reduction).
MSE = torch.nn.MSELoss()
L1 = torch.nn.L1Loss()


def hinge_loss(X, positive=True):
    if positive:
        return torch.relu(1 - X).mean()
    else: