Example #1
0
    def __init__(self,
                 num_instances,
                 latent_dim,
                 tracing_steps,
                 has_params=False,
                 fit_single_srn=True,
                 use_unet_renderer=False,
                 freeze_networks=False):
        """Build the scene network, ray marcher and pixel generator.

        Args:
            num_instances: Number of scene instances. Not used by this
                constructor.
            latent_dim: Latent code size; stored on ``self.latent_dim``.
            tracing_steps: Number of ray-marching steps handed to
                ``custom_layers.Raymarcher2``.
            has_params: Stored on ``self.has_params``; not otherwise used
                here.
            fit_single_srn: If True, build a single ``FCBlock`` ``self.phi``
                with 6 input features (no hypernetwork).
            use_unet_renderer: If True, render with a
                ``DeepvoxelsRenderer``; otherwise with an ``FCBlock`` that
                maps features to 3 output channels.
            freeze_networks: If True, set ``requires_grad = False`` on the
                pixel generator, the ray marcher and, if present,
                ``self.hyper_phi``.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.has_params = has_params

        self.num_hidden_units_phi = 256
        self.phi_layers = 4  # includes the in and out layers
        self.rendering_layers = 5  # includes the in and out layers
        self.sphere_trace_steps = tracing_steps
        self.freeze_networks = freeze_networks
        self.fit_single_srn = fit_single_srn

        if self.fit_single_srn:  # Fit a single scene with a single SRN (no hypernetworks)
            self.phi = pytorch_prototyping.FCBlock(
                hidden_ch=self.num_hidden_units_phi,
                num_hidden_layers=self.phi_layers - 2,
                in_features=6,
                out_features=self.num_hidden_units_phi)

        self.ray_marcher = custom_layers.Raymarcher2(
            num_feature_channels=self.num_hidden_units_phi,
            raymarch_steps=self.sphere_trace_steps)
        self.phi_scene = PhiScene()

        if use_unet_renderer:
            self.pixel_generator = custom_layers.DeepvoxelsRenderer(
                nf0=32,
                in_channels=self.num_hidden_units_phi,
                input_resolution=128,
                img_sidelength=128)
        else:
            self.pixel_generator = pytorch_prototyping.FCBlock(
                hidden_ch=self.num_hidden_units_phi,
                num_hidden_layers=self.rendering_layers - 1,
                in_features=self.num_hidden_units_phi,
                out_features=3,
                outermost_linear=True)

        if self.freeze_networks:
            frozen_modules = [self.pixel_generator, self.ray_marcher]
            # This constructor never creates ``hyper_phi``; accessing it
            # unconditionally raised AttributeError whenever
            # freeze_networks was True. Freeze it only if it exists.
            hyper_phi = getattr(self, "hyper_phi", None)
            if hyper_phi is not None:
                frozen_modules.append(hyper_phi)
            for module in frozen_modules:
                for param in module.parameters():
                    param.requires_grad = False

        # Losses
        self.l2_loss = nn.MSELoss(reduction="mean")

        # List of logs
        self.logs = list()
        print(self)
        print("Number of parameters:")
        util.print_network(self)
    def __init__(self,
                 num_instances,
                 latent_dim,
                 tracing_steps,
                 has_params=False,
                 fit_single_srn=False,
                 use_unet_renderer=False,
                 freeze_networks=False):
        """Build the scene network(s), ray marcher and pixel generator.

        Args:
            num_instances: Number of scene instances; sizes the latent-code
                embedding table when ``fit_single_srn`` is False.
            latent_dim: Size of each per-instance latent code, and the
                hypernetwork's input/hidden width.
            tracing_steps: Number of ray-marching steps handed to
                ``custom_layers.Raymarcher``.
            has_params: Stored on ``self.has_params``; not otherwise used
                here.
            fit_single_srn: If True, build a single ``FCBlock`` ``self.phi``
                (3 input features). Otherwise build an auto-decoder: a
                latent-code embedding table plus a ``HyperFC``
                hypernetwork ``self.hyper_phi``.
            use_unet_renderer: If True, render with a
                ``DeepvoxelsRenderer``; otherwise with an ``FCBlock`` that
                maps features to 3 output channels.
            freeze_networks: If True, set ``requires_grad = False`` on the
                pixel generator, the ray marcher and (in hypernetwork mode)
                ``self.hyper_phi``. Latent codes are left trainable.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.has_params = has_params

        self.num_hidden_units_phi = 256
        self.phi_layers = 4  # includes the in and out layers
        self.rendering_layers = 5  # includes the in and out layers
        self.sphere_trace_steps = tracing_steps
        self.freeze_networks = freeze_networks
        self.fit_single_srn = fit_single_srn

        if self.fit_single_srn:  # Fit a single scene with a single SRN (no hypernetworks)
            self.phi = pytorch_prototyping.FCBlock(hidden_ch=self.num_hidden_units_phi,
                                                   num_hidden_layers=self.phi_layers - 2,
                                                   in_features=3,
                                                   out_features=self.num_hidden_units_phi)
        else:
            # Auto-decoder: each scene instance gets its own code vector z
            # NOTE(review): ``.cuda()`` hard-codes a CUDA device, so
            # construction presumably fails on CPU-only machines — confirm
            # whether callers rely on this before removing it.
            self.latent_codes = nn.Embedding(num_instances, latent_dim).cuda()
            nn.init.normal_(self.latent_codes.weight, mean=0, std=0.01)

            self.hyper_phi = hyperlayers.HyperFC(hyper_in_ch=self.latent_dim,
                                                 hyper_num_hidden_layers=1,
                                                 hyper_hidden_ch=self.latent_dim,
                                                 hidden_ch=self.num_hidden_units_phi,
                                                 num_hidden_layers=self.phi_layers - 2,
                                                 in_ch=3,
                                                 out_ch=self.num_hidden_units_phi)

        self.ray_marcher = custom_layers.Raymarcher(num_feature_channels=self.num_hidden_units_phi,
                                                    raymarch_steps=self.sphere_trace_steps)

        if use_unet_renderer:
            self.pixel_generator = custom_layers.DeepvoxelsRenderer(nf0=32, in_channels=self.num_hidden_units_phi,
                                                                    input_resolution=128, img_sidelength=128)
        else:
            self.pixel_generator = pytorch_prototyping.FCBlock(hidden_ch=self.num_hidden_units_phi,
                                                               num_hidden_layers=self.rendering_layers - 1,
                                                               in_features=self.num_hidden_units_phi,
                                                               out_features=3,
                                                               outermost_linear=True)

        if self.freeze_networks:
            # ``Module.parameters()`` returns a generator, and generators do
            # not support ``+``; iterate each module's parameters instead.
            frozen_modules = [self.pixel_generator, self.ray_marcher]
            if not self.fit_single_srn:
                # ``hyper_phi`` is only created in the hypernetwork branch.
                frozen_modules.append(self.hyper_phi)
            for module in frozen_modules:
                for param in module.parameters():
                    param.requires_grad = False

        # Losses
        self.l2_loss = nn.MSELoss(reduction="mean")

        # List of logs
        self.logs = list()

        print(self)
        print("Number of parameters:")
        util.print_network(self)