Example #1
def run(input_file, model_type):
    f_id = str(uuid.uuid4())
    fname = secure_filename(input_file.filename)

    # create a per-upload folder for this request
    os.makedirs(os.path.join(UPLOAD_FOLDER, f_id), exist_ok=True)

    # update the default args with the per-upload id
    args = update_args(default_args, f_id)
    torch.manual_seed(args.seed)

    # select the solver and set args.ref_dir
    if model_type == "Human Face":
        solver = CelebA_HQ
        args.ref_dir = 'assets/representative/celeba_hq/ref'

        # human face crop
        pil_im = Image.open(input_file.stream).convert('RGB')
        im = np.uint8(pil_im)
        face_im = detect_face(copy.copy(im))

        # return early if no face could be detected
        if isinstance(face_im, bool):
            return 'no face'

        Image.fromarray(face_im).save(os.path.join(UPLOAD_FOLDER, f_id, fname))
    else:
        solver = AFHQ
        args.ref_dir = 'assets/representative/afhq/ref'

        input_file.save(os.path.join(UPLOAD_FOLDER, f_id, fname))

    # align image
    align_faces(args, args.inp_dir, args.out_dir)

    # define loaders
    loaders = Munch(src=get_test_loader(root=args.src_dir,
                                        img_size=args.img_size,
                                        batch_size=args.val_batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers),
                    ref=get_test_loader(root=args.ref_dir,
                                        img_size=args.img_size,
                                        batch_size=args.val_batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers))

    # generate image
    solver.sample(loaders, args.result_dir)

    # read image
    path = os.path.join(args.result_dir, 'reference.jpg')
    with open(path, 'rb') as f:
        data = f.read()
    result = io.BytesIO(data)

    # remove image data
    remove_image(args)

    return result
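Example #1 reads like a web-service handler: `input_file` exposes `.filename`, `.stream`, and `.save()`, which matches a Werkzeug `FileStorage`, so `run()` is presumably called from a Flask upload route. A minimal sketch of such a route, assuming a hypothetical endpoint and form-field names not shown in the source:

# Hypothetical Flask wiring for run(); the route and field names are assumptions.
from flask import Flask, request, send_file, abort

app = Flask(__name__)

@app.route('/translate', methods=['POST'])
def translate():
    input_file = request.files.get('image')   # Werkzeug FileStorage, as run() expects
    model_type = request.form.get('model_type', 'Human Face')
    if input_file is None:
        abort(400, 'missing image upload')
    result = run(input_file, model_type)
    if result == 'no face':
        abort(422, 'no face detected in the uploaded image')
    return send_file(result, mimetype='image/jpeg')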
Example #2
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    if args.mode == "train":
        args.img_datatype = find_img_datatype(args.train_img_dir)
    elif args.mode == "sample":
        args.img_datatype = find_img_datatype(args.src_dir)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_type=args.img_datatype,
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_type=args.img_datatype,
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_type=args.img_datatype,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_type=args.img_datatype,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_type=args.img_datatype,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
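Several of these entry points gate on `len(subdirs(...)) == args.num_domains` before building loaders. The `subdirs` helper is not shown here; a plausible sketch (an assumption, not necessarily the repository's exact implementation) simply lists the immediate sub-directories of a root folder, one per domain:

import os

def subdirs(dname):
    # list immediate sub-directories of dname; each one is treated as a domain
    return [d for d in os.listdir(dname)
            if os.path.isdir(os.path.join(dname, d))]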
Example #3
def main(args):

    pprint(vars(args))

    # set the random seed
    start_pro = fluid.default_startup_program()
    default_pro = fluid.default_main_program()
    start_pro.random_seed = args.seed
    default_pro.random_seed = args.seed

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))

        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
Example #4
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        sintel_path = "/home/tomstrident/datasets/"
        video_id = "temple_2"
        test_loader = getTestDatasetLoader(sintel_path, video_id)
        train_loader, eval_loader = get_loaderFC2(
            args.data_dir, args.style_dir, args.temp_dir, args.batch_size,
            args.num_workers, args.num_domains, args.mode)
        print("start training ...")
        print("args.num_domains:", args.num_domains)
        solver.train([train_loader, test_loader])
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        _, eval_loader = get_loaderFC2(args.data_dir, args.style_dir,
                                       args.temp_dir, args.batch_size,
                                       args.num_workers, args.num_domains,
                                       args.mode)
        print("len(eval_loader)", len(eval_loader))
        solver.evaluate(loader=eval_loader)
        #solver.eval_sintel()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
Example #5
def main(args):
    print(args)
    cudnn.benchmark = True
    if args.mode == 'train':
        torch.manual_seed(args.seed)

    solver = Solver(args)

    transform = transforms.Compose([
        transforms.Resize([args.img_size, args.img_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        if args.resume_iter > 0:
            solver._load_checkpoint(args.resume_iter)
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'eval':
        solver.evaluate()

    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)

    elif args.mode == 'inter':  # interpolation
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0).to(device))
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None
        y1 = torch.tensor([args.y1]).long().cuda()
        y2 = torch.tensor([args.y2]).long().cuda()
        outputs = interpolations(nets_ema,
                                 args.latent_dim,
                                 image,
                                 masks,
                                 lerp_step=0.1,
                                 y1=y1,
                                 y2=y2,
                                 lerp_mode=args.lerp_mode)
        path = os.path.join(save_dir, image_name)
        vutils.save_image(outputs.data, path, padding=0)

    elif args.mode == 'test':
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0)).cuda()
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None

        image_ref = None
        if args.test_mode == 'reference':
            image_ref = Variable(
                transform(Image.open(
                    args.input_ref).convert("RGB")).unsqueeze(0)).cuda()

        fake = test_single(nets_ema, image, masks, args.latent_dim, image_ref,
                           args.target_domain, args.single_mode)
        fake = torch.clamp(fake * 0.5 + 0.5, 0, 1)
        path = os.path.join(save_dir, image_name)
        vutils.save_image(fake.data, path, padding=0)

    elif args.mode == 'video':
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        solver._load_checkpoint(args.resume_iter)
        nets_ema = solver.nets_ema

        image_name = os.path.basename(args.input)
        image = Variable(
            transform(Image.open(
                args.input).convert('RGB')).unsqueeze(0)).cuda()
        masks = nets_ema.fan.get_heatmap(image) if args.w_hpf > 0 else None

        y1 = torch.tensor([args.y1]).long().cuda()
        y2 = torch.tensor([args.y2]).long().cuda()
        outputs = interpolations_loop(nets_ema,
                                      args.latent_dim,
                                      image,
                                      masks,
                                      lerp_step=0.02,
                                      y1=y1,
                                      y2=y2,
                                      lerp_mode=args.lerp_mode)
        outputs = torch.cat(outputs)
        outputs = tensor2ndarray255(outputs)
        path = os.path.join(save_dir, '{}-video.mp4'.format(image_name))
        save_video(path, outputs)

    else:
        raise NotImplementedError
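The video branch above converts generator output with `tensor2ndarray255` before writing frames. That helper is not shown; given the `Normalize(mean=0.5, std=0.5)` preprocessing, outputs live in roughly [-1, 1], so a reasonable sketch (an assumption about the actual helper) maps them back to uint8 NHWC frames:

import torch

def tensor2ndarray255(images):
    # assumed behaviour: map a batch of [-1, 1] NCHW tensors to uint8 NHWC frames
    images = torch.clamp(images * 0.5 + 0.5, 0, 1)    # back to [0, 1]
    images = images.mul(255).to(torch.uint8)          # scale to [0, 255]
    return images.permute(0, 2, 3, 1).cpu().numpy()   # NCHW -> NHWC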
Example #6
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    # create the loss CSV file and write its header row
    with open(args.loss_csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            "epoch", "d_loss_z_trg", "d_loss_x_ref", "g_loss_z_trg",
            "g_loss_x_ref"
        ])

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        src_skt=get_train_loader(
                            root=args.train_sketch_img_dir,
                            which='source',
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
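Example #6 only writes the CSV header before training; the per-epoch rows are presumably appended from inside `solver.train`. A hedged sketch of that logging step (the loss names mirror the header row; the call site inside the training loop is an assumption):

import csv

def log_losses(loss_csv_path, epoch, d_loss_z_trg, d_loss_x_ref,
               g_loss_z_trg, g_loss_x_ref):
    # append one row per epoch, matching the header written in main()
    with open(loss_csv_path, 'a', newline='') as f:
        csv.writer(f).writerow(
            [epoch, d_loss_z_trg, d_loss_x_ref, g_loss_z_trg, g_loss_x_ref])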
Example #7
def main(args):
    print(args)
    #wandb.init(project="stargan", entity="stacey", config=args, name=args.model_name)
    #cfg = wandb.config
    #cfg.update({"dataset" : "afhq", "type" : "train"})
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate(args)
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    elif args.mode == 'custom':
        # set up wandb logging for this custom run
        wandb.init(project="stargan", config=args, name=args.model_name)
        # src or ref may each be a dir or an image
        # make temporary folders for images
        if os.path.isfile(args.custom_src):
            src_dir = "tmp_src"
            full_src = src_dir + "/src"
            if os.path.exists(src_dir):
                shutil.rmtree(src_dir)
            os.makedirs(full_src)
            shutil.copy2(args.custom_src, full_src)
            src_images = src_dir
        else:
            src_images = args.custom_src
        if os.path.isfile(args.custom_ref):
            ref_dir = "tmp_ref"
            full_ref = ref_dir + "/ref"
            if os.path.exists(ref_dir):
                shutil.rmtree(ref_dir)
            os.makedirs(full_ref)
            shutil.copy2(args.custom_ref, full_ref)
            if args.extend_domain:
                # make some extra domains
                for d in [ref_dir + "/ref2", ref_dir + "/ref3"]:
                    os.makedirs(d)
                    shutil.copy2(args.custom_ref, d)
            ref_images = ref_dir
        else:
            ref_images = args.custom_ref
        loaders = Munch(src=get_test_loader(root=src_images,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=ref_images,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.custom(loaders)
    else:
        raise NotImplementedError
Example #8
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)
    print('train')
    if args.mode == 'train':
        # assert len(subdirs(args.train_img_dir)) == args.num_domains
        # assert len(subdirs(args.val_img_dir)) == args.num_domains
        if args.loss == 'arcface':
            loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                                 train_data=args.dataset,
                                                 which='source',
                                                 img_size=args.img_size,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 prob=args.randcrop_prob,
                                                 num_workers=args.num_workers),
                            val=get_test_loader(root=args.train_img_dir,
                                                train_data=args.dataset,
                                                img_size=args.img_size,
                                                batch_size=args.val_batch_size,
                                                shuffle=False,
                                                num_workers=args.num_workers))
        elif args.loss == 'perceptual':
            loaders = Munch(
                src=get_train_loader_vgg(root=args.train_img_dir,
                                         train_data=args.dataset,
                                         which='source',
                                         img_size=args.img_size,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         prob=args.randcrop_prob,
                                         num_workers=args.num_workers),
                val=get_test_loader_vgg(root=args.train_img_dir,
                                        train_data=args.dataset,
                                        img_size=args.img_size,
                                        batch_size=args.val_batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers))
        else:
            raise ValueError("unsupported loss: {}".format(args.loss))

        solver.train(loaders)
    elif args.mode == 'sample':
        # assert len(subdirs(args.src_dir)) == args.num_domains
        # assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(
            src=get_test_loader_vgg(root=args.train_img_dir,
                                    train_data=args.dataset,
                                    img_size=args.img_size,
                                    batch_size=args.val_batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)
            # ref=get_test_loader(root=args.ref_dir,
            #                     img_size=args.img_size,
            #                     batch_size=args.val_batch_size,
            #                     shuffle=False,
            #                     num_workers=args.num_workers)
        )
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
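All of these `main(args)` variants are driven by an argparse namespace built elsewhere in each file. A minimal, hypothetical parser covering just the flags shared across the examples (mode, seed, image size, batch sizes, directories, worker count) might look like the sketch below; the defaults are illustrative, not the repositories' actual values.

import argparse

def build_parser():
    # hypothetical parser; flag names follow the attributes used above
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train',
                        choices=['train', 'sample', 'eval', 'align'])
    parser.add_argument('--seed', type=int, default=777)
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--num_domains', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--val_batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--randcrop_prob', type=float, default=0.5)
    parser.add_argument('--train_img_dir', default='data/train')
    parser.add_argument('--val_img_dir', default='data/val')
    parser.add_argument('--src_dir', default='assets/representative/src')
    parser.add_argument('--ref_dir', default='assets/representative/ref')
    return parser

if __name__ == '__main__':
    main(build_parser().parse_args())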