Esempio n. 1
0
 def predict(self, X, n_horizon):
     """Autoregressively extend ``X`` by ``n_horizon`` model-predicted steps.

     At each step the whole sequence generated so far is fed to ``self.model``;
     the model's output at the last time step is appended to the sequence and
     becomes part of the next step's input.

     Args:
         X: Input sequence. Assumed to be laid out as (batch, time, features),
             matching how the model output is indexed below — TODO confirm.
         n_horizon: Number of future steps to generate. ``0`` (or a negative
             value) performs no iterations and returns ``X`` unchanged.

     Returns:
         ``X`` with ``n_horizon`` predicted steps concatenated along dim 1.
     """
     seq = X
     for _ in range(n_horizon):
         seq_pred = self.model(seq)
         # Slice with ``-1:`` so the time axis is kept: (batch, 1, features).
         # ``torch.atleast_3d(seq_pred[:, -1, :])`` turned the 2-D
         # (batch, features) slice into (batch, features, 1), which only
         # concatenates along dim 1 when there is exactly one feature.
         x_new = seq_pred[:, -1:, :]
         seq = torch.cat((seq, x_new), 1)
     return seq
    def __call__(self, images: Union[Tensor, NDArrayR], *,
                 labels: Union[Tensor, NDArrayR]) -> Tensor:
        """Apply the transformation.

        :param images: Greyscale images to be colorized. Expected to be unnormalized (in the range
            [0, 255]). Accepted layouts are ``(H, W)`` for a single image, ``(N, H, W)`` for a batch
            without a channel dimension, and ``(N, C, H, W)`` with ``C`` equal to 1 (or already 3).
            NumPy arrays are converted to ``float32`` tensors.
        :param labels: Indexes (0-9) indicating the gaussian distribution from which to sample each
            image's color. NumPy arrays are converted to ``long`` tensors.
        :returns: Images converted to RGB with shape ``(N, 3, H, W)``, or ``(N, 1, H, W)`` when
            ``self.greyscale`` is set (the RGB channels are then averaged).
        """
        if isinstance(images, np.ndarray):
            images = torch.as_tensor(images, dtype=torch.float32)
        if isinstance(labels, np.ndarray):
            labels = torch.as_tensor(labels, dtype=torch.long)

        # Prepend leading axes until the input is at least (N, H, W), so a lone (H, W) image is
        # treated as a batch of one. ``torch.atleast_3d`` is unsuitable here: it *appends* a
        # trailing axis ((H, W) -> (H, W, 1)), which turned a single image into a meaningless
        # (H, 3, W, 1) result.
        while images.ndim < 3:
            images = images.unsqueeze(0)
        # Add a singular channel dimension if one isn't already there: (N, H, W) -> (N, 1, H, W).
        if images.ndim == 3:
            images = images.unsqueeze(1)
        # Broadcast the grey channel across the three RGB channels.
        images = images.expand(-1, 3, -1, -1)

        # Presumably one RGB color per label; reshaped to broadcast over (N, 3, H, W).
        colors = self._sample_colors(self.palette[labels]).view(-1, 3, 1, 1)

        if self.binarize:
            # Threshold at 127, giving a {0., 1.} mask.
            images = (images > 127).float()
        # NOTE(review): without binarization the pixel values stay in [0, 255], while the blending
        # formulas below look like they assume [0, 1] intensities — confirm whether the caller
        # normalizes or whether ``binarize`` is effectively required.

        if self.background:
            if self.black:
                # colorful background, black digits
                images_colorized = (1 - images) * colors
            else:
                # colorful background, white digits
                images_colorized = images + colors
        elif self.black:
            # black background, colorful digits
            images_colorized = images * colors
        else:
            # white background, colorful digits
            images_colorized = 1 - images * (1 - colors)

        if self.greyscale:
            # Collapse back to a single channel by averaging R, G and B.
            images_colorized = images_colorized.mean(dim=1, keepdim=True)

        return images_colorized
Esempio n. 3
0
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     size = [0, 1]
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )