def forward(self):
    """Apply NaN replacement to ``self.input`` and return the result.

    ``self.op == "nan_to_num"`` selects the out-of-place
    ``torch.nan_to_num``; any other value selects the in-place
    ``torch.nan_to_num_``, which overwrites ``self.input``.

    NaN is always mapped to ``1.0``. When ``self.replace_inf`` is truthy
    only ``nan`` is passed, so infinities get torch's default replacement;
    otherwise ``posinf``/``neginf`` are passed as ``+inf``/``-inf`` so
    infinities are mapped onto themselves.
    """
    # compare inplace
    out_of_place = self.op == "nan_to_num"

    if out_of_place:
        if self.replace_inf:
            return torch.nan_to_num(self.input, nan=1.0)
        return torch.nan_to_num(
            self.input, nan=1.0, posinf=math.inf, neginf=-math.inf
        )

    # In-place variant: the returned tensor is ``self.input`` itself.
    if self.replace_inf:
        return torch.nan_to_num_(self.input, nan=1.0)
    return torch.nan_to_num_(
        self.input, nan=1.0, posinf=math.inf, neginf=-math.inf
    )
def accumulate_batch(
    self, batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fold a batch of samples into the running statistics.

    Args:
        batch: tensor of shape ``(N_samples,) + self.dim`` holding the
            samples to process.
        accumulate_by: optional tensor of non-negative integer bin indexes
            of shape ``(N_samples,)``. When given, sample ``n`` is
            accumulated into bin ``accumulate_by[n]``. When ``None`` (the
            default), every sample goes into the first (0th) bin.

    Returns:
        tensor of shape ``(N_bins,) + self.output_dim`` with the aggregated
        statistics *of this batch only*. The statistics accumulated so far
        are available from ``current_result()``. ``N_bins`` here is
        ``accumulate_by.max() + 1`` (the number of bins in the batch), not
        the overall bin count ``self.n_bins``.
    """
    batch_average, batch_sum, batch_count = self.batch_result(batch, accumulate_by)

    def _pad_bins(tensor, extra):
        # Append ``extra`` all-zero bins along dim 0 (same dtype/device).
        filler = tensor.new_zeros((extra,) + tensor.shape[1:])
        return torch.cat((tensor, filler), dim=0)

    # Positive: the batch has bins we have not seen yet, so grow our state.
    # Negative: the batch covers fewer bins, so pad the batch quantities.
    bin_deficit = batch_sum.shape[0] - self._n_bins
    if bin_deficit > 0:
        self._state = _pad_bins(self._state, bin_deficit)
        self._n = _pad_bins(self._n, bin_deficit)
        self._n_bins += bin_deficit
    elif bin_deficit < 0:
        batch_sum = _pad_bins(batch_sum, -bin_deficit)
        batch_count = _pad_bins(batch_count, -bin_deficit)

    # Incremental update of the running value using the new sum and counts.
    self._state += (batch_sum - batch_count * self._state) / (self._n + batch_count)
    self._n += batch_count
    # Bins with a zero total count divide 0 by 0 above; map that NaN to 0.
    self._state = torch.nan_to_num_(self._state, nan=0.0)

    return batch_average
def test_one_acc(dim, reduce_dims, reduction, do_accumulate_by, allclose):
    """Check one ``accumulate_batch`` call against a directly computed reference.

    A random batch is pushed through a fresh ``RunningStats``. Both the
    value returned by ``accumulate_batch`` and ``current_result()`` must
    match a reference built with plain tensor reductions. With
    ``do_accumulate_by`` the reference is built bin by bin, and bins that
    received no samples are expected to be zero.
    """
    runstats = RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims)
    # The batch axis is prepended, so each reduced sample axis shifts by one.
    mean_dims = (0,) + tuple(axis + 1 for axis in reduce_dims)

    n_samples = random.randint(3, 10)
    batch = torch.randn((n_samples,) + runstats.dim)

    if do_accumulate_by:
        bin_limit = random.randint(1, 5)
        accumulate_by = torch.randint(0, bin_limit, size=(batch.shape[0],))
        res = runstats.accumulate_batch(batch, accumulate_by=accumulate_by)

        # RMS reference: mean of squares per bin, square root taken at the end.
        values = batch.square() if reduction == Reduction.RMS else batch

        n_bins = int(accumulate_by.max()) + 1
        # An empty bin has a NaN mean; the expected value for it is 0.
        per_bin = [
            torch.nan_to_num(values[accumulate_by == bin_idx].mean(dim=mean_dims), nan=0.0)
            for bin_idx in range(n_bins)
        ]
        truth = torch.stack(per_bin, dim=0)

        kept_dims = tuple(
            size for axis, size in enumerate(runstats.dim) if axis not in reduce_dims
        )
        assert truth.shape[1:] == kept_dims

        if reduction == Reduction.RMS:
            truth.sqrt_()
    else:
        res = runstats.accumulate_batch(batch)
        if reduction == Reduction.MEAN:
            truth = batch.mean(dim=mean_dims)
        elif reduction == Reduction.RMS:
            truth = batch.square().mean(dim=mean_dims).sqrt()

    assert allclose(truth, res)
    assert allclose(truth, runstats.current_result())
def add_to_loss_dict(d: dict, key: str, loss: torch.Tensor, weight=None):
    """Add an optionally weighted loss term to the entry ``d[key]``.

    Args:
        d: dictionary of accumulated losses. ``key`` must already be
            present (a missing key raises ``KeyError``); a value of
            ``None`` means nothing has been accumulated yet.
        key: entry to update.
        loss: loss tensor to add. It is sanitised **in place**: NaN,
            +inf and -inf entries are all overwritten with 0.
        weight: optional multiplier applied to ``loss`` (after
            sanitising) before it is accumulated.
    """
    running = d[key]

    # Mutates the caller's tensor so non-finite values cannot poison the sum.
    torch.nan_to_num_(loss, nan=0.0, posinf=0.0, neginf=0.0)

    term = loss if weight is None else loss * weight

    if running is None:
        d[key] = term
    else:
        d[key] = running + term
def pointwise_ops(self):
    """Call a long list of pointwise torch ops on small sample tensors.

    Builds a handful of fixed and random tensors, invokes each op in the
    list below with them, and returns ``len(...)`` applied to all results.

    The argument list is order-sensitive: ``a``, ``b`` and ``w`` are
    modified in place part-way through (``a.uniform_``, ``torch.add(...,
    out=a)``, ``b.add_``, ``a.div_``, ``b.mul_``, ``torch.nan_to_num_(w)``,
    ``b.sub_``, ``torch.trunc_(a)``), so later calls see the mutated values.

    NOTE(review): ``len`` receives many positional arguments and some ops
    are given plain Python scalars (e.g. ``torch.ceil(float(...))``,
    ``torch.sum(5)``); plain Python/eager torch would reject these, so this
    method is presumably only meant to be compiled (e.g. by TorchScript)
    rather than called directly -- confirm against the caller.
    """
    a = torch.randn(4)
    b = torch.randn(4)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
    # NOTE(review): `t` is assigned a second time with the same value.
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
    f = torch.zeros(3)
    g = torch.tensor([-1, 0, 1])
    w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    return len(
        torch.abs(torch.tensor([-1, -2, 3])),
        torch.absolute(torch.tensor([-1, -2, 3])),
        torch.acos(a),
        torch.arccos(a),
        # uniform_ refills `a` in place with values in [1.0, 2.0).
        torch.acosh(a.uniform_(1.0, 2.0)),
        torch.add(a, 20),
        torch.add(a, b, out=a),
        b.add(a),
        b.add(a, out=b),
        b.add_(a),
        b.add(1),
        torch.add(a, torch.randn(4, 1), alpha=10),
        torch.addcdiv(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.addcmul(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.angle(a),
        torch.asin(a),
        torch.arcsin(a),
        torch.asinh(a),
        torch.arcsinh(a),
        torch.atan(a),
        torch.arctan(a),
        # `a` is refilled in place again, this time within [-1.0, 1.0).
        torch.atanh(a.uniform_(-1.0, 1.0)),
        torch.arctanh(a.uniform_(-1.0, 1.0)),
        torch.atan2(a, a),
        torch.bitwise_not(t),
        torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.ceil(a),
        torch.ceil(float(torch.tensor(0.5))),
        torch.ceil(torch.tensor(0.5).item()),
        torch.clamp(a, min=-0.5, max=0.5),
        torch.clamp(a, min=0.5),
        torch.clamp(a, max=0.5),
        torch.clip(a, min=-0.5, max=0.5),
        torch.conj(a),
        torch.copysign(a, 1),
        torch.copysign(a, b),
        torch.cos(a),
        torch.cosh(a),
        torch.deg2rad(
            torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])),
        torch.div(a, b),
        a.div(b),
        a.div(1),
        a.div_(b),
        torch.divide(a, b, rounding_mode="trunc"),
        torch.divide(a, b, rounding_mode="floor"),
        torch.digamma(torch.tensor([1.0, 0.5])),
        torch.erf(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
        torch.exp(torch.tensor([0.0, math.log(2.0)])),
        torch.exp(float(torch.tensor(1))),
        torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
        torch.expm1(torch.tensor([0.0, math.log(2.0)])),
        torch.fake_quantize_per_channel_affine(
            torch.randn(2, 2, 2),
            (torch.randn(2) + 1) * 0.05,
            torch.zeros(2),
            1,
            0,
            255,
        ),
        torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
        torch.float_power(torch.randint(10, (4, )), 2),
        torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
        torch.floor(a),
        torch.floor(float(torch.tensor(1))),
        torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
        torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
        torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
        torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.frac(torch.tensor([1.0, 2.5, -3.2])),
        torch.randn(4, dtype=torch.cfloat).imag,
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
        torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
        torch.lerp(
            torch.arange(1.0, 5.0),
            torch.empty(4).fill_(10),
            torch.full_like(torch.arange(1.0, 5.0), 0.5),
        ),
        torch.lgamma(torch.arange(0.5, 2, 0.5)),
        torch.log(torch.arange(5) + 10),
        torch.log10(torch.rand(5)),
        torch.log1p(torch.randn(5)),
        torch.log2(torch.rand(5)),
        torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logical_and(r, s),
        torch.logical_and(r.double(), s.double()),
        torch.logical_and(r.double(), s),
        torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
        torch.logical_not(
            torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
        torch.logical_not(
            torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
            out=torch.empty(3, dtype=torch.int16),
        ),
        torch.logical_or(r, s),
        torch.logical_or(r.double(), s.double()),
        torch.logical_or(r.double(), s),
        torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_xor(r, s),
        torch.logical_xor(r.double(), s.double()),
        torch.logical_xor(r.double(), s),
        torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logit(torch.rand(5), eps=1e-6),
        torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
        torch.i0(torch.arange(5, dtype=torch.float32)),
        torch.igamma(a, b),
        torch.igammac(a, b),
        torch.mul(torch.randn(3), 100),
        b.mul(a),
        b.mul(5),
        b.mul(a, out=b),
        b.mul_(a),
        b.mul_(5),
        torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
        torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
        # This literal tensor is passed to `len` directly; the nan_to_num
        # calls below operate on `w`, not on it.
        torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
        torch.nan_to_num(w),
        torch.nan_to_num_(w),
        torch.nan_to_num(w, nan=2.0),
        torch.nan_to_num(w, nan=2.0, posinf=1.0),
        torch.neg(torch.randn(5)),
        # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
        torch.polygamma(1, torch.tensor([1.0, 0.5])),
        torch.polygamma(2, torch.tensor([1.0, 0.5])),
        torch.polygamma(3, torch.tensor([1.0, 0.5])),
        torch.polygamma(4, torch.tensor([1.0, 0.5])),
        torch.pow(a, 2),
        torch.pow(2, float(torch.tensor(0.5))),
        torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
        torch.rad2deg(
            torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])),
        torch.randn(4, dtype=torch.cfloat).real,
        torch.reciprocal(a),
        torch.remainder(torch.tensor([-3.0, -2.0]), 2),
        torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.round(a),
        torch.round(torch.tensor(0.5).item()),
        torch.rsqrt(a),
        torch.sigmoid(a),
        torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sgn(a),
        torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sin(a),
        torch.sinc(a),
        torch.sinh(a),
        torch.sqrt(a),
        torch.square(a),
        torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
        b.sub(a),
        b.sub_(a),
        b.sub(5),
        torch.sum(5),
        torch.tan(a),
        torch.tanh(a),
        torch.true_divide(a, a),
        torch.trunc(a),
        torch.trunc_(a),
        torch.xlogy(f, g),
        torch.xlogy(f, g),
        torch.xlogy(f, 4),
        torch.xlogy(2, g),
    )