Beispiel #1
0
def test(args):
    """Translate every test image cluster with two pretrained model triples.

    Two ``(E1, E2, Decoder)`` triples are built: the "real/cartoon" models
    restored from ``args.load_rc`` and the "cartoon" models restored from
    ``args.load_c``. An empty path skips restoring that triple. Both triples
    are moved to the GPU when CUDA is available and switched to eval mode.
    Then ``args.out`` is created if needed, and ``trans`` is called once per
    ``(domain A, domain B)`` cluster pair returned by ``my_get_test_imgs``.

    Args:
        args: parsed options. This function reads ``sep``, ``resize``,
            ``load_rc``, ``load_c`` and ``out``. The whole object is also
            forwarded to ``my_get_test_imgs`` and ``trans``.
    """
    # Spatial factor shared by every network; computed once.
    size = int(args.resize / 64)

    def _load_eval_triple(checkpoint):
        """Build E1/E2/Decoder, optionally restore ``checkpoint``, set eval mode."""
        e1 = E1(args.sep, size)
        e2 = E2(args.sep, size)
        decoder = Decoder(size)

        if torch.cuda.is_available():
            e1 = e1.cuda()
            e2 = e2.cuda()
            decoder = decoder.cuda()

        # '' means "no checkpoint": keep the freshly initialised weights.
        if checkpoint != '':
            load_model_for_eval(checkpoint, e1, e2, decoder)

        return e1.eval(), e2.eval(), decoder.eval()

    # ---------- load model_real_cartoon ---------- #
    rc_e1, rc_e2, rc_decoder = _load_eval_triple(args.load_rc)

    # ---------- load model_cartoon ---------- #
    c_e1, c_e2, c_decoder = _load_eval_triple(args.load_c)

    # -------------- running -------------- #
    if args.out != "" and not os.path.exists(args.out):
        # makedirs so nested output paths work as well.
        os.makedirs(args.out)

    test_domA_cluster, test_domB_cluster = my_get_test_imgs(args)
    # zip is consumed lazily; there is no need to materialise it as a list.
    for idx, (test_domA, test_domB) in enumerate(
            zip(test_domA_cluster, test_domB_cluster)):
        trans(args, idx, test_domA, test_domB, rc_e1, rc_e2, rc_decoder,
              c_e1, c_e2, c_decoder)
Beispiel #2
0
def test(args):
    """Run ``trans`` with a real/cartoon model triple and a cartoon triple.

    Each triple is an ``E1`` / ``E2`` / ``Decoder`` set. It is moved to the
    GPU when CUDA is available, restored from its checkpoint path
    (``args.load_rc`` or ``args.load_c``; an empty string skips restoring)
    and switched to evaluation mode. ``args.out`` is created when it is
    missing and non-empty.
    """
    def build_triple(ckpt_path):
        # Order matters and is kept: construct, move to device, restore, eval.
        scale = int((args.resize / 64))
        nets = [E1(args.sep, scale), E2(args.sep, scale), Decoder(scale)]
        if torch.cuda.is_available():
            nets = [net.cuda() for net in nets]
        if ckpt_path != '':
            load_model_for_eval(os.path.join(ckpt_path), *nets)
        return tuple(net.eval() for net in nets)

    # ---------- load model_real_cartoon ---------- #
    rc_e1, rc_e2, rc_decoder = build_triple(args.load_rc)

    # ---------- load model_cartoon ---------- #
    c_e1, c_e2, c_decoder = build_triple(args.load_c)

    # -------------- running -------------- #
    out_dir = args.out
    if not os.path.exists(out_dir) and out_dir != "":
        os.mkdir(out_dir)

    trans(args, rc_e1, rc_e2, rc_decoder, c_e1, c_e2, c_decoder)
Beispiel #3
0
def get_eval_model(load, sep, resize):
    """Build an ``(E1, E2, Decoder)`` triple restored from ``load``, in eval mode.

    Args:
        load: checkpoint path passed straight to ``load_model_for_eval``.
            Restoring is unconditional here; an empty path is not special-cased.
        sep: forwarded as the first argument of ``E1`` and ``E2``.
        resize: input image size; the networks are built with
            ``int(resize / 64)``.

    Returns:
        ``(e1, e2, decoder)``, each moved to the GPU when CUDA is available.
    """
    size = int(resize / 64)
    e1 = E1(sep, size)
    e2 = E2(sep, size)
    decoder = Decoder(size)

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()

    # The value load_model_for_eval returns is not needed for inference.
    load_model_for_eval(load, e1, e2, decoder)

    return e1.eval(), e2.eval(), decoder.eval()


def eval(args):  # NOTE: shadows the builtin eval; name kept for existing callers.
    """Restore a trained model (optionally) and save sample images with it.

    Builds ``E1`` / ``E2`` / ``Decoder`` (GPU when available). When
    ``args.load`` is non-empty, it restores them from
    ``<args.load>/checkpoint``. It then switches them to eval mode, creates
    ``args.out`` if needed and calls ``save_imgs``.

    Args:
        args: parsed options. This function reads ``sep``, ``resize``,
            ``load`` and ``out``; the whole object is forwarded to
            ``save_imgs``.
    """
    size = int(args.resize / 64)
    e1 = E1(args.sep, size)
    e2 = E2(args.sep, size)
    decoder = Decoder(size)

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()

    # Default iteration tag so save_imgs always receives a bound value even
    # when no checkpoint is restored (otherwise: UnboundLocalError).
    _iter = 0
    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model_for_eval(save_file, e1, e2, decoder)

    e1 = e1.eval()
    e2 = e2.eval()
    decoder = decoder.eval()

    if args.out != "" and not os.path.exists(args.out):
        os.makedirs(args.out)

    save_imgs(args, e1, e2, decoder, _iter)
Beispiel #5
0
def train(args):
    """Train the encoders (E1, E2), the decoder and the discriminator.

    Domain A images are reconstructed from ``[e1(A), e2(A)]`` and domain B
    images from ``[e1(B), zeros]``. When ``args.discweight > 0``, ``disc``
    is trained on ``e1`` outputs to predict 1 for A and 0 for B. The
    autoencoder loss then adds ``discweight`` times a BCE term that pushes
    the discriminator's predictions for both domains towards 0.

    Training stops once ``_iter`` reaches ``args.iters``. That limit is
    checked before each pass over the data, so the last pass may run past
    it. Progress is printed every ``args.progress_iter`` iterations, sample
    images are saved every ``args.display_iter`` iterations and a checkpoint
    ``checkpoint_<iter>`` is written every ``args.save_iter`` iterations.

    Args:
        args: parsed options (``out``, ``root``, ``cropA``, ``cropB``,
            ``resize``, ``bs``, ``sep``, ``lr``, ``disclr``, ``discweight``,
            ``load``, ``iters``, ``progress_iter``, ``display_iter``,
            ``save_iter``).

    Raises:
        ValueError: if either training set holds fewer than ``args.bs``
            images. Only full batches are used, so training could never
            advance.
    """
    if not os.path.exists(args.out):
        os.makedirs(args.out)

    _iter = 0

    comp_transformA = transforms.Compose([
        transforms.CenterCrop(args.cropA),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    comp_transformB = transforms.Compose([
        transforms.CenterCrop(args.cropB),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domA_train = CustomDataset(os.path.join(args.root, 'trainA.txt'), transform=comp_transformA)
    domB_train = CustomDataset(os.path.join(args.root, 'trainB.txt'), transform=comp_transformB)
    if len(domA_train) < args.bs or len(domB_train) < args.bs:
        raise ValueError('each training set needs at least %d images (the batch size)' % args.bs)

    # Float targets: BCELoss rejects integer tensors, and an integer fill
    # value makes torch.full return int64 on current PyTorch.
    A_label = torch.full((args.bs,), 1.0)
    B_label = torch.full((args.bs,), 0.0)
    # All-zero code used in place of e2's output when encoding domain B.
    B_separate = torch.full((args.bs, args.sep * (args.resize // 64) * (args.resize // 64)), 0.0)

    e1 = E1(args.sep, args.resize // 64)
    e2 = E2(args.sep, args.resize // 64)
    decoder = Decoder(args.resize // 64)
    disc = Disc(args.sep, args.resize // 64)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()

        A_label = A_label.cuda()
        B_label = B_label.cuda()
        B_separate = B_separate.cuda()

        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=args.lr, betas=(0.5, 0.999))

    # Must be a list: a bare .parameters() generator is exhausted by the
    # optimizer, which would leave clip_grad_norm_ below with nothing to clip.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params, lr=args.disclr, betas=(0.5, 0.999))

    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    # A shuffling DataLoader reshuffles every time it is iterated, so the
    # loaders are built once and re-iterated each epoch.
    domA_loader = torch.utils.data.DataLoader(domA_train, batch_size=args.bs,
                                              shuffle=True, num_workers=6)
    domB_loader = torch.utils.data.DataLoader(domB_train, batch_size=args.bs,
                                              shuffle=True, num_workers=6)

    print('Started training...')
    while _iter < args.iters:
        for domA_img, domB_img in zip(domA_loader, domB_loader):
            # The labels and B_separate are sized for exactly args.bs samples.
            if domA_img.size(0) != args.bs or domB_img.size(0) != args.bs:
                break

            if torch.cuda.is_available():
                domA_img = domA_img.cuda()
                domB_img = domB_img.cuda()

            domA_img = domA_img.view((-1, 3, args.resize, args.resize))
            domB_img = domB_img.view((-1, 3, args.resize, args.resize))

            # ---- autoencoder step ----
            ae_optimizer.zero_grad()

            A_common = e1(domA_img)
            A_separate = e2(domA_img)
            A_encoding = torch.cat([A_common, A_separate], dim=1)

            B_common = e1(domB_img)
            B_encoding = torch.cat([B_common, B_separate], dim=1)

            A_decoding = decoder(A_encoding)
            B_decoding = decoder(B_encoding)

            loss = mse(A_decoding, domA_img) + mse(B_decoding, domB_img)

            if args.discweight > 0:
                preds_A = disc(A_common)
                preds_B = disc(B_common)
                loss += args.discweight * (bce(preds_A, B_label) + bce(preds_B, B_label))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5)
            ae_optimizer.step()

            # ---- discriminator step ----
            loss2 = None  # stays None when the discriminator is disabled
            if args.discweight > 0:
                disc_optimizer.zero_grad()

                # Only disc is updated here; encoder gradients would be
                # discarded by the next ae_optimizer.zero_grad() anyway.
                with torch.no_grad():
                    A_common = e1(domA_img)
                    B_common = e1(domB_img)

                disc_A = disc(A_common)
                disc_B = disc(B_common)

                loss2 = bce(disc_A, A_label) + bce(disc_B, B_label)

                loss2.backward()
                torch.nn.utils.clip_grad_norm_(disc_params, 5)
                disc_optimizer.step()

            if _iter % args.progress_iter == 0:
                loss_val = loss.item()
                loss2_val = loss2.item() if loss2 is not None else 0.0
                print('Outfile: %s | Iteration %d | loss %.6f | loss1: %.6f | loss2: %.6f' % (args.out, _iter, loss_val + loss2_val, loss_val, loss2_val))

            if _iter % args.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()

                save_imgs(args, e1, e2, decoder, _iter)

                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % args.save_iter == 0:
                save_file = os.path.join(args.out, 'checkpoint_%d' % _iter)
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer, _iter)

            _iter += 1
Beispiel #6
0
def train(args):
    """Train E1/E2/Decoder with an optional latent-space discriminator.

    ``args.out`` is first extended in place with a run-description suffix:
    the image size, plus ``sep``, ``disc_weight`` and ``disc_lr`` when they
    are positive or differ from 0.0002. Domain A images are reconstructed
    from ``[e1(A), e2(A)]`` and domain B images from ``[e1(B), zeros]``.

    When ``args.disc_weight > 0``, the discriminator is trained on ``e1``
    outputs to predict 1 for A and 0 for B. The autoencoder loss then adds
    ``disc_weight`` times a BCE term that pushes both domains' predictions
    towards 0.

    Training stops once ``_iter`` reaches ``args.iters``. The check happens
    before each pass over the data.

    Raises:
        ValueError: if either training folder has fewer than ``args.bs``
            images. Only full batches are used, so training could never
            advance.
    """
    args.out = args.out + '_size_' + str(args.resize)
    if args.sep > 0:
        args.out = args.out + '_sep_' + str(args.sep)
    if args.disc_weight > 0:
        args.out = args.out + '_disc-weight_' + str(args.disc_weight)
    if args.disc_lr != 0.0002:
        args.out = args.out + '_disc-lr_' + str(args.disc_lr)
    # The checkpoint below is written inside args.out.
    if not os.path.exists(args.out):
        os.makedirs(args.out)

    _iter = 0
    # Spatial factor shared by the networks and the zero separate code.
    # Tensor dimensions must be ints; `args.resize / 64` alone is a float.
    size = int(args.resize / 64)

    comp_transform = transforms.Compose([
        transforms.CenterCrop(args.crop),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domA_train = CustomImageFolder(root=os.path.join(args.dataroot, 'trainA'),
                                   transform=comp_transform)
    domB_train = CustomImageFolder(root=os.path.join(args.dataroot, 'trainB'),
                                   transform=comp_transform)
    if len(domA_train) < args.bs or len(domB_train) < args.bs:
        raise ValueError('each training set needs at least %d images (the batch size)' % args.bs)

    # Float targets: BCELoss rejects integer tensors, and an integer fill
    # value makes torch.full return int64 on current PyTorch.
    A_label = torch.full((args.bs, ), 1.0)
    B_label = torch.full((args.bs, ), 0.0)
    # All-zero code used in place of e2's output when encoding domain B.
    B_separate = torch.full((args.bs, args.sep * size * size), 0.0)

    e1 = E1(args.sep, size)
    e2 = E2(args.sep, size)
    decoder = Decoder(size)
    disc = Disc(args.sep, size)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()

        A_label = A_label.cuda()
        B_label = B_label.cuda()
        B_separate = B_separate.cuda()

        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(
        decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=args.lr, betas=(0.5, 0.999))

    # Must be a list: a bare .parameters() generator is exhausted by the
    # optimizer, which would leave clip_grad_norm_ below with nothing to clip.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params,
                                lr=args.disc_lr,
                                betas=(0.5, 0.999))

    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    # A shuffling DataLoader reshuffles every time it is iterated, so the
    # loaders are built once and re-iterated each epoch.
    domA_loader = torch.utils.data.DataLoader(domA_train,
                                              batch_size=args.bs,
                                              shuffle=True,
                                              num_workers=6)
    domB_loader = torch.utils.data.DataLoader(domB_train,
                                              batch_size=args.bs,
                                              shuffle=True,
                                              num_workers=6)

    while _iter < args.iters:
        for domA_img, domB_img in zip(domA_loader, domB_loader):
            # The labels and B_separate are sized for exactly args.bs samples;
            # a short final batch would make torch.cat / BCELoss fail.
            if domA_img.size(0) != args.bs or domB_img.size(0) != args.bs:
                break

            if torch.cuda.is_available():
                domA_img = domA_img.cuda()
                domB_img = domB_img.cuda()

            domA_img = domA_img.view((-1, 3, args.resize, args.resize))
            domB_img = domB_img.view((-1, 3, args.resize, args.resize))

            # ---- autoencoder step ----
            ae_optimizer.zero_grad()

            A_common = e1(domA_img)
            A_separate = e2(domA_img)
            A_encoding = torch.cat([A_common, A_separate], dim=1)

            B_common = e1(domB_img)
            B_encoding = torch.cat([B_common, B_separate], dim=1)

            A_decoding = decoder(A_encoding)
            B_decoding = decoder(B_encoding)

            loss = mse(A_decoding, domA_img) + mse(B_decoding, domB_img)

            if args.disc_weight > 0:
                preds_A = disc(A_common)
                preds_B = disc(B_common)
                loss += args.disc_weight * (bce(preds_A, B_label) +
                                            bce(preds_B, B_label))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5)
            ae_optimizer.step()

            # ---- discriminator step ----
            if args.disc_weight > 0:
                disc_optimizer.zero_grad()

                # Only disc is updated here; encoder gradients would be
                # discarded by the next ae_optimizer.zero_grad() anyway.
                with torch.no_grad():
                    A_common = e1(domA_img)
                    B_common = e1(domB_img)

                disc_A = disc(A_common)
                disc_B = disc(B_common)

                disc_loss = bce(disc_A, A_label) + bce(disc_B, B_label)

                disc_loss.backward()
                torch.nn.utils.clip_grad_norm_(disc_params, 5)
                disc_optimizer.step()

            if _iter % args.progress_iter == 0:
                print('Outfile: %s <<>> Iteration %d' % (args.out, _iter))

            # Switch modes only when images are actually produced.
            if _iter % args.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()

                save_imgs(args, e1, e2, decoder)

                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % args.save_iter == 0:
                save_file = os.path.join(args.out, 'checkpoint')
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer, _iter)

            _iter += 1
Beispiel #7
0
def train(config):
    """Train E1/E2/Decoder adversarially against a latent-space discriminator.

    Domain A images are reconstructed from ``[e1(A), e2(A)]`` and domain B
    images from ``[e1(B), zeros]``. Each iteration performs three updates:

    * one autoencoder step: reconstruction MSE plus ``config.adv_weight``
      times a BCE term pushing the discriminator's predictions on both
      domains towards the B label (0);
    * one discriminator step: A -> 1, B -> 0;
    * applying ``RhoClipper(0., 1.)`` to the decoder.

    Training stops once ``_iter`` reaches ``config.iters``. The check happens
    before each pass over the data.

    Raises:
        ValueError: if either training set holds fewer than ``config.bs``
            images. Only full batches are used, so training could never
            advance.
    """
    if not os.path.exists(config.out):
        os.makedirs(config.out)

    comp_transform = transforms.Compose([
        transforms.CenterCrop(config.crop),
        transforms.Resize(config.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domain_a_train = CustomDataset(os.path.join(config.root, 'trainA.txt'), transform=comp_transform)
    domain_b_train = CustomDataset(os.path.join(config.root, 'trainB.txt'), transform=comp_transform)
    if len(domain_a_train) < config.bs or len(domain_b_train) < config.bs:
        raise ValueError('each training set needs at least %d images (the batch size)' % config.bs)

    # Float targets: BCELoss rejects integer tensors, and an integer fill
    # value makes torch.full return int64 on current PyTorch.
    a_label = torch.full((config.bs,), 1.0)
    b_label = torch.full((config.bs,), 0.0)
    # All-zero code used in place of e2's output when encoding domain B.
    b_separate = torch.full((config.bs,
                             config.sep,
                             config.resize // (2 ** (config.n_blocks + 1)),
                             config.resize // (2 ** (config.n_blocks + 1))), 0.0)

    # build networks
    e1 = E1(sep=config.sep, size=config.resize)
    e2 = E2(n_feats=config.n_tot_feats, sep=config.sep)
    decoder = Decoder(n_feats=config.n_tot_feats)
    disc = Disc(size=config.resize, sep=config.sep)
    rho_clipper = RhoClipper(0., 1.)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()

        a_label = a_label.cuda()
        b_label = b_label.cuda()
        b_separate = b_separate.cuda()

        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=config.lr,
                              betas=(config.beta1, config.beta2), eps=config.eps)

    # Must be a list: a bare .parameters() generator is exhausted by the
    # optimizer, which would leave clip_grad_norm_ below with nothing to clip.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params, lr=config.d_lr,
                                betas=(config.beta1, config.beta2), eps=config.eps)

    _iter: int = 0
    if config.load != '':
        save_file = os.path.join(config.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    # A shuffling DataLoader reshuffles every time it is iterated, so the
    # loaders are built once and re-iterated each epoch.
    domain_a_loader = torch.utils.data.DataLoader(domain_a_train, batch_size=config.bs,
                                                  shuffle=True, num_workers=config.n_threads)
    domain_b_loader = torch.utils.data.DataLoader(domain_b_train, batch_size=config.bs,
                                                  shuffle=True, num_workers=config.n_threads)

    print('[*] Started training...')
    while _iter < config.iters:
        for domain_a_img, domain_b_img in zip(domain_a_loader, domain_b_loader):
            # The labels and b_separate are sized for exactly config.bs samples.
            if domain_a_img.size(0) != config.bs or domain_b_img.size(0) != config.bs:
                break

            if torch.cuda.is_available():
                domain_a_img = domain_a_img.cuda()
                domain_b_img = domain_b_img.cuda()

            domain_a_img = domain_a_img.view((-1, 3, config.resize, config.resize))
            domain_b_img = domain_b_img.view((-1, 3, config.resize, config.resize))

            # ---- autoencoder / generator step ----
            ae_optimizer.zero_grad()

            a_common = e1(domain_a_img)
            a_separate = e2(domain_a_img)
            a_encoding = torch.cat([a_common, a_separate], dim=1)

            b_common = e1(domain_b_img)
            b_encoding = torch.cat([b_common, b_separate], dim=1)

            a_decoding = decoder(a_encoding)
            b_decoding = decoder(b_encoding)

            g_loss = mse(a_decoding, domain_a_img) + mse(b_decoding, domain_b_img)

            preds_a = disc(a_common)
            preds_b = disc(b_common)
            g_loss += config.adv_weight * (bce(preds_a, b_label) + bce(preds_b, b_label))

            g_loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5.)
            ae_optimizer.step()

            # ---- discriminator step ----
            disc_optimizer.zero_grad()

            # Only disc is updated here; encoder gradients would be discarded
            # by the next ae_optimizer.zero_grad() anyway.
            with torch.no_grad():
                a_common = e1(domain_a_img)
                b_common = e1(domain_b_img)

            disc_a = disc(a_common)
            disc_b = disc(b_common)

            d_loss = bce(disc_a, a_label) + bce(disc_b, b_label)

            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(disc_params, 5.)
            disc_optimizer.step()

            decoder.apply(rho_clipper)

            if _iter % config.progress_iter == 0:
                print('[*] [%07d/%07d] d_loss : %.4f, g_loss : %.4f' %
                      (_iter, config.iters, d_loss.item(), g_loss.item()))

            if _iter % config.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()

                save_images(config, e1, e2, decoder, _iter)

                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % config.save_iter == 0:
                save_file = os.path.join(config.out, 'checkpoint')
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer, _iter)

            _iter += 1