Beispiel #1
0
    def visit(self, op):
        """Replace a dot op with an equivalent flattened, low-dimensional form.

        The two operands come from ``op.args`` and the per-operand reduction
        axes from ``op.x_reduction_axes`` / ``op.y_reduction_axes``. The
        replacement expression is registered with ``self.replace_op``.

        Args:
            op: The op being lowered. Must expose ``args`` (two operands),
                ``x_reduction_axes``, ``y_reduction_axes``, ``axes``,
                ``x_out_axes`` and ``y_out_axes``.
        """
        x, y = op.args
        x_reduction_axes = op.x_reduction_axes
        y_reduction_axes = op.y_reduction_axes
        out_axes = op.axes

        if len(x_reduction_axes) == 0:
            # No axes to reduce over: append one shared axis from make_axis(1)
            # to both operands so the code below always has a reduction axis.
            d = make_axis(1)
            x_reduction_axes = make_axes((d, ))
            y_reduction_axes = x_reduction_axes
            x = broadcast(x, x.axes + x_reduction_axes)
            y = broadcast(y, y_reduction_axes + y.axes)

        if x.is_scalar:
            # Keep a scalar operand (if any) in ``y``. The reduction axes are
            # swapped together with the operands so that ``x_reduction_axes``
            # still describes whatever tensor ``x`` now names.
            x, y = y, x
            x_reduction_axes, y_reduction_axes = (y_reduction_axes,
                                                  x_reduction_axes)

        if y.is_scalar:
            if x.is_scalar:
                out = x.scalar_op * y.scalar_op
                if len(x_reduction_axes) > 0:
                    out = out * x_reduction_axes.size
            else:
                out = Sum(x, x_reduction_axes) * y.scalar_op
            # Both scalar cases share this single broadcast to the output axes.
            out = broadcast(out, out_axes)
        else:
            # Order x as (remaining..., reduction...) and y as
            # (reduction..., remaining...).
            x_rem_axes = x.axes - x_reduction_axes
            x = axes_with_order(x, x_rem_axes + x_reduction_axes)

            y_rem_axes = y.axes - y_reduction_axes
            y = axes_with_order(y, y_reduction_axes + y_rem_axes)

            # The position handed to flatten_at is the boundary between the
            # remaining axes and the reduction axes of each operand.
            x = flatten_at(x, len(x.axes) - len(x_reduction_axes))
            y = flatten_at(y, len(y_reduction_axes))

            if len(out_axes) == 0:
                out = DotOneDimensional(x, y, axes=())
            elif len(x.axes) == 1:
                # x has a single axis left: transpose y and put it first.
                y = Transpose(y)
                out = DotTwoByOne(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotTwoByOne(x, y, axes=x.axes[0])
            else:
                out = DotTwoDimensional(
                    x,
                    y,
                    axes=([op.x_out_axes.flatten(),
                           op.y_out_axes.flatten()]))

            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
Beispiel #2
0
    def visit(self, op):
        """Replace a dot op with ``DotLowDimension`` over flattened operands.

        The two operands come from ``op.args`` and the shared reduction axes
        from ``op.reduction_axes``. The replacement expression is registered
        with ``self.replace_op``.

        Args:
            op: The op being lowered. Must expose ``args`` (two operands),
                ``reduction_axes``, ``axes``, ``x_out_axes`` and
                ``y_out_axes``.
        """
        x, y = op.args
        reduction_axes = op.reduction_axes
        out_axes = op.axes
        if len(reduction_axes) == 0:
            # TODO: this is a weird case, should we really support it?
            # Append one shared axis from make_axis(1) to both operands so
            # the code below always has a reduction axis to work with.
            d = make_axis(1)
            reduction_axes = make_axes((d, ))
            x = broadcast(x, x.axes + reduction_axes)
            y = broadcast(y, reduction_axes + y.axes)

        if x.is_scalar:
            # Keep a scalar operand (if any) in ``y``.
            x, y = y, x

        if y.is_scalar:
            if x.is_scalar:
                out = x.scalar_op * y.scalar_op
                if len(reduction_axes) > 0:
                    out = out * reduction_axes.size
            else:
                out = Sum(x, reduction_axes) * y.scalar_op
            # Both scalar cases share this single broadcast to the output axes.
            out = broadcast(out, out_axes)
        else:
            # move reduction axes to end
            x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
            # move reduction axes to front
            y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

            # x: split between the leading non-reduction axes and the
            # trailing reduction axes
            x = flatten_at(x, len(x.axes) - len(reduction_axes))
            # y: split between the leading reduction axes and the trailing
            # non-reduction axes
            y = flatten_at(y, len(reduction_axes))

            if len(out_axes) == 0:
                out = DotLowDimension(x, y, axes=())
            elif len(x.axes) == 1:
                # x has a single axis left: transpose y and put it first.
                y = Transpose(y)
                out = DotLowDimension(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotLowDimension(x, y, axes=x.axes[0])
            else:
                out = DotLowDimension(x,
                                      y,
                                      axes=([
                                          op.x_out_axes.flatten(True),
                                          op.y_out_axes.flatten(True)
                                      ]))

            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
Beispiel #3
0
    def visit(self, op, arg):
        """Make the parallel axis the leading axis of ``arg``.

        The axis is looked up in ``arg.axes`` by the name stored under
        ``op.metadata['parallel']``. If it is already first, nothing is
        changed; otherwise ``op`` is replaced by a copy that consumes the
        reordered argument.

        Args:
            op: Op whose metadata names the parallel axis.
            arg: The argument of ``op`` to reorder.

        Raises:
            AssertionError: If ``arg`` has no axis with the parallel name.
        """
        parallel_name = op.metadata['parallel'].name
        p_axis = arg.axes.find_by_name(parallel_name)
        assert len(p_axis) > 0, "Invalid to scatter a scalar"

        # Already in front: leave the graph untouched.
        if not (arg.axes.index(p_axis[0]) > 0):
            return

        # Parallel axis first, remaining axes after it.
        reordered = axes_with_order(arg, p_axis + (arg.axes - p_axis))
        # NOTE(review): the flatten_at/unflatten round trip presumably forces
        # the new order to be materialized -- confirm against the helpers.
        reordered = unflatten(flatten_at(reordered, 0))

        # replace the ops
        self.replace_op(op, op.copy_with_new_args([reordered]))
Beispiel #4
0
    def visit(self, op, arg):
        """Make the parallel axis the leading axis of ``arg``, if present.

        The axis is looked up in ``arg.axes`` by the name stored under
        ``op.send_node().metadata['parallel']``. Nothing happens when
        ``arg`` has no such axis or when the axis is already first;
        otherwise ``op`` is replaced by a copy that consumes the reordered
        argument.

        Args:
            op: Op whose send node's metadata names the parallel axis.
            arg: The argument of ``op`` to reorder.
        """
        parallel_name = op.send_node().metadata['parallel'].name
        p_axis = arg.axes.find_by_name(parallel_name)

        # No matching axis, or it already leads: leave the graph untouched.
        if len(p_axis) == 0 or not (arg.axes.index(p_axis[0]) > 0):
            return

        # Parallel axis first, remaining axes after it.
        reordered = axes_with_order(arg, p_axis + (arg.axes - p_axis))
        # NOTE(review): the flatten_at/unflatten round trip presumably forces
        # the new order to be materialized -- confirm against the helpers.
        reordered = unflatten(flatten_at(reordered, 0))

        # replace the ops
        self.replace_op(op, op.copy_with_new_args([reordered]))
Beispiel #5
0
    def visit(self, op, arg):
        """Replace ``op`` by its argument, reordered to ``op.axes`` if needed.

        Only acts when ``arg.metadata`` has a ``'parallel'`` entry. The
        parallel axis is looked up by name in ``op.axes``:

        * not found -- ``op`` is replaced by ``arg`` unchanged;
        * found but not leading -- ``op`` is replaced by ``arg`` reordered
          to ``op.axes``;
        * found and leading -- nothing is changed.

        Args:
            op: The op that may be replaced.
            arg: The argument of ``op``; its metadata names the parallel axis.
        """
        metadata = arg.metadata
        if 'parallel' not in metadata:
            return

        p_axis = op.axes.find_by_name(metadata['parallel'].name)

        if len(p_axis) == 0:
            self.replace_op(op, arg)
            return

        # Parallel axis already leads op.axes: leave the graph untouched.
        if not (op.axes.index(p_axis[0]) > 0):
            return

        # NOTE(review): the flatten_at/unflatten round trip presumably forces
        # the new order to be materialized -- confirm against the helpers.
        reordered = unflatten(flatten_at(axes_with_order(arg, op.axes), 0))

        # replace the ops
        self.replace_op(op, reordered)