Example #1
def forward(self):
    # Exercise the out-of-place vs. in-place variants of nan_to_num.
    if self.op == "nan_to_num":
        if self.replace_inf:
            # Defaults also replace +/-inf with the dtype's finite extremes.
            output = torch.nan_to_num(self.input, nan=1.0)
        else:
            # posinf=math.inf / neginf=-math.inf keep infinities as-is,
            # so only NaNs are replaced.
            output = torch.nan_to_num(self.input,
                                      nan=1.0,
                                      posinf=math.inf,
                                      neginf=-math.inf)
    else:
        if self.replace_inf:
            output = torch.nan_to_num_(self.input, nan=1.0)
        else:
            output = torch.nan_to_num_(self.input,
                                       nan=1.0,
                                       posinf=math.inf,
                                       neginf=-math.inf)
    return output
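
A note on the two inner branches: with its default arguments, torch.nan_to_num also replaces +inf and -inf with the largest and smallest finite values representable by the input's dtype; passing posinf=math.inf and neginf=-math.inf keeps infinities untouched, so only NaNs are rewritten. A minimal sketch of the difference:

import math
import torch

x = torch.tensor([float("nan"), float("inf"), -float("inf")])

# Defaults: NaN -> 0.0, +/-inf -> the dtype's finite extremes.
torch.nan_to_num(x)
# tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38])

# Replace only the NaN; leave infinities as they are.
torch.nan_to_num(x, nan=1.0, posinf=math.inf, neginf=-math.inf)
# tensor([1., inf, -inf])
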
Example #2
    def accumulate_batch(
        self, batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Accumulate a batch of samples into the running statistics.

        Args:
            batch: tensor of shape ``(N_samples,) + self.dim``. The batch of samples to process.
            accumulate_by: tensor of indexes of shape ``(N_samples,)``.
                If provided, the nth sample will be accumulated into the
                ``accumulate_by[n]``th bin. If ``None`` (the default), all
                samples will be accumulated into the first (0th) bin.
                The indexes should be non-negative integers.

        Returns:
            tensor of shape ``(N_bins,) + self.output_dim`` giving the
            aggregated statistics *for this input batch*. Accumulated
            statistics up to this point can be retrieved with
            ``current_result()``.

            ``N_bins`` is ``accumulate_by.max() + 1`` --- the number of bins
            in the batch --- and not the overall number of bins
            ``self.n_bins``.
        """

        average, new_sum, N = self.batch_result(batch, accumulate_by)

        # assert self._state.shape == ((self._n_bins,)+self._dim)
        # assert self._n.shape == ((self._n_bins,)+self._dim)

        # do we need new bins?
        N_to_add = new_sum.shape[0] - self._n_bins
        if N_to_add > 0:

            # time to expand
            self._state = torch.cat(
                (
                    self._state,
                    self._state.new_zeros((N_to_add,) + self._state.shape[1:]),
                ),
                dim=0,
            )
            self._n = torch.cat(
                (self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0
            )

            # assert self._state.shape == (self._n_bins + N_to_add,) + self._dim
            self._n_bins += N_to_add

        elif N_to_add < 0:

            # This batch has fewer bins than the running state; pad the
            # batch statistics with empty bins so the shapes line up.
            new_sum = torch.cat(
                (new_sum, new_sum.new_zeros((-N_to_add,) + new_sum.shape[1:])), dim=0
            )
            N = torch.cat((N, N.new_zeros((-N_to_add,) + N.shape[1:])), dim=0)

        # Incremental mean update: state += (new_sum - N * state) / (n + N).
        self._state += (new_sum - N * self._state) / (self._n + N)
        self._n += N

        # Bins with n + N == 0 produce 0/0 = NaN above; zero them out.
        self._state = torch.nan_to_num_(self._state, nan=0.0)

        return average
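
A minimal usage sketch for the method above. The constructor call is an assumption (the test in Example #3 passes dim, reduction, and reduce_dims); the accumulate_by semantics and current_result() come from the docstring:

import torch

stats = RunningStats(dim=(3,))  # assumed constructor arguments
batch = torch.randn(5, 3)
accumulate_by = torch.tensor([0, 0, 1, 1, 2])

# Three bins for this batch: accumulate_by.max() + 1 == 3.
per_bin = stats.accumulate_batch(batch, accumulate_by=accumulate_by)  # (3, 3) here

# Statistics accumulated over all batches seen so far.
running = stats.current_result()
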
Example #3
def test_one_acc(dim, reduce_dims, reduction, do_accumulate_by, allclose):
    runstats = RunningStats(dim=dim,
                            reduction=reduction,
                            reduce_dims=reduce_dims)
    # Shift by one to skip the leading batch dimension.
    reduce_in_dims = tuple(i + 1 for i in reduce_dims)
    batch = torch.randn((random.randint(3, 10), ) + runstats.dim)
    if do_accumulate_by:
        accumulate_by = torch.randint(0,
                                      random.randint(1, 5),
                                      size=(batch.shape[0], ))
        res = runstats.accumulate_batch(batch, accumulate_by=accumulate_by)

        if reduction == Reduction.RMS:
            batch = batch.square()

        # Ground truth: compute each bin's mean by hand.
        outs = []
        for i in range(max(accumulate_by) + 1):
            tmp = batch[accumulate_by == i].mean(dim=(0, ) + reduce_in_dims)
            # Empty bins give NaN means; zero them like RunningStats does.
            torch.nan_to_num_(tmp, nan=0.0)
            outs.append(tmp)

        truth = torch.stack(outs, dim=0)
        assert truth.shape[1:] == tuple(d for i, d in enumerate(runstats.dim)
                                        if i not in reduce_dims)

        if reduction == Reduction.RMS:
            truth.sqrt_()
    else:
        res = runstats.accumulate_batch(batch)
        if reduction == Reduction.MEAN:
            truth = batch.mean(dim=(0, ) + reduce_in_dims)
        elif reduction == Reduction.RMS:
            truth = batch.square().mean(dim=(0, ) + reduce_in_dims).sqrt()

    assert allclose(truth, res)
    assert allclose(truth, runstats.current_result())
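
The manual loop in this test reproduces what accumulate_batch does for empty bins: a bin index that receives no samples produces a NaN mean, which nan_to_num_ then zeroes. For example:

import torch

empty_mean = torch.empty(0, 3).mean(dim=0)  # tensor([nan, nan, nan])
torch.nan_to_num_(empty_mean, nan=0.0)      # tensor([0., 0., 0.])
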
Example #4
def add_to_loss_dict(d: dict, key: str, loss: torch.Tensor, weight=None):
    dk = d[key]
    # Sanitize in place: NaN, +inf, and -inf all become 0.0.
    torch.nan_to_num_(loss, nan=0.0, posinf=0.0, neginf=0.0)
    if weight is not None:
        loss = loss * weight
    d[key] = loss if dk is None else dk + loss
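
A short sketch of how this helper might be called; pre-seeding the dict with None is an assumption inferred from the dk is None check:

import torch

losses = {"recon": None}  # assumed pre-seeded entry
add_to_loss_dict(losses, "recon", torch.tensor(0.5))
add_to_loss_dict(losses, "recon", torch.tensor(float("nan")))  # NaN sanitized to 0.0
losses["recon"]  # tensor(0.5000)

Since the sanitization is in place, the caller's loss tensor is mutated as a side effect.
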
Example #5
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     # len() takes a single argument, so collect the results in a tuple.
     return len((
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, b, out=a),
         b.add(a),
         b.add(a, out=b),
         b.add_(a),
         b.add(1),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.ceil(float(torch.tensor(0.5))),
         torch.ceil(torch.tensor(0.5).item()),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         a.div(b),
         a.div(1),
         a.div_(b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp(float(torch.tensor(1))),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         torch.floor(float(torch.tensor(1))),
         torch.floor_divide(torch.tensor([4.0, 3.0]),
                            torch.tensor([2.0, 2.0])),
         torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         b.mul(a),
         b.mul(5),
         b.mul(a, out=b),
         b.mul_(a),
         b.mul_(5),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num_(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(2, float(torch.tensor(0.5))),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.round(torch.tensor(0.5).item()),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         b.sub(a),
         b.sub_(a),
         b.sub(5),
         torch.sum(5),
         torch.tan(a),
         torch.tanh(a),
         torch.true_divide(a, a),
         torch.trunc(a),
         torch.trunc_(a),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     ))