def train(epochs, interval, batchsize, validsize, data_path, sketch_path,
          extension, img_size, outdir, modeldir,
          gen_learning_rate, dis_learning_rate, beta1, beta2):
    """Train ``Style2Paint`` adversarially with content and VGG style/perceptual losses.

    Each iteration performs one discriminator update followed by one generator
    update. At iterations 1, 1 + interval, 1 + 2*interval, ... the generator
    weights are saved to ``{modeldir}/model_{iteration}.pt`` and the fixed
    validation batch is rendered through ``Visualizer`` into ``outdir``.

    Args:
        epochs: Number of passes over the dataset.
        interval: Checkpoint/visualization period in iterations; must be >= 1.
        batchsize: Mini-batch size; incomplete final batches are dropped.
        validsize: Passed to ``dataset.valid`` and to the visualizer.
        data_path, sketch_path, extension: Passed to ``IllustDataset``.
        img_size: Passed to ``LineCollator``.
        outdir: Output directory passed to the visualizer.
        modeldir: Directory receiving generator checkpoints.
        gen_learning_rate: Adam learning rate for the generator.
        dis_learning_rate: Adam learning rate for the discriminator.
        beta1, beta2: Adam beta coefficients shared by both optimizers.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """
    # Fail before any expensive setup instead of at the first modulo.
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    # Dataset definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & optimizer definition
    model = Style2Paint()
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(), lr=gen_learning_rate,
                               betas=(beta1, beta2))

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=dis_learning_rate,
                               betas=(beta1, beta2))

    # Frozen feature extractor, only consumed by the style/perceptual loss.
    vgg = Vgg19(requires_grad=False)
    vgg.cuda()
    vgg.eval()

    # Loss function and visualizer definition
    lossfunc = Style2paintsLossCalculator()
    visualizer = Visualizer()

    # A shuffling DataLoader reshuffles every time it is iterated, so one
    # instance serves all epochs.
    dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True,
                            collate_fn=collator, drop_last=True)

    iteration = 0
    for _ in range(epochs):
        for jit, war, line in tqdm(dataloader):
            iteration += 1

            # Discriminator update. The fake batch is detached from the
            # generator anyway, so skip building its autograd graph.
            with torch.no_grad():
                y = model(line, war)
            loss = lossfunc.adversarial_disloss(discriminator, y.detach(), jit)
            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update (fresh forward pass, this time with gradients).
            y = model(line, war)
            loss = lossfunc.adversarial_genloss(discriminator, y)
            loss += 10.0 * lossfunc.content_loss(y, jit)
            loss += lossfunc.style_and_perceptual_loss(vgg, y, jit)
            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            # Fires at 1, 1 + interval, ... . Unlike `iteration % interval == 1`,
            # this also works when interval == 1.
            if (iteration - 1) % interval == 0:
                torch.save(model.state_dict(), f"{modeldir}/model_{iteration}.pt")

                # NOTE(review): the model is still in train() mode here, so any
                # normalization layer that tracks running statistics is updated
                # by this validation forward pass — confirm this is intended.
                with torch.no_grad():
                    y = model(l_valid, c_valid)

                c = c_valid.detach().cpu().numpy()
                l = l_valid.detach().cpu().numpy()
                y = y.detach().cpu().numpy()

                visualizer(l, c, y, outdir, iteration, validsize)

            # .item() yields a plain float instead of a tensor repr.
            print(f"iteration: {iteration} Loss: {loss.item()}")
def train(epochs, interval, batchsize, validsize, data_path, sketch_path,
          extension, img_size, outdir, modeldir, learning_rate):
    """Train an AdaIN ``Style2Paint`` generator with adversarial and reconstruction losses.

    Each iteration performs one discriminator update (adversarial loss scaled
    by 0.01) and one generator update (0.01 * adversarial + ``maeloss`` against
    the colour target + 0.001 * positive-enforcing loss). At iterations
    1, 1 + interval, 1 + 2*interval, ... the generator weights are saved to
    ``{modeldir}/model_{iteration}.pt`` and the fixed validation batch is
    rendered through ``Visualizer`` into ``outdir``.

    Args:
        epochs: Number of passes over the dataset.
        interval: Checkpoint/visualization period in iterations; must be >= 1.
        batchsize: Mini-batch size; incomplete final batches are dropped.
        validsize: Passed to ``dataset.valid`` and to the visualizer.
        data_path, sketch_path, extension: Passed to ``IllustDataset``.
        img_size: Passed to ``LineCollator``.
        outdir: Output directory passed to the visualizer.
        modeldir: Directory receiving generator checkpoints.
        learning_rate: Adam learning rate shared by generator and discriminator.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """
    # Fail before any expensive setup instead of at the first modulo.
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    # Dataset definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & optimizer definition
    model = Style2Paint(attn_type="adain")
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

    # Loss function and visualizer definition
    lossfunc = Style2paintsLossCalculator()
    visualizer = Visualizer()

    # A shuffling DataLoader reshuffles every time it is iterated, so one
    # instance serves all epochs.
    dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True,
                            collate_fn=collator, drop_last=True)

    iteration = 0
    for _ in range(epochs):
        for color, line in tqdm(dataloader):
            iteration += 1

            # Discriminator update. The fake batch is detached from the
            # generator anyway, so skip building its autograd graph.
            with torch.no_grad():
                y = model(line, color)
            loss = 0.01 * lossfunc.adversarial_disloss(discriminator, y.detach(), color)
            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update (fresh forward pass, this time with gradients).
            y = model(line, color)
            loss = 0.01 * lossfunc.adversarial_genloss(discriminator, y)
            loss += maeloss(y, color)
            loss += 0.001 * lossfunc.positive_enforcing_loss(y)
            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            # Fires at 1, 1 + interval, ... . Unlike `iteration % interval == 1`,
            # this also works when interval == 1.
            if (iteration - 1) % interval == 0:
                torch.save(model.state_dict(), f"{modeldir}/model_{iteration}.pt")

                # NOTE(review): the model is still in train() mode here, so any
                # normalization layer that tracks running statistics is updated
                # by this validation forward pass — confirm this is intended.
                with torch.no_grad():
                    y = model(l_valid, c_valid)

                c = c_valid.detach().cpu().numpy()
                l = l_valid.detach().cpu().numpy()
                y = y.detach().cpu().numpy()

                visualizer(l, c, y, outdir, iteration, validsize)

            # .item() yields a plain float instead of a tensor repr.
            print(f"iteration: {iteration} Loss: {loss.item()}")