Пример #1
0
    def __init__(self,
                 nf0,
                 out_channels,
                 input_resolution,
                 output_sidelength):
        """Build a downsampling stem followed by a U-Net.

        A DownsamplingNet first processes the 3-channel input; a U-Net then maps the
        stem's output to ``out_channels`` feature maps.

        :param nf0: Base number of features. The DownsamplingNet is given the channel
            list [nf0, 2*nf0, 4*nf0, ...] (presumably one entry per layer).
        :param out_channels: Number of channels produced by the U-Net.
        :param input_resolution: Sidelength of the input image.
        :param output_sidelength: Sidelength of the produced feature map.
        :raises ValueError: If the configuration would give the DownsamplingNet no
            layers, i.e. util.num_divisible_by_2(input_resolution) is not greater than
            util.num_divisible_by_2(output_sidelength).
        """
        super().__init__()

        norm = nn.BatchNorm2d

        # The U-Net gets as many down levels as output_sidelength allows; the remaining
        # levels of input_resolution (as counted by util.num_divisible_by_2) are
        # handled by the DownsamplingNet stem.
        num_down_unet = util.num_divisible_by_2(output_sidelength)
        num_downsampling = util.num_divisible_by_2(input_resolution) - num_down_unet

        # With no stem layers the channel count nf0 * 2 ** (num_downsampling - 1)
        # below becomes a float, i.e. an invalid layer width; fail early and clearly.
        if num_downsampling < 1:
            raise ValueError(
                'input_resolution ({}) must be divisible by 2 more often than '
                'output_sidelength ({}); got num_downsampling={}'.format(
                    input_resolution, output_sidelength, num_downsampling))

        # Last entry of the stem's channel list; it is also the U-Net's input width
        # and its base feature count.
        stem_out_ch = nf0 * (2 ** (num_downsampling - 1))

        self.net = nn.Sequential(
            DownsamplingNet([nf0 * (2 ** i) for i in range(num_downsampling)],
                            in_channels=3,
                            use_dropout=False,
                            norm=norm),
            Unet(in_channels=stem_out_ch,
                 out_channels=out_channels,
                 nf0=stem_out_ch,
                 use_dropout=False,
                 max_channels=8 * nf0,
                 num_down=num_down_unet,
                 norm=norm)
        )
Пример #2
0
    def __init__(self, nf0, occnet_nf, frustrum_dims):
        """Build the occlusion network.

        :param nf0: Number of feature channels of the input volume. The first
            convolution takes nf0 + 1 channels, presumably the features plus one
            depth-coordinate channel (confirm with forward()).
        :param occnet_nf: Number of features used inside the occlusion sub-networks.
        :param frustrum_dims: Dimensions of the canonical view volume. Index -1 is the
            depth; indices 0 and 1 size the two remaining (image-plane) axes of the
            depth-coordinate volume (assumed [dim0, dim1, depth] — confirm with callers).
        """
        super().__init__()
        self.device = get_device()

        self.occnet_nf = occnet_nf

        self.frustrum_depth = frustrum_dims[-1]
        # Depth coordinates arange(-d // 2, d // 2) / d placed on dim 2, giving d values
        # in [-0.5, 0.5) for even d. NOTE: -d // 2 parses as (-d) // 2 and floors, so for
        # odd d the range is slightly asymmetric (it starts at -(d + 1) / (2d)).
        depth_coords = torch.arange(-self.frustrum_depth // 2,
                                    self.frustrum_depth // 2)[None, None, :, None, None].float().cuda(self.device) / self.frustrum_depth
        # Tile over both image-plane axes -> shape (1, 1, depth, dims[0], dims[1]); the
        # last axis is sized by frustrum_dims[1] so non-square frustra are tiled correctly.
        # Stored as a plain attribute (not a registered buffer), so it stays on
        # self.device and does not follow module.to() calls.
        self.depth_coords = depth_coords.repeat(1, 1, 1, frustrum_dims[0], frustrum_dims[1])

        self.occlusion_prep = nn.Sequential(
            Conv3dSame(nf0 + 1, self.occnet_nf, kernel_size=3, bias=False),
            nn.BatchNorm3d(self.occnet_nf),
            nn.ReLU(True),
        )

        # The 3D U-Net presumably halves each spatial axis per level, so its depth is
        # bounded by whichever axis can be halved the fewest times.
        num_down = min(util.num_divisible_by_2(self.frustrum_depth),
                       util.num_divisible_by_2(frustrum_dims[0]),
                       util.num_divisible_by_2(frustrum_dims[1]))

        self.occlusion_net = Unet3d(in_channels=self.occnet_nf,
                                    out_channels=self.occnet_nf,
                                    nf0=self.occnet_nf,
                                    num_down=num_down,
                                    max_channels=4 * self.occnet_nf,
                                    outermost_linear=False)

        # Softmax over dim 2, the axis on which depth_coords places the depth samples.
        # Input width 2 * occnet_nf + 1 presumably concatenates prep features, U-Net
        # features and the depth coordinates (confirm with forward()).
        self.softmax_net = nn.Sequential(
            Conv3dSame(2 * self.occnet_nf + 1, 1, kernel_size=3, bias=True),
            nn.Softmax(dim=2),
        )
    def __init__(self, nf0, in_channels, input_resolution, img_sidelength):
        """Record the network configuration and delegate layer construction to build_net().

        :param nf0: Base number of features.
        :param in_channels: Number of channels of the input feature map.
        :param input_resolution: Sidelength of the input feature map.
        :param img_sidelength: Sidelength of the output image.
        """
        super().__init__()

        # Keep the raw constructor arguments on the instance (presumably read by build_net()).
        self.nf0, self.in_channels = nf0, in_channels
        self.input_resolution, self.img_sidelength = input_resolution, img_sidelength

        # U-Net depth comes from the input resolution; any levels img_sidelength has
        # beyond that are left for upsampling.
        down_levels = util.num_divisible_by_2(input_resolution)
        total_levels = util.num_divisible_by_2(img_sidelength)
        self.num_down_unet = down_levels
        self.num_upsampling = total_levels - down_levels

        self.build_net()
Пример #4
0
    def __init__(self,
                 nf0,
                 in_channels,
                 input_resolution,
                 img_sidelength):
        """Build the rendering network: a U-Net, an optional upsampling tail, then tanh.

        If util.num_divisible_by_2(img_sidelength) does not exceed that of
        input_resolution, the U-Net directly emits the 3 output channels. Otherwise the
        U-Net emits 4 * nf0 channels, which an UpsamplingNet followed by two
        convolutions reduce to 3.

        :param nf0: Base number of features.
        :param in_channels: Number of channels of the input feature map.
        :param input_resolution: Sidelength of the input feature map.
        :param img_sidelength: Sidelength of the rendered image.
        """
        super().__init__()

        down_levels = util.num_divisible_by_2(input_resolution)
        up_levels = util.num_divisible_by_2(img_sidelength) - down_levels
        # With nothing to upsample, the U-Net is the last learned stage: it must emit the
        # 3 output channels itself, with a linear outermost layer (outermost_linear=True).
        unet_is_last = up_levels <= 0

        layers = [
            Unet(in_channels=in_channels,
                 out_channels=3 if unet_is_last else 4 * nf0,
                 outermost_linear=unet_is_last,
                 use_dropout=True,
                 dropout_prob=0.1,
                 nf0=nf0 * (2 ** up_levels),
                 norm=nn.BatchNorm2d,
                 max_channels=8 * nf0,
                 num_down=down_levels),
        ]

        if not unet_is_last:
            half_nf = nf0 // 2
            layers.extend([
                UpsamplingNet(per_layer_out_ch=up_levels * [nf0],
                              in_channels=4 * nf0,
                              upsampling_mode='transpose',
                              use_dropout=True,
                              dropout_prob=0.1),
                Conv2dSame(nf0, out_channels=half_nf, kernel_size=3, bias=False),
                nn.BatchNorm2d(half_nf),
                nn.ReLU(True),
                Conv2dSame(half_nf, 3, kernel_size=3),
            ])

        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)
Пример #5
0
    def __init__(self,
                 img_sidelength,
                 lifting_img_dims,
                 frustrum_img_dims,
                 grid_dims,
                 num_grid_feats=64,
                 nf0=64,
                 use_occlusion_net=True):
        ''' Initializes the DeepVoxels model.

        :param img_sidelength: The sidelength of the input images (for instance 512).
        :param lifting_img_dims: The dimensions of the feature map to be lifted.
        :param frustrum_img_dims: The dimensions of the canonical view volume that DeepVoxels are resampled to.
        :param grid_dims: The dimensions of the DeepVoxels grid.
        :param num_grid_feats: The number of features stored per voxel of the DeepVoxels grid.
        :param nf0: The number of features in the outermost layer of the U-Nets.
        :param use_occlusion_net: Whether to use the OcclusionNet. If False, a 2D depth-collapse
                                  network (self.depth_collapse_net) is built instead.
        '''
        super().__init__()

        self.use_occlusion_net = use_occlusion_net
        self.norm = nn.BatchNorm2d

        self.lifting_img_dims = lifting_img_dims
        self.frustrum_img_dims = frustrum_img_dims
        self.grid_dims = grid_dims

        # The frustrum depth is the number of voxels in the depth dimension of the canonical viewing volume.
        # It is meant to span the diagonal of the DeepVoxels grid; 1.5x the last grid dimension is used as
        # an approximation (the exact space diagonal of a cube is sqrt(3) ~= 1.73x its side).
        self.frustrum_depth = int(np.ceil(1.5 * grid_dims[-1]))

        self.nf0 = nf0  # Number of features to use in the outermost layer of all U-Nets
        self.n_grid_feats = num_grid_feats  # Number of features in the DeepVoxels grid.
        self.occnet_nf = 4  # Number of features to use in the 3D unet of the occlusion subnetwork

        # Presumably the number of times img_sidelength is divisible by 2, minus one
        # (e.g. 8 for 512) — confirm against util.num_divisible_by_2.
        num_downs = util.num_divisible_by_2(img_sidelength) - 1

        # The feature extractor is currently a single Conv2dPad layer (kernel_size=1, padding_size=4)
        # taking 256 input channels. NOTE(review): 256 input channels suggests it consumes precomputed
        # backbone features (an AlexNetConv4 extractor was previously referenced here) rather than
        # raw RGB images — confirm. Its output width matches the DeepVoxels grid feature count, as the
        # earlier DownsamplingNet + Unet extractor's did.
        self.feature_extractor = nn.Sequential(
            Conv2dPad(in_channels=256, out_channels=self.n_grid_feats, kernel_size=1, padding_size=4)
        )

        # Rendering net is an asymmetric UNet: UNet with skip connections and then straight upsampling
        self.rendering_net = nn.Sequential(
            Unet(in_channels=self.n_grid_feats,
                 out_channels=4 * self.nf0,
                 use_dropout=True,
                 dropout_prob=0.1,
                 nf0=self.nf0 * (2 ** (num_downs - 6)),
                 max_channels=8 * self.nf0,
                 num_down=5,
                 norm=self.norm),  # from 64 to 2 and back
            UpsamplingNet([4 * self.nf0, self.nf0] + max(0, num_downs - 7) * [self.nf0],
                          in_channels=4 * self.nf0,  # 4*self.nf0
                          use_dropout=True,
                          dropout_prob=0.1,
                          norm=self.norm,
                          upsampling_mode='transpose'),
            Conv2dSame(self.nf0, out_channels=self.nf0 // 2, kernel_size=3, bias=False),
            nn.BatchNorm2d(self.nf0 // 2),
            nn.ReLU(True),
            Conv2dSame(self.nf0 // 2, out_channels=3, kernel_size=3),
            nn.Tanh()
        )

        if self.use_occlusion_net:
            self.occlusion_net = OcclusionNet(nf0=self.n_grid_feats,
                                              occnet_nf=self.occnet_nf,
                                              frustrum_dims=[self.frustrum_img_dims[0], self.frustrum_img_dims[1],
                                                             self.frustrum_depth])
            print(self.occlusion_net)
        else:
            # Without the occlusion net, 2D convolutions take n_grid_feats * frustrum_depth input channels,
            # presumably the depth axis folded into the channel axis, and reduce them to nf0 channels.
            self.depth_collapse_net = nn.Sequential(
                Conv2dSame(self.n_grid_feats * self.frustrum_depth,
                           out_channels=self.nf0 * self.grid_dims[-1] // 2,
                           kernel_size=3,
                           bias=False),
                self.norm(self.nf0 * self.grid_dims[-1] // 2),
                nn.ReLU(True),
                Conv2dSame(self.nf0 * self.grid_dims[-1] // 2,
                           out_channels=self.nf0 * self.grid_dims[-1] // 8,
                           kernel_size=3,
                           bias=False),
                self.norm(self.nf0 * self.grid_dims[-1] // 8),
                nn.ReLU(True),
                Conv2dSame(self.nf0 * self.grid_dims[-1] // 8,
                           out_channels=self.nf0,
                           kernel_size=3,
                           bias=False),
                self.norm(self.nf0),
                nn.ReLU(True),
            )
            print(self.depth_collapse_net)

        # The deepvoxels grid is registered as a buffer - meaning, it is saved together with model parameters, but is
        # not trainable.
        self.register_buffer("deepvoxels",
                             torch.zeros(
                                 (1, self.n_grid_feats, self.grid_dims[0], self.grid_dims[1], self.grid_dims[2])))

        self.integration_net = IntegrationNet(self.n_grid_feats,
                                              use_dropout=True,
                                              coord_conv=True,
                                              per_feature=False,
                                              grid_dim=grid_dims[-1])

        self.inpainting_net = Unet3d(in_channels=self.n_grid_feats + 3,
                                     out_channels=self.n_grid_feats,
                                     num_down=2,
                                     nf0=self.n_grid_feats,
                                     max_channels=4 * self.n_grid_feats)

        print(100 * "*")
        print("inpainting_net")
        util.print_network(self.inpainting_net)
        print(self.inpainting_net)
        print("rendering net")
        util.print_network(self.rendering_net)
        print(self.rendering_net)
        print("feature extraction net")
        util.print_network(self.feature_extractor)
        print(self.feature_extractor)
        print(100 * "*")

        # Coordconv volume: integer voxel coordinates centred on the grid, stacked on a new leading
        # axis and divided by grid_dims[0]. NOTE(review): all three axes are normalised by grid_dims[0],
        # which only yields matching ranges for cubic grids — confirm this is intended.
        coord_conv_volume = np.mgrid[-self.grid_dims[0] // 2:self.grid_dims[0] // 2,
                                     -self.grid_dims[1] // 2:self.grid_dims[1] // 2,
                                     -self.grid_dims[2] // 2:self.grid_dims[2] // 2]

        coord_conv_volume = np.stack(coord_conv_volume, axis=0).astype(np.float32)
        coord_conv_volume = coord_conv_volume / self.grid_dims[0]
        # Stored as a plain tensor attribute (not a registered buffer): it is excluded from the
        # state dict and is not moved by module.to()/.cuda().
        self.coord_conv_volume = torch.Tensor(coord_conv_volume).float()[None, :, :, :, :]