def prod(self, dim=None, keepdim=False):
    """
    Returns the product of each row of the `input` tensor in the given
    dimension `dim`.

    If `keepdim` is `True`, the output tensor is of the same size as
    `input` except in the dimension `dim` where it is of size 1.
    Otherwise, `dim` is squeezed, resulting in the output tensor having
    1 fewer dimension than `input`.
    """
    # With no dimension given, reduce the flattened tensor along dim 0.
    if dim is None:
        return self.flatten().prod(dim=0)

    out = self.clone()
    # Log-reduction: each round multiplies the two halves of `dim`
    # together element-wise, then carries any odd leftover slice into
    # the next round by concatenating it back onto the shares.
    while out.size(dim) > 1:
        length = out.size(dim)
        half = length // 2
        left, right, leftover = out.split([half, half, length % 2], dim=dim)
        out = left.mul_(right)
        # No-op concat when `length` was even (leftover is empty).
        out.share = torch_cat([out.share, leftover.share], dim=dim)

    # Drop the reduced (now size-1) dimension unless asked to keep it.
    if not keepdim:
        out.share = out.share.squeeze(dim)
    return out
def sum(self, dim=None, keepdim=False):
    """Add all tensors along a given dimension using a log-reduction.

    Args:
        dim: dimension to reduce over; if None, all elements of the
            flattened tensor are summed.
        keepdim: if True and `dim` is given, retain `dim` as a size-1
            dimension in the output (matches the `prod` method's
            signature); ignored when `dim` is None.

    Returns:
        A tensor of the same shared type with the dimension reduced.
    """
    if dim is None:
        x = self.flatten()
    else:
        # Move the reduction dimension to the front so the loop below
        # can always reduce along dim 0.
        x = self.transpose(0, dim)

    # Log-reduction: pairwise-add the two halves of dim 0 each round,
    # carrying an odd leftover element into the next round.
    while x.size(0) > 1:
        extra = None
        if x.size(0) % 2 == 1:
            extra = x[0]
            x = x[1:]
        x0 = x[: (x.size(0) // 2)]
        x1 = x[(x.size(0) // 2) :]
        x = x0 + x1
        if extra is not None:
            x.share = torch_cat([x.share, extra.share.unsqueeze(0)])

    if dim is None:
        x = x.squeeze()
    else:
        # Restore the original dimension order, then drop the reduced
        # (size-1) dimension unless asked to keep it.
        x = x.transpose(0, dim)
        if not keepdim:
            x = x.squeeze(dim)
    return x