Exemplo n.º 1
0
def calculate_gather_axes(axes, gather_axis, num_devices):
    """Build the axes of a gathered result.

    Every axis in ``axes`` that compares equal to ``gather_axis`` is replaced
    by a new axis with the same name and ``num_devices`` times the length;
    all other axes are passed through unchanged.

    Arguments:
        axes: Iterable of axes to process.
        gather_axis: The axis whose length is scaled.
        num_devices: Factor applied to the length of the gather axis.

    Returns:
        The result of ``make_axes`` applied to the processed axes.
    """
    gathered = []
    for axis in axes:
        if gather_axis == axis:
            axis = make_axis(axis.length * num_devices, axis.name)
        gathered.append(axis)
    return make_axes(gathered)
Exemplo n.º 2
0
    def visit(self, op):
        """Replace ``op`` with an equivalent expression built from simpler ops.

        Scalar operands are handled with ``Sum`` and scalar multiplication.
        Otherwise both operands are reordered and flattened, and the product
        is expressed as ``DotOneDimensional``, ``DotTwoByOne`` or
        ``DotTwoDimensional`` depending on the resulting ranks. The new
        expression is installed with ``self.replace_op``.

        Arguments:
            op: The op to replace. It must expose ``args`` (two operands),
                ``x_reduction_axes``, ``y_reduction_axes``, ``axes``,
                ``x_out_axes`` and ``y_out_axes``.
        """
        lhs, rhs = op.args
        x_reduction_axes = op.x_reduction_axes
        y_reduction_axes = op.y_reduction_axes
        result_axes = op.axes
        if len(x_reduction_axes) == 0:
            # No reduction axes: introduce one shared length-1 axis and
            # broadcast both operands along it so the code below always has
            # a reduction axis to work with.
            unit_axes = make_axes((make_axis(1), ))
            x_reduction_axes = unit_axes
            y_reduction_axes = unit_axes
            lhs = broadcast(lhs, lhs.axes + x_reduction_axes)
            rhs = broadcast(rhs, y_reduction_axes + rhs.axes)

        if lhs.is_scalar:
            # If only one operand is scalar, keep it on the right-hand side.
            # NOTE(review): the reduction axes are not swapped along with the
            # operands — confirm that x_reduction_axes is still the right set
            # for the Sum below.
            lhs, rhs = rhs, lhs

        if not rhs.is_scalar:
            # General case: reduction axes go last on the left operand ...
            lhs = axes_with_order(
                lhs, (lhs.axes - x_reduction_axes) + x_reduction_axes)
            # ... and first on the right operand.
            rhs = axes_with_order(
                rhs, y_reduction_axes + (rhs.axes - y_reduction_axes))

            # Split each operand into a non-reduction part and a reduction
            # part at the boundary between the two groups of axes.
            lhs = flatten_at(lhs, len(lhs.axes) - len(x_reduction_axes))
            rhs = flatten_at(rhs, len(y_reduction_axes))

            if len(result_axes) == 0:
                out = DotOneDimensional(lhs, rhs, axes=())
            elif len(lhs.axes) == 1:
                # Left operand is one-dimensional: transpose the right one
                # and pass it as the two-dimensional argument.
                rhs = Transpose(rhs)
                out = DotTwoByOne(rhs, lhs, axes=rhs.axes[0])
            elif len(rhs.axes) == 1:
                out = DotTwoByOne(lhs, rhs, axes=lhs.axes[0])
            else:
                flat_out_axes = [op.x_out_axes.flatten(),
                                 op.y_out_axes.flatten()]
                out = DotTwoDimensional(lhs, rhs, axes=flat_out_axes)

            out = unflatten(out)
            out = ReorderAxes(out, result_axes)
        elif lhs.is_scalar:
            # Both operands are scalar.
            out = lhs.scalar_op * rhs.scalar_op
            if len(x_reduction_axes) > 0:
                out = out * x_reduction_axes.size
            # NOTE(review): broadcast is applied twice on this path; the
            # second call looks redundant — confirm before removing it.
            out = broadcast(out, op.axes)
            out = broadcast(out, op.axes)
        else:
            # Only the right operand is scalar.
            out = Sum(lhs, x_reduction_axes) * rhs.scalar_op
            out = broadcast(out, op.axes)

        self.replace_op(op, out)
Exemplo n.º 3
0
    def visit(self, op):
        """Replace ``op`` with an equivalent expression built from simpler ops.

        Scalar operands are handled with ``Sum`` and scalar multiplication.
        Otherwise both operands are reordered and flattened and the product
        is expressed with ``DotLowDimension``. The new expression is
        installed with ``self.replace_op``.

        Arguments:
            op: The op to replace. It must expose ``args`` (two operands),
                ``reduction_axes``, ``axes``, ``x_out_axes`` and
                ``y_out_axes``.
        """
        lhs, rhs = op.args
        reduction_axes = op.reduction_axes
        result_axes = op.axes
        if len(reduction_axes) == 0:
            # TODO: this is a weird case, should we really support it?
            # Introduce a length-1 axis and broadcast both operands along it
            # so there is always something to reduce over.
            reduction_axes = make_axes((make_axis(1), ))
            lhs = broadcast(lhs, lhs.axes + reduction_axes)
            rhs = broadcast(rhs, reduction_axes + rhs.axes)

        if lhs.is_scalar:
            # If only one operand is scalar, keep it on the right-hand side.
            lhs, rhs = rhs, lhs

        if not rhs.is_scalar:
            # Move the reduction axes to the end of the left operand ...
            lhs = axes_with_order(
                lhs, (lhs.axes - reduction_axes) + reduction_axes)
            # ... and to the front of the right operand.
            rhs = axes_with_order(
                rhs, reduction_axes + (rhs.axes - reduction_axes))

            # Flatten the non-reduction axes together and the reduction axes
            # together, for each operand in turn.
            lhs = flatten_at(lhs, len(lhs.axes) - len(reduction_axes))
            rhs = flatten_at(rhs, len(reduction_axes))

            if len(result_axes) == 0:
                out = DotLowDimension(lhs, rhs, axes=())
            elif len(lhs.axes) == 1:
                # Left operand is one-dimensional: transpose the right one
                # and pass it first.
                rhs = Transpose(rhs)
                out = DotLowDimension(rhs, lhs, axes=rhs.axes[0])
            elif len(rhs.axes) == 1:
                out = DotLowDimension(lhs, rhs, axes=lhs.axes[0])
            else:
                flat_out_axes = [
                    op.x_out_axes.flatten(True),
                    op.y_out_axes.flatten(True),
                ]
                out = DotLowDimension(lhs, rhs, axes=flat_out_axes)

            out = unflatten(out)
            out = ReorderAxes(out, result_axes)
        elif lhs.is_scalar:
            # Both operands are scalar.
            out = lhs.scalar_op * rhs.scalar_op
            if len(reduction_axes) > 0:
                out = out * reduction_axes.size
            # NOTE(review): broadcast is applied twice on this path; the
            # second call looks redundant — confirm before removing it.
            out = broadcast(out, op.axes)
            out = broadcast(out, op.axes)
        else:
            # Only the right operand is scalar.
            out = Sum(lhs, reduction_axes) * rhs.scalar_op
            out = broadcast(out, op.axes)

        self.replace_op(op, out)
Exemplo n.º 4
0
    def do_pass(self, ops, **kwargs):
        """Apply the per-sender parallel axis to the send side of each gather.

        The forwarded versions of ``ops`` are walked in reverse
        ``Op.ordered_ops`` order. For every op whose metadata ``'marker'`` is
        ``'gather'``, ``self.parallel_axis`` is created on first use from the
        op's ``'parallel'`` metadata axis, with its length divided by
        ``len(op.from_id)``, and ``update_parallel_axis`` is then called on
        ``op.send_node()``.

        Arguments:
            ops: Iterable of ops; each must expose ``forwarded``.
            **kwargs: Accepted but not used here.

        Raises:
            AssertionError: If the length of the ``'parallel'`` axis is not
                a multiple of ``len(op.from_id)``.
        """
        forwarded_ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(forwarded_ops)):
            if op.metadata.get('marker') == 'gather':
                # op is GatherRecvOp
                if self.parallel_axis is None:
                    # First gather seen: derive the per-sender axis and keep
                    # it on self so later gathers reuse it.
                    full_axis = op.metadata['parallel']
                    num_senders = len(op.from_id)
                    assert full_axis.length % num_senders == 0, \
                        '{} can not be equally divided by {}'.format(
                            full_axis, num_senders)
                    self.parallel_axis = make_axis(
                        name=full_axis.name,
                        length=full_axis.length // num_senders,
                        docstring='HeTr parallel axis')
                update_parallel_axis(op.send_node(), self.parallel_axis)
Exemplo n.º 5
0
    def do_pass(self, ops, **kwargs):
        """Clone the send-side graph of each gather op once per ``from_id`` entry.

        The forwarded versions of ``ops`` are walked in reverse
        ``Op.ordered_ops`` order. For every op whose metadata ``'marker'`` is
        ``'gather'``:

        * ``self.parallel_axes`` is created on first use from the op's
          ``'parallel'`` metadata axis, with its length divided by
          ``len(op.from_id)``;
        * ``clone_graph`` is called once per entry of ``op.from_id``, rooted
          at ``op.send_nodes[0]``;
        * ``op.send_nodes`` is replaced by the cloned roots, and
          ``self.send_nodes`` is updated to contain the new send ops instead
          of the replaced ones.

        Arguments:
            ops: Iterable of ops; each must expose ``forwarded``.
            **kwargs: Accepted but not used in the visible code.

        Raises:
            AssertionError: If the length of the ``'parallel'`` axis is not
                a multiple of ``len(op.from_id)``.
        """

        ops = OrderedSet(op.forwarded for op in ops)

        for op in reversed(Op.ordered_ops(ops)):
            if op.metadata.get('marker') == 'gather':
                # op is GatherRecvOp
                if self.parallel_axes is None:
                    # First gather seen: derive the per-clone axis and keep
                    # it on self so later gathers reuse it.
                    a = op.metadata['parallel']
                    assert a.length % len(op.from_id) == 0, '{} can not be equally divided by {}'\
                        .format(a, len(op.from_id))
                    # NOTE(review): despite the plural attribute name, a
                    # single axis is stored here.
                    self.parallel_axes = make_axis(
                        name=a.name,
                        length=a.length // len(op.from_id),
                        docstring='HeTr parallel axis')
                # Only the first send node is used as the root to clone.
                gather_send_op = op.send_nodes[0]

                # clone nodes for each device_id
                replaced_send_ops = OrderedSet()
                new_gather_send_nodes = OrderedSet()
                # NOTE(review): the loop variable `id` shadows the builtin
                # of the same name inside this method.
                for i, id in enumerate(op.from_id):
                    new_gather_send_op, new_sends, replaced_sends = clone_graph(
                        root=gather_send_op,
                        clone_id=id,
                        shared_queues_idx=i,
                        parallel_axis=self.parallel_axes,
                        num_clones=len(op.from_id))

                    new_gather_send_nodes.add(new_gather_send_op)

                    # Register the cloned root together with the other new
                    # send ops reported by clone_graph.
                    new_sends.add(new_gather_send_op)
                    for o in new_sends:
                        self.send_nodes.add(o)

                    # Accumulate the replaced sends across all clones; they
                    # are removed after the loop.
                    replaced_send_ops |= replaced_sends

                # The gather now receives from the cloned roots.
                op.send_nodes = new_gather_send_nodes

                # The original root was replaced by its clones, so drop it
                # along with every other replaced send op.
                replaced_send_ops.add(gather_send_op)
                for o in replaced_send_ops:
                    self.send_nodes.remove(o)