def calculate_gather_axes(axes, gather_axis, num_devices):
    """Build the axes of a tensor gathered from ``num_devices`` devices.

    Every axis in ``axes`` is copied unchanged, except the one that compares
    equal to ``gather_axis``. That axis is replaced by a new axis with the same
    name and ``num_devices`` times its length.

    Args:
        axes: Iterable of axis objects (each exposing ``length`` and ``name``).
        gather_axis: The axis along which per-device pieces are concatenated.
        num_devices: Number of devices whose pieces are gathered.

    Returns:
        The result of ``make_axes`` over the rebuilt axis list.
    """
    gathered = []
    for axis in axes:
        if gather_axis == axis:
            # Same name, but long enough to hold every device's slice.
            axis = make_axis(axis.length * num_devices, axis.name)
        gathered.append(axis)
    return make_axes(gathered)
def visit(self, op):
    """Lower a dot op (separate x/y reduction axes) into low-level dot ops.

    Reads the two operands from ``op.args``, plus ``op.x_reduction_axes``,
    ``op.y_reduction_axes`` and ``op.axes``. It builds a replacement
    subgraph and installs it with ``self.replace_op(op, out)``.

    Paths:
      * scalar ``y`` (after possibly swapping operands): multiply by
        ``y.scalar_op``, summing ``x`` over the reduction axes first when
        ``x`` is not scalar, then broadcast to ``op.axes``;
      * otherwise: reorder ``x`` to (remaining, reduction) and ``y`` to
        (reduction, remaining), flatten each at that boundary with
        ``flatten_at``, and dispatch on rank to ``DotOneDimensional``,
        ``DotTwoByOne`` or ``DotTwoDimensional``; finally ``unflatten`` and
        ``ReorderAxes`` back to ``op.axes``.
    """
    x, y = op.args
    x_reduction_axes = op.x_reduction_axes
    y_reduction_axes = op.y_reduction_axes
    out_axes = op.axes
    if len(x_reduction_axes) == 0:
        # No contraction axes: introduce a dummy length-1 axis, appended to x
        # and prepended to y, so the general reorder/flatten path still applies.
        d = make_axis(1)
        x_reduction_axes = make_axes((d, ))
        y_reduction_axes = x_reduction_axes
        x = broadcast(x, x.axes + x_reduction_axes)
        y = broadcast(y, y_reduction_axes + y.axes)
    if x.is_scalar:
        # Move a scalar operand into y so only y needs a scalar check below.
        # NOTE(review): x_reduction_axes / y_reduction_axes are NOT swapped
        # along with the operands; this only matters if they differ. Also,
        # after the dummy-axis broadcast above x always carries the reduction
        # axes, so this branch looks unreachable — confirm `is_scalar` semantics.
        temp = x
        x = y
        y = temp
    if y.is_scalar:
        if x.is_scalar:
            # scalar * scalar, scaled by the number of reduced elements.
            out = x.scalar_op * y.scalar_op
            if len(x_reduction_axes) > 0:
                out = out * x_reduction_axes.size
            out = broadcast(out, op.axes)
        else:
            # tensor . scalar == sum(tensor over reduction axes) * scalar
            out = Sum(x, x_reduction_axes) * y.scalar_op
        # NOTE(review): the scalar*scalar path reaches this second, apparently
        # redundant broadcast to the same axes.
        out = broadcast(out, op.axes)
    else:
        # x -> (remaining axes, reduction axes)
        x_rem_axes = x.axes - x_reduction_axes
        x = axes_with_order(x, x_rem_axes + x_reduction_axes)
        # y -> (reduction axes, remaining axes)
        y_rem_axes = y.axes - y_reduction_axes
        y = axes_with_order(y, y_reduction_axes + y_rem_axes)
        # Flatten each operand at the boundary between remaining and reduction
        # axes.
        x = flatten_at(x, len(x.axes) - len(x_reduction_axes))
        y = flatten_at(y, len(y_reduction_axes))
        if len(out_axes) == 0:
            # Full contraction: no output axes.
            out = DotOneDimensional(x, y, axes=())
        elif len(x.axes) == 1:
            # x has a single (flattened) axis: compute it as transpose(y) . x.
            y = Transpose(y)
            out = DotTwoByOne(y, x, axes=y.axes[0])
        elif len(y.axes) == 1:
            out = DotTwoByOne(x, y, axes=x.axes[0])
        else:
            out = DotTwoDimensional(
                x, y, axes=([op.x_out_axes.flatten(), op.y_out_axes.flatten()]))
        # Undo the flattening and restore the op's declared axis order.
        out = unflatten(out)
        out = ReorderAxes(out, out_axes)
    self.replace_op(op, out)
def visit(self, op):
    """Lower a dot op (shared reduction axes) into ``DotLowDimension`` ops.

    Reads the two operands from ``op.args``, plus ``op.reduction_axes`` and
    ``op.axes``. It builds a replacement subgraph and installs it with
    ``self.replace_op(op, out)``.

    Paths:
      * scalar ``y`` (after possibly swapping operands): multiply by
        ``y.scalar_op``, summing ``x`` over ``reduction_axes`` first when
        ``x`` is not scalar, then broadcast to ``op.axes``;
      * otherwise: reorder ``x`` to (remaining, reduction) and ``y`` to
        (reduction, remaining), flatten each at that boundary with
        ``flatten_at``, emit a ``DotLowDimension`` whose argument order and
        ``axes`` depend on the ranks, then ``unflatten`` and ``ReorderAxes``
        back to ``op.axes``.
    """
    x, y = op.args
    reduction_axes = op.reduction_axes
    out_axes = op.axes
    if len(reduction_axes) == 0:
        # TODO: this is a weird case, should we really support it?
        # No contraction axes: add a dummy length-1 axis, appended to x and
        # prepended to y, so the general reorder/flatten path below applies.
        d = make_axis(1)
        reduction_axes = make_axes((d, ))
        x = broadcast(x, x.axes + reduction_axes)
        y = broadcast(y, reduction_axes + y.axes)
    if x.is_scalar:
        # Move a scalar operand into y so only y needs a scalar check below.
        # NOTE(review): after the dummy-axis broadcast above x always carries
        # the reduction axes, so this branch looks unreachable — confirm
        # `is_scalar` semantics.
        x, y = y, x
    if y.is_scalar:
        if x.is_scalar:
            # scalar * scalar, scaled by the number of reduced elements.
            out = x.scalar_op * y.scalar_op
            if len(reduction_axes) > 0:
                out = out * reduction_axes.size
            out = broadcast(out, op.axes)
        else:
            # tensor . scalar == sum(tensor over reduction axes) * scalar
            out = Sum(x, reduction_axes) * y.scalar_op
        # NOTE(review): the scalar*scalar path reaches this second, apparently
        # redundant broadcast to the same axes.
        out = broadcast(out, op.axes)
    else:
        # move reduction_axes to the end of x
        x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
        # move reduction axes to the front of y
        y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))
        # flatten x at the boundary: non-reduction axes | reduction axes
        x = flatten_at(x, len(x.axes) - len(reduction_axes))
        # flatten y at the boundary: reduction axes | non-reduction axes
        y = flatten_at(y, len(reduction_axes))
        if len(out_axes) == 0:
            # Full contraction: no output axes.
            out = DotLowDimension(x, y, axes=())
        elif len(x.axes) == 1:
            # x has a single (flattened) axis: compute it as transpose(y) . x.
            y = Transpose(y)
            out = DotLowDimension(y, x, axes=y.axes[0])
        elif len(y.axes) == 1:
            out = DotLowDimension(x, y, axes=x.axes[0])
        else:
            out = DotLowDimension(x, y, axes=([
                op.x_out_axes.flatten(True),
                op.y_out_axes.flatten(True)
            ]))
        # Undo the flattening and restore the op's declared axis order.
        out = unflatten(out)
        out = ReorderAxes(out, out_axes)
    self.replace_op(op, out)
def do_pass(self, ops, **kwargs):
    """Fix the HeTr parallel axis and push it into each gather send subgraph.

    Replaces every op in ``ops`` with its ``forwarded`` op, then walks them in
    reverse ``Op.ordered_ops`` order. For each op whose
    ``metadata['marker']`` equals ``'gather'``:

    * while ``self.parallel_axis`` is still ``None``, it is derived from
      ``metadata['parallel']``. The new axis keeps that axis's name, and its
      length is the original length divided by ``len(op.from_id)``. An
      ``assert`` checks that the division is exact;
    * ``update_parallel_axis`` is called on the op's ``send_node()`` with
      ``self.parallel_axis``.

    Args:
        ops: Ops to scan.
        **kwargs: Accepted for the pass interface; not used here.
    """
    forwarded_ops = OrderedSet(candidate.forwarded for candidate in ops)
    for node in reversed(Op.ordered_ops(forwarded_ops)):
        if node.metadata.get('marker') != 'gather':
            continue
        # node is expected to be a GatherRecvOp
        if self.parallel_axis is None:
            full_axis = node.metadata['parallel']
            # NOTE: `assert` is removed under `python -O`, disabling this check.
            assert full_axis.length % len(node.from_id) == 0, \
                '{} can not be equally divided by {}'.format(
                    full_axis, len(node.from_id))
            self.parallel_axis = make_axis(
                name=full_axis.name,
                length=full_axis.length // len(node.from_id),
                docstring='HeTr parallel axis')
        send_side = node.send_node()
        # update_parallel_axis is defined elsewhere; presumably it rewrites
        # the send-side subgraph onto the per-device axis — confirm.
        update_parallel_axis(send_side, self.parallel_axis)
def do_pass(self, ops, **kwargs):
    """Clone each gather op's send subgraph once per source device.

    Replaces every op in ``ops`` with its ``forwarded`` op, then walks them in
    reverse ``Op.ordered_ops`` order. For each op whose
    ``metadata['marker']`` equals ``'gather'``:

    * while ``self.parallel_axes`` is still ``None``, it is derived from
      ``metadata['parallel']``. The new axis keeps that axis's name, and its
      length is the original length divided by ``len(op.from_id)``. An
      ``assert`` checks that the division is exact;
    * ``clone_graph`` is called on ``op.send_nodes[0]`` once per entry of
      ``op.from_id``. The returned root clones become the new
      ``op.send_nodes``;
    * ``self.send_nodes`` gains every new send op (root clones included) and
      loses every replaced send op, including the original root.

    Args:
        ops: Ops to scan.
        **kwargs: Accepted for the pass interface; not used here.
    """
    forwarded_ops = OrderedSet(candidate.forwarded for candidate in ops)
    for node in reversed(Op.ordered_ops(forwarded_ops)):
        if node.metadata.get('marker') != 'gather':
            continue
        # node is expected to be a GatherRecvOp
        if self.parallel_axes is None:
            full_axis = node.metadata['parallel']
            # NOTE: `assert` is removed under `python -O`, disabling this check.
            assert full_axis.length % len(node.from_id) == 0, \
                '{} can not be equally divided by {}'.format(
                    full_axis, len(node.from_id))
            self.parallel_axes = make_axis(
                name=full_axis.name,
                length=full_axis.length // len(node.from_id),
                docstring='HeTr parallel axis')
        original_root = node.send_nodes[0]

        # clone the send subgraph for each device_id
        retired_sends = OrderedSet()
        cloned_roots = OrderedSet()
        for queue_idx, device_id in enumerate(node.from_id):
            cloned_root, fresh_sends, superseded = clone_graph(
                root=original_root,
                clone_id=device_id,
                shared_queues_idx=queue_idx,
                parallel_axis=self.parallel_axes,
                num_clones=len(node.from_id))
            cloned_roots.add(cloned_root)
            # The cloned root is itself a send op the transformer must track.
            fresh_sends.add(cloned_root)
            for send_op in fresh_sends:
                self.send_nodes.add(send_op)
            retired_sends |= superseded

        node.send_nodes = cloned_roots

        # The original root has been replaced by its per-device clones.
        retired_sends.add(original_root)
        for send_op in retired_sends:
            self.send_nodes.remove(send_op)