Exemplo n.º 1
0
    def exec(self):
        """Recompute the output ``y`` of this slice from the current input ``x``.

        Each entry of ``self.indices`` decides how one axis shows up in ``y``:
        a ``slice`` keeps the axis with its sliced length (after
        ``normalize_slice`` against the input axis length), an ``int`` drops
        the axis, and ``None`` inserts a new axis of length 1.

        Returns:
            A 1-tuple ``(y,)`` holding the newly created output variable.
        """
        src = self.inputs["x"]
        out_shape = AxisKeyDict()

        for ax, idx in self.indices.items():
            if idx is None:
                out_shape[ax] = 1  # Insert axis
                continue

            if isinstance(idx, int):
                continue  # Remove axis

            if isinstance(idx, slice):
                norm = normalize_slice(idx, src.shape_dict[ax])
                span = abs(norm.stop - norm.start)
                out_shape[ax] = (span - 1) // abs(norm.step) + 1

        result = Variable(list(out_shape.values()), Order(list(out_shape.keys())))
        self.append_output("y", result)
        return (result,)
Exemplo n.º 2
0
    def exec(self):
        """Recompute the output ``C`` from the current inputs ``A`` and ``B``.

        ``C`` receives every axis of ``A`` that is not listed in
        ``self.axes[0]`` (in ``A``'s order), then every axis of ``B`` that is
        not listed in ``self.axes[1]`` (in ``B``'s order), each with its size
        taken from the operand it came from. A ``Tensorwise`` attribute is then
        registered for every axis of ``C``.

        Returns:
            A 1-tuple ``(C,)`` holding the newly created output variable.
        """
        lhs = self.inputs["A"]
        rhs = self.inputs["B"]
        out_shape = AxisKeyDict()

        for operand, reduced in ((lhs, self.axes[0]), (rhs, self.axes[1])):
            for ax in operand.order.axes:
                if ax in reduced:
                    continue  # contracted axes do not survive into C
                out_shape[ax] = operand.shape_dict[ax]

        out = Variable(list(out_shape.values()), Order(list(out_shape.keys())))
        self.append_output("C", out)
        for ax in out.order.axes:
            self.attributes.add(Tensorwise(self, axis=ax))
        return (out,)
Exemplo n.º 3
0
    def __call__(self, x: Variable):
        # assert index is valid
        for axis, index in self.indices.items():
            if axis in x.order.axes:
                if isinstance(index, slice):
                    index = normalize_slice(index, x.shape_dict[axis])

                    valid_start = -x.shape_dict[
                        axis] <= index.start <= x.shape_dict[axis]
                    valid_stop = -x.shape_dict[
                        axis] <= index.stop <= x.shape_dict[axis]
                    if not valid_start or not valid_stop:
                        raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

                    if ((abs(index.stop - index.start) - 1) //
                            abs(index.step)) + 1 < 0:
                        raise ValueError(f"""
[Slice] Slice operator doesn't support 0-size output:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

                elif isinstance(index, int):
                    if not -x.shape_dict[axis] <= index < x.shape_dict[axis]:
                        raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
    (valid range) = [{-x.shape_dict[axis]}, {x.shape_dict[axis]})
""")

                elif index is None:
                    raise ValueError(f"""
[Slice] Axis {axis} is already exist:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

            else:
                if index is not None:
                    raise ValueError(f"""
[Slice] Axis {axis} is not exist in input variable. In this case, index must be "None" (=insert new axis):
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

        if all(isinstance(index, int) for index in self.indices.values()):
            raise NotImplementedError(f"""
[Slice] Accessing to one element is not supported:
    (indices) = {self.indices}
""")

        y_shape_dict = AxisKeyDict()
        for axis, index in self.indices.items():
            if isinstance(index, slice):
                index = normalize_slice(index, x.shape_dict[axis])
                y_shape_dict[axis] = (
                    (abs(index.stop - index.start) - 1) // abs(index.step)) + 1

            elif isinstance(index, int):
                pass  # Remove axis

            elif index is None:
                y_shape_dict[axis] = 1  # Insert axis

        y = Variable(list(y_shape_dict.values()),
                     Order(list(y_shape_dict.keys())))

        for axis in x.order.axes:
            if axis in self.indices:
                index = self.indices[axis]
                if isinstance(
                        index, slice
                ) and index.start is None and index.stop is None and index.step is None:
                    # This axis is not sliced.
                    self.attributes.add(Tensorwise(axis))

            else:
                # This axis is not sliced.
                self.attributes.add(Tensorwise(axis))

        self.append_input("x", x)
        self.append_output("y", y)
        return y,
Exemplo n.º 4
0
    def __call__(self, A: Variable, B: Variable):
        for axis in self.axes[0]:
            assert axis in A.order.axes, f"""
[Tensordot] Input variable "A" must have axes "{axis}":
    (op) = {self}
    (op.axes[0]) = {self.axes[0]}
    (A) = {A}"""

        for axis in A.order.axes:
            if axis not in self.axes[0]:
                assert axis in self.axes[1] or axis not in B.order.axes, f"""
[Tensordot] Axes of "A" which are not reduced must not be contained in "B":
    (op) = {self}
    (A.order.axes) = {A.order.axes}
    (B.order.axes) = {B.order.axes}
    (op.axes) = {self.axes}"""

        for axis in self.axes[1]:
            assert axis in B.order.axes, f"""
[Tensordot] Input variable "B" must have axes "{axis}":
    (op) = {self}
    (op.axes[1]) = {self.axes[1]}
    (B) = {B}"""

        for axis in B.order.axes:
            if axis not in self.axes[1]:
                assert axis in self.axes[0] or axis not in A.order.axes, f"""
[Tensordot] Axes of "B" which are not reduced must not be contained in "A":
    (op) = {self}
    (A.order.axes) = {A.order.axes}
    (B.order.axes) = {B.order.axes}
    (op.axes) = {self.axes}"""

        reduction_size_a = mul(A.shape_dict[a] for a in self.axes[0])
        reduction_size_b = mul(B.shape_dict[a] for a in self.axes[1])
        assert reduction_size_a == reduction_size_b, f"""
[Tensordot] Reduction size of "A" and "B" must be same:
    (A) = {A}
    (B) = {B}
    (axes) = {self.axes}
    (reduction size of A) = {reduction_size_a}
    (reduction size of B) = {reduction_size_b}
"""

        c_shape_dict = AxisKeyDict()

        for axis in A.order.axes:
            if axis not in self.axes[0]:
                c_shape_dict[axis] = A.shape_dict[axis]

        for axis in B.order.axes:
            if axis not in self.axes[1]:
                c_shape_dict[axis] = B.shape_dict[axis]

        C = Variable(list(c_shape_dict.values()),
                     Order(list(c_shape_dict.keys())))

        for axis in C.order.axes:
            self.attributes.add(Tensorwise(axis))

        self.append_input("A", A)
        self.append_input("B", B)
        self.append_output("C", C)
        return C,