import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms

# Project-local modules are assumed importable; adjust the paths to the
# repository layout (model definitions, datasets, and checkpoint helpers):
# from models import E1, E2, Decoder, Disc
# from datasets import CustomDataset, CustomImageFolder
# from utils import (load_model, save_model, load_model_for_eval,
#                    save_imgs, save_images, trans, my_get_test_imgs, RhoClipper)


# Variant of test() that translates each cluster of test images separately.
def test(args):
    # ---------- load model_real_cartoon ---------- #
    rc_e1 = E1(args.sep, args.resize // 64)
    rc_e2 = E2(args.sep, args.resize // 64)
    rc_decoder = Decoder(args.resize // 64)
    if torch.cuda.is_available():
        rc_e1 = rc_e1.cuda()
        rc_e2 = rc_e2.cuda()
        rc_decoder = rc_decoder.cuda()
    if args.load_rc != '':
        save_file = os.path.join(args.load_rc)
        load_model_for_eval(save_file, rc_e1, rc_e2, rc_decoder)
    rc_e1 = rc_e1.eval()
    rc_e2 = rc_e2.eval()
    rc_decoder = rc_decoder.eval()

    # ---------- load model_cartoon ---------- #
    c_e1 = E1(args.sep, args.resize // 64)
    c_e2 = E2(args.sep, args.resize // 64)
    c_decoder = Decoder(args.resize // 64)
    if torch.cuda.is_available():
        c_e1 = c_e1.cuda()
        c_e2 = c_e2.cuda()
        c_decoder = c_decoder.cuda()
    if args.load_c != '':
        save_file = os.path.join(args.load_c)
        load_model_for_eval(save_file, c_e1, c_e2, c_decoder)
    c_e1 = c_e1.eval()
    c_e2 = c_e2.eval()
    c_decoder = c_decoder.eval()

    # -------------- running -------------- #
    if not os.path.exists(args.out) and args.out != "":
        os.mkdir(args.out)
    # trans(args, rc_e1, rc_e2, rc_decoder, c_e1, c_e2, c_decoder)
    test_domA_cluster, test_domB_cluster = my_get_test_imgs(args)
    for idx, (test_domA, test_domB) in enumerate(zip(test_domA_cluster,
                                                     test_domB_cluster)):
        trans(args, idx, test_domA, test_domB,
              rc_e1, rc_e2, rc_decoder, c_e1, c_e2, c_decoder)
# Earlier variant of test(): a single trans() call over the whole test set
# instead of iterating per cluster.
def test(args):
    # ---------- load model_real_cartoon ---------- #
    rc_e1 = E1(args.sep, args.resize // 64)
    rc_e2 = E2(args.sep, args.resize // 64)
    rc_decoder = Decoder(args.resize // 64)
    if torch.cuda.is_available():
        rc_e1 = rc_e1.cuda()
        rc_e2 = rc_e2.cuda()
        rc_decoder = rc_decoder.cuda()
    if args.load_rc != '':
        save_file = os.path.join(args.load_rc)
        load_model_for_eval(save_file, rc_e1, rc_e2, rc_decoder)
    rc_e1 = rc_e1.eval()
    rc_e2 = rc_e2.eval()
    rc_decoder = rc_decoder.eval()

    # ---------- load model_cartoon ---------- #
    c_e1 = E1(args.sep, args.resize // 64)
    c_e2 = E2(args.sep, args.resize // 64)
    c_decoder = Decoder(args.resize // 64)
    if torch.cuda.is_available():
        c_e1 = c_e1.cuda()
        c_e2 = c_e2.cuda()
        c_decoder = c_decoder.cuda()
    if args.load_c != '':
        save_file = os.path.join(args.load_c)
        load_model_for_eval(save_file, c_e1, c_e2, c_decoder)
    c_e1 = c_e1.eval()
    c_e2 = c_e2.eval()
    c_decoder = c_decoder.eval()

    # -------------- running -------------- #
    if not os.path.exists(args.out) and args.out != "":
        os.mkdir(args.out)
    trans(args, rc_e1, rc_e2, rc_decoder, c_e1, c_e2, c_decoder)
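# NOTE (sketch): a hypothetical command-line driver for the test paths above.
# The flag names mirror the attributes read by test(); the repository's real
# parser may define more options (root, crop sizes, etc.) and different defaults.
def _test_cli():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--sep', type=int, default=25)
    parser.add_argument('--resize', type=int, default=128)
    parser.add_argument('--load_rc', default='', help='checkpoint of the real/cartoon model')
    parser.add_argument('--load_c', default='', help='checkpoint of the cartoon model')
    parser.add_argument('--out', default='out')
    test(parser.parse_args())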
def get_eval_model(load, sep, resize):
    e1 = E1(sep, resize // 64)
    e2 = E2(sep, resize // 64)
    decoder = Decoder(resize // 64)
    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
    _iter = load_model_for_eval(load, e1, e2, decoder)  # iteration count unused here
    e1 = e1.eval()
    e2 = e2.eval()
    decoder = decoder.eval()
    return e1, e2, decoder
def eval(args):  # note: shadows the built-in eval(); kept for compatibility with callers
    e1 = E1(args.sep, args.resize // 64)
    e2 = E2(args.sep, args.resize // 64)
    decoder = Decoder(args.resize // 64)
    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
    _iter = 0  # default when no checkpoint is given, so save_imgs always gets a value
    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model_for_eval(save_file, e1, e2, decoder)
    e1 = e1.eval()
    e2 = e2.eval()
    decoder = decoder.eval()
    if not os.path.exists(args.out) and args.out != "":
        os.mkdir(args.out)
    save_imgs(args, e1, e2, decoder, _iter)
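# NOTE (sketch): `load_model_for_eval` is defined elsewhere in the repository.
# A minimal sketch, assuming checkpoints are single dicts keyed by module name
# (the key names here are assumptions, mirrored by the save_model sketch below):
def load_model_for_eval(save_file, e1, e2, decoder):
    state = torch.load(save_file,
                       map_location=None if torch.cuda.is_available() else 'cpu')
    e1.load_state_dict(state['e1'])
    e2.load_state_dict(state['e2'])
    decoder.load_state_dict(state['decoder'])
    # Return the iteration count so callers can label their outputs.
    return state.get('iter', 0)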
# Training variant with per-domain crop sizes and file-list datasets.
def train(args):
    if not os.path.exists(args.out):
        os.makedirs(args.out)
    _iter = 0

    comp_transformA = transforms.Compose([
        transforms.CenterCrop(args.cropA),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    comp_transformB = transforms.Compose([
        transforms.CenterCrop(args.cropB),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domA_train = CustomDataset(os.path.join(args.root, 'trainA.txt'),
                               transform=comp_transformA)
    domB_train = CustomDataset(os.path.join(args.root, 'trainB.txt'),
                               transform=comp_transformB)

    # Float targets: BCELoss expects float labels.
    A_label = torch.full((args.bs,), 1.0)
    B_label = torch.full((args.bs,), 0.0)
    # Domain B carries no "separate" content, so its separate code is all zeros.
    B_separate = torch.full((args.bs,
                             args.sep * (args.resize // 64) * (args.resize // 64)), 0.0)

    e1 = E1(args.sep, args.resize // 64)
    e2 = E2(args.sep, args.resize // 64)
    decoder = Decoder(args.resize // 64)
    disc = Disc(args.sep, args.resize // 64)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()
        A_label = A_label.cuda()
        B_label = B_label.cuda()
        B_separate = B_separate.cuda()
        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=args.lr, betas=(0.5, 0.999))
    # Materialize the generator: Adam would otherwise exhaust it before
    # clip_grad_norm_ sees any parameters.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params, lr=args.disclr, betas=(0.5, 0.999))

    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    print('Started training...')
    while True:
        domA_loader = torch.utils.data.DataLoader(domA_train, batch_size=args.bs,
                                                  shuffle=True, num_workers=6)
        domB_loader = torch.utils.data.DataLoader(domB_train, batch_size=args.bs,
                                                  shuffle=True, num_workers=6)
        if _iter >= args.iters:
            break

        for domA_img, domB_img in zip(domA_loader, domB_loader):
            # Skip the last, possibly smaller, batch of each epoch: B_separate
            # and the labels are sized for full batches.
            if domA_img.size(0) != args.bs or domB_img.size(0) != args.bs:
                break

            domA_img = Variable(domA_img)
            domB_img = Variable(domB_img)
            if torch.cuda.is_available():
                domA_img = domA_img.cuda()
                domB_img = domB_img.cuda()
            domA_img = domA_img.view((-1, 3, args.resize, args.resize))
            domB_img = domB_img.view((-1, 3, args.resize, args.resize))

            ae_optimizer.zero_grad()

            A_common = e1(domA_img)
            A_separate = e2(domA_img)
            A_encoding = torch.cat([A_common, A_separate], dim=1)
            B_common = e1(domB_img)
            B_encoding = torch.cat([B_common, B_separate], dim=1)

            A_decoding = decoder(A_encoding)
            B_decoding = decoder(B_encoding)

            # Reconstruction loss for both domains.
            loss = mse(A_decoding, domA_img) + mse(B_decoding, domB_img)
            loss2 = 0.0  # so the progress print below is defined when discweight == 0
            if args.discweight > 0:
                # Fool the discriminator: both common codes should look like domain B.
                preds_A = disc(A_common)
                preds_B = disc(B_common)
                loss += args.discweight * (bce(preds_A, B_label) + bce(preds_B, B_label))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5)
            ae_optimizer.step()

            if args.discweight > 0:
                disc_optimizer.zero_grad()
                A_common = e1(domA_img)
                B_common = e1(domB_img)
                disc_A = disc(A_common)
                disc_B = disc(B_common)
                loss2 = bce(disc_A, A_label) + bce(disc_B, B_label)
                loss2.backward()
                torch.nn.utils.clip_grad_norm_(disc_params, 5)
                disc_optimizer.step()

            if _iter % args.progress_iter == 0:
                print('Outfile: %s | Iteration %d | loss %.6f | loss1: %.6f | loss2: %.6f'
                      % (args.out, _iter, loss + loss2, loss, loss2))

            if _iter % args.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()
                save_imgs(args, e1, e2, decoder, _iter)
                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % args.save_iter == 0:
                save_file = os.path.join(args.out, 'checkpoint_%d' % _iter)
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer, _iter)

            _iter += 1
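# NOTE (sketch): the training loops rely on `save_model` / `load_model` helpers
# defined elsewhere. A minimal pair consistent with their call signatures,
# assuming a single-dict checkpoint format (the key names are assumptions):
def save_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer, _iter):
    torch.save({
        'e1': e1.state_dict(),
        'e2': e2.state_dict(),
        'decoder': decoder.state_dict(),
        'ae_optimizer': ae_optimizer.state_dict(),
        'disc': disc.state_dict(),
        'disc_optimizer': disc_optimizer.state_dict(),
        'iter': _iter,
    }, save_file)


def load_model(save_file, e1, e2, decoder, ae_optimizer, disc, disc_optimizer):
    state = torch.load(save_file,
                       map_location=None if torch.cuda.is_available() else 'cpu')
    e1.load_state_dict(state['e1'])
    e2.load_state_dict(state['e2'])
    decoder.load_state_dict(state['decoder'])
    ae_optimizer.load_state_dict(state['ae_optimizer'])
    disc.load_state_dict(state['disc'])
    disc_optimizer.load_state_dict(state['disc_optimizer'])
    return state['iter']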
# Training variant using ImageFolder datasets and an auto-suffixed output name.
def train(args):
    args.out = args.out + '_size_' + str(args.resize)
    if args.sep > 0:
        args.out = args.out + '_sep_' + str(args.sep)
    if args.disc_weight > 0:
        args.out = args.out + '_disc-weight_' + str(args.disc_weight)
    if args.disc_lr != 0.0002:
        args.out = args.out + '_disc-lr_' + str(args.disc_lr)
    if not os.path.exists(args.out):
        os.makedirs(args.out)  # ensure the suffixed output directory exists

    _iter = 0

    comp_transform = transforms.Compose([
        transforms.CenterCrop(args.crop),
        transforms.Resize(args.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domA_train = CustomImageFolder(root=os.path.join(args.dataroot, 'trainA'),
                                   transform=comp_transform)
    domB_train = CustomImageFolder(root=os.path.join(args.dataroot, 'trainB'),
                                   transform=comp_transform)

    # Float targets: BCELoss expects float labels. Integer division keeps the
    # separate-code size an int (a plain / would produce a float and fail).
    A_label = torch.full((args.bs,), 1.0)
    B_label = torch.full((args.bs,), 0.0)
    B_separate = torch.full((args.bs,
                             args.sep * (args.resize // 64) * (args.resize // 64)), 0.0)

    e1 = E1(args.sep, args.resize // 64)
    e2 = E2(args.sep, args.resize // 64)
    decoder = Decoder(args.resize // 64)
    disc = Disc(args.sep, args.resize // 64)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()
        A_label = A_label.cuda()
        B_label = B_label.cuda()
        B_separate = B_separate.cuda()
        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=args.lr, betas=(0.5, 0.999))
    # Materialize the generator: Adam would otherwise exhaust it before
    # clip_grad_norm_ sees any parameters.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params, lr=args.disc_lr, betas=(0.5, 0.999))

    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    while True:
        domA_loader = torch.utils.data.DataLoader(domA_train, batch_size=args.bs,
                                                  shuffle=True, num_workers=6)
        domB_loader = torch.utils.data.DataLoader(domB_train, batch_size=args.bs,
                                                  shuffle=True, num_workers=6)
        if _iter >= args.iters:
            break

        for domA_img, domB_img in zip(domA_loader, domB_loader):
            # Skip the last, possibly smaller, batch of each epoch: B_separate
            # and the labels are sized for full batches.
            if domA_img.size(0) != args.bs or domB_img.size(0) != args.bs:
                break

            domA_img = Variable(domA_img)
            domB_img = Variable(domB_img)
            if torch.cuda.is_available():
                domA_img = domA_img.cuda()
                domB_img = domB_img.cuda()
            domA_img = domA_img.view((-1, 3, args.resize, args.resize))
            domB_img = domB_img.view((-1, 3, args.resize, args.resize))

            ae_optimizer.zero_grad()

            A_common = e1(domA_img)
            A_separate = e2(domA_img)
            A_encoding = torch.cat([A_common, A_separate], dim=1)
            B_common = e1(domB_img)
            B_encoding = torch.cat([B_common, B_separate], dim=1)

            A_decoding = decoder(A_encoding)
            B_decoding = decoder(B_encoding)

            loss = mse(A_decoding, domA_img) + mse(B_decoding, domB_img)
            if args.disc_weight > 0:
                preds_A = disc(A_common)
                preds_B = disc(B_common)
                loss += args.disc_weight * (bce(preds_A, B_label) + bce(preds_B, B_label))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5)
            ae_optimizer.step()

            if args.disc_weight > 0:
                disc_optimizer.zero_grad()
                A_common = e1(domA_img)
                B_common = e1(domB_img)
                disc_A = disc(A_common)
                disc_B = disc(B_common)
                loss = bce(disc_A, A_label) + bce(disc_B, B_label)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(disc_params, 5)
                disc_optimizer.step()

            if _iter % args.progress_iter == 0:
                print('Outfile: %s <<>> Iteration %d' % (args.out, _iter))

            if _iter % args.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()
                save_imgs(args, e1, e2, decoder)
                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % args.save_iter == 0:
                save_file = os.path.join(args.out, 'checkpoint')
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer, _iter)

            _iter += 1
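import torchvision.utils as vutils


# NOTE (sketch): `save_imgs` is defined elsewhere in the repository. A minimal
# sketch of what it plausibly does during training, reconstructing a few
# held-out images through the encoders and decoder and writing a grid. The
# testA.txt / testB.txt file names, indexing behavior of CustomDataset, and
# the output file name are assumptions.
def save_imgs(args, e1, e2, decoder, _iter=0, n=4):
    test_transform = transforms.Compose([
        transforms.CenterCrop(args.crop),
        transforms.Resize(args.resize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dom_a = CustomDataset(os.path.join(args.root, 'testA.txt'), transform=test_transform)
    dom_b = CustomDataset(os.path.join(args.root, 'testB.txt'), transform=test_transform)

    # Domain B uses an all-zero separate code, exactly as in training.
    b_separate = torch.zeros(1, args.sep * (args.resize // 64) * (args.resize // 64))
    rows = []
    with torch.no_grad():
        for i in range(n):
            a = dom_a[i].unsqueeze(0)
            b = dom_b[i].unsqueeze(0)
            if torch.cuda.is_available():
                a, b, b_separate = a.cuda(), b.cuda(), b_separate.cuda()
            a_rec = decoder(torch.cat([e1(a), e2(a)], dim=1))
            b_rec = decoder(torch.cat([e1(b), b_separate], dim=1))
            rows.append(torch.cat([a, a_rec, b, b_rec], dim=0))
    vutils.save_image(torch.cat(rows, dim=0),
                      os.path.join(args.out, 'samples_%06d.png' % _iter),
                      nrow=4, normalize=True)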
# Training variant with config-driven hyper-parameters and rho clipping on the
# decoder's normalization layers.
def train(config):
    if not os.path.exists(config.out):
        os.makedirs(config.out)

    comp_transform = transforms.Compose([
        transforms.CenterCrop(config.crop),
        transforms.Resize(config.resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    domain_a_train = CustomDataset(os.path.join(config.root, 'trainA.txt'),
                                   transform=comp_transform)
    domain_b_train = CustomDataset(os.path.join(config.root, 'trainB.txt'),
                                   transform=comp_transform)

    # Float targets: BCELoss expects float labels. Domain B's separate code is
    # an all-zero feature map, downsampled by a factor of 2 ** (n_blocks + 1).
    a_label = torch.full((config.bs,), 1.0)
    b_label = torch.full((config.bs,), 0.0)
    b_separate = torch.full((config.bs, config.sep,
                             config.resize // (2 ** (config.n_blocks + 1)),
                             config.resize // (2 ** (config.n_blocks + 1))), 0.0)

    # build networks
    e1 = E1(sep=config.sep, size=config.resize)
    e2 = E2(n_feats=config.n_tot_feats, sep=config.sep)
    decoder = Decoder(n_feats=config.n_tot_feats)
    disc = Disc(size=config.resize, sep=config.sep)

    rho_clipper = RhoClipper(0., 1.)

    mse = nn.MSELoss()
    bce = nn.BCELoss()

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()
        disc = disc.cuda()
        a_label = a_label.cuda()
        b_label = b_label.cuda()
        b_separate = b_separate.cuda()
        mse = mse.cuda()
        bce = bce.cuda()

    ae_params = list(e1.parameters()) + list(e2.parameters()) + list(decoder.parameters())
    ae_optimizer = optim.Adam(ae_params, lr=config.lr,
                              betas=(config.beta1, config.beta2), eps=config.eps)
    # Materialize the generator: Adam would otherwise exhaust it before
    # clip_grad_norm_ sees any parameters.
    disc_params = list(disc.parameters())
    disc_optimizer = optim.Adam(disc_params, lr=config.d_lr,
                                betas=(config.beta1, config.beta2), eps=config.eps)

    _iter: int = 0
    if config.load != '':
        save_file = os.path.join(config.load, 'checkpoint')
        _iter = load_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer)

    e1 = e1.train()
    e2 = e2.train()
    decoder = decoder.train()
    disc = disc.train()

    print('[*] Started training...')
    while True:
        domain_a_loader = torch.utils.data.DataLoader(domain_a_train,
                                                      batch_size=config.bs,
                                                      shuffle=True,
                                                      num_workers=config.n_threads)
        domain_b_loader = torch.utils.data.DataLoader(domain_b_train,
                                                      batch_size=config.bs,
                                                      shuffle=True,
                                                      num_workers=config.n_threads)
        if _iter >= config.iters:
            break

        for domain_a_img, domain_b_img in zip(domain_a_loader, domain_b_loader):
            # Skip the last, possibly smaller, batch of each epoch.
            if domain_a_img.size(0) != config.bs or domain_b_img.size(0) != config.bs:
                break

            domain_a_img = Variable(domain_a_img)
            domain_b_img = Variable(domain_b_img)
            if torch.cuda.is_available():
                domain_a_img = domain_a_img.cuda()
                domain_b_img = domain_b_img.cuda()
            domain_a_img = domain_a_img.view((-1, 3, config.resize, config.resize))
            domain_b_img = domain_b_img.view((-1, 3, config.resize, config.resize))

            # ---------- auto-encoder / generator step ---------- #
            ae_optimizer.zero_grad()

            a_common = e1(domain_a_img)
            a_separate = e2(domain_a_img)
            a_encoding = torch.cat([a_common, a_separate], dim=1)
            b_common = e1(domain_b_img)
            b_encoding = torch.cat([b_common, b_separate], dim=1)

            a_decoding = decoder(a_encoding)
            b_decoding = decoder(b_encoding)

            # Reconstruction plus adversarial loss on the common codes.
            g_loss = mse(a_decoding, domain_a_img) + mse(b_decoding, domain_b_img)
            preds_a = disc(a_common)
            preds_b = disc(b_common)
            g_loss += config.adv_weight * (bce(preds_a, b_label) + bce(preds_b, b_label))

            g_loss.backward()
            torch.nn.utils.clip_grad_norm_(ae_params, 5.)
            ae_optimizer.step()

            # ---------- discriminator step ---------- #
            disc_optimizer.zero_grad()

            a_common = e1(domain_a_img)
            b_common = e1(domain_b_img)
            disc_a = disc(a_common)
            disc_b = disc(b_common)

            d_loss = bce(disc_a, a_label) + bce(disc_b, b_label)

            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(disc_params, 5.)
            disc_optimizer.step()

            # Keep the decoder's rho (normalization mixing) parameters in [0, 1].
            decoder.apply(rho_clipper)

            if _iter % config.progress_iter == 0:
                print('[*] [%07d/%07d] d_loss : %.4f, g_loss : %.4f'
                      % (_iter, config.iters, d_loss, g_loss))

            if _iter % config.display_iter == 0:
                e1 = e1.eval()
                e2 = e2.eval()
                decoder = decoder.eval()
                save_images(config, e1, e2, decoder, _iter)
                e1 = e1.train()
                e2 = e2.train()
                decoder = decoder.train()

            if _iter % config.save_iter == 0:
                save_file = os.path.join(config.out, 'checkpoint')
                save_model(save_file, e1, e2, decoder, ae_optimizer, disc,
                           disc_optimizer, _iter)

            _iter += 1
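# NOTE (sketch): `RhoClipper` is defined elsewhere. Since it is applied via
# decoder.apply(rho_clipper), it must be callable on every sub-module; a
# sketch consistent with the U-GAT-IT-style clipper, assuming the decoder's
# AdaILN/ILN layers expose their mixing coefficient as a `rho` parameter:
class RhoClipper:
    """Clamps every `rho` parameter it finds into [clip_min, clip_max]."""

    def __init__(self, clip_min, clip_max):
        assert clip_min < clip_max
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Called once per sub-module by nn.Module.apply().
        if hasattr(module, 'rho'):
            module.rho.data.clamp_(self.clip_min, self.clip_max)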