Ejemplo n.º 1
0
 def setups(self, dl: DataLoader):
     """Fit the minimum/maximum statistics used for min-max scaling.

     The fit runs only while ``self._setup`` is truthy and clears that flag
     when done, so later calls skip it. The data comes from the whole dataset
     (``dl.dataset.__getitem__([slice(None)])[0]``) unless
     ``self.use_single_batch`` is set, in which case the first item returned
     by ``dl.one_batch()`` is used.

     If ``self.by_var`` is a list, every entry (one variable index or a group
     of indices) is reduced separately from ``data[:, idx]``: single indices
     over ``self.axes``, groups over ``self.list_axes``. Variables not listed
     keep ``self.range_min`` / ``self.range_max`` as their min/max. Otherwise
     the whole tensor is reduced over ``self.axes``. ``keepdim`` is enabled
     whenever ``self.axes`` is not the empty tuple.

     When no fit happens and ``self.by_sample`` is set, placeholder tensors
     ``-1`` and ``1`` are stored as min and max.

     Args:
         dl: loader exposing ``dataset`` and ``one_batch()``.
     """
     if not self._setup:
         if self.by_sample:
             self.min, self.max = -torch.ones(1), torch.ones(1)
         return

     if self.use_single_batch:
         data, *_ = dl.one_batch()
     else:
         data = dl.dataset.__getitem__([slice(None)])[0]

     keep = self.axes != ()
     if self.by_var and is_listy(self.by_var):
         # Shape of a full reduction; used to pre-fill per-variable stats.
         stat_shape = torch.mean(data, dim=self.axes, keepdim=keep).shape
         lo = torch.zeros(*stat_shape, device=data.device) + self.range_min
         hi = torch.zeros(*stat_shape, device=data.device) + self.range_max
         for group in self.by_var:
             idx = group if is_listy(group) else [group]
             # Groups of variables share one statistic, hence the wider axes.
             dims = self.axes if len(idx) == 1 else self.list_axes
             # `[:, idx]` assumes variables live on dim 1 — TODO confirm layout.
             lo[:, idx] = data[:, idx].mul_min(dims, keepdim=keep)
             hi[:, idx] = data[:, idx].mul_max(dims, keepdim=keep)
     else:
         lo = data.mul_min(self.axes, keepdim=keep)
         hi = data.mul_max(self.axes, keepdim=keep)

     self.min, self.max = lo, hi
     if self.min.ndim == 0:
         msg = f'{self.__class__.__name__} min={self.min}, max={self.max}, by_sample={self.by_sample}, by_var={self.by_var}, by_step={self.by_step}\n'
     else:
         msg = f'{self.__class__.__name__} min shape={self.min.shape}, max shape={self.max.shape}, by_sample={self.by_sample}, by_var={self.by_var}, by_step={self.by_step}\n'
     pv(msg, self.verbose)
     self._setup = False
Ejemplo n.º 2
0
 def setups(self, dl: DataLoader):
     """Fit the mean/standard-deviation statistics used for standardization.

     The fit runs only while ``self._setup`` is truthy and clears that flag
     afterwards. Data is taken from the whole dataset
     (``dl.dataset.__getitem__([slice(None)])[0]``), or from the first item of
     ``dl.one_batch()`` when ``self.use_single_batch`` is set.

     If ``self.by_var`` is a list, each entry (one variable index or a group
     of indices) gets its own NaN-aware mean and std computed from
     ``data[:, idx]``. Single indices are reduced over ``self.axes`` and groups
     over ``self.list_axes``. Variables not listed keep mean 0 and std 1.
     Otherwise the statistics are reduced over ``self.axes``. The std is always
     floored at ``self.eps`` via ``torch.clamp_min``.

     When no fit happens and ``self.by_sample`` is set, placeholder tensors
     ``0`` and ``1`` are stored as mean and std.

     Args:
         dl: loader exposing ``dataset`` and ``one_batch()``.
     """
     if not self._setup:
         if self.by_sample:
             self.mean, self.std = torch.zeros(1), torch.ones(1)
         return

     if self.use_single_batch:
         data, *_ = dl.one_batch()
     else:
         data = dl.dataset.__getitem__([slice(None)])[0]

     if self.by_var and is_listy(self.by_var):
         # Shape of a full reduction; used to pre-fill per-variable stats.
         stat_shape = torch.mean(data, dim=self.axes, keepdim=self.axes != ()).shape
         center = torch.zeros(*stat_shape, device=data.device)
         scale = torch.ones(*stat_shape, device=data.device)
         for group in self.by_var:
             idx = group if is_listy(group) else [group]
             # Groups of variables share one statistic, hence the wider axes.
             dims = self.axes if len(idx) == 1 else self.list_axes
             center[:, idx] = torch_nanmean(data[:, idx], dim=dims, keepdim=True)
             spread = torch_nanstd(data[:, idx], dim=dims, keepdim=True)
             scale[:, idx] = torch.clamp_min(spread, self.eps)
     else:
         keep = self.axes != ()
         center = torch_nanmean(data, dim=self.axes, keepdim=keep)
         scale = torch.clamp_min(torch_nanstd(data, dim=self.axes, keepdim=keep), self.eps)

     self.mean, self.std = center, scale
     if self.mean.ndim == 0:
         msg = f'{self.__class__.__name__} mean={self.mean}, std={self.std}, by_sample={self.by_sample}, by_var={self.by_var}, by_step={self.by_step}\n'
     else:
         msg = f'{self.__class__.__name__} mean shape={self.mean.shape}, std shape={self.std.shape}, by_sample={self.by_sample}, by_var={self.by_var}, by_step={self.by_step}\n'
     pv(msg, self.verbose)
     self._setup = False
Ejemplo n.º 3
0
 def setups(self, dl: DataLoader):
     """Fit the clipping bounds from ``get_outliers_IQR``.

     Runs only while ``self._setup`` is truthy and clears the flag afterwards;
     otherwise it does nothing. Data is taken from the whole dataset
     (``dl.dataset.__getitem__([slice(None)])[0]``), or from the first item of
     ``dl.one_batch()`` when ``self.use_single_batch`` is set.

     ``get_outliers_IQR(o, self.axis)`` returns a pair of bounds, presumably
     the lower/upper IQR outlier thresholds (verify against the helper). Each
     bound is wrapped with ``tensor`` and stored as ``self.min`` /
     ``self.max``.

     Args:
         dl: loader exposing ``dataset`` and ``one_batch()``.
     """
     if self._setup:
         if not self.use_single_batch:
             o = dl.dataset.__getitem__([slice(None)])[0]
         else:
             o, *_ = dl.one_batch()
         # Named lower/upper rather than min/max so the builtins are not shadowed.
         lower, upper = get_outliers_IQR(o, self.axis)
         self.min, self.max = tensor(lower), tensor(upper)
         if self.axis is None:
             pv(f'{self.__class__.__name__} min={self.min}, max={self.max}, by_sample={self.by_sample}, by_var={self.by_var}\n',
                self.verbose)
         else:
             # Per-axis bounds can be large, so only their shapes are logged.
             pv(f'{self.__class__.__name__} min shape={self.min.shape}, max shape={self.max.shape}, by_sample={self.by_sample}, by_var={self.by_var}\n',
                self.verbose)
         self._setup = False
Ejemplo n.º 4
0
    def setups(self, dl: DataLoader):
        """Fit the median and interquartile range used for robust scaling.

        While ``self._setup`` is truthy, the data is taken from the whole
        dataset (``dl.dataset.__getitem__([slice(None)])[0]``), or from the
        first item of ``dl.one_batch()`` when ``self.use_single_batch`` is set.
        ``self.median`` is computed as the 50th percentile per row of the
        reshaped data. ``self.iqr`` is the difference of the two bounds returned
        by ``get_outliers_IQR`` (using ``self.quantile_range``), floored at
        ``self.eps``. Both results get a leading dimension via ``unsqueeze(0)``
        and the flag is cleared.

        Otherwise, any missing statistic is initialised on ``dl.device``:
        median to ``0`` and iqr to ``1``.

        Args:
            dl: loader exposing ``dataset``, ``one_batch()`` and ``device``.
        """
        if not self._setup:
            # No fit requested: only fill in neutral defaults that are missing.
            if self.median is None:
                self.median = torch.zeros(1, device=dl.device)
            if self.iqr is None:
                self.iqr = torch.ones(1, device=dl.device)
            return

        if self.use_single_batch:
            batch, *_ = dl.one_batch()
        else:
            batch = dl.dataset.__getitem__([slice(None)])[0]

        # Swap the first two axes of the 3-D tensor and merge everything after
        # the new first axis, so each row gathers all values of one index along
        # the original dim 1 (presumably one variable — TODO confirm layout).
        per_var = batch.permute(1, 0, 2).flatten(1)
        med = get_percentile(per_var, 50, axis=1)
        q_lo, q_hi = get_outliers_IQR(per_var, axis=1, quantile_range=self.quantile_range)
        self.median = med.unsqueeze(0)
        spread = (q_hi - q_lo).unsqueeze(0)
        self.iqr = torch.clamp_min(spread, self.eps)

        pv(f'{self.__class__.__name__} median={self.median.shape} iqr={self.iqr.shape}', self.verbose)
        self._setup = False