def img_generation(self, x, norm_weights, encoded_label, encoded_label_raw=None):
        """Run the upsampling branch and synthesize the output image.

        Walks the ``up_i`` blocks from coarsest (``n_downsample_G``) to finest
        (0). When ``self.add_raw_loss`` is set, a parallel "raw" branch is
        forked at the first skip-connection layer and driven by
        ``encoded_label_raw`` so its output can be supervised separately.

        Returns:
            (image, raw_image) — both tanh-squashed; raw_image is None when
            no raw loss is requested.
        """
        # main branch convolution layers
        for layer in range(self.n_downsample_G, -1, -1):
            adaptive = self.adap_spade and layer < self.n_adaptive_layers
            norm_weight = norm_weights[layer] if adaptive else None
            up_block = getattr(self, 'up_' + str(layer))
            # parallel raw branch (only for the finest n_sc_layers layers)
            if self.add_raw_loss and layer < self.n_sc_layers:
                if layer == self.n_sc_layers - 1:
                    x_raw = x  # fork from the main branch at this depth
                x_raw = up_block(x_raw,
                                 encoded_label_raw[layer],
                                 conv_weights=None,
                                 norm_weights=norm_weight)
                if layer != 0:
                    x_raw = self.up(x_raw)
            x = up_block(x,
                         encoded_label[layer],
                         conv_weights=None,
                         norm_weights=norm_weight)
            if layer != 0:
                x = self.up(x)

        # final synthesized image
        fake_raw_img = torch.tanh(self.conv_img(actvn(x)))
        if self.add_raw_loss:
            x_raw = torch.tanh(self.conv_img(actvn(x_raw)))
        else:
            x_raw = None

        return fake_raw_img, x_raw
# Example 2
    def forward(self, input, weights=None):
        """Encode *input* through the downsampling path and, when
        ``self.decode`` is set, decode it back up (optionally U-Net style).

        ``weights`` supplies externally generated convolution weights for the
        first ``params_free_layers`` layers (and the first layer itself when
        ``first_layer_free``); remaining layers use this module's own convs.
        """
        if input is None:
            return None

        # first layer: adaptive (supplied weights) or this module's own conv
        if self.first_layer_free:
            feats = [actvn(batch_conv(input, weights[0]))]
            weights = weights[1:]  # remaining entries index the later layers
        else:
            feats = [self.conv_first(input)]

        # downsampling path
        for idx in range(self.n_downsample_S):
            use_own = idx >= self.params_free_layers or self.decode
            if use_own:
                down = getattr(self, 'down_%d' % idx)(feats[-1])
            else:
                down = actvn(batch_conv(feats[-1], weights[idx], stride=2))
            feats.append(down)

        if not self.decode:
            return feats

        # decoding path; without a U-Net only the bottleneck is kept
        if not self.unet:
            feats = [feats[-1]]
        for idx in reversed(range(self.n_downsample_S)):
            cur = feats[-1]
            if self.unet and idx != self.n_downsample_S - 1:
                cur = torch.cat([cur, feats[idx + 1]], dim=1)
            if idx >= self.params_free_layers:
                up = getattr(self, 'up_%d' % idx)(cur)
            else:
                cur = nn.Upsample(scale_factor=2)(cur)
                up = actvn(batch_conv(cur, weights[idx]))
            feats.append(up)

        if self.unet:
            feats = feats[self.n_downsample_S:]
        return feats[::-1]
# Example 3
 def forward_face(self, label, label_refs, img_refs, img_coarse):
     """Refine a coarse face image into the final output.

     Generates SPADE weights from the references plus the coarse image,
     then runs the upsampling branch and tanh-squashes the result.
     """
     x, encoded_label, _, norm_weights, _, _, _, _, _ = self.weight_generation(
         img_refs, label_refs, label, img_coarse=img_coarse)

     layer = self.n_downsample_G
     while layer >= 0:
         use_adap = self.adap_spade and layer < self.n_adaptive_layers
         w = norm_weights[layer] if use_adap else None
         x = getattr(self, 'up_' + str(layer))(x, encoded_label[layer],
                                               norm_weights=w)
         if layer != 0:
             x = self.up(x)
         layer -= 1

     return torch.tanh(self.conv_img(actvn(x)))
# Example 4
    def forward(self, label, label_refs, img_refs, prev=(None, None), t=0, img_coarse=None):
        """Generate one output frame from the current label map.

        Args:
            label: current semantic label map.
            label_refs, img_refs: reference label maps / images.
            prev: (previous label, previous image) pair; ``(None, None)``
                for the first frame. NOTE: changed from a mutable list
                default ``[None, None]`` to an immutable tuple — ``prev`` is
                only read here, so callers are unaffected, and this avoids
                the shared-mutable-default pitfall.
            t: frame index passed through to weight generation.
            img_coarse: if given, run face refinement instead.

        Returns:
            (img_final, flow, flow_mask, img_raw, img_warp, mu, logvar,
             atn_vis, ref_idx)
        """
        ### for face refinement
        if img_coarse is not None:
            return self.forward_face(label, label_refs, img_refs, img_coarse)        

        ### SPADE weight generation
        x, encoded_label, conv_weights, norm_weights, mu, logvar, atn, atn_vis, ref_idx \
            = self.weight_generation(img_refs, label_refs, label, t=t)        

        ### flow estimation         
        flow, flow_mask, img_warp, ds_ref = self.flow_generation(label, label_refs, img_refs, prev, atn, ref_idx)

        flow_mask_ref, flow_mask_prev = flow_mask
        img_ref_warp, img_prev_warp = img_warp           
        # keep a copy of the un-combined labels for the raw-output branch
        if self.add_raw_output_loss: encoded_label_raw = [encoded_label[i] for i in range(self.n_sc_layers)]
        encoded_label = self.SPADE_combine(encoded_label, ds_ref)
        
        ### main branch convolution layers (coarsest layer first)
        for i in range(self.n_downsample_G, -1, -1):            
            conv_weight = conv_weights[i] if (self.adap_conv and i < self.n_adaptive_layers) else None
            norm_weight = norm_weights[i] if (self.adap_spade and i < self.n_adaptive_layers) else None                  
            # raw branch is forked at the first skip-connection layer
            if self.add_raw_output_loss and i < self.n_sc_layers:
                if i == self.n_sc_layers - 1: x_raw = x
                x_raw = getattr(self, 'up_'+str(i))(x_raw, encoded_label_raw[i], conv_weights=conv_weight, norm_weights=norm_weight)    
                if i != 0: x_raw = self.up(x_raw)            
            x = getattr(self, 'up_'+str(i))(x, encoded_label[i], conv_weights=conv_weight, norm_weights=norm_weight)
            if i != 0: x = self.up(x)

        ### raw synthesized image
        x = self.conv_img(actvn(x))
        img_raw = torch.tanh(x)        

        ### combine with reference / previous images
        if not self.spade_combine:
            ### combine raw result with reference image
            if self.warp_ref:
                img_final = img_raw * flow_mask_ref + img_ref_warp * (1 - flow_mask_ref)        
            else:
                img_final = img_raw
                if not self.warp_prev: img_raw = None

            ### combine generated frame with previous frame
            if self.warp_prev and prev[0] is not None:
                img_final = img_final * flow_mask_prev + img_prev_warp * (1 - flow_mask_prev)        
        else:
            # SPADE-combine mode: warping is handled inside the generator,
            # so the raw output is only kept when explicitly supervised
            img_final = img_raw
            img_raw = None if not self.add_raw_output_loss else torch.tanh(self.conv_img(actvn(x_raw)))
                
        return img_final, flow, flow_mask, img_raw, img_warp, mu, logvar, atn_vis, ref_idx
# Example 5
    def forward(self,
                label,
                label_refs,
                img_refs,
                prev=(None, None),
                t=0,
                img_coarse=None):
        """Generate one output frame from the current label map.

        Args:
            label: current semantic label map.
            label_refs, img_refs: reference label maps / images.
            prev: (previous label, previous image) pair; ``(None, None)``
                for the first frame. NOTE: changed from a mutable list
                default ``[None, None]`` to an immutable tuple — ``prev``
                is only unpacked/indexed here, never mutated, so this is
                backward compatible and removes the mutable-default pitfall.
            t: frame index passed through to weight generation.
            img_coarse: if given, run face refinement instead.

        Returns:
            (img_final, flow, weight, img_raw, img_warp, mu, logvar, atn,
             ref_idx)
        """
        ### for face refinement
        if img_coarse is not None:
            return self.forward_face(label, label_refs, img_refs, img_coarse)

        ### SPADE weight generation
        x, encoded_label, conv_weights, norm_weights, mu, logvar, atn, ref_idx \
            = self.weight_generation(img_refs, label_refs, label, t=t)

        ### flow estimation
        has_prev = prev[0] is not None
        label_ref, img_ref = self.pick_ref([label_refs, img_refs], ref_idx)
        label_prev, img_prev = prev
        flow, weight, img_warp, ds_ref = self.flow_generation(
            label, label_ref, img_ref, label_prev, img_prev, has_prev)

        weight_ref, weight_prev = weight
        img_ref_warp, img_prev_warp = img_warp
        encoded_label = self.SPADE_combine(encoded_label, ds_ref)

        ### main branch convolution layers (coarsest layer first)
        for i in range(self.n_downsample_G, -1, -1):
            conv_weight = conv_weights[i] if (
                self.adap_conv and i < self.n_adaptive_layers) else None
            norm_weight = norm_weights[i] if (
                self.adap_spade and i < self.n_adaptive_layers) else None
            x = getattr(self, 'up_' + str(i))(x,
                                              encoded_label[i],
                                              conv_weights=conv_weight,
                                              norm_weights=norm_weight)
            if i != 0: x = self.up(x)

        ### raw synthesized image
        x = self.conv_img(actvn(x))
        img_raw = torch.tanh(x)

        ### combine with reference / previous images
        if not self.spade_combine:
            ### combine raw result with reference image
            if self.warp_ref:
                img_final = img_raw * weight_ref + img_ref_warp * (1 -
                                                                   weight_ref)
            else:
                img_final = img_raw
                if not self.warp_prev: img_raw = None

            ### combine generated frame with previous frame
            if self.warp_prev and has_prev:
                img_final = img_final * weight_prev + img_prev_warp * (
                    1 - weight_prev)
        else:
            # SPADE-combine mode: warping handled inside the generator
            img_final = img_raw
            img_raw = None

        return img_final, flow, weight, img_raw, img_warp, mu, logvar, atn, ref_idx