def __fit_one(self, link, content4_2, style3_2, style4_2):
        """Run one optimisation step of the image held in ``link``.

        ``self.model(link.x)`` must return exactly two feature maps, unpacked
        here as ``layer3_2`` and ``layer4_2``. The step minimises the sum of:

        * a content term: ``self.content_weight`` times the mean squared
          error between ``layer4_2`` and ``content4_2``;
        * two MRF style terms, one per feature map, each built from the
          values returned by ``util.nearest_neighbor_patch``; the
          ``layer4_2`` term carries an extra factor of 1.5;
        * a total-variation term on ``link.x`` weighted by ``self.tv_weight``.

        After back-propagation, ``self.optimizer.update()`` is applied.

        Args:
            link: object exposing ``x`` (the image being optimised) and
                ``zerograds()``; presumably a Chainer link -- confirm.
            content4_2: target features for the content term; wrapped in a
                ``Variable`` before comparison.
            style3_2: ``(style_patch, style_patch_norm)`` pair for ``layer3_2``.
            style4_2: ``(style_patch, style_patch_norm)`` pair for ``layer4_2``.

        Returns:
            list of ``(label, float)`` pairs in the order ``content_``,
            ``style_3_2``, ``style_4_2``, ``tv``.
        """
        xp = self.xp
        link.zerograds()
        layer3_2, layer4_2 = self.model(link.x)
        if self.keep_color:
            # keep_color is not implemented for this two-layer variant (there
            # is no grayscale style pass here), so the flag only prints a
            # notice on every call. The print() form works in Python 2 and 3.
            print("don't keep color!")

        # Hard-coded extra weight on the layer4_2 style term.
        layer4_2_style_scale = 1.5

        def mrf_style_loss(layer, style, weight):
            # Unpack the precomputed style patches, find the nearest-neighbour
            # data via util, and form
            #   weight * (sum(layer**2) * size2 / size - 2 * sum(near) / size).
            style_patch, style_patch_norm = style
            near, size, size2 = util.nearest_neighbor_patch(
                layer, style_patch, style_patch_norm)
            return weight * (
                F.sum(F.square(layer)) * size2 / size - 2 * F.sum(near) / size)

        loss_info = []
        loss = Variable(xp.zeros((), dtype=np.float32))

        content_loss = self.content_weight * F.mean_squared_error(
            layer4_2, Variable(content4_2))
        loss_info.append(('content_', float(content_loss.data)))
        loss += content_loss

        # Each style term gets its own label so the two layers remain
        # distinguishable in loss_info (e.g. if it is turned into a dict).
        style_terms = (
            ('style_3_2', layer3_2, style3_2, self.style_weight),
            ('style_4_2', layer4_2, style4_2,
             self.style_weight * layer4_2_style_scale),
        )
        for label, layer, style, weight in style_terms:
            style_loss = mrf_style_loss(layer, style, weight)
            loss_info.append((label, float(style_loss.data)))
            loss += style_loss

        tv_loss = self.tv_weight * util.total_variation(link.x)
        loss_info.append(('tv', float(tv_loss.data)))
        loss += tv_loss
        loss.backward()
        self.optimizer.update()
        return loss_info
 def __fit_one(self, link, content_layers, style_patches):
     """Run a single optimisation step on the image held by ``link``.

     The step sums weighted content, MRF style and total-variation losses,
     back-propagates the total and then calls ``self.optimizer.update()``.

     Args:
         link: object exposing ``x`` (the image being optimised) and
             ``zerograds()``.
         content_layers: iterable of ``(name, target)`` pairs; each target
             is compared with ``self.model(link.x)[name]`` by mean squared
             error, scaled by ``self.content_weight``.
         style_patches: iterable of ``(name, style_patch, style_patch_norm)``
             triples; each named feature map is cut into patches with
             ``util.patch`` and compared (mean squared error, scaled by
             ``self.style_weight``) with the result of
             ``util.nearest_neighbor_patch``.

     Returns:
         list of ``(label, value)`` pairs with float values, ordered as the
         ``content_<name>`` entries, then the ``style_<name>`` entries, then
         ``tv``.
     """
     xp = self.xp
     link.zerograds()
     features = self.model(link.x)
     # With keep_color, style is matched on features of util.gray(link.x)
     # (presumably a grayscale copy of the image) instead of the image's own.
     style_features = (self.model(util.gray(link.x)) if self.keep_color
                       else features)

     # Collect (label, term) pairs first; they are summed in this order.
     terms = []
     for name, target in content_layers:
         terms.append(('content_' + name, self.content_weight *
                       F.mean_squared_error(features[name], target)))
     for name, style_patch, style_patch_norm in style_patches:
         patches = util.patch(style_features[name])
         matched = util.nearest_neighbor_patch(
             patches, style_patch, style_patch_norm)
         terms.append(('style_' + name, self.style_weight *
                       F.mean_squared_error(patches, matched)))
     terms.append(('tv', self.tv_weight * util.total_variation(link.x)))

     total = Variable(xp.zeros((), dtype=np.float32))
     loss_info = []
     for label, term in terms:
         loss_info.append((label, float(term.data)))
         total += term
     total.backward()
     self.optimizer.update()
     return loss_info