# Style target: load the style image, resize it to opts.image_size * 2,
# convert it to a normalized batch-of-one tensor on `device`, and run the
# loss network on it once. The resulting Gram matrices are fixed targets for
# every training step, so none of this needs an autograd graph.
style_img = Image.open(opts.style_image).convert('RGB')
with torch.no_grad():
    style_img_tensor = transforms.Compose([
        transforms.Resize(opts.image_size * 2),
        transforms.ToTensor(),
        tensor_normalizer(),
    ])(style_img).unsqueeze(0).to(device)
    style_loss_features = loss_network(style_img_tensor)
    gram_style = [gram_matrix(feat) for feat in style_loss_features]
# Prints the field names of the loss network's output (a namedtuple-like
# object exposing `_fields`), i.e. which feature layers feed the style loss.
print('# of VGG-19 layers which style loss use:', style_loss_features._fields)

# Training setup: default tensor type and the shared MSE criterion.
torch.set_default_tensor_type('torch.FloatTensor')
mse_loss = torch.nn.MSELoss()
def train(transformer, loss_network, gram_style, gram_matrix, train_loader,
          content_weight, regularization, style_weights, log_interval,
          optimizer, device, steps, base_steps=0):
    """Train a style-transfer network for ``steps`` optimizer updates.

    Each update minimizes the sum of:
      * content loss: ``content_weight`` * MSE between element 2 of
        ``loss_network`` applied to the output and to the input;
      * style loss: for each ``(l, weight)`` in ``style_weights``,
        ``weight`` * MSE between ``gram_matrix`` of the output's feature
        ``l`` and ``gram_style[l]`` (broadcast with ``expand_as``);
      * regularization: ``regularization`` * sum of absolute horizontal and
        vertical neighbour differences of the output.

    ``train_loader`` is re-iterated until ``steps`` batches have been used.
    Every ``log_interval`` steps the averaged losses are printed, and a debug
    image is written with ``save_debug_image`` (this uses the global
    ``opts.style_name``).

    NOTE(review): if this module also defines ``train(steps, base_steps=0)``
    afterwards, that later definition shadows this one. Confirm which one
    callers expect.

    Args:
        transformer: Module being trained; called as ``transformer(x)``.
        loss_network: Feature extractor whose output is indexed by position.
        gram_style: Target Gram matrices, indexed like ``style_weights``.
        gram_matrix: Callable mapping a feature tensor to its Gram matrix.
        train_loader: Iterable of ``(x, _)`` batches.
        content_weight: Scale of the content loss.
        regularization: Scale of the total-variation term.
        style_weights: Per-layer style weights (passed through ``float``).
        log_interval: Steps between log lines / debug images.
        optimizer: Optimizer over the transformer's parameters.
        device: Device the input batches are moved to.
        steps: Number of optimizer steps to run before returning.
        base_steps: Offset added to the step number in debug file names.

    Raises:
        ValueError: If a pass over ``train_loader`` yields no batches
            (the loop would otherwise never terminate).
    """
    # Local criterion, so that this function no longer depends on a module global.
    mse_loss = torch.nn.MSELoss()
    transformer.train()
    count = 0
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_reg_loss = 0.
    while True:
        seen_batch = False
        for x, _ in train_loader:
            seen_batch = True
            count += 1
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # The content target is a constant, so no graph is needed for it...
            with torch.no_grad():
                f_xc_c = loss_network(x)[2]
            # ...but the output's features MUST stay in the graph, otherwise
            # the content and style losses contribute no gradient at all.
            features_y = loss_network(y)
            content_loss = content_weight * mse_loss(features_y[2], f_xc_c)
            reg_loss = regularization * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))
            style_loss = 0.
            for l, weight in enumerate(style_weights):
                gram_y = gram_matrix(features_y[l])
                style_loss += float(weight) * mse_loss(
                    gram_y, gram_style[l].expand_as(gram_y))
            total_loss = content_loss + style_loss + reg_loss
            total_loss.backward()
            optimizer.step()
            # Accumulate plain floats. Summing the loss tensors would keep
            # every step's autograd graph alive until the next log line.
            agg_content_loss += float(content_loss)
            agg_style_loss += float(style_loss)
            agg_reg_loss += float(reg_loss)
            if count % log_interval == 0:
                mesg = "{} [{}/{}] content: {:.2f} style: {:.2f} reg: {:.2f} total: {:.6f}".format(
                    time.ctime(), count, steps,
                    agg_content_loss / log_interval,
                    agg_style_loss / log_interval,
                    agg_reg_loss / log_interval,
                    (agg_content_loss + agg_style_loss + agg_reg_loss) / log_interval)
                print(mesg)
                agg_content_loss = 0.
                agg_style_loss = 0.
                agg_reg_loss = 0.
                # Inference-only forward pass for the debug image.
                transformer.eval()
                with torch.no_grad():
                    y = transformer(x)
                save_debug_image(
                    x, y.detach(),
                    "./fast-neural-style/debug_{}/{}.png".format(
                        opts.style_name, base_steps + count))
                transformer.train()
            if count >= steps:
                return
        if not seen_batch:
            raise ValueError("train_loader yielded no batches")
def train(steps, base_steps=0):
    """Train the global ``transformer`` with an added noise-stability loss.

    Uses module-level names: ``transformer``, ``train_loader``,
    ``optimizer``, ``device``, ``loss_network``, ``gram_style``,
    ``gram_matrix``, ``mse_loss``, ``save_debug_image`` and the
    hyper-parameters ``CONTENT_WEIGHT``, ``REGULARIZATION``,
    ``STYLE_WEIGHTS``, ``NOISE_P``, ``NOISE_STD``, ``NOISE_WEIGHT``,
    ``LOG_INTERVAL``.

    Each step minimizes content + style + total-variation losses, plus a
    stability term. That term is ``NOISE_WEIGHT`` * MSE between the output
    for a noised input and the detached output for the clean input. The
    noise is Gaussian with std ``NOISE_STD``, clamped to [-1, 1], and is
    applied per element with probability ``NOISE_P`` (a Bernoulli mask).

    Every ``LOG_INTERVAL`` steps the averaged losses are printed and a debug
    image (input, eval-mode output, last noised output) is saved.

    Args:
        steps: Number of optimizer steps to run before returning.
        base_steps: Offset added to the step number in debug file names.

    Raises:
        ValueError: If a pass over ``train_loader`` yields no batches
            (the loop would otherwise never terminate).
    """
    transformer.train()
    count = 0
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_reg_loss = 0.
    agg_stable_loss = 0.
    while True:
        seen_batch = False
        for x, _ in train_loader:
            seen_batch = True
            count += 1
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # Sample the sparse input perturbation. It is constant w.r.t. the
            # parameters, so no graph is needed for it.
            with torch.no_grad():
                mask = torch.bernoulli(
                    torch.ones_like(x, device=device, dtype=torch.float) * NOISE_P)
                noise = torch.normal(
                    torch.zeros_like(x),
                    torch.ones_like(x, device=device, dtype=torch.float) * NOISE_STD).clamp(-1, 1)
            y_noise = transformer(x + noise * mask)

            # The content target needs no gradient; the output's features do.
            with torch.no_grad():
                f_xc_c = loss_network(x)[2]
            features_y = loss_network(y)
            content_loss = CONTENT_WEIGHT * mse_loss(features_y[2], f_xc_c)
            reg_loss = REGULARIZATION * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))
            style_loss = 0.
            for l, weight in enumerate(STYLE_WEIGHTS):
                gram_y = gram_matrix(features_y[l])
                style_loss += float(weight) * mse_loss(
                    gram_y, gram_style[l].expand_as(gram_y))
            # reshape (rather than view) also works for non-contiguous outputs.
            stability_loss = NOISE_WEIGHT * mse_loss(
                y_noise.reshape(-1), y.reshape(-1).detach())

            total_loss = content_loss + style_loss + reg_loss + stability_loss
            total_loss.backward()
            optimizer.step()
            # Accumulate plain floats. Summing the loss tensors would keep
            # every step's autograd graph alive until the next log line.
            agg_content_loss += float(content_loss)
            agg_style_loss += float(style_loss)
            agg_reg_loss += float(reg_loss)
            agg_stable_loss += float(stability_loss)
            if count % LOG_INTERVAL == 0:
                mesg = "{} [{}/{}] content: {:.2f} style: {:.2f} reg: {:.2f} stable: {:.2f} total: {:.6f}".format(
                    time.ctime(), count, steps,
                    agg_content_loss / LOG_INTERVAL,
                    agg_style_loss / LOG_INTERVAL,
                    agg_reg_loss / LOG_INTERVAL,
                    agg_stable_loss / LOG_INTERVAL,
                    (agg_content_loss + agg_style_loss + agg_reg_loss + agg_stable_loss) / LOG_INTERVAL)
                print(mesg)
                agg_content_loss = 0.
                agg_style_loss = 0.
                agg_reg_loss = 0.
                agg_stable_loss = 0.
                # Inference-only forward pass for the debug image.
                transformer.eval()
                with torch.no_grad():
                    y = transformer(x)
                save_debug_image(x, y.detach(), y_noise.detach(),
                                 "../debug/{}.png".format(base_steps + count))
                transformer.train()
            if count >= steps:
                return
        if not seen_batch:
            raise ValueError("train_loader yielded no batches")