Example #1
0
    def __init__(self, config: Config):
        """Build a plain MLP INR: [Linear + activation] blocks and a Linear head.

        Args:
            config: provides ``in_features``, ``out_features``,
                ``layer_sizes`` (assumed non-empty — TODO confirm with
                callers), ``has_bias`` and ``activation``.
        """
        # Zero-arg super() (Python 3 idiom), consistent with the other
        # modules in this codebase that already use it.
        super().__init__()

        sizes = config.layer_sizes
        layers = [
            nn.Linear(config.in_features, sizes[0], bias=config.has_bias),
            create_activation(config.activation),
        ]

        # One Linear + activation block per consecutive pair of hidden sizes.
        for in_dim, out_dim in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_dim, out_dim, bias=config.has_bias))
            layers.append(create_activation(config.activation))

        # Final projection carries no activation.
        layers.append(
            nn.Linear(sizes[-1], config.out_features, bias=config.has_bias))

        self.model = nn.Sequential(*layers)
Example #2
0
    def __init__(self, config: Config):
        """Fourier-feature INR: a random (optionally learnable) Fourier basis
        feeding an MLP whose hidden blocks may be wrapped in LinearResidual.

        Args:
            config: provides ``num_fourier_feats``, ``layer_sizes`` (assumed
                non-empty — TODO confirm with callers), ``in_features``,
                ``out_features``, ``has_bias``, ``activation``,
                ``residual.enabled``, ``scale`` and ``learnable_basis``.
        """
        # Zero-arg super() (Python 3 idiom), consistent with the other
        # modules in this codebase that already use it.
        super().__init__()

        sizes = config.layer_sizes
        # Input width is num_fourier_feats * 2 — presumably a sin and a cos
        # component per basis vector; verify against the forward pass.
        layers = [
            nn.Linear(config.num_fourier_feats * 2, sizes[0],
                      bias=config.has_bias),
            create_activation(config.activation),
        ]

        # One hidden block per consecutive pair of hidden sizes, optionally
        # wrapped in a residual connection.
        for in_dim, out_dim in zip(sizes, sizes[1:]):
            transform = nn.Sequential(
                nn.Linear(in_dim, out_dim, bias=config.has_bias),
                create_activation(config.activation))

            if config.residual.enabled:
                layers.append(LinearResidual(config.residual, transform))
            else:
                layers.append(transform)

        # Final projection carries no activation.
        layers.append(
            nn.Linear(sizes[-1], config.out_features, bias=config.has_bias))

        self.model = nn.Sequential(*layers)

        # Random projection basis for the Fourier features; it is trained
        # only when config.learnable_basis is set.
        basis_matrix = config.scale * torch.randn(config.num_fourier_feats,
                                                  config.in_features)
        self.basis_matrix = nn.Parameter(basis_matrix,
                                         requires_grad=config.learnable_basis)
Example #3
0
    def init_model(self):
        """Assemble the INR as one nn.Sequential: a fully-controlled coord
        layer, sine-activated hidden layers, and a fully-controlled head."""
        inr_cfg = self.config.hp.inr

        # First layer is of full control, hence the dedicated layer type.
        layers = self.create_transform(
            inr_cfg.coord_dim,
            inr_cfg.layer_sizes[0],
            layer_type=inr_cfg.coords_layer_type,
            is_coord_layer=True)
        layers.append(INRProxy(self.create_sine(inr_cfg.w0_initial)))

        # Middle layers are large so they are controlled via AdaIN.
        for src_dim, dst_dim in zip(inr_cfg.layer_sizes, inr_cfg.layer_sizes[1:]):
            layers.extend(
                self.create_transform(src_dim, dst_dim,
                                      layer_type=inr_cfg.hid_layer_type))
            layers.append(INRProxy(self.create_sine(inr_cfg.w0)))

        # The last layer is small so let's also control it fully.
        layers.extend(
            self.create_transform(
                inr_cfg.layer_sizes[-1],
                self.config.data.num_img_channels,
                layer_type='linear'))
        layers.append(INRProxy(create_activation(inr_cfg.output_activation)))

        self.model = nn.Sequential(*layers)
Example #4
0
    def init_model(self):
        """Build the INR: a sines/cosines coord embedding, a hidden stack
        (optionally residual and/or coord-skipped), and a linear output head
        followed by the configured output activation."""
        inr_cfg = self.config.hp.inr

        layers = self.create_transform(
            inr_cfg.coord_dim,
            inr_cfg.layer_sizes[0] // 2,
            layer_type=inr_cfg.coords_layer_type,
            is_coord_layer=True)
        layers.append(INRProxy(create_activation('sines_cosines')))

        hid_layers = []
        for i in range(len(inr_cfg.layer_sizes) - 1):
            # With skip_coords, every hidden layer after the first also sees
            # the embedded coordinates concatenated to its input.
            input_dim = inr_cfg.layer_sizes[i]
            if inr_cfg.get('skip_coords') and i > 0:
                input_dim += inr_cfg.layer_sizes[0]

            block = self.create_transform(
                input_dim,
                inr_cfg.layer_sizes[i + 1],
                layer_type=inr_cfg.hid_layer_type)
            block.append(INRProxy(create_activation(inr_cfg.activation)))

            if inr_cfg.residual:
                hid_layers.append(INRResidual(INRSequential(*block)))
            else:
                hid_layers.extend(block)

        # The skip-coords wrapper owns the whole hidden stack; otherwise the
        # hidden layers are spliced straight into the main sequence.
        if inr_cfg.get('skip_coords'):
            layers.append(INRInputSkip(*hid_layers))
        else:
            layers.extend(hid_layers)

        layers.extend(
            self.create_transform(inr_cfg.layer_sizes[-1],
                                  self.config.data.num_img_channels, 'linear'))
        layers.append(INRProxy(create_activation(inr_cfg.output_activation)))

        self.model = nn.Sequential(*layers)
Example #5
0
    def __init__(self, in_features: int, out_features: int, config: Config,
                 is_first_layer: bool):
        """Single transform block: (Equal)Linear, optional BatchNorm1d, then
        the configured activation. Marks itself residual when enabled and the
        shapes allow it (same width, not the first layer)."""
        super().__init__()

        # Residual connections require matching input/output widths and are
        # never applied to the first layer.
        self.residual = bool(config.residual
                             and in_features == out_features
                             and not is_first_layer)
        if self.residual:
            # Learnable scalar weighting of the main branch vs. the identity.
            self.main_branch_weight = nn.Parameter(
                torch.tensor(config.main_branch_weight))

        linear_cls = EqualLinear if config.equalized_lr else nn.Linear
        modules = [linear_cls(in_features, out_features)]

        if config.has_bn:
            modules.append(nn.BatchNorm1d(out_features))

        modules.append(
            create_activation(config.activation, **config.activation_kwargs))

        self.transform = nn.Sequential(*modules)
Example #6
0
    def init_model(self):
        """Assemble a multi-resolution, multi-block INR.

        Builds ``config.hp.inr.num_blocks`` blocks; each block bundles a
        coord embedder, a transform stack, a connector to the previous
        block's features, and a to_rgb head (or identity). The blocks are
        stored in an INRModuleDict under keys ``b_<i>``.
        """
        blocks = []
        # Per-block resolution configs: either synthesized per block index,
        # or looked up from preset per-resolution parameters.
        if self.config.hp.inr.res_increase_scheme.enabled:
            res_configs = [
                self.create_res_config(i)
                for i in range(self.config.hp.inr.num_blocks)
            ]
        else:
            resolutions = self.generate_img_sizes(
                self.config.data.target_img_size)
            res_configs = [
                self.config.hp.inr.resolutions_params[resolutions[i]]
                for i in range(self.config.hp.inr.num_blocks)
            ]

        print('resolution:', [c.resolution for c in res_configs])
        print('dim:', [c.dim for c in res_configs])
        print('num_learnable_coord_feats:',
              [c.num_learnable_coord_feats for c in res_configs])
        print('to_rgb:', [c.to_rgb for c in res_configs])
        # Used below to shrink to_rgb init stds when additional scaling is on.
        num_to_rgb_blocks = sum(c.to_rgb for c in res_configs)

        for i, res_config in enumerate(res_configs):
            # 1. Creating coord fourier feat embedders for each resolution.
            coord_embedder = INRSequential(
                INRFourierFeats(res_config),
                INRProxy(create_activation('sines_cosines')))
            # x2 — presumably one sin and one cos component per Fourier
            # feature; confirm against the 'sines_cosines' activation.
            coord_feat_dim = coord_embedder[0].get_num_feats() * 2

            # 2. Main branch. The first block does not need any wiring, but later blocks use it.
            # A good thing is that we do not need skip-coords anymore.
            if i > 0:
                # Different-resolution blocks are wired together with the connector
                connector_layers = [
                    INRResConnector(
                        res_configs[i - 1].dim, coord_feat_dim, res_config.dim,
                        self.config.hp.inr.upsampling_mode,
                        **self.config.hp.inr.module_kwargs.se_factorized),
                ]
                if self.config.hp.inr.use_pixel_norm:
                    connector_layers.append(INRPixelNorm())
                connector_layers.append(
                    INRProxy(
                        create_activation(
                            self.config.hp.inr.activation,
                            **self.config.hp.inr.activation_kwargs)))
                connector = INRSequential(*connector_layers)
            else:
                connector = INRIdentity()

            # 3. Transform stack: n_layers of transform (+ optional pixel
            # norm and noise injection) each followed by the activation.
            transform_layers = []
            for j in range(res_config.n_layers):
                if i == 0 and j == 0:
                    input_size = coord_feat_dim  # Since we do not have previous feat dims
                elif self.config.hp.inr.skip_coords:
                    # Coord features are concatenated to the block features.
                    input_size = coord_feat_dim + res_config.dim
                else:
                    input_size = res_config.dim

                transform_layers.extend(
                    self.create_transform(
                        input_size,
                        res_config.dim,
                        layer_type=self.config.hp.inr.hid_layer_type))

                if self.config.hp.inr.use_pixel_norm:
                    transform_layers.append(INRPixelNorm())

                if self.config.hp.inr.use_noise:
                    transform_layers.append(INRNoiseInjection())

                transform_layers.append(
                    INRProxy(
                        create_activation(
                            self.config.hp.inr.activation,
                            **self.config.hp.inr.activation_kwargs)))

            # 4. to_rgb head: forced on for the last block, identity
            # otherwise unless the res config asks for one.
            if res_config.to_rgb or i == (self.config.hp.inr.num_blocks - 1):
                to_rgb_weight_std = self.compute_weight_std(
                    res_config.dim, is_coord_layer=False)
                to_rgb_bias_std = self.compute_bias_std(res_config.dim,
                                                        is_coord_layer=False)

                # NOTE(review): dividing by sqrt(num_to_rgb_blocks) — looks
                # like it normalizes the summed contribution of all to_rgb
                # heads; confirm against how the outputs are combined.
                if self.config.hp.inr.additionaly_scale_to_rgb:
                    to_rgb_weight_std /= np.sqrt(num_to_rgb_blocks)
                    to_rgb_bias_std /= np.sqrt(num_to_rgb_blocks)

                to_rgb = INRToRGB(res_config.dim,
                                  self.config.hp.inr.to_rgb_activation,
                                  self.config.hp.inr.upsampling_mode,
                                  to_rgb_weight_std, to_rgb_bias_std)
            else:
                to_rgb = INRIdentity()

            # 5. Wrap the transform stack: INRCoordsSkip handles the coord
            # concatenation (skipped for the very first layer via
            # concat_to_the_first=False when i == 0).
            if self.config.hp.inr.skip_coords:
                transform = INRCoordsSkip(*transform_layers,
                                          concat_to_the_first=i > 0)
            else:
                transform = INRSequential(*transform_layers)

            blocks.append(
                INRModuleDict({
                    'coord_embedder': coord_embedder,
                    'transform': transform,
                    'connector': connector,
                    'to_rgb': to_rgb,
                }))

        self.model = INRModuleDict({f'b_{i}': b for i, b in enumerate(blocks)})