Пример #1
0
    def forward(self, x_in, pose_info, z=None):
        """Generate an image from ``x_in`` with a U-Net conditioned on pose and a latent.

        Args:
            x_in: Input tensor. Its ``device`` and ``dtype`` are used when building
                the pose-channel images. Presumably (N, C, H, W) with
                H == W == ``self.current_imsize`` — TODO confirm.
            pose_info: Pose data forwarded to ``generate_pose_channel_images``.
            z: Optional latent tensor, concatenated (dim 1) at the bottleneck. When
                ``None`` it is produced by ``self.generate_latent_variable(x_in)``.

        Returns:
            ``self.to_rgb_new(...)`` of the decoder output when
            ``self.transition_step == 0``. Otherwise, the result of
            ``get_transition_value`` applied to the ``to_rgb_old`` and
            ``to_rgb_new`` outputs with ``self.transition_value``. That result is
            presumably a fade-in blend of the two images (confirm against its
            definition); it is the generated output, not a shape.
        """
        if z is None:
            z = self.generate_latent_variable(x_in)

        # Encoder feature maps kept for the skip connections, in production order.
        skips = []
        if self.transition_step == 0:
            features = self.from_rgb_new(x_in)
        else:
            # Transition phase: run the input through both the old input layer and
            # the new input layer + new down block. The new path is downsampled
            # (presumably to match the old path) before the two are mixed.
            old_path = self.from_rgb_old(x_in)
            new_path = self.new_down(self.from_rgb_new(x_in))
            # Also read back as skips[0] by the final merge at the end.
            skips.append(new_path)
            features = get_transition_value(
                old_path, self.downsampling(new_path), self.transition_value)

        # Every down block except the last contributes a skip connection.
        for down_block in self.core_blocks_down[:-1]:
            features = down_block(features)
            skips.append(features)
        features = self.core_blocks_down[-1](features)

        # pose_channels[0] joins the bottleneck; later entries feed the decoder.
        pose_channels = generate_pose_channel_images(
            4, self.current_imsize, x_in.device, pose_info, x_in.dtype)
        features = self.core_blocks_up[0](
            torch.cat((features, pose_channels[0], z), dim=1))

        # The step-th remaining up block consumes the step-th most recent skip.
        for step, up_block in enumerate(self.core_blocks_up[1:], start=1):
            skip = skips[-step]
            assert skip.shape == features.shape, "IDX: {}, skip_x: {}, x: {}".format(
                step - 1, skip.shape, features.shape)
            features = up_block(
                torch.cat((features, skip, pose_channels[step]), dim=1))

        if self.transition_step == 0:
            return self.to_rgb_new(features)

        rgb_old = self.to_rgb_old(features)
        merged = torch.cat((features, skips[0], pose_channels[-1]), dim=1)
        rgb_new = self.to_rgb_new(self.new_up(merged))
        # The mixed tensor returned here is the generator's output image.
        return get_transition_value(rgb_old, rgb_new, self.transition_value)
Пример #2
0
    def forward(self, x, condition, pose_info):
        """Score ``x`` together with ``condition``, injecting pose channels per stage.

        Args:
            x: Input tensor, concatenated with ``condition`` along dim 1. Its
                ``device`` and ``dtype`` are used when building the pose-channel
                images.
            condition: Conditioning tensor concatenated with ``x`` along dim 1.
            pose_info: Pose data forwarded to ``generate_pose_channel_images``.

        Returns:
            ``self.output_layer`` applied to the final features flattened to
            ``(batch, -1)``.
        """
        pose_channels = generate_pose_channel_images(
            4, self.current_imsize, x.device, pose_info, x.dtype)
        features = torch.cat((x, condition), dim=1)

        old_path = self.from_rgb_old(features)
        new_path = self.from_rgb_new(features)
        at_base_resolution = self.current_imsize == 4
        # Above 4x4, the last pose channel image is fed to the new block.
        if not at_base_resolution:
            new_path = torch.cat((new_path, pose_channels[-1]), dim=1)
        new_path = self.new_block(new_path)
        features = get_transition_value(old_path, new_path, self.transition_value)

        # Walk pose_channels backwards from the end, one entry per core stage.
        # Start at [-1] when the new block had none, otherwise at [-2].
        first = 1 if at_base_resolution else 2
        for offset, stage in enumerate(self.core_model.children(), start=first):
            features = stage(torch.cat((features, pose_channels[-offset]), dim=1))

        return self.output_layer(features.view(features.shape[0], -1))