Example #1
 def __init__(self,
              video_shape,
              spatial_scratch=False,
              temporal_scratch=False,
              device='cuda'):
     """ Allocate the spatial and temporal factor tensors on device.

     video_shape is indexed as [0] and [2] for the factor trailing axes,
     with [1] sizing the component count (assumed -- confirm in callers).
     """
     super().__init__(video_shape[1], device=device)
     dev, f32 = self.device, torch.float32
     self.spatial = TensorFactor(
         (self.num_components, video_shape[0]),
         scratch=spatial_scratch, device=dev, dtype=f32)
     self.temporal = TensorFactor(
         (self.num_components, video_shape[2]),
         scratch=temporal_scratch, device=dev, dtype=f32)
Example #2
 def __init__(self, video_shape, device='cuda'):
     """ Allocate the movie tensor on the target device """
     super().__init__(video_shape[1], device=device)
     self.mov = TensorFactor(
         video_shape[0],
         scratch=False,
         device=self.device,
         dtype=torch.float32,
     )
Example #3
 def __init__(self, max_num_components, region_shape, device='cuda'):
     """ Allocate support, distance, and label tensors on device """
     super().__init__(max_num_components, device)
     per_pixel_shape = (max_num_components, region_shape[0])
     # binary membership mask of each component's region
     self.support = TensorFactor(per_pixel_shape,
                                 device=device,
                                 dtype=torch.uint8)
     self.distance = TensorFactor(per_pixel_shape,
                                  device=device,
                                  dtype=torch.float32)
     self.labels = TensorFactor((max_num_components,),
                                device=device,
                                dtype=torch.long)
Example #4
 def __init__(self,
              max_num_components,
              video_shape,
              device='cuda'):
     """ Allocate necessary buffers & copy initialization from host """
     super().__init__(max_num_components, video_shape, device=device)
     ncomp, dev = self.num_components, self.device
     # per-component, per-pixel distances, with a scratch buffer
     self.distance = TensorFactor((ncomp, video_shape[0]),
                                  scratch=True,
                                  device=dev,
                                  dtype=torch.float32)
     # region id associated with each component
     self.regions = TensorFactor((ncomp,), device=dev, dtype=torch.long)
     self.lambdas = TensorFactor((ncomp,), device=dev, dtype=torch.float32)
Example #5
class LowRankVideo(VideoWrapper):
    """ Manages Tensors and multiplication with a Low-Rank Factored Video """
    def __init__(self,
                 video_shape,
                 spatial_scratch=False,
                 temporal_scratch=False,
                 device='cuda'):
        """ Allocate required tensors on device.

        video_shape: indexed as [0] pixel count, [1] component count,
        [2] frame count (assumed from the allocations -- confirm callers).
        """
        super().__init__(video_shape[1], device=device)
        # (num_components, num_pixels) spatial factor
        self.spatial = TensorFactor((self.num_components, video_shape[0]),
                                    scratch=spatial_scratch,
                                    device=self.device,
                                    dtype=torch.float32)
        # (num_components, num_frames) temporal factor
        self.temporal = TensorFactor((self.num_components, video_shape[2]),
                                     scratch=temporal_scratch,
                                     device=self.device,
                                     dtype=torch.float32)

    @property
    def fov(self):
        """ Shape of the field of view in pixels/voxels """
        return self.spatial.shape[1:]

    @property
    def frames(self):
        """ Length of the video in frames """
        return self.temporal.shape[1:]

    def _resize(self, num_components):
        """ Resize each factor axis inplace to represent N components """
        self.spatial.resize_axis_(0, num_components)
        self.temporal.resize_axis_(0, num_components)

    def permute(self, permutation):
        """ Reorder component in factors according to provided permutation """
        self.spatial.permute_(permutation)
        self.temporal.permute_(permutation)

    def set(self, h_spatial, h_temporal):
        """ Initializes Factors From ndarrays on host """
        self.spatial.data.copy_(h_spatial)
        self.temporal.data.copy_(h_temporal)

    def forward(self, spatial, temporal, **kwargs):
        """ Computes temporal.scratch through the factored video.

        Two-step product: intermediate = self.spatial @ spatial.data.T,
        then temporal.scratch = intermediate.T @ self.temporal.
        NOTE(review): an earlier docstring referenced ``mov.t()``, which
        belongs to the dense-video wrapper, not this low-rank one.

        Raises TypeError unless the 'intermediate' keyword argument (a
        TensorFactor sized for the components-x-components product) is given.
        """
        if 'intermediate' not in kwargs:
            raise TypeError(
                "Missing 1 required keyword argument: 'intermediate'")
        torch.matmul(self.spatial.data,
                     spatial.data.t(),
                     out=kwargs['intermediate'].data)
        torch.matmul(kwargs['intermediate'].data.t(),
                     self.temporal.data,
                     out=temporal.scratch)

    def backward(self, temporal, spatial, **kwargs):
        """ Computes spatial.scratch through the factored video.

        Two-step product: intermediate = self.temporal @ temporal.data.T,
        then spatial.scratch = intermediate.T @ self.spatial.
        NOTE(review): an earlier docstring referenced ``mov.t()``, which
        belongs to the dense-video wrapper, not this low-rank one.

        Raises TypeError unless the 'intermediate' keyword argument is given.
        """
        if 'intermediate' not in kwargs:
            raise TypeError(
                "Missing 1 required keyword argument: 'intermediate'")
        torch.matmul(self.temporal.data,
                     temporal.data.t(),
                     out=kwargs['intermediate'].data)
        torch.matmul(kwargs['intermediate'].data.t(),
                     self.spatial.data,
                     out=spatial.scratch)
Example #6
class RegionMetadata(FactorCollection):
    """ Manages metadata for each region in localized Semi-NMF """

    def __init__(self, max_num_components, region_shape, device='cuda'):
        """ Allocate support, distance, and label tensors on device """
        super().__init__(max_num_components, device)
        per_pixel_shape = (max_num_components, region_shape[0])
        # binary mask of pixels belonging to each component's region
        self.support = TensorFactor(per_pixel_shape,
                                    device=device,
                                    dtype=torch.uint8)
        self.distance = TensorFactor(per_pixel_shape,
                                     device=device,
                                     dtype=torch.float32)
        self.labels = TensorFactor((max_num_components,),
                                   device=device,
                                   dtype=torch.long)

    def _resize(self, num_components):
        """ Resize the leading axis of every factor to num_components """
        for factor in (self.support, self.distance, self.labels):
            factor.resize_axis_(0, num_components)

    def permute(self, permutation):
        """ Reorder components in every factor by the given permutation """
        for factor in (self.support, self.distance, self.labels):
            factor.permute_(permutation)

    def set(self, h_support, h_distance, h_labels):
        """ Copy host ndarrays into the device factors """
        self.support.data.copy_(h_support)
        self.distance.data.copy_(h_distance)
        self.labels.data.copy_(h_labels)
Example #7
 def __init__(self,
              max_num_components,
              video_shape,
              device='cuda'):
     """ Allocate necessary factor and scratch buffers on device """
     super().__init__(max_num_components, device)
     ncomp, dev, f32 = self.num_components, self.device, torch.float32
     # spatial & temporal factors carry scratch copies for updates
     self.spatial = TensorFactor((ncomp, video_shape[0]),
                                 scratch=True, device=dev, dtype=f32)
     self.intermediate = TensorFactor((video_shape[1], ncomp),
                                      device=dev, dtype=f32)
     self.temporal = TensorFactor((ncomp, video_shape[2]),
                                  scratch=True, device=dev, dtype=f32)
     # components-x-components Gram matrix buffer
     self.covariance = TensorFactor((ncomp, ncomp), device=dev, dtype=f32)
     self.scale = TensorFactor((ncomp,),
                               scratch=True, device=dev, dtype=f32)
     self.index = TensorFactor((ncomp,), device=dev, dtype=torch.long)
Example #8
class BaseNMF(VideoFactorizer):
    """ Manages Factor/Buffer Lifetime & Manipulation For HALS NMF """
    def __init__(self,
                 max_num_components,
                 video_shape,
                 device='cuda'):
        """ Allocate necessary factor and scratch buffers on device.

        video_shape: [0] sizes the spatial trailing axis, [2] the temporal
        one, and [1] the first axis of `intermediate` (assumed -- confirm
        against the video wrapper's forward/backward).
        """
        super().__init__(max_num_components, device)
        # spatial/temporal factors carry scratch buffers for HALS updates
        self.spatial = TensorFactor((self.num_components, video_shape[0]),
                                    scratch=True,
                                    device=self.device,
                                    dtype=torch.float32)
        self.intermediate = TensorFactor((video_shape[1], self.num_components),
                                         device=self.device,
                                         dtype=torch.float32)
        self.temporal = TensorFactor((self.num_components, video_shape[2]),
                                     scratch=True,
                                     device=self.device,
                                     dtype=torch.float32)
        # components-x-components Gram matrix (Sigma in the updates below)
        self.covariance = TensorFactor((self.num_components,)*2,
                                       device=self.device,
                                       dtype=torch.float32)
        self.scale = TensorFactor((self.num_components,),
                                  scratch=True,
                                  device=self.device,
                                  dtype=torch.float32)
        # permutation buffer used when pruning/reordering components
        self.index = TensorFactor((self.num_components,),
                                  device=self.device,
                                  dtype=torch.long)

    def _resize(self, num_components):
        """ Resize each factor axis inplace to represent N components """
        self.spatial.resize_axis_(0, num_components)
        self.temporal.resize_axis_(0, num_components)
        self.intermediate.resize_axis_(1, num_components)
        # covariance is square in num_components, so both axes must change;
        # raw resize_ is used instead of the single-axis helper
        self.covariance.data.resize_((num_components,)*2)
        self.scale.resize_axis_(0, num_components)
        self.index.resize_axis_(0, num_components)

    def permute(self, permutation):
        """ Reorder component in factors according to provided permutation.

        Only spatial/temporal are permuted; covariance, scale, and index are
        recomputed buffers, so reordering them would be wasted work.
        """
        self.spatial.permute_(permutation)
        self.temporal.permute_(permutation)

    def prune_empty_components(self, p_norm=2):
        """ Remove components with spatial empty factors.

        Sorts components by descending spatial norm so the empty ones land
        at the tail, then truncates by lowering num_components (assumed to
        trigger _resize via a property setter -- confirm in base class).
        """
        torch.norm(self.spatial.data,
                   p=p_norm,
                   dim=-1,
                   out=self.scale.data)
        # scratch becomes a 0/1 indicator of non-empty components
        self.scale.scratch.copy_(self.scale.data)
        self.scale.scratch.gt_(0)
        nnz = int(torch.sum(self.scale.scratch).item())
        if nnz < self.num_components:
            print("Removing components, num remaining: {}".format(nnz))
            self.index.data[:] = torch.argsort(
                self.scale.data, descending=True
            )
            self.permute(self.index.data)
            self.num_components = nnz

    def _spatial_precompute(self, video_wrapper, **kwargs):
        """ Precompute Sigma, Delta For Use In Spatial Update """
        # Sigma = A'A
        torch.matmul(self.temporal.data,
                     self.temporal.data.t(),
                     out=self.covariance.data)
        # Delta = C'Y'
        video_wrapper.backward(self.temporal,
                               self.spatial,
                               intermediate=self.intermediate,
                               **kwargs)

    @abstractmethod
    def _spatial_update(self, **kwargs):
        """ Use Precomputed Quantities To Update Spatial Factors """
        ...

    def update_spatial(self, video_wrapper, **kwargs):
        """ Perform the HALS update of the spatial factor """
        self._spatial_precompute(video_wrapper, **kwargs)
        self._spatial_update(**kwargs)

    def normalize_spatial(self, p_norm=float('inf')):
        """ Scale each spatial component to have unit p-norm.

        The inverse scaling is applied to temporal so the product is
        unchanged. NOTE(review): a zero-norm component would divide by
        zero here -- presumably prune_empty_components runs first; confirm.
        """
        torch.norm(self.spatial.data,
                   p=p_norm,
                   dim=-1,
                   out=self.scale.data)
        # [..., None] broadcasts the per-component norm across the last axis
        torch.div(self.spatial.data,
                  self.scale.data[..., None],
                  out=self.spatial.data)
        torch.mul(self.temporal.data,
                  self.scale.data[..., None],
                  out=self.temporal.data)

    def _temporal_precompute(self, video_wrapper, **kwargs):
        """ Precompute Sigma, Delta For Use In Temporal Update """
        # Sigma = C'C
        torch.matmul(self.spatial.data,
                     self.spatial.data.t(),
                     out=self.covariance.data)
        # Delta = A'Y
        video_wrapper.forward(self.spatial,
                              self.temporal,
                              intermediate=self.intermediate,
                              **kwargs)

    @abstractmethod
    def _temporal_update(self, **kwargs):
        """ Use Precomputed Quantities To Update Temporal Factors """
        ...

    def update_temporal(self, video_wrapper, **kwargs):
        """ Perform the HALS update of the temporal factor """
        self._temporal_precompute(video_wrapper, **kwargs)
        self._temporal_update(**kwargs)

    def normalize_temporal(self, p_norm=float('inf')):
        """ Scale each temporal component to have unit p-norm.

        The inverse scaling is applied to spatial so the product is
        unchanged (mirror of normalize_spatial).
        """
        torch.norm(self.temporal.data,
                   p=p_norm,
                   dim=-1,
                   out=self.scale.data)
        torch.div(self.temporal.data,
                  self.scale.data[..., None],
                  out=self.temporal.data)
        torch.mul(self.spatial.data,
                  self.scale.data[..., None],
                  out=self.spatial.data)
Example #9
class LocalizedNMF(SpatialHals, TemporalHals, BaseNMF):
    """ Adds Region-Localization Factors & Updates To Hals Factorization """
    def __init__(self,
                 max_num_components,
                 video_shape,
                 device='cuda'):
        """ Allocate necessary buffers & copy initialization from host """
        super().__init__(max_num_components, video_shape, device=device)
        # per-component, per-pixel distance penalty, with a scratch buffer
        self.distance = TensorFactor((self.num_components, video_shape[0]),
                                     scratch=True,
                                     device=self.device,
                                     dtype=torch.float32)
        # region id each component was initialized from
        self.regions = TensorFactor((self.num_components,),
                                    device=self.device,
                                    dtype=torch.long)
        # per-component localization weight (assumed -- confirm usage)
        self.lambdas = TensorFactor((self.num_components,),
                                    device=self.device,
                                    dtype=torch.float32)

    def _resize(self, num_components):
        """ Resize each factor axis inplace to represent N components """
        super()._resize(num_components)
        self.distance.resize_axis_(0, num_components)
        self.regions.resize_axis_(0, num_components)
        self.lambdas.resize_axis_(0, num_components)

    def permute(self, permutation):
        """ Reorder component in factors according to provided permutation """
        super().permute(permutation)
        # distance's scratch buffer is reordered too (scratch=True)
        self.distance.permute_(permutation, scratch=True)
        self.regions.permute_(permutation)
        self.lambdas.permute_(permutation)

    def _spatial_precompute(self, video_wrapper, **kwargs):
        """ Precompute Sigma, Delta then subtract the distance penalty.

        NOTE(review): assumes distance.scratch was populated (presumably a
        lambda-scaled penalty) before this call -- confirm in the update.
        """
        super()._spatial_precompute(video_wrapper, **kwargs)
        # Delta <- Delta - penalty, applied in place on spatial.scratch
        torch.sub(self.spatial.scratch,
                  self.distance.scratch,
                  out=self.spatial.scratch)

    def set_from_regions(self, region_factorizations, region_metadata):
        """ Use Within-Region Factorizations To Init Full FOV Factors """
        # number of components contributed by each region factorization
        ranks = [len(factorization) for factorization in region_factorizations]
        # NOTE(review): np.sum returns a numpy integer -- downstream users
        # of num_components must tolerate that type; verify
        self.num_components = np.sum(ranks)
        sdx = 0
        self.spatial.data.fill_(0.0)
        for rdx, rank in enumerate(ranks):
            # scatter the region's spatial factor into its support pixels;
            # NOTE(review): uint8 masks for masked_scatter_ are deprecated
            # in newer torch versions -- may need a bool cast
            self.spatial.data[sdx:sdx+rank].masked_scatter_(
                region_metadata.support.data[rdx],
                region_factorizations[rdx].spatial.data
            )
            self.temporal.data[sdx:sdx+rank].copy_(
                region_factorizations[rdx].temporal.data
            )
            self.distance.data[sdx:sdx+rank].copy_(
                region_metadata.distance.data[rdx]
            )
            # record the originating region for these components
            self.regions.data[sdx:sdx+rank].fill_(rdx)
            sdx += rank