def img_generation(self, x, norm_weights, encoded_label, encoded_label_raw=None):
    """Run the ``up_<i>`` layer stack and turn the result into an image.

    Layers are applied from ``i = self.n_downsample_G`` down to ``0``; after
    every layer except ``up_0`` the features go through ``self.up``.

    Args:
        x: Input features for the first (highest-index) layer.
        norm_weights: Indexable by layer; entry ``i`` is passed as
            ``norm_weights=`` to ``up_<i>`` when ``self.adap_spade`` is set
            and ``i < self.n_adaptive_layers``, otherwise ``None`` is passed.
        encoded_label: Indexable by layer; entry ``i`` is the second
            positional argument of ``up_<i>`` on the main branch.
        encoded_label_raw: Same role as ``encoded_label`` but for the raw
            branch; only indexed when ``self.add_raw_loss`` is set.

    Returns:
        ``(fake_raw_img, x_raw)`` where ``fake_raw_img`` is
        ``tanh(conv_img(actvn(x)))`` of the main branch and ``x_raw`` is the
        same projection of the raw branch, or ``None`` when
        ``self.add_raw_loss`` is false.

    NOTE(review): the raw branch is first bound at
    ``i == self.n_sc_layers - 1``. With ``add_raw_loss`` set this requires
    ``1 <= n_sc_layers <= n_downsample_G + 1``; outside that range ``x_raw``
    is read before assignment.
    """
    want_raw = self.add_raw_loss

    for level in range(self.n_downsample_G, -1, -1):
        # Adaptive normalization weights are only used on the first
        # ``n_adaptive_layers`` levels (and only if adap_spade is enabled).
        if self.adap_spade and level < self.n_adaptive_layers:
            spade_weight = norm_weights[level]
        else:
            spade_weight = None
        # This method never supplies convolution weights.
        layer_kwargs = dict(conv_weights=None, norm_weights=spade_weight)
        layer = getattr(self, 'up_' + str(level))
        upsample_after = level != 0

        # Raw branch runs BEFORE the main branch so that, at the fork level,
        # it starts from the main-branch features of the previous level.
        if want_raw and level < self.n_sc_layers:
            if level == self.n_sc_layers - 1:
                x_raw = x
            x_raw = layer(x_raw, encoded_label_raw[level], **layer_kwargs)
            if upsample_after:
                x_raw = self.up(x_raw)

        x = layer(x, encoded_label[level], **layer_kwargs)
        if upsample_after:
            x = self.up(x)

    # Project both branches to image space with the shared conv_img layer.
    fake_raw_img = torch.tanh(self.conv_img(actvn(x)))
    if want_raw:
        raw_branch_img = torch.tanh(self.conv_img(actvn(x_raw)))
    else:
        raw_branch_img = None
    return fake_raw_img, raw_branch_img
def forward(self, input, weights=None):
    """Encode ``input`` (and optionally decode it) into a list of features.

    Args:
        input: Input to the first layer. ``None`` short-circuits to ``None``.
        weights: Indexable collection of externally supplied kernels, used
            through ``batch_conv`` for the first layer when
            ``self.first_layer_free`` is set and for every level
            ``i < self.params_free_layers``. When ``first_layer_free`` is
            set, ``weights[0]`` belongs to the first layer and the remaining
            entries are re-indexed from 0.

    Returns:
        ``None`` if ``input`` is ``None``.
        If ``self.decode`` is false: ``[first, down_0, ..., down_{n-1}]``
        outputs in the order they were computed.
        If ``self.decode`` is true: the deepest encoder output together with
        all decoder outputs, ordered with the last decoded level first and
        the deepest encoder output last.
    """
    if input is None:
        return None

    # ---- first layer ----
    if self.first_layer_free:
        feats = [actvn(batch_conv(input, weights[0]))]
        weights = weights[1:]
    else:
        feats = [self.conv_first(input)]

    # ---- encoder ----
    for level in range(self.n_downsample_S):
        current = feats[-1]
        # When decoding, the encoder always uses its own `down_<i>` modules.
        if level >= self.params_free_layers or self.decode:
            down = getattr(self, 'down_%d' % level)(current)
        else:
            down = actvn(batch_conv(current, weights[level], stride=2))
        feats.append(down)

    if not self.decode:
        return feats

    # ---- decoder ----
    # Without skip connections only the deepest feature is needed from here.
    if not self.unet:
        feats = [feats[-1]]

    deepest = self.n_downsample_S - 1
    for level in reversed(range(self.n_downsample_S)):
        dec_in = feats[-1]
        if self.unet and level != deepest:
            # feats[level + 1] is the encoder output of `down_<level>`
            # (index 0 holds the first-layer output); join on dim 1.
            dec_in = torch.cat([dec_in, feats[level + 1]], dim=1)
        if level >= self.params_free_layers:
            up = getattr(self, 'up_%d' % level)(dec_in)
        else:
            # Upsample by 2 first, then apply the supplied kernel without
            # passing any stride argument to batch_conv.
            dec_in = nn.Upsample(scale_factor=2)(dec_in)
            up = actvn(batch_conv(dec_in, weights[level]))
        feats.append(up)

    if self.unet:
        # Drop the encoder outputs that were kept only for skip connections.
        feats = feats[self.n_downsample_S:]
    return feats[::-1]
def forward_face(self, label, label_refs, img_refs, img_coarse):
    """Generate an image for the face-refinement path.

    Calls ``self.weight_generation(img_refs, label_refs, label,
    img_coarse=img_coarse)`` (which must return exactly nine values), then
    runs the ``up_<i>`` layers from ``i = self.n_downsample_G`` down to ``0``
    and projects the result with ``conv_img`` followed by ``tanh``.

    Args:
        label: Forwarded to ``weight_generation``.
        label_refs: Forwarded to ``weight_generation``.
        img_refs: Forwarded to ``weight_generation``.
        img_coarse: Forwarded to ``weight_generation`` as ``img_coarse=``.

    Returns:
        ``torch.tanh(self.conv_img(actvn(x)))`` for the final features ``x``.
    """
    # Only the initial features, the per-layer label encodings and the
    # normalization weights are used here; the other six outputs are ignored.
    feat, label_feats, _, norm_weights, _, _, _, _, _ = self.weight_generation(
        img_refs, label_refs, label, img_coarse=img_coarse)

    for level in range(self.n_downsample_G, -1, -1):
        if self.adap_spade and level < self.n_adaptive_layers:
            spade_weight = norm_weights[level]
        else:
            spade_weight = None
        layer = getattr(self, 'up_' + str(level))
        feat = layer(feat, label_feats[level], norm_weights=spade_weight)
        # Every level except the last one is followed by self.up.
        if level != 0:
            feat = self.up(feat)

    return torch.tanh(self.conv_img(actvn(feat)))
def forward(self, label, label_refs, img_refs, prev=[None, None], t=0, img_coarse=None):
    """Synthesize a frame, optionally blending it with warped images.

    Args:
        label: Label input for the frame to generate.
        label_refs: Reference labels, forwarded to ``weight_generation`` and
            ``flow_generation``.
        img_refs: Reference images, forwarded to the same two helpers.
        prev: Two-item sequence forwarded to ``flow_generation``;
            ``prev[0] is None`` disables blending with the previous frame.
            NOTE: the default is a single shared list object; this method
            itself only reads it.
        t: Forwarded to ``weight_generation`` as ``t=``.
        img_coarse: When not ``None`` the call is delegated entirely to
            ``self.forward_face`` and its result is returned as-is.

    Returns:
        ``(img_final, flow, flow_mask, img_raw, img_warp, mu, logvar,
        atn_vis, ref_idx)``. ``img_raw`` is ``None`` when
        ``spade_combine`` is set without ``add_raw_output_loss``, or when
        neither ``warp_ref`` nor ``warp_prev`` is set.

    NOTE(review): with ``add_raw_output_loss`` set, the raw branch is first
    bound at ``i == self.n_sc_layers - 1``; this needs
    ``1 <= n_sc_layers <= n_downsample_G + 1``, otherwise ``x_raw`` is read
    before assignment.
    """
    # --- face refinement: separate, simpler path ---
    if img_coarse is not None:
        return self.forward_face(label, label_refs, img_refs, img_coarse)

    # --- SPADE weight generation ---
    (x, encoded_label, conv_weights, norm_weights,
     mu, logvar, atn, atn_vis, ref_idx) = self.weight_generation(
        img_refs, label_refs, label, t=t)

    # --- flow estimation ---
    flow, flow_mask, img_warp, ds_ref = self.flow_generation(
        label, label_refs, img_refs, prev, atn, ref_idx)
    flow_mask_ref, flow_mask_prev = flow_mask
    img_ref_warp, img_prev_warp = img_warp

    # Snapshot the label encodings for the raw branch BEFORE SPADE_combine
    # replaces `encoded_label`.
    if self.add_raw_output_loss:
        encoded_label_raw = [encoded_label[k] for k in range(self.n_sc_layers)]
    encoded_label = self.SPADE_combine(encoded_label, ds_ref)

    # --- main branch convolution layers ---
    for level in range(self.n_downsample_G, -1, -1):
        if self.adap_conv and level < self.n_adaptive_layers:
            conv_weight = conv_weights[level]
        else:
            conv_weight = None
        if self.adap_spade and level < self.n_adaptive_layers:
            norm_weight = norm_weights[level]
        else:
            norm_weight = None
        layer_kwargs = dict(conv_weights=conv_weight, norm_weights=norm_weight)
        layer = getattr(self, 'up_' + str(level))
        upsample_after = level != 0

        # Raw branch goes first so it forks from the main-branch features
        # produced by the previous level.
        if self.add_raw_output_loss and level < self.n_sc_layers:
            if level == self.n_sc_layers - 1:
                x_raw = x
            x_raw = layer(x_raw, encoded_label_raw[level], **layer_kwargs)
            if upsample_after:
                x_raw = self.up(x_raw)

        x = layer(x, encoded_label[level], **layer_kwargs)
        if upsample_after:
            x = self.up(x)

    # --- raw synthesized image ---
    img_raw = torch.tanh(self.conv_img(actvn(x)))

    # --- combine with reference / previous images ---
    if self.spade_combine:
        # The synthesized image is already the final one; `img_raw` is
        # re-purposed for the raw-branch image (or dropped).
        img_final = img_raw
        if self.add_raw_output_loss:
            img_raw = torch.tanh(self.conv_img(actvn(x_raw)))
        else:
            img_raw = None
    else:
        if self.warp_ref:
            # Blend raw result with the warped reference image.
            img_final = img_raw * flow_mask_ref + img_ref_warp * (1 - flow_mask_ref)
        else:
            img_final = img_raw
            if not self.warp_prev:
                img_raw = None
        if self.warp_prev and prev[0] is not None:
            # Blend the current result with the warped previous frame.
            img_final = img_final * flow_mask_prev + img_prev_warp * (1 - flow_mask_prev)

    return img_final, flow, flow_mask, img_raw, img_warp, mu, logvar, atn_vis, ref_idx
def forward(self, label, label_refs, img_refs, prev=[None, None], t=0, img_coarse=None):
    """Synthesize a frame, optionally blending it with warped images.

    Args:
        label: Label input for the frame to generate.
        label_refs: Reference labels; forwarded to ``weight_generation`` and
            ``pick_ref``.
        img_refs: Reference images; forwarded to the same two helpers.
        prev: Two-item sequence unpacked as ``(label_prev, img_prev)``;
            ``prev[0] is None`` means there is no previous frame.
            NOTE: the default is a single shared list object; this method
            only reads it.
        t: Forwarded to ``weight_generation`` as ``t=``.
        img_coarse: When not ``None`` the call is delegated entirely to
            ``self.forward_face`` and its result is returned as-is.

    Returns:
        ``(img_final, flow, weight, img_raw, img_warp, mu, logvar, atn,
        ref_idx)``. ``img_raw`` is ``None`` when ``spade_combine`` is set,
        or when neither ``warp_ref`` nor ``warp_prev`` is set.
    """
    # --- face refinement: separate, simpler path ---
    if img_coarse is not None:
        return self.forward_face(label, label_refs, img_refs, img_coarse)

    # --- SPADE weight generation ---
    (x, encoded_label, conv_weights, norm_weights,
     mu, logvar, atn, ref_idx) = self.weight_generation(
        img_refs, label_refs, label, t=t)

    # --- flow estimation ---
    has_prev = prev[0] is not None
    label_ref, img_ref = self.pick_ref([label_refs, img_refs], ref_idx)
    label_prev, img_prev = prev
    flow, weight, img_warp, ds_ref = self.flow_generation(
        label, label_ref, img_ref, label_prev, img_prev, has_prev)
    weight_ref, weight_prev = weight
    img_ref_warp, img_prev_warp = img_warp
    encoded_label = self.SPADE_combine(encoded_label, ds_ref)

    # --- main branch convolution layers ---
    for level in range(self.n_downsample_G, -1, -1):
        if self.adap_conv and level < self.n_adaptive_layers:
            conv_weight = conv_weights[level]
        else:
            conv_weight = None
        if self.adap_spade and level < self.n_adaptive_layers:
            norm_weight = norm_weights[level]
        else:
            norm_weight = None
        layer = getattr(self, 'up_' + str(level))
        x = layer(x, encoded_label[level],
                  conv_weights=conv_weight, norm_weights=norm_weight)
        # Every level except the last one is followed by self.up.
        if level != 0:
            x = self.up(x)

    # --- raw synthesized image ---
    img_raw = torch.tanh(self.conv_img(actvn(x)))

    # --- combine with reference / previous images ---
    if self.spade_combine:
        # The synthesized image is already final; no separate raw output.
        img_final = img_raw
        img_raw = None
    else:
        if self.warp_ref:
            # Blend raw result with the warped reference image.
            img_final = img_raw * weight_ref + img_ref_warp * (1 - weight_ref)
        else:
            img_final = img_raw
            if not self.warp_prev:
                img_raw = None
        if self.warp_prev and has_prev:
            # Blend the current result with the warped previous frame.
            img_final = img_final * weight_prev + img_prev_warp * (1 - weight_prev)

    return img_final, flow, weight, img_raw, img_warp, mu, logvar, atn, ref_idx