def optimize_adam(self, init_img, alpha=0.5, beta1=0.9, beta2=0.999, eps=1e-8,
                  iterations=2000, save=50, filename='iter', str_contrast=False):
    """Optimize the generated image in place with the Adam update rule.

    Args:
        init_img: chainer ``Variable`` holding the image being optimized;
            its ``.data`` is updated in place each iteration.
        alpha, beta1, beta2, eps: Adam hyper-parameters, passed straight
            to ``chainer.optimizers.Adam``.
        iterations: number of gradient steps to run.
        save: write a snapshot every ``save`` iterations (0 disables saving).
        filename: prefix for saved snapshot files (``<filename>_<iter>.png``).
        str_contrast: forwarded to ``save_image`` as its ``contrast`` flag.
    """
    chainer_adam = optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2, eps=eps)
    chainer_adam.t = 0
    # Per-parameter Adam state: first ('m') and second ('v') moment buffers.
    state = {'m': xp.zeros_like(init_img.data),
             'v': xp.zeros_like(init_img.data)}
    out_img = Variable(xp.zeros_like(init_img.data), volatile=True)
    time_start = time.time()
    for epoch in range(iterations):
        chainer_adam.t += 1
        loss = self.loss_total(init_img)
        loss.backward()
        loss.unchain_backward()
        # Normalize the gradient by its L1 norm.
        # BUG FIX: the original summed |g * g| — i.e. the squared L2 norm
        # (abs of a square is a no-op) — although the variable name and
        # comment both promised an L1 norm.
        grad_l1_norm = xp.sum(xp.absolute(init_img.grad))
        init_img.grad /= grad_l1_norm
        if gpu_flag:
            chainer_adam.update_one_gpu(init_img, state)
        else:
            chainer_adam.update_one_cpu(init_img, state)
        init_img.zerograd()
        # Save a snapshot every `save` iterations.
        if save != 0 and (epoch + 1) % save == 0:
            if self.preserve_color:
                # Recombine the optimized luminance with the stored
                # content-image chrominance to keep the original colors.
                init_img_lum = separate_lum_chr(init_img)[0]
                if gpu_flag:
                    init_img_lum.to_gpu()
                out_img.copydata(init_img_lum + self.content_img_chr)
            else:
                out_img.copydata(init_img)
            save_image(out_img, filename + '_' + str(epoch + 1) + '.png',
                       contrast=str_contrast)
            print(
                "Image Saved at Iteration %.0f, Time Used: %.4f, Total Loss: %.4f"
                % ((epoch + 1), (time.time() - time_start), loss.data))
class ArtNN:
    """Neural style transfer (cf. Gatys et al., "A Neural Algorithm of
    Artistic Style") driven by a pre-trained convolutional network.

    The wrapped ``neural_net`` is expected to return a tuple of layer
    activations; the LAST entry is used as the content representation and
    all PRECEDING entries as the style representations.
    """

    def __init__(self, neural_net, content_image, style_image, content_img_chr,
                 alpha=50.0, beta=10000.0, keep_color=False):
        """Precompute the content feature map and style Gram matrices.

        Args:
            neural_net: callable returning a tuple of layer activations.
            content_image: chainer ``Variable`` with the content image.
            style_image: chainer ``Variable`` with the style image.
            content_img_chr: chrominance of the content image; only read
                when ``keep_color`` is True.
            alpha: weighting factor for the content loss.
            beta: weighting factor for the style loss.
            keep_color: preserve the content image's colors in snapshots.
        """
        self.neural_net = neural_net
        self.preserve_color = keep_color  # flag for preserving color
        self.alpha = alpha  # weighting factor for content
        self.beta = beta  # weighting factor for style
        # Deep-copy the inputs so later updates cannot alias them.
        self.content_img = Variable(xp.zeros_like(content_image.data))
        self.style_img = Variable(xp.zeros_like(style_image.data))
        self.content_img_chr = Variable(xp.zeros_like(content_image.data))
        self.content_img.copydata(content_image)
        self.style_img.copydata(style_image)
        if keep_color:
            self.content_img_chr.copydata(content_img_chr)
        # Last layer -> content representation; all earlier layers -> style.
        self.content_rep = self.neural_net(self.content_img)[-1:]
        self.style_rep = self.neural_net(self.style_img)[:-1]
        self.content_feat_map = self.feature_map(self.content_rep)
        self.style_feat_cor = self.feature_cor(self.style_rep)

    # extract feature map from a filtered image
    @staticmethod
    def feature_map(filtered_reps):
        """Flatten each layer activation to (channels, spatial) matrices.

        NOTE(review): the reshape to ``(num_channel, -1)`` assumes a batch
        size of 1 — confirm against the network's output shape.
        """
        feat_map_list = []
        for rep in filtered_reps:
            num_channel = rep.shape[1]
            feat_map = F.reshape(rep, (num_channel, -1))
            feat_map_list.append(feat_map)
        return tuple(feat_map_list)

    # compute feature correlations of a filtered image,
    # correlations are given by the Gram matrix
    # cf. equation (3) of the article
    def feature_cor(self, filtered_reps):
        """Return the Gram matrix of each layer's flattened feature map."""
        gram_mat_list = []
        feat_map_list = self.feature_map(filtered_reps)
        for feat_map in feat_map_list:
            # G = F F^T over the flattened spatial dimension.
            gram_mat = F.matmul(feat_map, feat_map, transa=False, transb=True)
            gram_mat_list.append(gram_mat)
        return tuple(gram_mat_list)

    # content loss function
    # cf. equation (1) of the article
    def loss_content(self, gen_img_rep):
        """Half the MSE between content and generated feature maps."""
        feat_map_gen = self.feature_map(gen_img_rep)
        feat_loss = F.mean_squared_error(self.content_feat_map[0],
                                         feat_map_gen[0]) / 2.0
        return feat_loss

    # style loss function
    # cf. equation (5) of the article
    def loss_style(self, gen_img_rep):
        """Layer-weighted MSE between style and generated Gram matrices."""
        feat_cor_gen = self.feature_cor(gen_img_rep)
        feat_loss = 0
        for i in range(len(feat_cor_gen)):
            orig_shape = self.style_rep[i].shape
            feat_map_size = orig_shape[2] * orig_shape[3]  # M_l
            # 1 / (4 M_l^2) weighting from equation (4); the 1/N_l^2 factor
            # is absorbed by mean_squared_error's averaging.
            layer_wt = 4.0 * feat_map_size ** 2.0
            feat_loss += F.mean_squared_error(self.style_feat_cor[i],
                                              feat_cor_gen[i]) / layer_wt
        return feat_loss

    # total loss function
    # cf. equation (7) of the article
    def loss_total(self, input_img):
        """Weighted sum: alpha * content loss + beta * style loss."""
        input_img_rep = self.neural_net(input_img)
        content_loss = self.loss_content(input_img_rep[-1:])
        style_loss = self.loss_style(input_img_rep[:-1])
        total_loss = self.alpha * content_loss + self.beta * style_loss
        return total_loss

    def _save_progress(self, init_img, out_img, loss, epoch, time_start,
                       filename, str_contrast):
        # Shared snapshot logic for both optimizers (was duplicated
        # verbatim in optimize_adam and optimize_rmsprop).
        if self.preserve_color:
            # Recombine optimized luminance with the content chrominance
            # so the snapshot keeps the original colors.
            init_img_lum = separate_lum_chr(init_img)[0]
            if gpu_flag:
                init_img_lum.to_gpu()
            out_img.copydata(init_img_lum + self.content_img_chr)
        else:
            out_img.copydata(init_img)
        save_image(out_img, filename + '_' + str(epoch + 1) + '.png',
                   contrast=str_contrast)
        print(
            "Image Saved at Iteration %.0f, Time Used: %.4f, Total Loss: %.4f"
            % ((epoch + 1), (time.time() - time_start), loss.data))

    def optimize_adam(self, init_img, alpha=0.5, beta1=0.9, beta2=0.999,
                      eps=1e-8, iterations=2000, save=50, filename='iter',
                      str_contrast=False):
        """Optimize the generated image in place with Adam.

        Args:
            init_img: chainer ``Variable`` updated in place each step.
            alpha, beta1, beta2, eps: Adam hyper-parameters.
            iterations: number of gradient steps.
            save: snapshot period in iterations (0 disables saving).
            filename: snapshot file prefix (``<filename>_<iter>.png``).
            str_contrast: forwarded to ``save_image`` as ``contrast``.
        """
        chainer_adam = optimizers.Adam(alpha=alpha, beta1=beta1,
                                       beta2=beta2, eps=eps)
        chainer_adam.t = 0
        # Adam moment buffers.
        state = {'m': xp.zeros_like(init_img.data),
                 'v': xp.zeros_like(init_img.data)}
        out_img = Variable(xp.zeros_like(init_img.data), volatile=True)
        time_start = time.time()
        for epoch in range(iterations):
            chainer_adam.t += 1
            loss = self.loss_total(init_img)
            loss.backward()
            loss.unchain_backward()
            # Normalize the gradient by its L1 norm.
            # BUG FIX: the original summed |g * g| (squared L2 norm);
            # the name and comment both promised an L1 norm.
            grad_l1_norm = xp.sum(xp.absolute(init_img.grad))
            init_img.grad /= grad_l1_norm
            if gpu_flag:
                chainer_adam.update_one_gpu(init_img, state)
            else:
                chainer_adam.update_one_cpu(init_img, state)
            init_img.zerograd()
            # save image every 'save' iteration
            if save != 0 and (epoch + 1) % save == 0:
                self._save_progress(init_img, out_img, loss, epoch,
                                    time_start, filename, str_contrast)

    def optimize_rmsprop(self, init_img, lr=0.1, alpha=0.95, momentum=0.9,
                         eps=1e-4, iterations=2000, save=50, filename='iter',
                         str_contrast=False):
        """Optimize the generated image in place with RMSpropGraves.

        Args:
            init_img: chainer ``Variable`` updated in place each step.
            lr, alpha, momentum, eps: RMSpropGraves hyper-parameters.
            iterations: number of gradient steps.
            save: snapshot period in iterations (0 disables saving).
            filename: snapshot file prefix (``<filename>_<iter>.png``).
            str_contrast: forwarded to ``save_image`` as ``contrast``.
        """
        chainer_rms = optimizers.RMSpropGraves(lr=lr, alpha=alpha,
                                               momentum=momentum, eps=eps)
        # RMSpropGraves state buffers.
        state = {'n': xp.zeros_like(init_img.data),
                 'g': xp.zeros_like(init_img.data),
                 'delta': xp.zeros_like(init_img.data)}
        out_img = Variable(xp.zeros_like(init_img.data), volatile=True)
        time_start = time.time()
        for epoch in range(iterations):
            loss = self.loss_total(init_img)
            loss.backward()
            loss.unchain_backward()
            # Normalize the gradient by its L1 norm.
            # BUG FIX: same squared-L2 bug as in optimize_adam.
            grad_l1_norm = xp.sum(xp.absolute(init_img.grad))
            init_img.grad /= grad_l1_norm
            if gpu_flag:
                chainer_rms.update_one_gpu(init_img, state)
            else:
                chainer_rms.update_one_cpu(init_img, state)
            init_img.zerograd()
            # save image every 'save' iteration
            if save != 0 and (epoch + 1) % save == 0:
                self._save_progress(init_img, out_img, loss, epoch,
                                    time_start, filename, str_contrast)