Example #1
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2

    # Construct sampling grid.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
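A minimal usage sketch for the snippet above. This is hypothetical: it assumes `apply_affine_transformation` and its helpers `construct_affine_bandlimit_filter` / `upfirdn2d` (from the surrounding StyleGAN3-style codebase) are importable.

import numpy as np
import torch

# Hypothetical call: rotate a batch of images by 30 degrees using a
# 3x3 homogeneous transform matrix.
angle = np.pi / 6
mat = np.array([[np.cos(angle), -np.sin(angle), 0],
                [np.sin(angle),  np.cos(angle), 0],
                [0,              0,             1]], dtype=np.float32)

x = torch.randn(2, 3, 64, 64)  # [N, C, H, W]
z, m = apply_affine_transformation(x, mat, up=4)
# z holds the resampled images; m is 1 where the output pixel was sampled
# from valid, band-limited image content and 0 elsewhere.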
Example #2
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or (isinstance(x, torch.Tensor) and int(x.shape[0]) == 1))

        # Input.
        if self.in_channels == 0:
            conv1_w = next(w_iter)
            x = self.input(ws.shape[0], conv1_w, device=ws.device, dtype=dtype, memory_format=memory_format)
        else:
            misc.assert_shape(x, [None, self.in_channels, self.input_resolution, self.input_resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        x = maybe_upsample(x, self.cfg.upsampling_mode, self.up)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, conv1_w, fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            conv0_w = next(w_iter)

            if self.coord_fuser is not None:
                x = self.coord_fuser(x, conv0_w, dtype=dtype, memory_format=memory_format)

            x = self.conv0(x, conv0_w, fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        if self.extra_convs is not None:
            for conv, w in zip(self.extra_convs, w_iter):
                x = conv(x, w, fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.input_resolution, self.input_resolution])

            if self.up == 2:
                if self.cfg.upsampling_mode is None:
                    img = upfirdn2d.upsample2d(img, self.resample_filter)
                else:
                    img = maybe_upsample(img, self.cfg.upsampling_mode, 2)

        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
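`maybe_upsample` is defined elsewhere in this fork; a plausible minimal sketch that matches how it is called above, treating `upsampling_mode` as a torch.nn.functional.interpolate mode such as 'nearest' or 'bilinear'. This is an assumption, not the repo's actual code.

import torch.nn.functional as F

def maybe_upsample(x, upsampling_mode, up):
    # Hypothetical helper: no-op unless an interpolation mode is configured
    # and the block actually upsamples.
    if upsampling_mode is None or up <= 1:
        return x
    return F.interpolate(x, scale_factor=up, mode=upsampling_mode)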
Example #3
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
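A pattern shared by all of these forward passes is consuming `ws` one style vector per layer via `iter(ws.unbind(dim=1))`. A standalone illustration of that pattern, with arbitrarily chosen shapes:

import torch

num_conv, num_torgb, w_dim = 2, 1, 512
ws = torch.randn(4, num_conv + num_torgb, w_dim)  # [batch, num_ws, w_dim]

w_iter = iter(ws.unbind(dim=1))  # yields num_ws tensors of shape [batch, w_dim]
w_conv0 = next(w_iter)           # style for conv0
w_conv1 = next(w_iter)           # style for conv1
w_torgb = next(w_iter)           # style for the ToRGB layer
assert w_torgb.shape == (4, w_dim)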
Example #4
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            # !!! custom: non-square init_res changes the expected input size.
            misc.assert_shape(x, [None, self.in_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
            # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            # !!! custom: non-square init_res changes the expected image size.
            misc.assert_shape(img, [None, self.img_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
            # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
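The `# !!! custom` assertions generalize the standard `resolution // 2` input size to non-square base grids: with the square default `init_res = [4, 4]`, the expression `resolution * init_res[i] // 8` reduces to exactly `resolution // 2`, while other aspect ratios scale each axis independently. A quick check (the `[4, 8]` value is an illustrative assumption):

init_res = [4, 8]  # hypothetical 1:2 base grid
for resolution in [8, 16, 32]:
    h = resolution * init_res[0] // 8
    w = resolution * init_res[1] // 8
    print(resolution, (h, w))
# 8  -> (4, 8)    the base grid itself
# 16 -> (8, 16)
# 32 -> (16, 32)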
Example #5
    def forward(self, images, debug_percentile=None):
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size], device=device) < self.xflip * self.p,
                i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(
                torch.rand([batch_size], device=device) <
                self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t,
                                    (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width),
                                            torch.round(t[:, 1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.scale * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt(
            (1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.aniso * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(
                    t,
                    torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:
            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1],
                        device=device)  # [idx, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width - 1, height - 1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images,
                                             pad=[mx0, mx1, my0, my1],
                                             mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images,
                                            f=self.Hz_geom,
                                            down=2,
                                            padding=-Hz_pad * 2,
                                            flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(
                torch.rand([batch_size], device=device) <
                self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(
                    b,
                    torch.erfinv(debug_percentile * 2 - 1) *
                    self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(
                torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(
                torch.rand([batch_size], device=device) <
                self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(
                    c,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3),
                          device=device)  # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C  # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C  # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(
                torch.randn([batch_size, 1, 1], device=device) *
                self.saturation_std)
            s = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 4:
                alpha = images[:, 3, :].unsqueeze(dim=1)              # [batch_size, 1, ...]
                rgb = C[:, :3, :3] @ images[:, :3, :] + C[:, :3, 3:]  # [batch_size, 3, ...]
                images = torch.cat([rgb, alpha], dim=1)               # [batch_size, 4, ...]
            elif num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError(
                    'Image must be RGBA (4 channels), RGB (3 channels) or L (1 channel)'
                )
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(
                np.array([10, 1, 1, 1]) / 13,
                device=device)  # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands],
                           device=device)  # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(
                    torch.randn([batch_size], device=device) *
                    self.imgfilter_std)
                t_i = torch.where(
                    torch.rand([batch_size], device=device) <
                    self.imgfilter * self.p * band_strength, t_i,
                    torch.ones_like(t_i))
                if debug_percentile is not None:
                    if band_strength > 0:
                        t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std))
                    else:
                        t_i = torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device)  # Temporary gain vector.
                t[:, i] = t_i  # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt()  # Normalize power.
                g = g * t  # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1])  # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images,
                                             pad=[p, p, p, p],
                                             mode='reflect')
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(2),
                                           groups=batch_size * num_channels)
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(3),
                                           groups=batch_size * num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1],
                                device=device).abs() * self.noise_std
            sigma = torch.where(
                torch.rand([batch_size, 1, 1, 1], device=device) <
                self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(
                    sigma,
                    torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1],
                              self.cutout_size,
                              device=device)
            size = torch.where(
                torch.rand([batch_size, 1, 1, 1, 1], device=device) <
                self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
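Example #5 builds its geometry from small homogeneous-matrix helpers (`translate2d`, `scale2d`, `rotate2d`, and their `*_inv` counterparts) that are defined next to it in the original augmentation pipeline but not shown here. A simplified sketch of what they compute, assuming parameters are either plain scalars or batch-shaped tensors:

import torch

def matrix(*rows, device=None):
    # Stack rows of scalars/tensors into a transform; tensor entries carry
    # the batch dimension, scalars are broadcast to match.
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return torch.tensor([list(row) for row in rows], dtype=torch.float32, device=device)
    elems = [x if isinstance(x, torch.Tensor) else torch.full_like(ref[0], x) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))

def translate2d(tx, ty, **kwargs):
    return matrix([1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs)

def scale2d(sx, sy, **kwargs):
    return matrix([sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs)

def rotate2d(theta, **kwargs):  # theta: tensor of angles in radians
    return matrix([torch.cos(theta), -torch.sin(theta), 0],
                  [torch.sin(theta), torch.cos(theta), 0],
                  [0, 0, 1], **kwargs)

# The inverses follow directly from the parameters.
def translate2d_inv(tx, ty, **kwargs): return translate2d(-tx, -ty, **kwargs)
def scale2d_inv(sx, sy, **kwargs):     return scale2d(1 / sx, 1 / sy, **kwargs)
def rotate2d_inv(theta, **kwargs):     return rotate2d(-theta, **kwargs)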
Example #6
    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels + self.segmentation_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)

        if self.is_last or self.architecture == 'skip':
            w_temp = next(w_iter)
            rgb = self.torgb(x, w_temp, fused_modconv=fused_modconv)
            rgb = rgb.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            segmentation = self.tosegmentation(x, w_temp, fused_modconv=fused_modconv)
            newImg = torch.cat((rgb, segmentation), dim=1)
            img = img.add_(newImg) if img is not None else newImg

            if self.is_last:
                # Snap the segmentation logits to a (near) one-hot mask: after
                # subtracting the per-pixel channel maximum, only the winning
                # channel stays positive (at exactly eps), so max(0, .) / eps
                # rounds to 1 there and to 0 everywhere else.
                originalSegmentation = img[:, 3:]
                maxs = torch.max(originalSegmentation, dim=1)[0].unsqueeze(1)
                afterSubtraction = originalSegmentation - maxs + self.eps
                finalArray = torch.round(torch.max(self.zeroTensor, afterSubtraction) / self.eps)

                img[:, 3:] = finalArray

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
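The `is_last` branch in Example #6 relies on that rounding trick to binarize the segmentation channels. A tiny numeric check of the trick in isolation, where `eps` stands in for `self.eps` and is assumed to be a small positive constant such as 1e-8, and `zero` stands in for `self.zeroTensor`:

import torch

eps = 1e-8
zero = torch.zeros(1)  # stand-in for self.zeroTensor

logits = torch.tensor([[0.2], [1.5], [-0.7]])    # [channels, pixels]: channel 1 wins
maxs = torch.max(logits, dim=0)[0].unsqueeze(0)  # per-pixel channel maximum
one_hot = torch.round(torch.max(zero, logits - maxs + eps) / eps)
print(one_hot.squeeze(1))  # tensor([0., 1., 0.])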