import random
from typing import Callable, Dict, List, Optional, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchelie.utils as tu

# PerceptualNet, ImageNetInputNorm, bgram and hist_loss come from the
# surrounding library; their exact import paths are not shown in these
# snippets.


def __init__(self,
             layers: List[str],
             rescale: bool = False,
             loss_fn: Callable[[torch.Tensor, torch.Tensor],
                               torch.Tensor] = F.mse_loss,
             use_avg_pool: bool = True,
             remove_unused_layers: bool = True):
    super(PerceptualLoss, self).__init__()
    # The network whose activations at `layers` will be compared with
    # `loss_fn` (MSE by default).
    self.m = PerceptualNet(layers,
                           use_avg_pool=use_avg_pool,
                           remove_unused_layers=remove_unused_layers)
    self.norm = ImageNetInputNorm()
    self.rescale = rescale
    self.loss_fn = loss_fn
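# Usage sketch (not part of the original snippets): this assumes
# PerceptualLoss is an nn.Module whose forward compares the activations of
# two images with `loss_fn`; the forward pass is not shown above, so treat
# the call convention as an assumption.


def _perceptual_loss_example() -> torch.Tensor:
    ploss = PerceptualLoss(['conv1_1', 'conv2_2'])
    fake = torch.rand(1, 3, 224, 224)  # e.g. a generator output
    real = torch.rand(1, 3, 224, 224)  # e.g. the ground truth image
    return ploss(fake, real)  # scalar perceptual distance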
class NeuralStyleLoss(nn.Module):
    """
    Style Transfer loss by Leon Gatys
    https://arxiv.org/abs/1508.06576

    Set the style and content before performing a forward pass.
    """

    def __init__(self):
        super(NeuralStyleLoss, self).__init__()
        self.style_layers = [
            'conv1_1',
            'conv2_1',
            'conv3_1',
            'conv4_1',
            'conv5_1',
        ]
        self.content_layers = ['conv3_2']
        self.net = PerceptualNet(self.style_layers + self.content_layers)
        self.norm = ImageNetInputNorm()

    def get_style_content_(self, img, detach):
        # Run the perceptual net and split its activations into the style
        # and the content subsets.
        _, activations = self.net(self.norm(img), detach=detach)
        style = {
            l: a
            for l, a in activations.items() if l in self.style_layers
        }
        content = {
            l: a
            for l, a in activations.items() if l in self.content_layers
        }
        return style, content

    def set_style(self, style_img, style_ratio, style_layers=None):
        """
        Set the style.

        Args:
            style_img (3xHxW tensor): an image tensor
            style_ratio (float): a multiplier for the style loss to make it
                greater or smaller than the content loss
            style_layers (list of str, optional): the layers on which to
                compute the style, or `None` to keep them unchanged
        """
        self.ratio = style_ratio

        if style_layers is not None:
            self.style_layers = style_layers
            self.net.set_keep_layers(names=self.style_layers +
                                     self.content_layers)

        # The style is represented by the Gram matrices of the style layers.
        with torch.no_grad():
            activations = self.get_style_content_(style_img[None],
                                                  detach=True)[0]
        self.style_grams = {
            layer_id: bgram(layer_data)
            for layer_id, layer_data in activations.items()
        }

    def set_content(self, content_img, content_layers=None):
        """
        Set the content.

        Args:
            content_img (3xHxW tensor): an image tensor
            content_layers (list of str, optional): the layers on which to
                compute the content representation, or `None` to keep them
                unchanged
        """
        if content_layers is not None:
            self.content_layers = content_layers
            self.net.set_keep_layers(names=self.style_layers +
                                     self.content_layers)

        with torch.no_grad():
            acts = self.get_style_content_(content_img[None], detach=True)[1]
        self.photo_activations = acts

    def forward(self, input_img):
        """
        Actually compute the loss
        """
        style_acts, content_acts = self.get_style_content_(input_img,
                                                           detach=False)

        # Style term: L1 distance between Gram matrices, averaged over the
        # style layers.
        style_loss = 0
        for j in style_acts:
            this_loss = F.l1_loss(bgram(style_acts[j]),
                                  self.style_grams[j],
                                  reduction='sum')
            style_loss += (1 / len(style_acts)) * this_loss

        # Content term: L1 distance between raw activations.
        content_loss = 0
        for j in content_acts:
            content_loss += F.l1_loss(content_acts[j],
                                      self.photo_activations[j])

        # Blend the two terms with weights that sum to 1: for ratio r,
        # loss = content / (1 + r) + r * style / (1 + r).
        c_ratio = 1. / (1. + self.ratio)
        s_ratio = self.ratio / (1. + self.ratio)
        return c_ratio * content_loss + s_ratio * style_loss, {
            'content_loss': content_loss.item(),
            'style_loss': style_loss.item()
        }
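# A hedged sketch of the classic Gatys optimization loop driving the class
# above: set the targets once, then descend on the pixels of a canvas
# initialized from the content image. LBFGS, the learning rate and the step
# count are illustrative choices, not prescribed by the original code.


def _style_transfer_example(content: torch.Tensor,
                            style: torch.Tensor) -> torch.Tensor:
    # `content` and `style` are 3xHxW images in [0, 1].
    loss_fn = NeuralStyleLoss()
    loss_fn.set_style(style, style_ratio=1.0)
    loss_fn.set_content(content)

    canvas = content.clone()[None].requires_grad_()  # optimize the pixels
    opt = torch.optim.LBFGS([canvas], lr=0.5)

    def closure():
        opt.zero_grad()
        loss, logs = loss_fn(canvas)
        loss.backward()
        return loss

    for _ in range(50):
        opt.step(closure)
    return canvas.detach()[0]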
class NeuralStyleLoss(nn.Module):
    """
    Style Transfer loss by Leon Gatys
    https://arxiv.org/abs/1508.06576

    Set the style and content before performing a forward pass.
    """
    style_hists: Dict[str, torch.Tensor]
    net: PerceptualNet

    def __init__(self) -> None:
        super(NeuralStyleLoss, self).__init__()
        self.style_layers = [
            'conv1_1',
            'conv2_1',
            'conv3_1',
            'conv4_1',
            'conv5_1',
        ]
        self.style_weights = {'conv5_1': 1.0, 'conv4_1': 1.0}
        self.style_hists = {}
        self.content_layers = ['conv3_2']
        self.hists_layers = ['conv5_1']

        self.net = PerceptualNet(self.style_layers + self.content_layers,
                                 remove_unused_layers=False)
        tu.freeze(self.net)
        self.norm = ImageNetInputNorm()

    def get_style_content_(
            self, img: torch.Tensor,
            detach: bool) -> Dict[str, Dict[str, torch.Tensor]]:
        activations: Dict[str, torch.Tensor]
        _, activations = self.net(self.norm(img), detach=detach)

        # Gram matrices of the individual style layers.
        grams = {
            layer_id: bgram(layer_data)
            for layer_id, layer_data in activations.items()
            if layer_id in self.style_layers
        }

        # Cross-layer Gram matrices: a style layer spelled 'convX:convY'
        # correlates two feature maps; the lower-resolution map is upsampled
        # (nearest neighbor) so both can be concatenated channel-wise.
        act_names = list(activations.keys())
        for i in range(len(act_names)):
            for j in range(i, len(act_names)):
                comb_name = act_names[i] + ':' + act_names[j]
                if comb_name not in self.style_layers:
                    continue
                small = activations[act_names[i]]
                big = activations[act_names[j]]
                if small.shape[-1] > big.shape[-1]:
                    small, big = big, small
                small = F.interpolate(small,
                                      size=big.shape[-2:],
                                      mode='nearest')
                comb = torch.cat([big, small], dim=1)
                grams[comb_name] = bgram(comb)

        # Content activations are mean-centered and rescaled per channel so
        # that the content loss focuses on structure rather than magnitude.
        content = {
            layer: (a - a.mean((2, 3), keepdim=True)) /
            torch.sqrt(a.std((2, 3), keepdim=True) + 1e-8)
            for layer, a in activations.items()
            if layer in self.content_layers
        }

        hists = {
            layer: a
            for layer, a in activations.items() if layer in self.hists_layers
        }

        return {'grams': grams, 'content': content, 'hists': hists}

    def set_style(self,
                  style_img: torch.Tensor,
                  style_ratio: float,
                  style_layers: Optional[List[str]] = None,
                  style_weights: Optional[Dict[str, float]] = None) -> None:
        """
        Set the style.

        Args:
            style_img (3xHxW tensor): an image tensor
            style_ratio (float): a multiplier for the style loss to make it
                greater or smaller than the content loss
            style_layers (list of str, optional): the layers on which to
                compute the style, or `None` to keep them unchanged
            style_weights (dict of str to float, optional): per-layer
                weights for the style loss, or `None` to keep them unchanged
        """
        self.ratio = torch.tensor(style_ratio)

        if style_layers is not None:
            self.style_layers = style_layers
            self.net.set_keep_layers(names=self.style_layers +
                                     self.content_layers)

        if style_weights is not None:
            self.style_weights = style_weights

        with torch.no_grad():
            out = self.get_style_content_(style_img, detach=True)

        self.style_grams = out['grams']
        self.style_hists = out['hists']

    def set_content(self,
                    content_img: torch.Tensor,
                    content_layers: Optional[List[str]] = None) -> None:
        """
        Set the content.

        Args:
            content_img (3xHxW tensor): an image tensor
            content_layers (list of str, optional): the layers on which to
                compute the content representation, or `None` to keep them
                unchanged
        """
        if content_layers is not None:
            self.content_layers = content_layers
            self.net.set_keep_layers(names=self.style_layers +
                                     self.content_layers)

        with torch.no_grad():
            acts = self.get_style_content_(content_img,
                                           detach=True)['content']
        self.photo_activations = acts

    def forward(
        self, input_img: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Actually compute the loss
        """
        out = self.get_style_content_(input_img, detach=False)
        style_grams, content_acts, hists = (out['grams'], out['content'],
                                            out['hists'])

        # Unlike the version above, `ratio` is the style fraction directly:
        # loss = (1 - r) * content + r * (style + hists).
        c_ratio = 1. - self.ratio.squeeze()
        s_ratio = self.ratio.squeeze()

        losses = {}

        # Style term: per-layer L1 distance between Gram matrices, averaged
        # over layers and scaled by the per-layer style weights.
        style_loss = cast(torch.Tensor, 0.)
        avg = 1 / len(style_grams)
        for j in style_grams:
            this_loss = F.l1_loss(style_grams[j],
                                  self.style_grams[j],
                                  reduction='none').sum((1, 2))
            w = self.style_weights.get(j, 1)
            this_loss = avg * w * this_loss
            losses['style:' + j] = (s_ratio * this_loss).mean().item()
            style_loss = style_loss + this_loss
        losses['style_loss'] = (s_ratio * style_loss).mean().item()

        # Content term: MSE between the normalized activations.
        content_loss = cast(torch.Tensor, 0.)
        avg = 1 / len(content_acts)
        for j in content_acts:
            this_loss = F.mse_loss(content_acts[j],
                                   self.photo_activations[j],
                                   reduction='none').mean((1, 2, 3))
            content_loss = content_loss + avg * this_loss
            losses['content:' + j] = (c_ratio * this_loss).mean().item()
        losses['content_loss'] = (c_ratio * content_loss).mean().item()

        # Histogram term: only evaluated on a random ~10% of the calls
        # (randint(0, 20) > 18 picks 19 or 20, i.e. 2 values out of 21).
        hists_loss = cast(torch.Tensor, 0.)
        losses['hists_loss'] = 0.
        if random.randint(0, 20) > 18:
            for layer in hists.keys():
                hists_loss = hists_loss + hist_loss(hists[layer],
                                                    self.style_hists[layer])
            losses['hists_loss'] = hists_loss.mean().item()

        loss = (c_ratio * content_loss + s_ratio *
                (style_loss + hists_loss)).mean()
        losses['loss'] = loss.item()
        return loss, losses
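# A hedged usage sketch for the refined loss above, with batched 1x3xHxW
# images (forward reduces over a batch dimension). The style_ratio and
# style_weights values are illustrative, and Adam with this step count is an
# arbitrary choice, not prescribed by the original code.


def _refined_style_example(content: torch.Tensor,
                           style: torch.Tensor) -> torch.Tensor:
    loss_fn = NeuralStyleLoss()
    # Weight the coarsest style layer more heavily than the default 1.0.
    loss_fn.set_style(style,
                      style_ratio=0.5,
                      style_weights={'conv4_1': 1.0, 'conv5_1': 2.0})
    loss_fn.set_content(content)

    canvas = content.clone().requires_grad_()
    opt = torch.optim.Adam([canvas], lr=0.05)
    for _ in range(200):
        opt.zero_grad()
        loss, logs = loss_fn(canvas)  # logs maps e.g. 'style:conv5_1' to floats
        loss.backward()
        opt.step()
    return canvas.detach()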