def forward(self, x, w, fused_modconv=True):
    """Apply the style-modulated, non-demodulated output convolution.

    The previous signature was ``forward(self)`` while the body used ``x``,
    ``w`` and ``fused_modconv``; every call therefore failed with a
    ``NameError`` (or silently read unrelated module globals). They are now
    explicit parameters, mirroring the masked variant
    ``forward(self, x, w, mask, fused_modconv=True)``.

    Args:
        x: Input feature map handed to ``modulated_conv2d``.
        w: Latent(s); ``self.affine(w) * self.weight_gain`` gives the styles.
        fused_modconv: Forwarded unchanged to ``modulated_conv2d``.

    Returns:
        The convolution output after ``bias_act.bias_act`` with ``self.bias``
        (cast to the output dtype) and ``clamp=self.conv_clamp``.
    """
    styles = self.affine(w) * self.weight_gain
    # demodulate=False: the styles scale the weights but the result is not
    # renormalised.
    x = modulated_conv2d(x=x, weight=self.weight, styles=styles,
                         demodulate=False, fused_modconv=fused_modconv)
    x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
    return x
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
    """Style-modulated convolution with optional noise, then bias + activation.

    Args:
        x: Input features; asserted to be
            ``[None, self.weight.shape[1], R // up, R // up]`` where
            ``R = self.resolution`` and ``up = self.up``.
        w: Latent(s) turned into styles by ``self.affine``.
        noise_mode: ``'random'`` draws fresh Gaussian noise of shape
            ``[batch, 1, R, R]``; ``'const'`` uses ``self.noise_const``;
            ``'none'`` adds no noise. Noise is only used when
            ``self.use_noise`` is truthy.
        fused_modconv: Forwarded to ``modulated_conv2d``.
        gain: Multiplies both ``self.act_gain`` and ``self.conv_clamp``.

    Returns:
        Output of ``bias_act.bias_act`` on the convolution result.
    """
    assert noise_mode in ['random', 'const', 'none']
    side = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], side, side])
    styles = self.affine(w)

    noise = None
    if self.use_noise:
        if noise_mode == 'random':
            noise_shape = [x.shape[0], 1, self.resolution, self.resolution]
            noise = torch.randn(noise_shape, device=x.device) * self.noise_strength
        elif noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

    x = modulated_conv2d(
        x=x,
        weight=self.weight,
        styles=styles,
        noise=noise,
        up=self.up,
        padding=self.padding,
        resample_filter=self.resample_filter,
        flip_weight=(self.up == 1),  # slightly faster
        fused_modconv=fused_modconv,
    )

    scaled_gain = self.act_gain * gain
    scaled_clamp = None if self.conv_clamp is None else self.conv_clamp * gain
    return bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation,
                             gain=scaled_gain, clamp=scaled_clamp)
def forward(self, x, w, mask, fused_modconv=True):
    """Run the output conv once per latent and blend the results with ``mask``.

    Args:
        x: Input features, one entry per sample along dim 0.
        w: Latents unpacked as ``(w_n, w_m, _) = w.shape``. Each of the
            ``w_m`` latents of a sample modulates its own copy of that
            sample's features.
        mask: Blend weights. ``mask.unsqueeze(2)`` is multiplied into the
            ``[w_n, w_m, *feature_shape]`` per-latent output before summing
            over the ``w_m`` axis. Presumably shaped ``[w_n, w_m, H, W]``;
            TODO confirm with callers.
        fused_modconv: Forwarded to ``modulated_conv2d``.

    Returns:
        Mask-weighted sum over the ``w_m`` per-latent outputs.
    """
    n_samples, n_latents, _ = w.shape
    # Fold the latent axis into the batch: one style row per (sample, latent).
    flat_latents = w.view([n_samples * n_latents, -1])
    styles = self.affine(flat_latents) * self.weight_gain

    # Each sample's features are repeated so they line up with its latents.
    repeated = x.repeat_interleave(n_latents, 0)
    out = modulated_conv2d(x=repeated, weight=self.weight, styles=styles,
                           demodulate=False, fused_modconv=fused_modconv)
    out = bias_act.bias_act(out, self.bias.to(out.dtype), clamp=self.conv_clamp)

    per_latent = out.view(n_samples, n_latents, *out.shape[1:])
    return (per_latent * mask.unsqueeze(2)).sum(1)
def forward(self, x, latmask, w, noise_mode='random', fused_modconv=True, gain=1):
    """Style-modulated convolution with latent masking and custom output sizes.

    Unlike the stock synthesis layer, this variant forwards ``latmask``,
    ``self.countHW``, ``self.splitfine``, ``self.size`` and
    ``self.scale_type`` to ``modulated_conv2d``. It also does not assert the
    input spatial shape; presumably this is because ``x`` may have a
    non-default size (TODO confirm).

    Args:
        x: Input features; ``x.shape[0]`` is the batch and ``x.shape[2:]``
            the spatial size used for noise when no explicit size applies.
        latmask: Passed through to ``modulated_conv2d``; its meaning is
            defined there.
        w: Latent(s) turned into styles by ``self.affine``.
        noise_mode: ``'random'``, ``'const'`` or ``'none'``. Noise is only
            used when ``self.use_noise`` is truthy.
        fused_modconv: Forwarded to ``modulated_conv2d``.
        gain: Multiplies both ``self.act_gain`` and ``self.conv_clamp``.

    Returns:
        Output of ``bias_act.bias_act`` on the convolution result.
    """
    assert noise_mode in ['random', 'const', 'none']
    styles = self.affine(w)

    noise = None
    if self.use_noise and noise_mode == 'random':
        # Custom size: when upsampling with an explicit target size, draw noise
        # at that size; otherwise match x's spatial size.
        # NOTE(review): with up == 2 and self.size None this is x's (input)
        # size, not the upsampled one -- confirm modulated_conv2d copes.
        sz = self.size if self.up == 2 and self.size is not None else x.shape[2:]
        noise = torch.randn([x.shape[0], 1, *sz], device=x.device) * self.noise_strength
    if self.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength
        # NOTE(review): unlike the 'random' branch, this one also requires
        # self.resolution > 4 before using self.size -- confirm intended.
        noise_size = (self.size
                      if self.up == 2 and self.size is not None and self.resolution > 4
                      else x.shape[2:])
        # fix_size receives the noise with two leading singleton dims added;
        # [0][0] strips them again afterwards.
        noise = fix_size(noise.unsqueeze(0).unsqueeze(0), noise_size,
                         scale_type=self.scale_type)[0][0]

    x = modulated_conv2d(
        x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
        # Custom arguments understood by this project's modulated_conv2d.
        latmask=latmask, countHW=self.countHW, splitfine=self.splitfine,
        size=self.size, scale_type=self.scale_type,
        padding=self.padding, resample_filter=self.resample_filter,
        flip_weight=(self.up == 1),  # slightly faster
        fused_modconv=fused_modconv)

    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation,
                          gain=act_gain, clamp=act_clamp)
    return x
def forward(self, x, gain=1):
    """Resampling 2-D convolution followed by bias + activation.

    The weight is scaled by ``self.weight_gain`` and cast to ``x.dtype``. The
    bias, if present, is cast to the input's dtype before the convolution
    runs.

    Args:
        x: Input feature map.
        gain: Multiplies both ``self.act_gain`` and ``self.conv_clamp``.

    Returns:
        Output of ``bias_act.bias_act`` on the resampled convolution.
    """
    scaled_weight = (self.weight * self.weight_gain).to(x.dtype)
    bias = None if self.bias is None else self.bias.to(x.dtype)

    out = conv2d_resample.conv2d_resample(
        x=x,
        w=scaled_weight,
        f=self.resample_filter,
        up=self.up,
        down=self.down,
        padding=self.padding,
        flip_weight=(self.up == 1),  # slightly faster
    )

    clamp = None if self.conv_clamp is None else self.conv_clamp * gain
    return bias_act.bias_act(out, bias, act=self.activation,
                             gain=self.act_gain * gain, clamp=clamp)
def forward(self, x):
    """Fully-connected layer: ``x @ W.T`` plus optional bias and activation.

    ``W`` is ``self.weight`` cast to ``x.dtype`` and scaled by
    ``self.weight_gain``. The bias, if present, is cast to ``x.dtype`` and
    scaled by ``self.bias_gain``.

    When the activation is ``'linear'`` and a bias exists, the product and
    bias add are fused into one ``torch.addmm``. Otherwise the plain matmul
    result is passed, with the bias and ``self.activation``, to
    ``bias_act.bias_act``.
    """
    weight = self.weight.to(dtype=x.dtype) * self.weight_gain

    bias = None
    if self.bias is not None:
        bias = self.bias.to(dtype=x.dtype)
        if self.bias_gain != 1:
            bias = bias * self.bias_gain

    if self.activation != 'linear' or bias is None:
        product = x.matmul(weight.t())
        return bias_act.bias_act(product, bias, act=self.activation)
    return torch.addmm(bias.unsqueeze(0), x, weight.t())
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
    """Synthesis layer with optional instance norm and an optional FMM path.

    Args:
        x: Input features; asserted to be
            ``[None, self.weight.shape[1], R // up, R // up]`` where
            ``R = self.resolution`` and ``up = self.up``.
        w: Latent(s) turned into styles by ``self.affine``.
        noise_mode: ``'random'`` draws fresh noise of shape
            ``[batch, 1, R, R]``; ``'const'`` uses ``self.noise_const``;
            ``'none'`` adds no noise. Noise is only used when
            ``self.cfg.use_noise`` is truthy.
        fused_modconv: Forwarded to ``modulated_conv2d``; not used by the
            FMM path.
        gain: Multiplies both ``self.act_gain`` and ``self.conv_clamp``.

    Returns:
        Output of ``bias_act.bias_act`` on the modulated result.
    """
    assert noise_mode in ['random', 'const', 'none']
    side = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], side, side])
    styles = self.affine(w)

    noise = None
    if self.cfg.use_noise:
        if noise_mode == 'random':
            noise_shape = [x.shape[0], 1, self.resolution, self.resolution]
            noise = torch.randn(noise_shape, device=x.device) * self.noise_strength
        elif noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

    if self.instance_norm:
        # Divide by the per-(sample, channel) std over the two spatial dims.
        x = x / (x.std(dim=[2, 3], keepdim=True) + 1e-8)  # [batch_size, c, h, w]

    if self.cfg.fmm.enabled:
        x = fmm_modulate_linear(x=x, weight=self.weight, styles=styles,
                                noise=noise, activation=self.cfg.fmm.activation)
    else:
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            noise=noise,
            up=self.up,
            padding=self.padding,
            resample_filter=self.resample_filter,
            flip_weight=(self.up == 1),  # slightly faster
            fused_modconv=fused_modconv,
        )

    scaled_gain = self.act_gain * gain
    scaled_clamp = None if self.conv_clamp is None else self.conv_clamp * gain
    return bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation,
                             gain=scaled_gain, clamp=scaled_clamp)
def forward(self, x, w, mask, noise_mode='random', fused_modconv=True, gain=1):
    """Synthesis layer that blends several latents per sample via ``mask``.

    Each sample's features are repeated once per latent. Each copy is
    modulated by its own latent, and the per-latent outputs are weighted by
    ``mask`` and summed.

    Args:
        x: Input features; asserted to be
            ``[None, self.weight.shape[1], R // up, R // up]`` where
            ``R = self.resolution`` and ``up = self.up``.
        w: Latents unpacked as ``(w_n, w_m, _) = w.shape``.
        mask: Blend weights. ``mask.unsqueeze(2)`` is multiplied into the
            ``[w_n, w_m, *feature_shape]`` output before summing over
            ``w_m``. Presumably shaped ``[w_n, w_m, H, W]``; TODO confirm.
        noise_mode: ``'random'``, ``'const'`` or ``'none'``. Noise is only
            used when ``self.use_noise`` is truthy.
        fused_modconv: Forwarded to ``modulated_conv2d``.
        gain: Multiplies both ``self.act_gain`` and ``self.conv_clamp``.

    Returns:
        Mask-weighted sum over the ``w_m`` per-latent outputs.
    """
    assert noise_mode in ['random', 'const', 'none']
    in_resolution = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
    w_n, w_m, _ = w.shape
    # Fold the latent axis into the batch: one style row per (sample, latent).
    styles = self.affine(w.view([w_n * w_m, -1]))

    noise = None
    if self.use_noise and noise_mode == 'random':
        noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
                            device=x.device) * self.noise_strength
        # x is repeated w_m times along dim 0 below, so the convolution output
        # has batch x.shape[0] * w_m.
        # A [N, 1, R, R] noise tensor does not broadcast against that batch
        # when N > 1 and w_m > 1. Repeat it the same way instead.
        # This gives every latent copy of a sample the same noise field, which
        # is what broadcasting already did in the N == 1 case.
        noise = noise.repeat_interleave(w_m, 0)
    if self.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength

    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None

    x = x.repeat_interleave(w_m, 0)
    x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise,
                         up=self.up, padding=self.padding,
                         resample_filter=self.resample_filter,
                         flip_weight=(self.up == 1),  # slightly faster
                         fused_modconv=fused_modconv)
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation,
                          gain=act_gain, clamp=act_clamp)
    # Un-fold the latent axis, weight each latent's output by the mask, and sum.
    return x.view(w_n, w_m, *x.shape[1:]).mul(mask.unsqueeze(2)).sum(1)