コード例 #1
0
def main(args):
    """Run the stage selected by ``args.mode``.

    Prints ``args``, enables ``cudnn.benchmark``, seeds torch with
    ``args.seed`` and constructs a ``Solver`` before dispatching:

    * ``'train'``  -- builds ``src``/``ref``/``val`` loaders and calls
      ``solver.train``.
    * ``'sample'`` -- builds ``src``/``ref`` loaders and calls
      ``solver.sample``.
    * ``'eval'``   -- calls ``solver.evaluate()``.
    * ``'align'``  -- calls ``core.wing.align_faces`` with ``args.inp_dir``
      and ``args.out_dir``.

    Raises:
        NotImplementedError: for any other mode (after the Solver is built).
        AssertionError: when ``len(subdirs(root))`` of a dataset root
            differs from ``args.num_domains``.
    """
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    mode = args.mode

    # ``args.img_datatype`` is attached before the Solver is constructed;
    # only the train and sample modes set it.
    if mode in ("train", "sample"):
        datatype_root = args.train_img_dir if mode == "train" else args.src_dir
        args.img_datatype = find_img_datatype(datatype_root)

    solver = Solver(args)

    if mode == 'train':
        for root in (args.train_img_dir, args.val_img_dir):
            assert len(subdirs(root)) == args.num_domains

        # Both training loaders differ only in the ``which`` selector.
        train_kwargs = dict(root=args.train_img_dir,
                            img_type=args.img_datatype,
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers)
        source_loader = get_train_loader(which='source', **train_kwargs)
        reference_loader = get_train_loader(which='reference', **train_kwargs)
        validation_loader = get_test_loader(root=args.val_img_dir,
                                            img_type=args.img_datatype,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
        solver.train(Munch(src=source_loader,
                           ref=reference_loader,
                           val=validation_loader))
    elif mode == 'sample':
        for root in (args.src_dir, args.ref_dir):
            assert len(subdirs(root)) == args.num_domains

        # Source and reference loaders share everything except their root.
        test_kwargs = dict(img_type=args.img_datatype,
                           img_size=args.img_size,
                           batch_size=args.val_batch_size,
                           shuffle=False,
                           num_workers=args.num_workers)
        source_loader = get_test_loader(root=args.src_dir, **test_kwargs)
        reference_loader = get_test_loader(root=args.ref_dir, **test_kwargs)
        solver.sample(Munch(src=source_loader, ref=reference_loader))
    elif mode == 'eval':
        solver.evaluate()
    elif mode == 'align':
        # Imported lazily so the other modes do not need core.wing.
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
コード例 #2
0
def main(args):
    """Run the stage selected by ``args.mode``.

    Prints ``args``, enables ``cudnn.benchmark``, seeds torch with
    ``args.seed`` and constructs a ``Solver`` before dispatching:

    * ``'train'``  -- builds a test loader via ``getTestDatasetLoader`` and a
      train loader via ``get_loaderFC2``, then calls
      ``solver.train([train_loader, test_loader])``.
    * ``'sample'`` -- builds ``src``/``ref`` test loaders and calls
      ``solver.sample``.
    * ``'eval'``   -- builds the eval loader via ``get_loaderFC2`` and calls
      ``solver.evaluate(loader=eval_loader)``.
    * ``'align'``  -- calls ``core.wing.align_faces`` with ``args.inp_dir``
      and ``args.out_dir``.

    Optional attributes (train mode only):
        args.sintel_path: dataset root passed to ``getTestDatasetLoader``;
            defaults to the previously hard-coded
            ``"/home/tomstrident/datasets/"``.
        args.sintel_video_id: video id passed to ``getTestDatasetLoader``;
            defaults to the previously hard-coded ``"temple_2"``.

    Raises:
        NotImplementedError: for any other mode (after the Solver is built).
        AssertionError: in sample mode when ``len(subdirs(root))`` differs
            from ``args.num_domains``.
    """
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        # These used to be hard-coded to one developer's machine. They can
        # now be overridden through ``args``; the defaults keep the old
        # behaviour when the attributes are absent.
        sintel_path = getattr(args, "sintel_path",
                              "/home/tomstrident/datasets/")
        video_id = getattr(args, "sintel_video_id", "temple_2")
        test_loader = getTestDatasetLoader(sintel_path, video_id)
        # The second return value (eval loader) is not used while training.
        train_loader, _ = get_loaderFC2(
            args.data_dir, args.style_dir, args.temp_dir, args.batch_size,
            args.num_workers, args.num_domains, args.mode)
        print("start training ...")
        print("args.num_domains:", args.num_domains)
        solver.train([train_loader, test_loader])
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        # Only the eval loader is needed here; the train loader is discarded.
        _, eval_loader = get_loaderFC2(args.data_dir, args.style_dir,
                                       args.temp_dir, args.batch_size,
                                       args.num_workers, args.num_domains,
                                       args.mode)
        print("len(eval_loader)", len(eval_loader))
        solver.evaluate(loader=eval_loader)
        # NOTE: a call to solver.eval_sintel() was disabled here in the
        # original source.
    elif args.mode == 'align':
        # Imported lazily so the other modes do not need core.wing.
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
コード例 #3
0
ファイル: main.py プロジェクト: sallyy1/hairgan
def main(args):
    """Run the stage selected by ``args.mode``.

    Prints ``args``, enables ``cudnn.benchmark``, seeds torch with
    ``args.seed`` and constructs a ``Solver`` before dispatching:

    * ``'train'``  -- builds ``src``/``ref``/``val`` loaders and calls
      ``solver.train``.
    * ``'sample'`` -- calls ``solver.sample()`` with no arguments.
    * ``'eval'``   -- calls ``solver.evaluate()`` and prints every
      key/value pair of the first returned mapping.

    Raises:
        NotImplementedError: for any other mode (after the Solver is built).
    """
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)
    mode = args.mode

    if mode == 'train':
        # The two training loaders differ only in the ``which`` selector.
        train_kwargs = dict(root=args.train_img_dir,
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers,
                            dataset_dir=args.dataset_dir)
        source_loader = get_train_loader(which='source', **train_kwargs)
        reference_loader = get_train_loader(which='reference', **train_kwargs)
        validation_loader = get_val_loader(root=args.val_img_dir,
                                           img_size=args.img_size,
                                           batch_size=args.val_batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           dataset_dir=args.dataset_dir)
        solver.train(Munch(src=source_loader,
                           ref=reference_loader,
                           val=validation_loader))

    elif mode == 'sample':
        # Added by hyun to fit the 'styling_ref' mode.
        solver.sample()

        # Added by hyun (currently disabled):
        #parsing(respth='./results/label/src', dspth= os.path.join(args.src_dir,trg_domain)) # parsing src_image
        #parsing(respth='./results/label/others', dspth= os.path.join(args.result_dir,trg_domain) # parsing fake_image
        #reconstruct() # 'styling' mode

    elif mode == 'eval':
        # evaluate() returns a pair; only the per-key values are reported.
        fid_values, _ = solver.evaluate()
        for name, score in fid_values.items():
            print(name, score)
    else:
        raise NotImplementedError
コード例 #4
0
ファイル: main.py プロジェクト: mrfwn/api-ts-pet
def main():
    """Stylize the uploaded image named in ``imgNames.txt``.

    Reads an image file name from ``imgNames.txt``, loads it from
    ``tmp/uploads/``, loads one of the reference images
    ``src/stargan/assets/ref_<0..6>.jpg`` chosen with ``randrange(7)``, passes
    both to ``solver.sample`` and writes the result to ``tmp/`` under a random
    15-character name. The generated file name is written to ``output.txt``.

    Raises:
        FileNotFoundError: if the source or reference image cannot be read.
        OSError: if the result image cannot be written.
    """
    solver = Solver()

    with open('imgNames.txt') as f:
        # Strip surrounding whitespace: a trailing newline in the file would
        # otherwise become part of the path and cv2.imread would silently
        # return None.
        img_name = f.read().strip()

    src_path = "tmp/uploads/" + img_name
    src = cv2.imread(src_path)
    if src is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("could not read source image: " + src_path)

    ref_path = "src/stargan/assets/ref_" + str(randrange(7)) + ".jpg"
    ref = cv2.imread(ref_path)
    if ref is None:
        raise FileNotFoundError("could not read reference image: " + ref_path)

    res_img = solver.sample(src, ref)
    name_f = ''.join(
        random.choices(string.ascii_uppercase + string.digits, k=15)) + '.jpg'

    # Only publish the name once the image is actually on disk.
    if not cv2.imwrite("tmp/" + name_f, res_img):
        raise OSError("could not write result image: tmp/" + name_f)

    # Plain write mode is enough; the file is never read back here.
    with open('output.txt', 'w') as f:
        f.write(name_f)
コード例 #5
0
def main(args):
    """Run the stage selected by ``args.mode``.

    Prints ``args``, enables ``cudnn.benchmark``, seeds torch with
    ``args.seed``, constructs a ``Solver`` and (re)creates the loss CSV at
    ``args.loss_csv_path`` with a header row. It then dispatches:

    * ``'train'``  -- builds ``src``/``src_skt``/``ref``/``val`` loaders and
      calls ``solver.train``.
    * ``'sample'`` -- builds ``src``/``ref`` loaders and calls
      ``solver.sample``.
    * ``'eval'``   -- calls ``solver.evaluate()``.
    * ``'align'``  -- calls ``core.wing.align_faces`` with ``args.inp_dir``
      and ``args.out_dir``.

    Raises:
        NotImplementedError: for any other mode (after the Solver is built
            and the CSV has been created).
        AssertionError: when ``len(subdirs(root))`` of a dataset root
            differs from ``args.num_domains``.
    """
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    # Create (or truncate) the loss CSV and write its header in one pass.
    # The csv module expects a text-mode file opened with newline='' so that
    # its own line terminator is written unchanged on every platform.
    # NOTE: as before, this happens for every mode, not only for training.
    with open(args.loss_csv_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerow([
            "epoch", "d_loss_z_trg", "d_loss_x_ref", "g_loss_z_trg",
            "g_loss_x_ref"
        ])

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        # Same 'source' loader, but rooted at the sketch
                        # image directory.
                        src_skt=get_train_loader(
                            root=args.train_sketch_img_dir,
                            which='source',
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
    elif args.mode == 'sample':
        assert len(subdirs(args.src_dir)) == args.num_domains
        assert len(subdirs(args.ref_dir)) == args.num_domains
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.sample(loaders)
    elif args.mode == 'eval':
        solver.evaluate()
    elif args.mode == 'align':
        # Imported lazily so the other modes do not need core.wing.
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    else:
        raise NotImplementedError
コード例 #6
0
ファイル: main.py プロジェクト: staceysv/stargan-v2
def main(args):
    """Run the stage selected by ``args.mode``.

    Prints ``args``, enables ``cudnn.benchmark``, seeds torch with
    ``args.seed`` and constructs a ``Solver`` before dispatching:

    * ``'train'``  -- builds ``src``/``ref``/``val`` loaders and calls
      ``solver.train``.
    * ``'sample'`` -- builds ``src``/``ref`` test loaders and calls
      ``solver.sample``.
    * ``'eval'``   -- calls ``solver.evaluate(args)``.
    * ``'align'``  -- calls ``core.wing.align_faces`` with ``args.inp_dir``
      and ``args.out_dir``.
    * ``'custom'`` -- starts a wandb run, stages ``args.custom_src`` /
      ``args.custom_ref`` into ``tmp_src`` / ``tmp_ref`` when they are single
      files, builds test loaders and calls ``solver.custom``.

    Raises:
        NotImplementedError: for any other mode (after the Solver is built).
        AssertionError: in train/sample mode when ``len(subdirs(root))``
            differs from ``args.num_domains``.
    """
    print(args)
    # NOTE: wandb initialisation for the regular modes is disabled; only the
    # 'custom' mode below starts a run.
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)
    mode = args.mode

    def make_test_pair(src_root, ref_root):
        # Unshuffled src/ref test loaders that differ only in their root.
        shared = dict(img_size=args.img_size,
                      batch_size=args.val_batch_size,
                      shuffle=False,
                      num_workers=args.num_workers)
        return Munch(src=get_test_loader(root=src_root, **shared),
                     ref=get_test_loader(root=ref_root, **shared))

    def stage_copy(image_path, folder):
        # Create ``folder`` and copy the single image into it.
        os.makedirs(folder)
        shutil.copy2(image_path, folder)

    if mode == 'train':
        for root in (args.train_img_dir, args.val_img_dir):
            assert len(subdirs(root)) == args.num_domains

        # The two training loaders differ only in the ``which`` selector.
        train_kwargs = dict(root=args.train_img_dir,
                            img_size=args.img_size,
                            batch_size=args.batch_size,
                            prob=args.randcrop_prob,
                            num_workers=args.num_workers)
        source_loader = get_train_loader(which='source', **train_kwargs)
        reference_loader = get_train_loader(which='reference', **train_kwargs)
        validation_loader = get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers)
        solver.train(Munch(src=source_loader,
                           ref=reference_loader,
                           val=validation_loader))
    elif mode == 'sample':
        for root in (args.src_dir, args.ref_dir):
            assert len(subdirs(root)) == args.num_domains
        solver.sample(make_test_pair(args.src_dir, args.ref_dir))
    elif mode == 'eval':
        solver.evaluate(args)
    elif mode == 'align':
        # Imported lazily so the other modes do not need core.wing.
        from core.wing import align_faces
        align_faces(args, args.inp_dir, args.out_dir)
    elif mode == 'custom':
        wandb.init(project="stargan", config=args, name=args.model_name)

        # Source and reference may each be a directory or a single image.
        # A single image is copied into a fresh temporary folder tree whose
        # root is then handed to the loader.
        if not os.path.isfile(args.custom_src):
            src_images = args.custom_src
        else:
            src_images = "tmp_src"
            if os.path.exists(src_images):
                shutil.rmtree(src_images)
            stage_copy(args.custom_src, src_images + "/src")

        if not os.path.isfile(args.custom_ref):
            ref_images = args.custom_ref
        else:
            ref_images = "tmp_ref"
            if os.path.exists(ref_images):
                shutil.rmtree(ref_images)
            stage_copy(args.custom_ref, ref_images + "/ref")
            if args.extend_domain:
                # Replicate the reference image into two extra sub-folders.
                for suffix in ("/ref2", "/ref3"):
                    stage_copy(args.custom_ref, ref_images + suffix)

        solver.custom(make_test_pair(src_images, ref_images))
    else:
        raise NotImplementedError