Esempio n. 1
0
        def maybe_move_axis_to_last(a, axis):
            def move_axis_to_last(a, axis):
                return array_ops.transpose(
                    a,
                    array_ops.concat([
                        math_ops.range(axis),
                        math_ops.range(axis + 1, array_ops.rank(a)), [axis]
                    ],
                                     axis=0))

            return np_utils.cond(
                axis == np_utils.subtract(array_ops.rank(a), 1), lambda: a,
                lambda: move_axis_to_last(a, axis))
Esempio n. 2
0
    def f(a, b):  # pylint: disable=missing-docstring
        # We can't assign to captured variable `axisa`, so make a new variable
        if axis is None:
            axis_a = axisa
            axis_b = axisb
            axis_c = axisc
        else:
            axis_a = axis
            axis_b = axis
            axis_c = axis
        if axis_a < 0:
            axis_a = np_utils.add(axis_a, array_ops.rank(a))
        if axis_b < 0:
            axis_b = np_utils.add(axis_b, array_ops.rank(b))

        def maybe_move_axis_to_last(a, axis):
            def move_axis_to_last(a, axis):
                return array_ops.transpose(
                    a,
                    array_ops.concat([
                        math_ops.range(axis),
                        math_ops.range(axis + 1, array_ops.rank(a)), [axis]
                    ],
                                     axis=0))

            return np_utils.cond(
                axis == np_utils.subtract(array_ops.rank(a), 1), lambda: a,
                lambda: move_axis_to_last(a, axis))

        a = maybe_move_axis_to_last(a, axis_a)
        b = maybe_move_axis_to_last(b, axis_b)
        a_dim = np_utils.getitem(array_ops.shape(a), -1)
        b_dim = np_utils.getitem(array_ops.shape(b), -1)

        def maybe_pad_0(a, size_of_last_dim):
            def pad_0(a):
                return array_ops.pad(
                    a,
                    array_ops.concat([
                        array_ops.zeros([array_ops.rank(a) - 1, 2],
                                        dtypes.int32),
                        constant_op.constant([[0, 1]], dtypes.int32)
                    ],
                                     axis=0))

            return np_utils.cond(math_ops.equal(size_of_last_dim, 2),
                                 lambda: pad_0(a), lambda: a)

        a = maybe_pad_0(a, a_dim)
        b = maybe_pad_0(b, b_dim)
        c = math_ops.cross(*np_utils.tf_broadcast(a, b))
        if axis_c < 0:
            axis_c = np_utils.add(axis_c, array_ops.rank(c))

        def move_last_to_axis(a, axis):
            r = array_ops.rank(a)
            return array_ops.transpose(
                a,
                array_ops.concat([
                    math_ops.range(axis), [r - 1],
                    math_ops.range(axis, r - 1)
                ],
                                 axis=0))

        c = np_utils.cond(
            (a_dim == 2) & (b_dim == 2),
            lambda: c[..., 2],
            lambda: np_utils.cond(  # pylint: disable=g-long-lambda
                axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c,
                lambda: move_last_to_axis(c, axis_c)))
        return c