Example #1
    def forward(self, batch_size: int, left_borders_idx: Tensor) -> Tensor:
        misc.assert_shape(left_borders_idx, [batch_size])

        noise = torch.randn(batch_size, self.channel_dim, self.resolution, self.resolution, device=left_borders_idx.device)
        out = self.coord_fuser(noise, left_borders_idx=left_borders_idx, memory_format=torch.contiguous_format)

        return out
Example #2
    def forward(self, ws, mask=None, **block_kwargs):
        if ws.ndim == 3:
            ws = ws.unsqueeze(1)

        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(2, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        if mask is None:
            mask = ws.new_ones([1, ws.shape[1], self.img_resolution, self.img_resolution]) / ws.shape[1]

        misc.assert_shape(mask, [None, ws.shape[1], self.img_resolution, self.img_resolution])
        masks = [mask]
        for _ in range(len(self.block_resolutions) - 1):
            masks.insert(0, F.avg_pool2d(masks[0], 2))

        x = img = None
        for res, cur_ws, cur_mask in zip(self.block_resolutions, block_ws, masks):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, cur_mask, **block_kwargs)
        return img
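The mask handling above builds a pyramid by repeatedly average-pooling the full-resolution mask, so each synthesis block gets a mask at its own resolution. A minimal standalone sketch of the same construction (resolutions are assumed for illustration, not taken from the repo):

# Standalone sketch of the mask pyramid: the full-resolution mask is
# repeatedly average-pooled; coarser levels are prepended so the list
# lines up with block_resolutions from coarse to fine.
import torch
import torch.nn.functional as F

block_resolutions = [4, 8, 16, 32]       # hypothetical resolutions
mask = torch.ones(1, 2, 32, 32) / 2      # [1, num_latents, H, W]

masks = [mask]
for _ in range(len(block_resolutions) - 1):
    masks.insert(0, F.avg_pool2d(masks[0], 2))  # prepend the next coarser level

for res, m in zip(block_resolutions, masks):
    assert m.shape[-1] == res            # 4, 8, 16, 32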
Example #3
    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(
            x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)
        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn(
                [x.shape[0], 1, self.resolution, self.resolution],
                device=x.device) * self.noise_strength
            #noise += self.noise_const.expand_as(noise) * 0
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1)  # slightly faster
        x = modulated_conv2d(x=x,
                             weight=self.weight,
                             styles=styles,
                             noise=noise,
                             up=self.up,
                             padding=self.padding,
                             resample_filter=self.resample_filter,
                             flip_weight=flip_weight,
                             fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x,
                              self.bias.to(x.dtype),
                              act=self.activation,
                              gain=act_gain,
                              clamp=act_clamp)
        return x
Example #4
    def forward(self, x, img, force_fp32=False):
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(
                x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            misc.assert_shape(
                img,
                [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = upfirdn2d.downsample2d(
                img,
                self.resample_filter) if self.architecture == 'skip' else None

        # Main layers.
        if self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img
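A quick numeric check, purely illustrative, of why both resnet branches above are scaled by np.sqrt(0.5): for independent branches with unit variance, the scaling keeps the variance of their sum near 1 instead of letting it double.

import numpy as np
import torch

a, b = torch.randn(1_000_000), torch.randn(1_000_000)
print((a + b).var())                                # ~2.0: variance doubles
print((np.sqrt(0.5) * a + np.sqrt(0.5) * b).var())  # ~1.0: variance preserved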
Example #5
    def forward(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(
                    self.w_avg, self.w_avg_beta))
                self.w_cov.copy_((self.w_avg_beta * self.w_cov) + (
                    (self.w_avg_beta - self.w_avg_beta**2) *
                    (x.detach() - self.w_avg).T @ (x.detach() - self.w_avg)))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(
                        x[:, :truncation_cutoff], truncation_psi)
        return x
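The truncation step above is the standard StyleGAN truncation trick: each w is pulled toward the tracked average w_avg via lerp, optionally only for the first truncation_cutoff of the broadcast ws. A minimal sketch with dummy shapes (not repo code):

import torch

num_ws, w_dim = 14, 512                  # assumed dimensions
w_avg = torch.zeros(w_dim)               # tracked moving average of W
x = torch.randn(4, num_ws, w_dim)        # broadcast latents [N, num_ws, w_dim]

truncation_psi, truncation_cutoff = 0.7, 8
# lerp pulls each w toward w_avg; the cutoff truncates only the first layers.
x[:, :truncation_cutoff] = w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)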
Example #6
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
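To see why the fused path above is equivalent to convolving each sample with its own modulated weights, here is a self-contained re-derivation in plain PyTorch (toy shapes, no repo dependencies): the per-sample modulated and demodulated weights are folded into one grouped convolution over the batch reshaped into a single image.

import torch
import torch.nn.functional as F

N, I, O, k, H = 3, 4, 5, 3, 8
x = torch.randn(N, I, H, H)
weight = torch.randn(O, I, k, k)
styles = torch.randn(N, I)

# Modulate + demodulate per-sample weights: [N, O, I, k, k].
w = weight.unsqueeze(0) * styles.reshape(N, 1, -1, 1, 1)
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
w = w * dcoefs.reshape(N, -1, 1, 1, 1)

# Fused: one grouped conv over the flattened batch.
y_fused = F.conv2d(x.reshape(1, -1, H, H), w.reshape(-1, I, k, k),
                   padding=1, groups=N).reshape(N, O, H, H)

# Reference: loop over samples with each sample's own weights.
y_ref = torch.stack([F.conv2d(x[i:i+1], w[i], padding=1)[0] for i in range(N)])
assert torch.allclose(y_fused, y_ref, atol=1e-5)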
Example #7
    def forward(self, ws, c=None, **block_kwargs):
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img
Example #8
    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(
            x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.cfg.use_noise and noise_mode == 'random':
            noise = torch.randn(
                [x.shape[0], 1, self.resolution, self.resolution],
                device=x.device) * self.noise_strength
        if self.cfg.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1)  # slightly faster

        if self.instance_norm:
            x = x / (x.std(dim=[2, 3], keepdim=True) + 1e-8)  # [batch_size, c, h, w]

        if self.cfg.fmm.enabled:
            x = fmm_modulate_linear(x=x,
                                    weight=self.weight,
                                    styles=styles,
                                    noise=noise,
                                    activation=self.cfg.fmm.activation)
        else:
            x = modulated_conv2d(x=x,
                                 weight=self.weight,
                                 styles=styles,
                                 noise=noise,
                                 up=self.up,
                                 padding=self.padding,
                                 resample_filter=self.resample_filter,
                                 flip_weight=flip_weight,
                                 fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x,
                              self.bias.to(x.dtype),
                              act=self.activation,
                              gain=act_gain,
                              clamp=act_clamp)
        return x
Example #9
    def forward(self, batch_size: int, shifts: Optional[Tensor]=None) -> Tensor:
        x = self.const_input.unsqueeze(0).repeat([batch_size, 1, 1, 1]) # [b, c, h, w]

        if shifts is not None:
            misc.assert_shape(shifts, [batch_size, 2])
            assert shifts.max().item() <= 1.0
            assert shifts.min().item() >= -1.0

            coords = generate_coords(batch_size, self.const_input.shape[1], device=x.device, align_corners=True) # [b, 2, h, w]

            # # Applying the shift
            # coords = coords + shifts.view(batch_size, 2, 1, 1) # [b, 2, h, w]

            # # Converting into F.grid_sample coords:
            # # 1. Convert the range
            # coords = coords + 1 # [-1, 1] => [0, 2]
            # # 2. Perform padding_mode=replicate
            # # coords[coords > 0] = coords[coords > 0] % (2 + 1e-12)
            # # coords[coords < 0] = -(-coords[coords < 0] % 2) + 2 + (1e-12)
            # # 3. Convert back to [-1, 1] range
            # coords = coords - 1 # [0, 2] => [-1, 1]
            # # 4. F.grid_sample uses flipped coordinates (TODO: should we too?)
            # coords[:, 1] = coords[:, 1] * -1.0
            # # 5. It also uses different shape
            # coords = coords.permute(0, 2, 3, 1) # [b, h, w, 2]

            # Performing a slower, but less error-prone approach
            # (convert shifts from [-1, 1] to [-2, 2], so we are now [-3, 3])
            coords = coords + 2 * shifts.view(batch_size, 2, 1, 1) # [b, 2, h, w]
            coords = coords / 3 # [-3, 3] => [-1, 1] range
            coords = coords.permute(0, 2, 3, 1)
            assert coords.min().item() >= -1
            assert coords.max().item() <= 1

            x = torch.cat([x, x, x], dim=3) # [b, c, h, w * 3]
            x = F.grid_sample(x, coords, mode='bilinear', align_corners=True) # [b, c, h, w]

            # torch.save(coords.detach().cpu(), '/tmp/trash/coords')
            # torch.save(x.detach().cpu(), '/tmp/trash/x')
            # torch.save(self.const_input.detach().cpu(), '/tmp/trash/const_input')

            # assert torch.allclose(x[0], self.const_input, atol=1e-4)

        return x
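A simplified standalone sketch of the shifting scheme above (toy sizes, and it glosses over the y-flip caveat discussed in the comments): the canvas is tiled three times along the width, and shift-adjusted coordinates spanning [-3, 3] are rescaled into the [-1, 1] range that F.grid_sample expects over the tripled canvas.

import torch
import torch.nn.functional as F

b, c, h, w = 2, 3, 4, 4
x = torch.randn(b, c, h, w)
shifts = torch.zeros(b, 2)  # zero shift => sampling stays over the middle copy

# Base coords in [-1, 1], shaped [b, 2, h, w] with (x, y) channels.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h),
                        torch.linspace(-1, 1, w), indexing='ij')
coords = torch.stack([xs, ys]).unsqueeze(0).repeat(b, 1, 1, 1)

coords = (coords + 2 * shifts.view(b, 2, 1, 1)) / 3  # [-3, 3] => [-1, 1]
coords = coords.permute(0, 2, 3, 1)                  # [b, h, w, 2]

x3 = torch.cat([x, x, x], dim=3)                     # [b, c, h, 3w]
out = F.grid_sample(x3, coords, mode='bilinear', align_corners=True)  # [b, c, h, w]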
Example #10
def fmm_modulate(
    conv_weight: Tensor,
    fmm_weights: Tensor,
    fmm_mod_type: str='mult',
    demodulate: bool=False,
    fmm_add_weight: float=1.0,
    activation: Optional[str]=None) -> Tensor:
    """
    Applies FMM fmm_weights to a given conv weight tensor
    """
    batch_size, out_channels, in_channels, kh, kw = conv_weight.shape

    assert fmm_weights.shape[1] % (in_channels + out_channels) == 0

    rank = fmm_weights.shape[1] // (in_channels + out_channels)
    lhs = fmm_weights[:, : rank * out_channels].view(batch_size, out_channels, rank)
    rhs = fmm_weights[:, rank * out_channels :].view(batch_size, rank, in_channels)

    modulation = lhs @ rhs # [batch_size, out_channels, in_channels]
    modulation = modulation / np.sqrt(rank)
    misc.assert_shape(modulation, [batch_size, out_channels, in_channels])
    modulation = modulation.unsqueeze(3).unsqueeze(4) # [batch_size, out_channels, in_channels, 1, 1]

    if activation == "tanh":
        modulation = modulation.tanh()
    elif activation in ['linear', None]:
        pass
    elif activation == 'sigmoid':
        modulation = modulation.sigmoid() - 0.5
    else:
        raise NotImplementedError

    if fmm_mod_type == 'mult':
        out = conv_weight * (modulation + 1.0)
    elif fmm_mod_type == 'add':
        out = conv_weight + fmm_add_weight * modulation
    else:
        raise NotImplementedError

    if demodulate:
        out = out / out.norm(dim=[2, 3, 4], keepdim=True)

    return out
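A compact usage sketch with dummy tensors (dimensions are assumptions, not repo values) of the low-rank factorization above: the per-sample style vector is split into two factors whose product forms one [out_channels, in_channels] modulation matrix per sample.

import numpy as np
import torch

batch_size, out_channels, in_channels, rank, k = 2, 8, 4, 3, 3
conv_weight = torch.randn(batch_size, out_channels, in_channels, k, k)
fmm_weights = torch.randn(batch_size, rank * (out_channels + in_channels))

lhs = fmm_weights[:, :rank * out_channels].view(batch_size, out_channels, rank)
rhs = fmm_weights[:, rank * out_channels:].view(batch_size, rank, in_channels)
modulation = (lhs @ rhs) / np.sqrt(rank)                 # [b, O, I]
out = conv_weight * (modulation[..., None, None] + 1.0)  # 'mult' mode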
Example #11
    def forward(self, x, img, cmap, force_fp32=False):
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])  # [NCHW]
        _ = force_fp32  # unused
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'skip':
            misc.assert_shape(
                img,
                [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            x = x + self.fromrgb(img)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)

        # Conditioning.
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            x = (x * cmap).sum(dim=1,
                               keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x
Example #12
def fast_bilinear_mult_row(x: Tensor, styles: Tensor, shifts: Optional[Tensor]=None) -> Tensor:
    b, c, h, w = x.shape
    context_size = 2
    misc.assert_shape(styles, [b, c, context_size + 1])

    centers = shifts
    if centers is None:
        centers = torch.zeros(b, 2, dtype=styles.dtype, device=styles.device)

    misc.assert_shape(centers, [b, 2])
    assert centers.min().item() >= -1.0
    assert centers.max().item() <= 1.0

    # Centers are [-1, 1] range, but w_before/w_after positions correspond to -2/2.
    # Constructing the bounds for each center
    # The size of the square is 2: it is in [-1, 1] x [-1, 1]
    # Bounds correspond to left and right borders
    assert context_size == 2
    bounds = torch.stack([
        torch.stack([centers[:, 0] - 1, centers[:, 1]], dim=1),
        torch.stack([centers[:, 0] + 1, centers[:, 1]], dim=1)
    ], dim=1) # [b, 2, 2]
    bounds = bounds.unsqueeze(1) # [b, 1, 2, 2] == [b, h, w, 2]

    # Now, grid sample assume [-1, 1] range, so adjust:
    bounds.mul_(0.5)

    # Also, for F.grid_sample we need to flip y coordinate
    bounds[:, :, :, 1].mul_(-1.0)

    # Now, we can get our interpolated embeddings
    w_bounds = F.grid_sample(styles.unsqueeze(2), bounds.to(styles.dtype), mode='bilinear', align_corners=True) # [b, c, 1, 2]

    # Now, we can interpolate and modulate
    modulation = F.interpolate(w_bounds, size=(1, w), mode='bilinear', align_corners=True) # [b, c, 1, w]
    x = x * modulation # [b, c, h, w]

    # print('PERFORMED fast_bilinear_mult_row')

    return x
Example #13
    def forward(self, x, w, mask, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        w_n, w_m, _ = w.shape
        styles = self.affine(w.view([w_n * w_m, -1]))

        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None

        x = x.repeat_interleave(w_m, 0)
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x.view(w_n, w_m, *x.shape[1:]).mul(mask.unsqueeze(2)).sum(1)
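The repeat_interleave / view / weighted-sum pattern above evaluates every latent on every sample, then recombines the per-latent outputs as a mask-weighted sum. An isolated sketch with toy shapes (the per-style convolution is elided):

import torch

w_n, w_m, c, h, w = 2, 3, 4, 8, 8
x = torch.randn(w_n, c, h, w)
mask = torch.softmax(torch.randn(w_n, w_m, h, w), dim=1)  # weights per latent

x_rep = x.repeat_interleave(w_m, 0)  # [w_n * w_m, c, h, w]
# ... per-style processing of x_rep would go here ...
out = x_rep.view(w_n, w_m, c, h, w).mul(mask.unsqueeze(2)).sum(1)  # [w_n, c, h, w]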
Example #14
def fast_manual_bilinear_mult_row(x: Tensor, styles: Tensor, left_borders_idx: Tensor, grid_size: int, w_coord_dist: float, w_lerp_multiplier: float=1.0) -> Tensor:
    b, c, h, w = x.shape
    misc.assert_shape(styles, [b, 3, c])
    misc.assert_shape(left_borders_idx, [b])

    w_dist = int(0.5 * w_coord_dist * w)
    interp_coefs = torch.linspace(1 / (2 * w_dist), 1 - 1 / (2 * w_dist), w_dist, device=x.device, dtype=styles.dtype) # [w_dist]
    interp_coefs = interp_coefs * w_lerp_multiplier
    interp_coefs = interp_coefs.view(1, w_dist, 1) # [1, w_dist, 1]
    styles_grid_left = styles[:, 0].unsqueeze(1) * (w_lerp_multiplier - interp_coefs) + styles[:, 1].unsqueeze(1) * interp_coefs # [b, w_dist, c]
    styles_grid_right = styles[:, 1].unsqueeze(1) * (w_lerp_multiplier - interp_coefs) + styles[:, 2].unsqueeze(1) * interp_coefs # [b, w_dist, c]
    styles_grid = torch.cat([styles_grid_left, styles_grid_right], dim=1).to(x.dtype) # [b, 2 * w_dist, c]

    # Left borders were randomly sampled in [0, 2 * w_dist - w] integer range
    # We use them to select the corresponding styles
    patch_size = w // grid_size
    batch_idx = torch.arange(b, device=x.device).view(-1, 1).repeat(1, w) # [b, w]
    grid_idx = (left_borders_idx.unsqueeze(1) * patch_size) + torch.arange(w, device=x.device).view(1, -1) # [b, w]
    latents = styles_grid[batch_idx, grid_idx].permute(0, 2, 1) # [b, c, w]
    x = x * latents.unsqueeze(2) # [b, c, h, w]

    return x
Example #15
File: networks.py, Project: LordHui/inr-gan
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or (isinstance(x, Tensor) and int(x.shape[0]) == 1))

        # Input.
        if self.in_channels == 0:
            conv1_w = next(w_iter)
            x = self.input(ws.shape[0], conv1_w, device=ws.device, dtype=dtype, memory_format=memory_format)
        else:
            misc.assert_shape(x, [None, self.in_channels, self.input_resolution, self.input_resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        x = maybe_upsample(x, self.cfg.upsampling_mode, self.up)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, conv1_w, fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            conv0_w = next(w_iter)

            if self.coord_fuser is not None:
                x = self.coord_fuser(x, conv0_w, dtype=dtype, memory_format=memory_format)

            x = self.conv0(x, conv0_w, fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        if self.extra_convs is not None:
            for conv, w in zip(self.extra_convs, w_iter):
                x = conv(x, w, fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.input_resolution, self.input_resolution])

            if self.up == 2:
                if self.cfg.upsampling_mode is None:
                    img = upfirdn2d.upsample2d(img, self.resample_filter)
                else:
                    img = maybe_upsample(img, self.cfg.upsampling_mode, 2)

        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
Example #16
def fast_bilinear_mult(x, styles):
    """
    x: [b, c, h, w],
    styles: [b, c, 2, 2]
    """
    b, c, h, w = x.shape
    misc.assert_shape(styles, [b, c, 2, 2])

    kwargs = dict(device=x.device, dtype=x.dtype)
    top_to_bottom = torch.linspace(1, 0, h, **kwargs).unsqueeze(1)
    left_to_right = torch.linspace(1, 0, w, **kwargs).unsqueeze(0)
    coefs_11 = top_to_bottom * left_to_right # [h, w]
    coefs_12 = top_to_bottom * (1.0 - left_to_right) # [h, w]
    coefs_21 = (1.0 - top_to_bottom) * left_to_right  # [h, w]
    coefs_22 = (1.0 - top_to_bottom) * (1.0 - left_to_right) # [h, w]
    coefs = torch.stack([coefs_11, coefs_12, coefs_21, coefs_22]) # [4, h, w]
    coefs = coefs.unsqueeze(0).unsqueeze(2) # [1, 4, 1, h, w]
    xs = (x.unsqueeze(1) * coefs) # [b, 4, c, h, w]
    styles = styles.permute(0, 2, 3, 1).view(b, 4, c) # [b, 4, c]
    styles = styles.view(b, 4, c, 1, 1) # [b, 4, c, 1, 1]
    y = (xs * styles).sum(dim=1) # [b, c, h, w]

    return y
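A quick standalone check of the coefficient construction above: the four bilinear coefficient maps form a partition of unity at every pixel, so identical corner styles reduce fast_bilinear_mult to a plain channelwise multiply.

import torch

h, w = 5, 7
top_to_bottom = torch.linspace(1, 0, h).unsqueeze(1)
left_to_right = torch.linspace(1, 0, w).unsqueeze(0)
coefs = torch.stack([
    top_to_bottom * left_to_right,
    top_to_bottom * (1 - left_to_right),
    (1 - top_to_bottom) * left_to_right,
    (1 - top_to_bottom) * (1 - left_to_right),
])  # [4, h, w]
assert torch.allclose(coefs.sum(dim=0), torch.ones(h, w))  # partition of unity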
Example #17
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
# !!! custom
            misc.assert_shape(x, [None, self.in_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
            # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
# !!! custom
            misc.assert_shape(img, [None, self.img_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
            # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
Example #18
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
Example #19
    def forward(self, batch_size: int, w: Tensor, w_context: Tensor, left_borders_idx: Tensor) -> Tensor:
        misc.assert_shape(w, [batch_size, self.w_dim])
        misc.assert_shape(w_context, [batch_size, 2, self.w_dim])
        misc.assert_shape(left_borders_idx, [batch_size])

        # Computing the global features
        w_all = torch.stack([w_context[:, 0], w, w_context[:, 1]], dim=1) # [b, 3, w_dim]
        styles = self.affine(w_all.view(-1, self.w_dim)).view(batch_size, 3, self.channel_dim) # [b, 3, c]
        raw_const_inputs = self.input_column.unsqueeze(0).unsqueeze(3).repeat(batch_size, 1, 1, self.resolution) # [b, c, h, w]
        latents = fast_manual_bilinear_mult_row(raw_const_inputs, styles, left_borders_idx, self.grid_size, self.w_coord_dist, self.w_lerp_multiplier)

        # Ok, now for each cell in the grid we need to compute its high-frequency coordinates
        # Otherwise, it will be too difficult for the model to understand the relative positions
        coords = generate_shifted_coords(left_borders_idx, self.resolution, self.grid_size, self.w_coord_dist, device=w.device)
        bases = self.basis.unsqueeze(0).repeat(batch_size, 1, 1) # [batch_size, dim, 2]
        raw_coord_embs = torch.einsum('bdc,bcxy->bdxy', bases, coords) # [batch_size, dim, img_size, img_size]
        coord_embs = torch.cat([raw_coord_embs.sin(), raw_coord_embs.cos()], dim=1) # [batch_size, dim * 2, img_size, img_size]

        # Computing final inputs
        inputs = torch.cat([latents, coord_embs], dim=1) # [b, c, grid_size, grid_size]

        return inputs
Example #20
    def forward(self,
                x: Tensor,
                w: Tensor = None,
                dtype=None,
                memory_format=None) -> Tensor:
        """
        Dims:
            @arg x is [batch_size, in_channels, img_size, img_size]
            @arg w is [batch_size, w_dim]
            @return out is [batch_size, in_channels + fourier_dim + cips_dim, img_size, img_size]
        """
        assert memory_format is torch.contiguous_format

        if self.cfg.fallback:
            return x

        batch_size, in_channels, img_size = x.shape[:3]
        out = x

        if self.use_full_cache and (self._full_cache is not None) and (self._full_cache.device == x.device) and \
           (self._full_cache.shape == (batch_size, self.get_total_dim(), img_size, img_size)):
            return torch.cat([x, self._full_cache], dim=1)

        if (self._fourier_embs_cache is not None) and (self._fourier_embs_cache.device == x.device) and \
           (self._fourier_embs_cache.shape == (batch_size, self.get_total_dim() - self.const_emb_size, img_size, img_size)):
            out = torch.cat([out, self._fourier_embs_cache], dim=1)
        else:
            raw_embs = []
            raw_coords = generate_coords(batch_size, img_size, x.device)  # [batch_size, coord_dim, img_size, img_size]

            if self.use_raw_coords:
                out = torch.cat([out, raw_coords], dim=1)

            if self.log_emb_size > 0:
                log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, log_emb_size, 2]
                raw_log_embs = torch.einsum('bdc,bcxy->bdxy', log_bases, raw_coords)  # [batch_size, log_emb_size, img_size, img_size]
                raw_embs.append(raw_log_embs)

            if self.random_emb_size > 0:
                random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, random_emb_size, 2]
                raw_random_embs = torch.einsum('bdc,bcxy->bdxy', random_bases, raw_coords)  # [batch_size, random_emb_size, img_size, img_size]
                raw_embs.append(raw_random_embs)

            if self.shared_emb_size > 0:
                shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, shared_emb_size, 2]
                raw_shared_embs = torch.einsum('bdc,bcxy->bdxy', shared_bases, raw_coords)  # [batch_size, shared_emb_size, img_size, img_size]
                raw_embs.append(raw_shared_embs)

            if self.predictable_emb_size > 0:
                misc.assert_shape(w, [batch_size, None])
                mod = self.affine(w)  # [batch_size, W_size + b_size]
                W = self.fourier_scale * mod[:, :self.W_size]  # [batch_size, W_size]
                W = W.view(batch_size, self.predictable_emb_size, self.cfg.coord_dim)  # [batch_size, predictable_emb_size, coord_dim]
                bias = mod[:, self.W_size:].view(batch_size, self.predictable_emb_size, 1, 1)  # [batch_size, predictable_emb_size, 1, 1]
                raw_predictable_embs = torch.einsum('bdc,bcxy->bdxy', W, raw_coords) + bias  # [batch_size, predictable_emb_size, img_size, img_size]
                raw_embs.append(raw_predictable_embs)

            if len(raw_embs) > 0:
                raw_embs = torch.cat(raw_embs, dim=1)  # [batch_size, log_emb_size + random_emb_size + predictable_emb_size, img_size, img_size]
                raw_embs = raw_embs.contiguous()  # [batch_size, -1, img_size, img_size]
                out = torch.cat([out, raw_embs.sin().to(dtype=dtype, memory_format=memory_format)], dim=1)  # [batch_size, -1, img_size, img_size]

                if self.use_cosine:
                    out = torch.cat([out, raw_embs.cos().to(dtype=dtype, memory_format=memory_format)], dim=1)  # [batch_size, -1, img_size, img_size]

        if self.predictable_emb_size == 0 and self.shared_emb_size == 0 and out.shape[1] > x.shape[1]:
            self._fourier_embs_cache = out[:, x.shape[1]:].detach()

        if self.const_emb_size > 0:
            const_embs = self.const_embs.repeat([batch_size, 1, 1, 1])
            const_embs = const_embs.to(dtype=dtype, memory_format=memory_format)
            out = torch.cat([out, const_embs], dim=1)  # [batch_size, total_dim, img_size, img_size]

        if self.use_full_cache and self.predictable_emb_size == 0 and self.shared_emb_size == 0 and out.shape[1] > x.shape[1]:
            self._full_cache = out[:, x.shape[1]:].detach()

        return out
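The Fourier embedding branches above all follow the same einsum pattern: project a coordinate grid in [-1, 1] through a basis of 2D frequencies, then concatenate sin/cos features. A condensed standalone sketch (random basis and sizes are assumptions):

import torch

batch_size, emb_size, img_size = 2, 16, 8
ys, xs = torch.meshgrid(torch.linspace(-1, 1, img_size),
                        torch.linspace(-1, 1, img_size), indexing='ij')
coords = torch.stack([xs, ys]).unsqueeze(0).repeat(batch_size, 1, 1, 1)  # [b, 2, s, s]

basis = torch.randn(emb_size, 2)                     # random frequency basis
bases = basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [b, emb_size, 2]
raw = torch.einsum('bdc,bcxy->bdxy', bases, coords)  # [b, emb_size, s, s]
embs = torch.cat([raw.sin(), raw.cos()], dim=1)      # [b, 2 * emb_size, s, s]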
Example #21
    def forward(self,
                x,
                img,
                ws,
                force_fp32=False,
                fused_modconv=None,
                **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [
                None, self.in_channels, self.resolution // 2,
                self.resolution // 2
            ])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           gain=np.sqrt(0.5),
                           **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [
                None, self.img_channels + self.segmentation_channels,
                self.resolution // 2, self.resolution // 2
            ])
            img = upfirdn2d.upsample2d(img, self.resample_filter)

        if self.is_last or self.architecture == 'skip':
            w_temp = next(w_iter)
            rgb = self.torgb(x, w_temp, fused_modconv=fused_modconv)
            rgb = rgb.to(dtype=torch.float32,
                         memory_format=torch.contiguous_format)
            segmentation = self.tosegmentation(x,
                                               w_temp,
                                               fused_modconv=fused_modconv)
            newImg = torch.cat((rgb, segmentation), dim=1)
            img = img.add_(newImg) if img is not None else newImg

            if self.is_last:
                originalSegmentation = img[:, 3:]
                maxs = torch.max(originalSegmentation, dim=1)[0].unsqueeze(1)
                afterSubtraction = originalSegmentation - maxs + self.eps
                finalArray = torch.round(
                    torch.max(self.zeroTensor, afterSubtraction) / self.eps)

                img[:, 3:] = finalArray

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
def modulated_conv2d(
        x,  # input, shape=[batch_size, in_channels, in_height, in_width]
        weight,  # weights, shape=[out_channels, in_channels, kernel_height, kernel_width]
        styles,  # modulation coefficients, shape=[batch_size, in_channels]
        noise=None,  # optional noise to add to the output activations
        up=1,  # upsampling factor
        down=1,  # downsampling factor
        padding=0,  # padding with respect to the upsampled image
        resample_filter=None,  # low-pass filter prepared via upfirdn2d.setup_filter()
        demodulate=True,  # apply weight demodulation?
        flip_weight=True,  # False = convolution, True = correlation
        fused_modconv=True,  # perform modulation, convolution, and demodulation as one fused op?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape

    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    misc.assert_shape(styles, [batch_size, in_channels])

    # Normalize inputs
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(
            float('inf'), dim=[1, 2, 3], keepdim=True))
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)

    # Calculate per-sample weights and demodulation coefficients
    w = None
    demod_coeff = None

    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)

    if demodulate:
        demod_coeff = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()

    if demodulate and fused_modconv:
        w = w * demod_coeff.reshape(batch_size, -1, 1, 1, 1)

    # Modulation execution by scaling activations
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(
            x=x,
            w=weight.to(x.dtype),
            f=resample_filter,
            up=up,
            down=down,
            padding=padding,
            flip_weight=flip_weight,
        )

        if demodulate and noise is not None:
            x = fma.fma(x,
                        demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1),
                        noise.to(x.dtype))
        elif demodulate:
            x = x * demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))

        return x

    with misc.suppress_tracer_warnings():
        batch_size = int(batch_size)

    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)

    x = conv2d_resample.conv2d_resample(
        x=x,
        w=w.to(x.dtype),
        f=resample_filter,
        up=up,
        down=down,
        padding=padding,
        groups=batch_size,
        flip_weight=flip_weight,
    )
    x = x.reshape(batch_size, -1, *x.shape[2:])

    if noise is not None:
        x = x.add_(noise)

    return x