def forward(self, x, w, fused_modconv=True):
    """Style-modulated conv (no demodulation) followed by a clamped bias.

    Args:
        x: Input feature tensor.
        w: Latent vector fed to the affine style layer.
        fused_modconv: Whether to use the fused modulated-conv path.

    Returns:
        The convolved tensor with bias applied (clamped to self.conv_clamp).
    """
    # BUG FIX: the original signature was `def forward(self):` but the body
    # referenced `x`, `w` and `fused_modconv`, which would raise NameError on
    # every call. The parameters now match the sibling forward() examples.
    styles = self.affine(w) * self.weight_gain

    x = modulated_conv2d(x=x, weight=self.weight, styles=styles,
                         demodulate=False, fused_modconv=fused_modconv)
    x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)

    return x
# ---- Example #2 ----
    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        """Synthesis-layer forward pass: affine style, optional noise
        injection, modulated convolution, then gained/clamped bias-activation.

        `noise_mode` must be 'random', 'const', or 'none'.
        """
        assert noise_mode in ['random', 'const', 'none']
        expected_res = self.resolution // self.up
        misc.assert_shape(
            x, [None, self.weight.shape[1], expected_res, expected_res])
        styles = self.affine(w)

        # Select the noise tensor according to the requested mode.
        noise = None
        if self.use_noise:
            if noise_mode == 'random':
                noise = self.noise_strength * torch.randn(
                    [x.shape[0], 1, self.resolution, self.resolution],
                    device=x.device)
                #noise += self.noise_const.expand_as(noise) * 0
            elif noise_mode == 'const':
                noise = self.noise_const * self.noise_strength

        x = modulated_conv2d(x=x,
                             weight=self.weight,
                             styles=styles,
                             noise=noise,
                             up=self.up,
                             padding=self.padding,
                             resample_filter=self.resample_filter,
                             flip_weight=(self.up == 1),  # slightly faster
                             fused_modconv=fused_modconv)

        total_gain = self.act_gain * gain
        total_clamp = None if self.conv_clamp is None else self.conv_clamp * gain
        return bias_act.bias_act(x,
                                 self.bias.to(x.dtype),
                                 act=self.activation,
                                 gain=total_gain,
                                 clamp=total_clamp)
# ---- Example #3 ----
 def forward(self, x, w, mask, fused_modconv=True):
     """Per-component modulated conv (no demodulation), masked and summed.

     `w` has shape [N, M, w_dim]: each of the M style vectors produces its
     own conv output, which is weighted by `mask` and reduced over M.
     """
     batch, n_styles, _ = w.shape
     flat_w = w.view([batch * n_styles, -1])
     styles = self.affine(flat_w) * self.weight_gain
     # Duplicate each input M times so every style gets its own copy.
     x = x.repeat_interleave(n_styles, 0)
     x = modulated_conv2d(x=x, weight=self.weight, styles=styles,
                          demodulate=False, fused_modconv=fused_modconv)
     x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
     stacked = x.view(batch, n_styles, *x.shape[1:])
     return stacked.mul(mask.unsqueeze(2)).sum(1)
# ---- Example #4 ----
    def forward(self,
                x,
                latmask,
                w,
                noise_mode='random',
                fused_modconv=True,
                gain=1):
        """Synthesis-layer forward with latent-mask blending and custom
        (possibly non-square) output sizes.

        Args:
            x: Input feature tensor; spatial dims may deviate from
               self.resolution (the shape assert below is disabled).
            latmask: Latent blending mask forwarded to modulated_conv2d.
            w: Intermediate latent fed to the affine style layer.
            noise_mode: One of 'random', 'const', 'none'.
            fused_modconv: Use the fused modulated-conv implementation.
            gain: Extra multiplier applied to activation gain and clamp.

        Returns:
            Activated output feature tensor.
        """
        # def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        # NOTE(review): in_resolution is currently unused — the shape assert
        # is commented out, presumably because custom sizes would violate it.
        in_resolution = self.resolution // self.up
        # misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            # !!! custom
            # When upsampling (up == 2) with a configured target size, build
            # noise at that size; otherwise match the input's spatial dims.
            sz = self.size if self.up == 2 and self.size is not None else x.shape[
                2:]
            noise = torch.randn([x.shape[0], 1, *sz],
                                device=x.device) * self.noise_strength
            # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength
            # !!! custom noise size
            # Past the 4x4 stage, the constant noise buffer is rescaled to
            # the custom size. fix_size presumably operates on NCHW tensors
            # (hence the double unsqueeze and [0][0] round-trip) — confirm
            # against its definition.
            noise_size = self.size if self.up == 2 and self.size is not None and self.resolution > 4 else x.shape[
                2:]
            noise = fix_size(noise.unsqueeze(0).unsqueeze(0),
                             noise_size,
                             scale_type=self.scale_type)[0][0]

        # print(x.shape, noise.shape, self.size, self.up)

        flip_weight = (self.up == 1)  # slightly faster
        # Custom modulated_conv2d variant: also takes latmask, countHW,
        # splitfine, size and scale_type (not part of the stock signature).
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            noise=noise,
            up=self.up,
            latmask=latmask,
            countHW=self.countHW,
            splitfine=self.splitfine,
            size=self.size,
            scale_type=self.scale_type,  # !!! custom
            padding=self.padding,
            resample_filter=self.resample_filter,
            flip_weight=flip_weight,
            fused_modconv=fused_modconv)

        # Apply bias + activation with caller-scaled gain and clamp.
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x,
                              self.bias.to(x.dtype),
                              act=self.activation,
                              gain=act_gain,
                              clamp=act_clamp)
        return x
    def forward(self, x, gain=1):
        """Unmodulated conv2d with optional up/down resampling, followed by
        bias + activation with gain/clamp scaling."""
        scaled_weight = self.weight * self.weight_gain
        bias = None if self.bias is None else self.bias.to(x.dtype)

        x = conv2d_resample.conv2d_resample(
            x=x,
            w=scaled_weight.to(x.dtype),
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
            flip_weight=(self.up == 1))  # slightly faster

        gain_total = self.act_gain * gain
        clamp_total = None if self.conv_clamp is None else self.conv_clamp * gain
        return bias_act.bias_act(x, bias, act=self.activation,
                                 gain=gain_total, clamp=clamp_total)
# ---- Example #6 ----
    def forward(self, x):
        """Fully-connected layer: x @ W^T (+ bias), with optional activation.

        Weight is cast to x's dtype and scaled by weight_gain; the bias is
        optionally scaled by bias_gain when it differs from 1.
        """
        weight = self.weight.to(dtype=x.dtype) * self.weight_gain
        bias = self.bias
        if bias is not None:
            bias = bias.to(dtype=x.dtype)
            if self.bias_gain != 1:
                bias = bias * self.bias_gain

        # Fast path: fused multiply-add when no nonlinearity is needed.
        if self.activation == 'linear' and bias is not None:
            return torch.addmm(bias.unsqueeze(0), x, weight.t())

        out = x.matmul(weight.t())
        return bias_act.bias_act(out, bias, act=self.activation)
# ---- Example #7 ----
    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        """Synthesis-layer forward supporting two modulation back-ends: the
        FMM linear path (when self.cfg.fmm.enabled) or the standard modulated
        convolution. Optional instance normalization is applied beforehand.
        """
        assert noise_mode in ['random', 'const', 'none']
        expected_res = self.resolution // self.up
        misc.assert_shape(
            x, [None, self.weight.shape[1], expected_res, expected_res])
        styles = self.affine(w)

        # Noise selection is driven by the config flag plus noise_mode.
        noise = None
        if self.cfg.use_noise:
            if noise_mode == 'random':
                noise = self.noise_strength * torch.randn(
                    [x.shape[0], 1, self.resolution, self.resolution],
                    device=x.device)
            elif noise_mode == 'const':
                noise = self.noise_const * self.noise_strength

        # Normalize per-channel spatial std before modulation if requested.
        if self.instance_norm:
            x = x / (x.std(dim=[2, 3], keepdim=True) + 1e-8)  # [batch_size, c, h, w]

        if self.cfg.fmm.enabled:
            x = fmm_modulate_linear(x=x,
                                    weight=self.weight,
                                    styles=styles,
                                    noise=noise,
                                    activation=self.cfg.fmm.activation)
        else:
            x = modulated_conv2d(x=x,
                                 weight=self.weight,
                                 styles=styles,
                                 noise=noise,
                                 up=self.up,
                                 padding=self.padding,
                                 resample_filter=self.resample_filter,
                                 flip_weight=(self.up == 1),  # slightly faster
                                 fused_modconv=fused_modconv)

        total_gain = self.act_gain * gain
        total_clamp = None if self.conv_clamp is None else self.conv_clamp * gain
        return bias_act.bias_act(x,
                                 self.bias.to(x.dtype),
                                 act=self.activation,
                                 gain=total_gain,
                                 clamp=total_clamp)
# ---- Example #8 ----
    def forward(self, x, w, mask, noise_mode='random', fused_modconv=True, gain=1):
        """Per-component modulated conv with noise injection; the M per-style
        outputs are weighted by `mask` and summed over the component axis.

        `w` has shape [N, M, w_dim].
        """
        assert noise_mode in ['random', 'const', 'none']
        expected_res = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], expected_res, expected_res])
        batch, n_comp, _ = w.shape
        styles = self.affine(w.view([batch * n_comp, -1]))

        # NOTE: noise is built from the PRE-repeat batch size, matching the
        # original statement order.
        noise = None
        if self.use_noise:
            if noise_mode == 'random':
                noise = self.noise_strength * torch.randn(
                    [x.shape[0], 1, self.resolution, self.resolution], device=x.device)
            elif noise_mode == 'const':
                noise = self.noise_const * self.noise_strength

        # Duplicate each input M times so every style gets its own copy.
        x = x.repeat_interleave(n_comp, 0)
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise,
                             up=self.up, padding=self.padding,
                             resample_filter=self.resample_filter,
                             flip_weight=(self.up == 1),  # slightly faster
                             fused_modconv=fused_modconv)
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation,
                              gain=self.act_gain * gain,
                              clamp=self.conv_clamp * gain if self.conv_clamp is not None else None)
        stacked = x.view(batch, n_comp, *x.shape[1:])
        return stacked.mul(mask.unsqueeze(2)).sum(1)