class Trial:
    """Training harness for an AnimeGAN-style generator / discriminator pair.

    Owns the dataset and dataloader, both networks and their optimizers, the
    loss module, a TensorBoard ``SummaryWriter`` and, when enabled, Apex AMP
    mixed precision. Several training schedules are exposed as methods:
    ``init_train`` (content-loss-only generator warm-up), ``train_1`` and
    ``train_2`` (adversarial training variants) and the NoGAN-style stages
    ``Generator_NOGAN``, ``Discriminator_NOGAN`` and ``GAN_NOGAN``.
    """

    def __init__(self,
                 data_dir: str = './dataset',
                 log_dir: str = './logs',
                 device: str = "cuda:0",
                 batch_size: int = 2,
                 init_lr: float = 0.5,
                 G_lr: float = 0.0004,
                 D_lr: float = 0.0008,
                 level: str = "O1",
                 patch: bool = False,
                 init_training_epoch: int = 10,
                 train_epoch: int = 10,
                 optim_type: str = "ADAM",
                 pin_memory: bool = True,
                 grad_set_to_none: bool = True,
                 num_workers: int = 4):
        """Build the data pipeline, models, optimizers, losses and logging.

        Args:
            data_dir: Dataset root. Training triples are read from
                ``<data_dir>/Shinkai``; test photos from
                ``<data_dir>/test/test_photo256``.
            log_dir: TensorBoard log directory.
            device: Requested torch device; CPU is used instead when CUDA is
                unavailable.
            batch_size: Dataloader batch size.
            init_lr: Generator learning rate written at the start of
                ``init_train``.
            G_lr: Generator learning rate.
            D_lr: Discriminator learning rate.
            level: Apex AMP opt level; ``"O0"`` disables mixed precision.
            patch: Use ``PatchDiscriminator`` instead of ``Discriminator``.
            init_training_epoch: Number of epochs run by ``init_train``.
            train_epoch: Number of epochs run by ``train_1`` / ``train_2``.
            optim_type: Optimizer name forwarded to ``GANOptimizer``.
            pin_memory: Forwarded to ``DataLoader``.
            grad_set_to_none: Passed as ``set_to_none`` to ``zero_grad`` in
                the methods that honour it.
            num_workers: Dataloader worker processes (default keeps the
                previously hard-coded value of 4).
        """
        self.data_dir = data_dir
        self.dataset = Dataset(root=data_dir + "/Shinkai",
                               style_transform=tr.transform,
                               smooth_transform=tr.transform)
        self.pin_memory = pin_memory
        self.batch_size = batch_size
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=num_workers,
                                     pin_memory=pin_memory)
        self.device = torch.device(
            device) if torch.cuda.is_available() else torch.device('cpu')
        self.G = Generator().to(self.device)
        self.patch = patch
        if self.patch:
            self.D = PatchDiscriminator().to(self.device)
        else:
            self.D = Discriminator().to(self.device)
        self.init_model_weights()
        self.optimizer_G = GANOptimizer(optim_type,
                                        self.G.parameters(),
                                        lr=G_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=False)
        self.optimizer_D = GANOptimizer(optim_type,
                                        self.D.parameters(),
                                        lr=D_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=True)
        self.loss = Loss(device=self.device).to(self.device)
        self.init_lr = init_lr
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.grad_set_to_none = grad_set_to_none
        self.writer = tensorboard.SummaryWriter(log_dir=log_dir)
        self.init_train_epoch = init_training_epoch
        self.train_epoch = train_epoch
        # Timestamp used to namespace TensorBoard tags; set lazily by
        # get_test_image / train_2 / Discriminator_NOGAN.
        self.init_time = None
        self.level = level
        # Gate AMP on the *resolved* device: the requested string may say
        # "cuda:0" even though we fell back to CPU above.
        if self.level != "O0" and self.device.type != "cpu":
            self.fp16 = True
            [self.G, self.D], [self.optimizer_G, self.optimizer_D] = amp.initialize(
                [self.G, self.D], [self.optimizer_G, self.optimizer_D],
                opt_level=self.level)
        else:
            self.fp16 = False

    def init_model_weights(self):
        """Apply ``weights_init`` to every submodule of G and D."""
        self.G.apply(weights_init)
        self.D.apply(weights_init)

    @classmethod
    def from_config(cls):
        """Placeholder alternate constructor; not implemented (returns None)."""
        pass

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _backward(self, loss, optimizer):
        """Backpropagate ``loss``, through Apex AMP loss scaling when fp16 is on."""
        if self.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    @staticmethod
    def _log_interval(num_batches: int, chunks: int = 20) -> int:
        """Return a logging period of ``num_batches // chunks`` batches, never below 1."""
        return max(1, num_batches // chunks)

    def _random_test_image_path(self) -> Path:
        """Pick a random *file* (directories excluded) under ``<data_dir>/test/test_photo256``."""
        test_dir = Path(self.data_dir).joinpath('test', 'test_photo256')
        candidates = [p for p in test_dir.glob('**/*') if p.is_file()]
        return random.choice(candidates)

    def _write_histogram(self, tag: str, values, step: int):
        """Log a histogram, skipping tensors that are None (e.g. grads cleared with set_to_none)."""
        if values is not None:
            self.writer.add_histogram(tag, values, step)

    def _loss_plateaued(self, total_loss_arr, epoch: int) -> bool:
        """Plot the slope of the per-epoch loss curve and report whether it flattened.

        Returns True when the latest ``np.gradient`` value of ``total_loss_arr``
        is above -1.0 (loss improving by less than 1 per epoch). With fewer than
        two recorded epochs no slope exists, so nothing is plotted and False is
        returned.
        """
        if len(total_loss_arr) < 2:
            return False
        fig = plt.figure(figsize=(8, 8))
        slope = np.gradient(total_loss_arr)
        plt.plot(np.arange(len(total_loss_arr)), slope)
        thresh = -1.0
        plt.axhline(thresh, c='r')
        plt.title(f"{self.init_time}")
        self.writer.add_figure(f"{self.init_time}", fig, epoch)
        return slope[-1] > thresh

    # ------------------------------------------------------------------ #
    # Training schedules
    # ------------------------------------------------------------------ #
    def init_train(self, con_weight: float = 1.0):
        """Warm up the generator with the content loss only.

        Args:
            con_weight: Multiplier on ``self.loss.content_loss``.
        """
        test_img = self.get_test_image()
        meter = AverageMeter("Loss")
        self.writer.flush()
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=0.9999,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=self.init_train_epoch)
        # NOTE(review): OneCycleLR recomputes the lr on every step(), so this
        # override only affects the very first optimizer step.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.init_lr
        for epoch in tqdm(range(self.init_train_epoch)):
            meter.reset()
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                generator_output = self.G(train)
                content_loss = self.loss.content_loss(generator_output,
                                                      train) * con_weight
                self._backward(content_loss, self.optimizer_G)
                self.optimizer_G.step()
                lr_scheduler.step()
                meter.update(content_loss.detach())
            self.writer.add_scalar(f"Loss : {self.init_time}",
                                   meter.sum.item(), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
        # Restore the regular generator learning rate for later stages.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr

    def eval_image(self, epoch: int, caption, img):
        """Run one image through the generator and log the result.

        Args:
            epoch: Zero-based index; the image is logged at step ``epoch + 1``
                (``GAN_NOGAN`` passes a global batch index here).
            caption: TensorBoard tag.
            img: Input accepted by ``tr.transform`` (a PIL image in the
                callers in this class).
        """
        self.G.eval()
        try:
            styled_test_img = tr.transform(img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                styled_test_img = self.G(styled_test_img)
            styled_test_img = styled_test_img.to('cpu').squeeze()
            self.write_image(styled_test_img, caption, epoch + 1)
            self.writer.flush()
        finally:
            # Always return to training mode, even if inference/logging fails.
            self.G.train()

    def write_image(self,
                    image: torch.Tensor,
                    img_caption: str = "sample_image",
                    step: int = 0):
        """Log a CHW image tensor to TensorBoard as uint8 HWC.

        Assumes ``tr.inv_norm`` maps the generator's normalized output back
        to roughly [0, 1] (the original note says [-1, 1] -> [0, 1]); the
        result is clipped to [0, 1] before scaling.
        """
        image = torch.clip(tr.inv_norm(image).to(torch.float), 0, 1)
        image *= 255.  # [0, 1] -> [0, 255]
        image = image.permute(1, 2, 0).to(dtype=torch.uint8)
        self.writer.add_image(img_caption, image, step, dataformats='HWC')
        self.writer.flush()

    def write_weights(self, epoch: int, write_D=True, write_G=True):
        """Log parameter and gradient histograms.

        For D only parameters whose name contains ``depthwise`` or
        ``pointwise`` are logged; for G every parameter is. Gradients that
        are currently None are skipped.
        """
        if write_D:
            for name, weight in self.D.named_parameters():
                if 'depthwise' in name or 'pointwise' in name:
                    self._write_histogram(
                        f"Discriminator {name} {self.init_time}", weight,
                        epoch)
                    self._write_histogram(
                        f"Discriminator {name}.grad {self.init_time}",
                        weight.grad, epoch)
            self.writer.flush()
        if write_G:
            for name, weight in self.G.named_parameters():
                self._write_histogram(f"Generator {name} {self.init_time}",
                                      weight, epoch)
                self._write_histogram(
                    f"Generator {name}.grad {self.init_time}", weight.grad,
                    epoch)
            self.writer.flush()

    def train_1(
        self,
        adv_weight: float = 300.,
        con_weight: float = 1.5,
        gra_weight: float = 3.,
        col_weight: float = 10.,
    ):
        """Adversarial training with content, style, reconstruction and perceptual terms.

        Args:
            adv_weight: Weight on the LSGAN adversarial terms.
            con_weight: Weight on ``content_loss``.
            gra_weight: Weight on ``style_loss``.
            col_weight: Weight on ``reconstruction_loss``.
        """
        test_img = Image.open(self._random_test_image_path())
        self.writer.add_image(f'test image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        for epoch in tqdm(range(self.train_epoch)):
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # ---- Discriminator step ----
                self.D.zero_grad()
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                train = train.to(self.device)
                style_loss_value = self.D(style).view(-1)
                generator_output = self.G(train)
                real_output = self.D(generator_output.detach()).view(-1)
                # NOTE(review): the other schedules grayscale the *style*
                # images here; this one grayscales the photos - confirm intent.
                gray_train = tr.inv_gray_transform(train)
                greyscale_output = self.D(gray_train).view(-1)
                smoothed_loss = self.D(smooth).view(-1)
                dis_adv_loss = adv_weight * (
                    torch.pow(style_loss_value - 1, 2).mean() +
                    torch.pow(real_output, 2).mean())
                dis_gray_loss = torch.pow(greyscale_output, 2).mean()
                dis_edge_loss = torch.pow(smoothed_loss, 2).mean()
                discriminator_loss = dis_adv_loss + dis_gray_loss + dis_edge_loss
                self._backward(discriminator_loss, self.optimizer_D)
                self.optimizer_D.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} Discriminator losses', {
                            'adversarial loss': dis_adv_loss.item(),
                            'grayscale loss': dis_gray_loss.item(),
                            'edge loss': dis_edge_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

                # ---- Generator step (D re-evaluated after its update) ----
                real_output = self.D(generator_output).view(-1)
                per_loss = self.loss.perceptual_loss(train, generator_output)
                style_loss = self.loss.style_loss(generator_output, style)
                content_loss = self.loss.content_loss(generator_output, train)
                recon_loss = self.loss.reconstruction_loss(
                    generator_output, train)
                # NOTE(review): a tv_loss was previously computed here but
                # never added to generator_loss; removed as dead work.
                self.G.zero_grad()
                gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                      2).mean()
                gen_con_loss = con_weight * content_loss
                gen_sty_loss = gra_weight * style_loss
                gen_rec_loss = col_weight * recon_loss
                gen_per_loss = per_loss
                generator_loss = (gen_adv_loss + gen_con_loss + gen_sty_loss +
                                  gen_rec_loss + gen_per_loss)
                self._backward(generator_loss, self.optimizer_G)
                self.optimizer_G.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': gen_adv_loss.item(),
                            'content loss': gen_con_loss.item(),
                            'style loss': gen_sty_loss.item(),
                            'reconstruction loss': gen_rec_loss.item(),
                            'perceptual loss': gen_per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()
            self.write_weights(epoch + 1)
            self.eval_image(epoch, f'{self.init_time} style img', test_img)

    def train_2(self,
                adv_weight: float = 1.0,
                threshold: float = 3.,
                G_train_iter: int = 1,
                D_train_iter: int = 1):
        """Adversarial training with a perceptual weight that ramps up.

        Every batch, while the epoch's accumulated discriminator loss exceeds
        ``threshold``, the perceptual-loss weight grows by 0.05; the first
        time it does not, the weight is frozen for the rest of the run.

        Args:
            adv_weight: Weight on the LSGAN adversarial terms.
            threshold: Accumulated-D-loss threshold controlling the ramp.
            G_train_iter: Generator updates per batch.
            D_train_iter: Discriminator updates per batch.
        """
        test_img = Image.open(self._random_test_image_path())
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'sample_image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        perception_weight = 0.
        keep_constant = False
        for epoch in tqdm(range(self.train_epoch)):
            total_dis_loss = 0.
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad()
                train = train.to(self.device)
                style = style.to(self.device)
                for _ in range(D_train_iter):
                    style_loss_value = self.D(style).view(-1)
                    generator_output = self.G(train)
                    real_output = self.D(generator_output.detach()).view(-1)
                    dis_adv_loss = adv_weight * (
                        torch.pow(style_loss_value - 1, 2).mean() +
                        torch.pow(real_output, 2).mean())
                    total_dis_loss += dis_adv_loss.item()
                    self._backward(dis_adv_loss, self.optimizer_D)
                    self.optimizer_D.step()
                self.G.zero_grad()
                for _ in range(G_train_iter):
                    generator_output = self.G(train)
                    real_output = self.D(generator_output).view(-1)
                    per_loss = perception_weight * \
                        self.loss.perceptual_loss(train, generator_output)
                    gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                          2).mean()
                    gen_loss = gen_adv_loss + per_loss
                    self._backward(gen_loss, self.optimizer_G)
                    self.optimizer_G.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': dis_adv_loss.item(),
                            'Generator adversarial loss': gen_adv_loss.item(),
                            'perceptual loss': per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()
                if total_dis_loss > threshold and not keep_constant:
                    perception_weight += 0.05
                else:
                    keep_constant = True
                self.writer.add_scalar(
                    f'total discriminator loss {self.init_time}',
                    total_dis_loss, i + epoch * len(self.dataloader))
            # write_weights requires a step; previously called without one.
            self.write_weights(epoch + 1)
            self.eval_image(epoch, f'styled image {self.init_time}',
                            test_img)

    def __call__(self):
        """Run the default schedule: ``init_train`` then ``train_1``."""
        self.init_train()
        self.train_1()

    # ------------------------------------------------------------------ #
    # Checkpointing
    # ------------------------------------------------------------------ #
    def save_trial(self, epoch: int, train_type: str):
        """Save models, optimizers (and AMP state) to ``<train_type>_<level>.pt``.

        NOTE(review): callers embed ``init_time`` ("%H:%M") in ``train_type``;
        the colon is not a legal filename character on Windows.
        """
        save_dir = Path(f"{train_type}_{self.level}.pt")
        training_details = {
            "epoch": epoch,
            "gen": {
                "gen_state_dict": self.G.state_dict(),
                "optim_G_state_dict": self.optimizer_G.state_dict()
            },
            "dis": {
                "dis_state_dict": self.D.state_dict(),
                "optim_D_state_dict": self.optimizer_D.state_dict()
            }
        }
        if self.fp16:
            training_details["amp"] = amp.state_dict()
        torch.save(training_details, save_dir.as_posix())

    def load_trial(self, dir: Path):
        """Restore a checkpoint written by ``save_trial``.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on GPU
        can be loaded on a CPU-only host.
        """
        assert dir.is_file(), f"No such file: {dir}"
        assert dir.suffix == ".pt", "Filetype not compatible"
        state_dicts = torch.load(dir.as_posix(), map_location=self.device)
        self.G.load_state_dict(state_dicts["gen"]["gen_state_dict"])
        self.optimizer_G.load_state_dict(
            state_dicts["gen"]["optim_G_state_dict"])
        self.D.load_state_dict(state_dicts["dis"]["dis_state_dict"])
        self.optimizer_D.load_state_dict(
            state_dicts["dis"]["optim_D_state_dict"])
        if self.fp16:
            amp.load_state_dict(state_dicts["amp"])
        typer.echo("Loaded Weights")

    # ------------------------------------------------------------------ #
    # NoGAN stages
    # ------------------------------------------------------------------ #
    def Generator_NOGAN(self,
                        epochs: int = 1,
                        style_weight: float = 20.,
                        content_weight: float = 1.2,
                        recon_weight: float = 10.,
                        tv_weight: float = 1e-6,
                        loss: List[str] = ['content_loss']):
        """Train the generator NoGAN-style (feature losses only, no discriminator).

        Args:
            epochs: Maximum epochs; training stops early once
                ``_loss_plateaued`` reports a flat loss curve (after epoch 2).
            style_weight, content_weight, recon_weight, tv_weight: Weights for
                the corresponding terms.
            loss: Which of ``'content_loss'``, ``'style_loss'``,
                ``'recon_loss'``, ``'tv_loss'`` to include (not mutated).
        """
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
        test_img = self.get_test_image()
        max_lr = self.G_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        for epoch in tqdm(range(epochs)):
            total_losses = 0.
            meter.reset()
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                generator_output = self.G(train)
                if 'style_loss' in loss:
                    style = style.to(self.device)
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.
                if 'content_loss' in loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.
                if 'recon_loss' in loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.
                if 'tv_loss' in loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.
                total_loss = content_loss + tv_loss + recon_loss + style_loss
                self._backward(total_loss, self.optimizer_G)
                self.optimizer_G.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()
                loss_dict = {
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }
                meter.update(*[loss_dict[k].detach() for k in loss])
            # float() works for both a 0-d tensor and the 0. start value.
            total_loss_arr = np.append(total_loss_arr, float(total_losses))
            self.writer.add_scalars(f'{self.init_time} NOGAN generator losses',
                                    meter.as_dict('sum'), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
            if epoch > 2 and self._loss_plateaued(total_loss_arr, epoch):
                break
        self.save_trial(epoch, f'G_NG_{self.init_time}')

    def Discriminator_NOGAN(
            self,
            epochs: int = 3,
            adv_weight: float = 1.0,
            edge_weight: float = 1.0,
            loss: List[str] = ['real_adv_loss', 'fake_adv_loss', 'gray_loss']):
        """Train the discriminator alone against the current generator (NoGAN stage).

        Only ``optimizer_D`` is stepped. Each term is scaled by 1.7 and then by
        ``adv_weight``, following the AnimeGANv2 reference implementation
        (tools/ops.py, commit 5946b6a, L217).

        Args:
            epochs: Maximum epochs; training stops early once
                ``_loss_plateaued`` reports a flat loss curve (after epoch 2).
            adv_weight: Global multiplier on every discriminator term.
            edge_weight: Currently unused (kept for interface compatibility).
            loss: Which of ``'real_adv_loss'``, ``'fake_adv_loss'``,
                ``'gray_loss'`` to log (not mutated).
        """
        for g in self.optimizer_D.param_groups:
            g['lr'] = self.D_lr
        max_lr = self.D_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_D,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")
        for epoch in tqdm(range(epochs)):
            meter.reset()
            total_losses = 0.
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                # G is not trained here and its output is only used detached,
                # so skip building its autograd graph.
                with torch.no_grad():
                    generator_output = self.G(train)
                real_logits = self.D(style).view(-1)
                fake_logits = self.D(generator_output).view(-1)
                real_adv_loss = torch.pow(real_logits - 1,
                                          2).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.pow(fake_logits,
                                          2).mean() * 1.7 * adv_weight
                gray_train = tr.inv_gray_transform(style)
                greyscale_output = self.D(gray_train).view(-1)
                gray_loss = torch.pow(greyscale_output,
                                      2).mean() * 1.7 * adv_weight
                total_loss = real_adv_loss + fake_adv_loss + gray_loss
                self._backward(total_loss, self.optimizer_D)
                self.optimizer_D.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()
                loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss
                }
                meter.update(*[loss_dict[k].detach() for k in loss])
            # Record this epoch's loss; the early-stop check needs the history
            # (it was previously never filled, so np.gradient failed).
            total_loss_arr = np.append(total_loss_arr, float(total_losses))
            self.writer.add_scalars(
                f'{self.init_time} NOGAN discriminator loss',
                meter.as_dict('sum'), epoch)
            self.writer.flush()
            if epoch > 2 and self._loss_plateaued(total_loss_arr, epoch):
                break

    def GAN_NOGAN(
        self,
        epochs: int = 1,
        GAN_G_lr: float = 0.00008,
        GAN_D_lr: float = 0.000016,
        D_loss: List[str] = [
            "real_adv_loss", "fake_adv_loss", "gray_loss", "edge_loss"
        ],
        adv_weight: float = 300.,
        edge_weight: float = 0.1,
        G_loss: List[str] = [
            "adv_loss", "content_loss", "style_loss", "recon_loss"
        ],
        style_weight: float = 20.,
        content_weight: float = 1.2,
        recon_weight: float = 10.,
        tv_weight: float = 1e-6,
    ):
        """Joint adversarial fine-tuning after the NoGAN pre-training stages.

        Per batch: one discriminator step on all four terms (real, fake, gray,
        edge), then one generator step on the adversarial term plus the
        feature terms selected in ``G_loss``. Scalars are logged about 20
        times per epoch; a sample image and a checkpoint are written each
        epoch.

        Args:
            epochs: Number of epochs.
            GAN_G_lr, GAN_D_lr: Learning rates written into the optimizers.
            D_loss: Discriminator terms to *log* (all four are always trained).
            adv_weight: Multiplier on every adversarial term.
            edge_weight: Currently unused (kept for interface compatibility).
            G_loss: Generator terms; ``'style_loss'``, ``'content_loss'``,
                ``'recon_loss'``, ``'tv_loss'`` are included only if listed.
            style_weight, content_weight, recon_weight, tv_weight: Term weights.
        """
        test_img = self.get_test_image()
        dis_meter = LossMeters(*D_loss)
        gen_meter = LossMeters(*G_loss)
        for g in self.optimizer_G.param_groups:
            g['lr'] = GAN_G_lr
        for g in self.optimizer_D.param_groups:
            g['lr'] = GAN_D_lr
        # Clamp to >= 1 so loaders with fewer than 20 batches don't hit i % 0.
        update_duration = self._log_interval(len(self.dataloader))
        for epoch in tqdm(range(epochs)):
            dis_meter.reset()
            gen_meter.reset()
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                step = i + epoch * len(self.dataloader)

                # ---- Discriminator step ----
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                generator_output = self.G(train)
                real_logits = self.D(style).view(-1)
                fake_logits = self.D(generator_output.detach()).view(-1)
                gray_train = tr.inv_gray_transform(style)
                grayscale_output = self.D(gray_train).view(-1)
                smoothed_output = self.D(smooth).view(-1)
                real_adv_loss = torch.square(real_logits -
                                             1.).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.square(
                    fake_logits).mean() * 1.7 * adv_weight
                gray_loss = torch.square(
                    grayscale_output).mean() * 1.7 * adv_weight
                edge_loss = torch.square(
                    smoothed_output).mean() * 1.0 * adv_weight
                total_D_loss = real_adv_loss + fake_adv_loss + gray_loss + edge_loss
                self._backward(total_D_loss, self.optimizer_D)
                self.optimizer_D.step()
                D_loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss,
                    'edge_loss': edge_loss
                }
                dis_meter.update(*[D_loss_dict[k].detach() for k in D_loss])
                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Dis loss',
                                            dis_meter.as_dict('val'), step)
                    self.writer.flush()

                # ---- Generator step ----
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                # Must run *after* optimizer_D.step(): backpropagating through
                # a D forward recorded before the in-place parameter update
                # fails autograd's version check.
                G_adv_logits = self.D(generator_output).view(-1)
                G_adv_loss = torch.square(G_adv_logits -
                                          1.).mean() * adv_weight
                if 'style_loss' in G_loss:
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.
                if 'content_loss' in G_loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.
                if 'recon_loss' in G_loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.
                if 'tv_loss' in G_loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.
                total_G_loss = G_adv_loss + content_loss + tv_loss + recon_loss + style_loss
                self._backward(total_G_loss, self.optimizer_G)
                self.optimizer_G.step()
                G_loss_dict = {
                    'adv_loss': G_adv_loss,
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }
                gen_meter.update(*[G_loss_dict[k].detach() for k in G_loss])
                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Gen loss',
                                            gen_meter.as_dict('val'), step)
                    self.writer.flush()
            self.eval_image(i + epoch * len(self.dataloader),
                            f'{self.init_time} reconstructed img', test_img)
            self.save_trial(epoch, f'GAN_NG_{self.init_time}')

    def get_test_image(self):
        """Pick a random test photo, stamp ``init_time`` and log the photo.

        Always overwrites ``self.init_time`` with the current "%H:%M" time.
        """
        test_img = Image.open(self._random_test_image_path())
        self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'{self.init_time} sample_image',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        return test_img
def run_fast_style_transfer(content_training_images,
                            style_image_path,
                            epochs,
                            batch_size,
                            content_weight=0.6,
                            style_weight=0.4,
                            total_variation_weight=1e-5):
    """Train a feed-forward style-transfer network (TensorFlow 1.x graph mode).

    Builds the transform network and the content/style/total-variation loss.
    It then runs Adam over non-overlapping mini-batches of
    ``content_training_images`` for ``epochs`` passes. Sample images are
    written every 50 batches, and the transform-network variables are saved
    to ``weights/<style name>_weights/model.ckpt`` after each epoch.

    Relies on module-level names defined elsewhere: ``height``, ``width``,
    ``TNET``, ``Loss``, ``content_layers``, ``style_layers``, ``load_img``,
    ``load_batch``, ``deprocess_img``, ``tf``, ``K``, ``cv2``, ``os``.

    Args:
        content_training_images: Sequence of content image paths/entries
            accepted by ``load_batch``. Any trailing remainder smaller than
            ``batch_size`` is skipped.
        style_image_path: Path of the style image. Its base name without
            extension names the checkpoint directory.
        epochs: Number of passes over the training images.
        batch_size: Images per optimisation step.
        content_weight, style_weight, total_variation_weight: Loss weights.
    """
    with tf.Session() as sess:
        K.set_session(sess)
        input_batch = tf.placeholder(tf.float32,
                                     shape=(None, height, width, 3),
                                     name="input_batch")
        init_image = TNET.get_TransformNet('transform_network', input_batch)
        loss = Loss(init_image, content_layers, style_layers)
        content_loss = loss.content_loss(input_batch)
        style_var = tf.Variable(load_img(style_image_path))
        style_loss = loss.style_loss(style_var)
        tv_loss = loss.tv_loss(init_image)
        total_loss = (style_weight * style_loss +
                      content_weight * content_loss +
                      total_variation_weight * tv_loss)

        # Only the transform network is optimised and checkpointed.
        transform_net = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='transform_network')
        opt = tf.train.AdamOptimizer(learning_rate=0.0005,
                                     beta1=0.9,
                                     epsilon=1e-08).minimize(
                                         total_loss, var_list=transform_net)
        sess.run(tf.global_variables_initializer())
        Tnet_saver = tf.train.Saver(transform_net)
        # Load VGG19 weights *after* the global initializer, which would
        # otherwise overwrite them.
        loss.load_weights_to_vgg19(
            "vgg_weights/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5")

        style_name = os.path.splitext(os.path.basename(style_image_path))[0]
        dir_model = "weights/" + style_name + "_weights/"
        os.makedirs(dir_model, exist_ok=True)

        num_batches = len(content_training_images) // batch_size
        for i in range(epochs):
            avg_loss = 0
            avg_cnt = 1
            for j in range(num_batches):
                # Non-overlapping slices: batch j covers [j*bs, (j+1)*bs).
                start = j * batch_size
                batch = load_batch(content_training_images[start:start +
                                                           batch_size])
                temp = sess.run([
                    total_loss, style_loss, content_loss, tv_loss, init_image,
                    opt
                ],
                                feed_dict={input_batch: batch})
                # Incremental running mean of total_loss over this epoch.
                avg_loss = (avg_loss * (avg_cnt - 1) + temp[0]) / avg_cnt
                avg_cnt += 1
                print('epoch: ', i, 'batch: ', j, ' loss: ', temp[:4],
                      'avg loss: ', avg_loss)
                if j % 50 == 0:
                    # Third image of the batch when available, else the last.
                    sample = min(2, len(batch) - 1)
                    image = deprocess_img(temp[4][sample],
                                          batch[sample].shape[:-1])
                    cv2.imwrite(str(i) + '-' + str(j) + '-temp.jpg', image)
                    if i == 0:
                        image_ori = deprocess_img(batch[sample],
                                                  batch[sample].shape[:-1])
                        cv2.imwrite(
                            str(i) + '-' + str(j) + '-temp-orgi.jpg',
                            image_ori)
            Tnet_saver.save(sess, dir_model + "model.ckpt")
            print('\n Data Saved ... ')