예제 #1
0
    def _broadcast_bop(self, op_name, arr_1, arr_2) -> BlockArray:
        """Apply a binary operation block by block, broadcasting operands first.

        NumPy's own implementation is faster, so this path should be avoided
        whenever possible.

        Args:
            op_name: Name of binary operation.
            arr_1: A BlockArray.
            arr_2: A BlockArray.

        Returns:
            A BlockArray whose block at each grid entry is
            ``block_1.bop(op_name, block_2, {})`` for the operands' blocks at
            that same entry.
        """
        lhs, rhs = arr_1, arr_2
        if lhs.shape != rhs.shape:
            # NOTE(review): the broadcast target is computed from the two *grid*
            # shapes but handed to broadcast_to(); confirm broadcast_to expects
            # a grid shape here and not an array shape.
            target = array_utils.broadcast_shape(lhs.grid.grid_shape,
                                                 rhs.grid.grid_shape)
            lhs = lhs.broadcast_to(target)
            rhs = rhs.broadcast_to(target)

        out_dtype = array_utils.get_bop_output_type(op_name, lhs.dtype, rhs.dtype)
        # The result takes its shape and block shape from the (possibly
        # broadcast) left operand; each result grid entry is filled from the
        # same entry of both operands.
        out_grid = ArrayGrid(lhs.shape, lhs.block_shape, out_dtype.__name__)
        result = BlockArray(out_grid, self.cm)
        for entry in result.grid.get_entry_iterator():
            left_block: Block = lhs.blocks[entry]
            right_block: Block = rhs.blocks[entry]
            result.blocks[entry] = left_block.bop(op_name, right_block, {})
        return result
예제 #2
0
 def shape(self):
     """Return the output shape of applying ``self.op_name`` to the operands.

     ``matmul`` and ``tensordot`` delegate to ``self._tdop_shape``; any other
     op yields the broadcast of the left and right operand shapes.
     """
     lhs = self.left.shape()
     rhs = self.right.shape()
     if self.op_name in ("matmul", "tensordot"):
         return self._tdop_shape(lhs, rhs)
     return array_utils.broadcast_shape(lhs, rhs)
예제 #3
0
def test_bop_broadcasting():
    """Check array_utils broadcasting helpers against NumPy's broadcasting.

    Every pair of operand shapes with up to five axes is tried; each axis
    length is drawn from {0, 1, 2, 3}, and a length of 0 drops that axis.
    When NumPy can broadcast the operands, ``array_utils.broadcast_shape``
    must agree with NumPy's result shape. When NumPy raises ``ValueError``,
    ``array_utils.can_broadcast_shapes`` must report False in both argument
    orders and ``array_utils.broadcast_shape`` must raise ``ValueError`` too.
    """
    dims = [0, 1, 2, 3]
    max_axes = 5

    def get_array(shape):
        # Zero-length entries are dropped; an all-zero spec gives a 0-d array.
        # Zeros rather than np.empty: only the shape matters, and multiplying
        # uninitialized memory can emit overflow/invalid RuntimeWarnings.
        return np.zeros(tuple(d for d in shape if d > 0))

    # Iterate lazily instead of materializing all 4**10 combinations.
    shape_specs = itertools.product(dims, repeat=2 * max_axes)
    total = len(dims) ** (2 * max_axes)
    for shapes in tqdm.tqdm(shape_specs, total=total):
        A: np.ndarray = get_array(shapes[:max_axes])
        B: np.ndarray = get_array(shapes[max_axes:])
        # Only the NumPy op belongs inside the try: a ValueError raised by
        # array_utils itself on broadcastable shapes must fail the test, not
        # be mistaken for an expected broadcast failure.
        try:
            expected = (A * B).shape
        except ValueError:
            assert not array_utils.can_broadcast_shapes(B.shape, A.shape)
            assert not array_utils.can_broadcast_shapes(A.shape, B.shape)
            with pytest.raises(ValueError):
                array_utils.broadcast_shape(A.shape, B.shape)
        else:
            assert expected == array_utils.broadcast_shape(A.shape, B.shape)
예제 #4
0
 def _mem_cost(self):
     """Estimate the memory cost of performing this operation.

     The estimate is approximated by the size of the result alone: it is the
     number of elements in the output (the product of the output shape);
     element size / dtype is not factored in. Both operands must already be
     ``Leaf`` nodes whose blocks can be fetched from ``self.cluster_state``.
     """
     assert isinstance(self.left, Leaf) and isinstance(self.right, Leaf)
     lblock: Block = self.cluster_state.get_block(self.left.block_id)
     rblock: Block = self.cluster_state.get_block(self.right.block_id)
     if self.op_name == "matmul" or self.op_name == "tensordot":
         output_shape = self._tdop_shape(lblock.shape, rblock.shape)
     else:
         assert array_utils.can_broadcast_shapes(lblock.shape, rblock.shape)
         output_shape = array_utils.broadcast_shape(lblock.shape,
                                                    rblock.shape)
     # np.prod, not the np.product alias (deprecated in NumPy 1.25 and
     # removed in NumPy 2.0).
     return np.prod(output_shape)
예제 #5
0
 def __pow__(self, other):
     """Compute ``self ** other``.

     ``other`` is first converted with ``self.other_to_ba``. The ``graphs`` of
     the two operands are combined with ``**``, and the result is built via
     ``self.ga_from_arr``. Its shape is the broadcast of the two operand shapes.
     """
     rhs = self.other_to_ba(other)
     powered_graphs = self.graphs ** rhs.graphs
     result_shape = array_utils.broadcast_shape(self.shape, rhs.shape)
     return self.ga_from_arr(powered_graphs, result_shape)