Example #1
    def from_0to1(self, normalized: T) -> T:
        """
        Set value of this parameter using a normalized value in the range [0,1]

        Args:
          normalized (T): value within [0,1] range to convert to range defined by
          minimum and maximum
        """
        # TODO: These asserts are very slow
        # assert torch.all(0.0 <= normalized)
        # assert torch.all(normalized <= 1.0)

        if not self.symmetric:
            if self.curve != 1.0:
                normalized = torch.exp2(torch.log2(normalized) / self.curve)

            return self.minimum + (self.maximum - self.minimum) * normalized

        # Apply the curve symmetrically about the midpoint of the range.
        # Note: the curved value must be assigned back to dist (not normalized),
        # otherwise the curve == 1.0 case returns values in the wrong half-range.
        dist = 2.0 * normalized - 1.0
        if self.curve != 1.0:
            dist = torch.where(
                dist == 0.0,
                dist,
                torch.exp2(torch.log2(torch.abs(dist)) / self.curve) * torch.sign(dist),
            )

        return self.minimum + (self.maximum - self.minimum) / 2.0 * (dist + 1.0)
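A minimal usage sketch (hypothetical values; assuming a ModuleParameterRange-style class exposing minimum, maximum, curve and symmetric as above):

    # curve=0.5 squares the normalized value: exp2(log2(x) / 0.5) == x ** 2,
    # so normalized 0.5 lands a quarter of the way through the range.
    pr = ModuleParameterRange(20.0, 20000.0, curve=0.5)
    pr.from_0to1(torch.tensor(0.5))  # 20 + (20000 - 20) * 0.25 ~= 5015.0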
Example #2
    def __getitem__(self, item):
        """
        img: (3, H, W)
        coord: (n_points, 2)
        img_coord: (3, n_points)
        """
        img_path = os.path.join(self.data_root, 'images',
                                'train_%d' % self.group, self.data_list[item])
        img_0 = Image.open(img_path).convert('RGB')
        img = self.transform(img_0)
        # coord = torch.randn(self.n_point, 2) / 3
        # coord = coord.clip(-1, 1)
        # bell-shaped sampling: arcsin of a uniform variable concentrates
        # points near the image center; /1.5 keeps most values inside [-1, 1]
        coord = torch.arcsin((torch.rand(self.n_point, 2) - 0.5) * 2) / 1.5
        coord = coord.clip(-1, 1)
        # sample image points from original resolution
        img_coord = F.grid_sample(
            self.tot(img_0)[None],
            coord[None, None, :, :])  # img_coord: (1, 3, 1, n_points)

        # (n_P, 2) -> (n_P, 16 * 2) -> (n_P, 16, 2)
        coord_mapped = coord.repeat(1, 16).view(self.n_point, 16, 2)
        # (n_P, 16, 2) * (1, 16, 2) -> (n_P, 16, 2)
        coord_mapped = coord_mapped * torch.exp2(
            (torch.arange(0, 16) / 2)[:, None].repeat(1, 2))[None]
        # (n_P, 32)
        coord_mapped = torch.sin(coord_mapped.view(self.n_point, 32))

        return img, coord_mapped, img_coord.view(3, self.n_point)
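The coord → coord_mapped block above is a Fourier positional encoding: each 2-D coordinate is scaled by 16 octave-spaced frequencies (2 ** (k / 2) for k = 0..15) and passed through sin. A standalone sketch of the same mapping (hypothetical n_point):

    import torch

    n_point = 4
    coord = torch.rand(n_point, 2) * 2 - 1           # (n_P, 2) in [-1, 1]
    freqs = torch.exp2(torch.arange(0, 16) / 2)      # (16,) frequencies
    feat = coord[:, None, :] * freqs[None, :, None]  # (n_P, 16, 2)
    feat = torch.sin(feat.reshape(n_point, 32))      # (n_P, 32)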
Example #3
    def __init__(self, optimizer, lr_conv, lr_line):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1792, 768, kernel_size=3, padding=1, groups=4),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, groups=3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, groups=2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=0, groups=3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 128, kernel_size=3, padding=0, groups=1),
            nn.LeakyReLU(inplace=True),
            nn.Flatten(),  # 128 * 4 * 4
        )

        self.linear_net_1 = nn.Sequential(
            nn.Linear(2048 + 32, 1024),
            nn.LeakyReLU(),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
            MemLinear(512),
        )
        self.linear_net_2 = nn.Sequential(
            nn.Linear(1024 + 32, 512),
            nn.LeakyReLU(),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            MemLinear(256),
            nn.Linear(512, 3),
            nn.Sigmoid(),
        )
        self.coord_batch = 8192
        self.coord_total = 256 * 256

        r = torch.arange(-128, 128) * 3.14159265358979323846264 / 256
        img_grid = torch.cat(torch.meshgrid([r,
                                             r])).view(2, 256,
                                                       256).repeat(16, 1, 1)
        img_grid = img_grid * torch.exp2(
            (torch.arange(0, 16) / 2)[:, None].repeat(1, 1, 2).view(32, 1, 1))
        img_grid = torch.sin(img_grid)
        self.img_grid = img_grid.view(32, 256 * 256).T.cuda()

        self.optim_conv = optimizer(self.conv_net.parameters(), lr=lr_conv)
        lin_net = chain(self.linear_net_1.parameters(),
                        self.linear_net_2.parameters())
        self.optim_line = optimizer(lin_net, lr=lr_line)
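A quick shape check of the conv stack above (dummy input; `net` is a hypothetical instance of this class): the three padding=1 convs preserve an 8x8 input, the two padding=0 convs shrink it to 4x4, so Flatten yields 128 * 4 * 4 = 2048, matching the nn.Linear(2048 + 32, 1024) input.

    x = torch.randn(1, 1792, 8, 8)
    assert net.conv_net(x).shape == (1, 2048)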
Example #4
    def test_from_0to1(self):
        # Test with linear range
        param_range = ModuleParameterRange(0.0, 10.0)
        assert param_range.from_0to1(tensor(0.5)) == 5.0

        norm_params = torch.linspace(0.0, 1.0, 10)
        params = param_range.from_0to1(norm_params)
        expected = norm_params * 10.0
        assert torch.all(params.eq(expected))

        # Test with log scaling
        param_range = ModuleParameterRange(0.0, 10.0, curve=0.5)
        norm_params = torch.linspace(0.0, 1.0, 10)
        params = param_range.from_0to1(norm_params)
        expected = torch.exp2(torch.log2(norm_params) / 0.5) * 10.0
        assert torch.all(params.eq(expected))

        # Test with exponential scaling
        param_range = ModuleParameterRange(0.0, 10.0, curve=2.0)
        norm_params = torch.linspace(0.0, 1.0, 10)
        params = param_range.from_0to1(norm_params)
        expected = torch.exp2(torch.log2(norm_params) / 2.0) * 10.0
        assert torch.all(params.eq(expected))
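A hedged sketch of one more case, exercising the symmetric branch from Example #1 (assuming ModuleParameterRange accepts the symmetric flag used there; expected values follow that example's formula):

        # Test with symmetric curve: 0.5 maps to the midpoint for any curve
        param_range = ModuleParameterRange(-10.0, 10.0, curve=2.0, symmetric=True)
        assert param_range.from_0to1(tensor(0.5)) == 0.0
        assert param_range.from_0to1(tensor(1.0)) == 10.0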
Example #5
    def __init__(self):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1792, 768, kernel_size=3, padding=1, groups=4),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, groups=3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, groups=2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=0, groups=3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 128, kernel_size=3, padding=0, groups=1),
            nn.LeakyReLU(inplace=True),
            nn.Flatten(),  # 128 * 4 * 4
        )
        self.linear_net_1 = nn.Sequential(
            nn.Linear(2048 + 16, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 512),
        )
        self.linear_net_2 = nn.Sequential(
            nn.Linear(512 + 16, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 3),
            nn.LeakyReLU(inplace=True),
        )
        self.coord_batch = 4096
        self.coord_total = 256 * 256

        # torch.range is deprecated (and end-inclusive); arange produces the
        # same values.
        r = torch.arange(-128, 128) * 3.14159265358979323846264 / 256
        img_grid = torch.cat(torch.meshgrid([r, r])).view(2, 256,
                                                          256).repeat(8, 1, 1)
        img_grid = img_grid * torch.exp2(
            torch.arange(0.0, 8.0)[:, None].repeat(1, 1, 2).view(16, 1, 1))
        img_grid = torch.sin(img_grid)
        self.img_grid = img_grid.view(16, 256 * 256).T.cuda()
Example #6
    def __init__(self):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1792, 768, kernel_size=3, padding=1, groups=4),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, groups=3),
            nn.LeakyReLU(inplace=True),  # 8 * 8
        )

        self.hyper_net = HyperNet()

        self.coord_batch = 256 * 64
        self.coord_total = 256 * 256

        r = torch.arange(-128, 128) / 128
        grid = torch.meshgrid([r, r])
        img_grid = torch.cat([grid[1], grid[0]]).view(2, 256,
                                                      256).repeat(16, 1, 1)
        img_grid = img_grid * torch.exp2(
            (torch.arange(0, 16) / 2)[:, None].repeat(1, 1, 2).view(32, 1, 1))
        img_grid = torch.sin(img_grid)
        # self.img_grid: (256 * 256, 32)
        self.img_grid = img_grid.view(32, 256 * 256).T.cuda()
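The img_grid blocks in Examples #3, #5 and #6 all precompute a sinusoidal Fourier-feature grid for one 256x256 image. A standalone sketch of the idea (16-frequency variant; shapes noted inline):

    import torch

    r = torch.arange(-128, 128) / 128             # (256,) in [-1, 1)
    gy, gx = torch.meshgrid([r, r])               # two (256, 256) grids
    grid = torch.stack([gx, gy])                  # (2, 256, 256)
    freqs = torch.exp2(torch.arange(0, 16) / 2)   # (16,) octave-spaced
    feat = grid[None] * freqs[:, None, None, None]  # (16, 2, 256, 256)
    feat = torch.sin(feat).view(32, 256 * 256).T  # (256 * 256, 32)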
Example #7
    def forward(self, images, debug_percentile=None):
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile,
                                               dtype=torch.float32,
                                               device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size], device=device) < self.xflip * self.p,
                i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(
                torch.rand([batch_size], device=device) <
                self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 -
                 1) * self.xint_max
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t,
                                    (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width),
                                            torch.round(t[:, 1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.scale * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt(
            (1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.rotate_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot, theta,
                torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
                                        np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.aniso * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.rotate_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot, theta,
                torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(
                    t,
                    torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:
            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1],
                        device=device)  # [idx, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0,
                                          2).flatten(1)  # [xy, batch * idx]
            margin = torch.cat([-margin,
                                margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant(
                [Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(
                misc.constant([width - 1, height - 1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images,
                                             pad=[mx0, mx1, my0, my1],
                                             mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(
                2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5,
                                device=device) @ G_inv @ translate2d_inv(
                                    -0.5, -0.5, device=device)

            # Execute transformation.
            shape = [
                batch_size, num_channels, (height + Hz_pad * 2) * 2,
                (width + Hz_pad * 2) * 2
            ]
            G_inv = scale2d(2 / images.shape[3],
                            2 / images.shape[2],
                            device=device) @ G_inv @ scale2d_inv(
                                2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :],
                                                   size=shape,
                                                   align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images,
                                            f=self.Hz_geom,
                                            down=2,
                                            padding=-Hz_pad * 2,
                                            flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(
                torch.rand([batch_size], device=device) <
                self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(
                    b,
                    torch.erfinv(debug_percentile * 2 - 1) *
                    self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(
                torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(
                torch.rand([batch_size], device=device) <
                self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(
                    c,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3),
                          device=device)  # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C  # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.hue_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < self.hue * self.p,
                theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
                                        np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C  # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(
                torch.randn([batch_size, 1, 1], device=device) *
                self.saturation_std)
            s = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 4:
                alpha = images[:, 3, :].unsqueeze(dim=1)  # [batch_size, 1, ...]
                rgb = C[:, :3, :3] @ images[:, :3, :] + C[:, :3, 3:]  # [batch_size, 3, ...]
                images = torch.cat([rgb, alpha], dim=1)  # [batch_size, 4, ...]
            elif num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2,
                                                  keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError(
                    'Image must be RGBA (4 channels), RGB (3 channels) or L (1 channel)'
                )
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(
                np.array([10, 1, 1, 1]) / 13,
                device=device)  # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands],
                           device=device)  # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(
                    torch.randn([batch_size], device=device) *
                    self.imgfilter_std)
                t_i = torch.where(
                    torch.rand([batch_size], device=device) <
                    self.imgfilter * self.p * band_strength, t_i,
                    torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(
                        t_i,
                        torch.exp2(
                            torch.erfinv(debug_percentile * 2 - 1) *
                            self.imgfilter_std)
                    ) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands],
                               device=device)  # Temporary gain vector.
                t[:, i] = t_i  # Replace i'th element.
                t = t / (expected_power * t.square()).sum(
                    dim=-1, keepdims=True).sqrt()  # Normalize power.
                g = g * t  # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat(
                [1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1,
                                         -1])  # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape(
                [1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images,
                                             pad=[p, p, p, p],
                                             mode='reflect')
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(2),
                                           groups=batch_size * num_channels)
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(3),
                                           groups=batch_size * num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1],
                                device=device).abs() * self.noise_std
            sigma = torch.where(
                torch.rand([batch_size, 1, 1, 1], device=device) <
                self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(
                    sigma,
                    torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn(
                [batch_size, num_channels, height, width],
                device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1],
                              self.cutout_size,
                              device=device)
            size = torch.where(
                torch.rand([batch_size, 1, 1, 1, 1], device=device) <
                self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height,
                                   device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >=
                      size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >=
                      size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
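Several branches above (scale, aniso, contrast, saturation, imgfilter) draw multiplicative factors as torch.exp2(torch.randn(...) * std), i.e. log2-normally distributed around 1. A minimal sketch of the pattern, mirroring the debug_percentile inverse via erfinv used in the code (hypothetical std):

    import torch

    std = 0.2
    s = torch.exp2(torch.randn([8]) * std)  # random gains centered on 1.0
    # deterministic gain for a given percentile, as in the debug branches:
    p = torch.tensor(0.9)
    s_p = torch.exp2(torch.erfinv(p * 2 - 1) * std)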
Example #8
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2, dtype=torch.long),  # zero_points must be an integer tensor
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Example #9
def midi_to_hz(midi: T) -> T:
    """
    Convert from midi (linear pitch) to frequency in Hz.
    """
    return 440.0 * (torch.exp2((midi - 69.0) / 12.0))
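A quick check (concert A, and one octave up):

    midi_to_hz(torch.tensor(69.0))  # tensor(440.)
    midi_to_hz(torch.tensor(81.0))  # tensor(880.); +12 MIDI doubles the frequency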
Example #10
    def _compute(self, input_texts, model_id, stride=512, device=None):

        if device is not None:
            assert device in ["gpu", "cpu",
                              "cuda"], "device should be gpu, cpu or cuda."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>")

        encodings = tokenizer(input_texts,
                              padding=True,
                              return_tensors="pt",
                              return_special_tokens_mask=True).to(device)

        encoded_texts = encodings["input_ids"]
        special_tokens_masks = encodings["special_tokens_mask"]

        max_model_length = model.config.n_positions

        ppls = []

        for text_index in logging.tqdm(range(0, len(encoded_texts))):
            encoded_text = encoded_texts[text_index]
            special_tokens_mask = special_tokens_masks[text_index]

            encoded_text_length = len(encoded_text) - special_tokens_mask.sum()

            nlls = []

            target_index = max(1, min(stride - 1, encoded_text_length - 1))

            while target_index < encoded_text_length:
                start_index = max(0, target_index - (max_model_length - 1))

                input_ids = encoded_text[start_index:target_index + 1]

                target_ids = input_ids.clone()
                target_ids[:-1] = -100

                attn_mask = torch.ones(len(input_ids)).to(device)
                attn_mask[-1] = 0

                with torch.no_grad():
                    outputs = model(input_ids,
                                    labels=target_ids,
                                    attention_mask=attn_mask)
                    neg_log_likelihood = outputs[0]

                nlls.append(neg_log_likelihood)

                target_index += stride

            if len(nlls) > 0:
                # CrossEntropyLoss returns the NLL in nats, so the conventional
                # perplexity is torch.exp of this mean; exp2 computes 2 ** mean_nll.
                ppls.append(torch.exp2(torch.mean(torch.stack(nlls))))

        ppl = torch.mean(torch.stack(ppls))

        return {"perplexity": float(ppl)}
Example #11
import math

import torch

# digamma
torch.digamma(torch.tensor([1, 0.5]))

# erf
torch.erf(torch.tensor([0, -1., 10.]))

# erfc
torch.erfc(torch.tensor([0, -1., 10.]))

# erfinv
torch.erfinv(torch.tensor([0, 0.5, -1.]))

# exp
torch.exp(torch.tensor([0, math.log(2.)]))

# exp2
torch.exp2(torch.tensor([0, math.log2(2.), 3, 4]))

# expm1
torch.expm1(torch.tensor([0, math.log(2.)]))

# fake_quantize_per_channel_affine
x = torch.randn(2, 2, 2)
scales = (torch.randn(2) + 1) * 0.05
zero_points = torch.zeros(2).to(torch.long)
torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255)

# fake_quantize_per_tensor_affine
a = torch.randn(4)
torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255)

# float_power
torch.float_power(torch.randint(10, (4, )), 2)
Example #12
    def _compute(self,
                 input_texts,
                 model_id,
                 batch_size: int = 16,
                 add_start_token: bool = True,
                 device=None):

        if device is not None:
            assert device in ["gpu", "cpu",
                              "cuda"], "device should be gpu, cpu or cuda."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(
                tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens(
                {"pad_token": existing_special_tokens[0]})

        if add_start_token:
            # leave room for <BOS> token to be added:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = model.config.max_length - 1
        else:
            max_tokenized_len = model.config.max_length

        encodings = tokenizer(
            input_texts,
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]

        # check that each input is long enough:
        if add_start_token:
            assert torch.all(torch.ge(
                attn_masks.sum(1),
                1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        ppls = []
        loss_fct = CrossEntropyLoss(reduction="none")

        for start_index in logging.tqdm(
                range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]

            if add_start_token:
                bos_tokens_tensor = torch.tensor(
                    [[tokenizer.bos_token_id]] *
                    encoded_batch.size(dim=0)).to(device)
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch],
                                          dim=1)
                attn_mask = torch.cat([
                    torch.ones(bos_tokens_tensor.size(),
                               dtype=torch.int64).to(device), attn_mask
                ],
                                      dim=1)

            labels = encoded_batch

            with torch.no_grad():
                out_logits = model(encoded_batch,
                                   attention_mask=attn_mask).logits

            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

            # CrossEntropyLoss returns the NLL in nats, so the conventional
            # perplexity is torch.exp of this ratio; exp2 computes 2 ** mean_nll.
            perplexity_batch = torch.exp2(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) *
                 shift_attention_mask_batch).sum(1) /
                shift_attention_mask_batch.sum(1))

            ppls += perplexity_batch.tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}