def generate_default_dot_layout(op):
    """
    Generates the default layout assignment for a dot operation on GPU.
    By default the first operand may be transposed, but not the second.
    The output layout of the dot operation is defined by rows (the non-reduction
    axes from the first operand) and columns (the non-reduction axes from the
    second operand).

    Arguments:
        op (DotOp): op to generate layout for

    Returns:
        GPUDotLayoutAssignment for this op
    """
    axes_list = Axes.as_flattened_list(op.axes)
    rows_axis = [axes_list.index(a) for a in Axes.as_flattened_list(op.x_out_axes)]
    cols_axis = [axes_list.index(a) for a in Axes.as_flattened_list(op.y_out_axes)]

    # By default allow the first argument to be transposed, but not the second
    # TODO: this could be bad for perf; add a heuristic?
    return [GPUDotLayoutAssignment(True, False, axes_list, [rows_axis, cols_axis])]
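# Illustrative only: a minimal sketch of the row/column index computation above,
# using plain Python lists in place of ngraph Axes objects (hypothetical axis names).
axes_list_example = ['N', 'M', 'K']    # flattened output axes of the dot op
x_out_example = ['N']                  # non-reduction axes from the first operand
y_out_example = ['M', 'K']             # non-reduction axes from the second operand
rows_axis_example = [axes_list_example.index(a) for a in x_out_example]  # -> [0]
cols_axis_example = [axes_list_example.index(a) for a in y_out_example]  # -> [1, 2]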
def __init__(self, op, arg):
    super(GPULutLayoutConstraint, self).__init__(op, arg)
    # A two-axis argument keeps its two axis groups separate; otherwise all of
    # the argument's axes form a single group.
    if len(arg.axes) == 2:
        self.order = [Axes.as_flattened_list(Axes(arg.axes[0])),
                      Axes.as_flattened_list(Axes(arg.axes[1]))]
    else:
        self.order = [Axes.as_flattened_list(arg.axes)]
def __init__(self, op, arg):
    super(GPUDotLayoutConstraint, self).__init__(op, arg)

    # Determine whether this constraint's argument is the first (A) or second (B)
    # operand of the dot op and record the corresponding reduction and output axes.
    args = list(self.op.args)
    self.op_axes = Axes.as_flattened_list(self.op.axes)
    if self.arg.forwarded is args[0].forwarded:
        self.operand = 'A'
        self.reduction_axes = Axes.as_flattened_list(self.op.reduction_axes)
        self.out_axes = Axes.as_flattened_list(self.op.x_out_axes)
    elif self.arg.forwarded is args[1].forwarded:
        self.operand = 'B'
        self.reduction_axes = Axes.as_flattened_list(self.op.reduction_axes)
        self.out_axes = Axes.as_flattened_list(self.op.y_out_axes)
    else:
        raise ValueError("Invalid argument for constraint")
def generate_default_layout(clss, axes, max_out_axes):
    """
    Generates a default layout assignment for an elementwise operation.

    Arguments:
        axes: List of axes in the output of the operation
        max_out_axes: The maximum number of strided axes supported by the
            kernel for this operation

    Returns:
        A list containing a single layout assignment
    """
    axes_list = Axes.as_flattened_list(axes)

    # Need to divide op axes into `max_out_axes` sets
    if len(axes_list) > max_out_axes:
        num_splits = max_out_axes - 1
        num_axes = len(axes_list)
        split_points = list(reversed([(num_axes - (i + 1)) for i in range(num_splits)]))
        layout = split_points_to_groups(split_points, len(axes_list))
    else:
        layout = [[i] for i in range(len(axes_list))]

    return [clss(axes_list, layout)]
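# A minimal sketch of what `split_points_to_groups` is assumed to do here: turn a
# sorted list of split points into contiguous groups of axis indices. This is an
# illustration of the assumed contract, not the library implementation.
def split_points_to_groups_sketch(split_points, num_axes):
    bounds = [0] + list(split_points) + [num_axes]
    return [list(range(start, end)) for start, end in zip(bounds[:-1], bounds[1:])]

# With 5 axes and max_out_axes = 3, the code above picks split_points = [3, 4],
# which collapses the leading axes into one group: [[0, 1, 2], [3], [4]].
assert split_points_to_groups_sketch([3, 4], 5) == [[0, 1, 2], [3], [4]]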
def update_parallel_axis(root, parallel_axis):
    for op in Op.ordered_ops([root]):

        if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
            op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

        if getattr(op, 'axes', None) is not None \
                and parallel_axis in Axes.as_flattened_list(op.axes):
            op._axes = set_parallel_axes(op.axes, parallel_axis)

        if isinstance(op, DotOp):
            if parallel_axis in op.x_out_axes:
                op.x_out_axes = set_parallel_axes(op.x_out_axes, parallel_axis)
            elif parallel_axis in op.y_out_axes:
                op.y_out_axes = set_parallel_axes(op.y_out_axes, parallel_axis)
            else:
                raise ValueError("Missing parallel_axis in Op's "
                                 "x_out_axes or y_out_axes")

        if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
            op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)
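# A minimal sketch of the assumed contract of `set_parallel_axes`: replace the axis
# whose name matches the (possibly resized) parallel_axis and leave the rest alone.
# The import path and the name-based matching are assumptions for illustration, not
# the library implementation; flattened-axis grouping is ignored for simplicity.
from ngraph.op_graph.axes import Axes, make_axes

def set_parallel_axes_sketch(axes, parallel_axis):
    return make_axes([parallel_axis if a.name == parallel_axis.name else a
                      for a in Axes.as_flattened_list(axes)])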
def generate_default_onehot_layout(op):
    """
    Generates the default layout assignment for a onehot operation on GPU.

    Arguments:
        op (OneHotOp): op to generate layout for

    Returns:
        GPULayoutAssignment for this op
    """
    axes_list = Axes.as_flattened_list(op.axes)
    oh_axis = axes_list.index(op.axis)
    other_group = [i for i, a in enumerate(axes_list) if a is not op.axis]

    if oh_axis == 0:
        return [GPUDotLayoutAssignment(True, False, axes_list,
                                       [[oh_axis], other_group])]
    elif oh_axis == (len(axes_list) - 1):
        return [GPUDotLayoutAssignment(True, False, axes_list,
                                       [other_group, [oh_axis]])]
    else:
        group0 = [i for i in other_group if i < oh_axis]
        group1 = [i for i in other_group if i > oh_axis]
        return [GPUDotLayoutAssignment(True, False, axes_list,
                                       [group0, [oh_axis], group1])]
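# Illustrative only: the middle-axis case above splits the remaining axis positions
# around the onehot axis (hypothetical positions).
axes_positions_demo = [0, 1, 2, 3]
oh_axis_demo = 2
other_demo = [i for i in axes_positions_demo if i != oh_axis_demo]
group0_demo = [i for i in other_demo if i < oh_axis_demo]
group1_demo = [i for i in other_demo if i > oh_axis_demo]
assert [group0_demo, [oh_axis_demo], group1_demo] == [[0, 1], [2], [3]]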
def get_flattened_axes(x):
    """
    Ordered list of axes visible to MKLDNN
    """
    return [axis for axis in Axes.as_flattened_list(x) if axis.name != '__NG_DEPTH']
def generate_comms_layout(op, max_out_axes):
    """
    Generates layout assignment for communication operations.

    Arguments:
        op (CommunicationOp): op to generate layout for
        max_out_axes: The maximum number of strided axes supported by the
            kernel for this operation

    Returns:
        GPULayoutAssignment for this op
    """
    parallel_axis = op.metadata['parallel']
    axes_list = Axes.as_flattened_list(op.axes)

    if parallel_axis not in axes_list:
        return GPULayoutAssignment.generate_default_layout(op.axes, max_out_axes)

    parallel_axis_index = axes_list.index(parallel_axis)
    if len(axes_list) > max_out_axes:
        num_splits = max_out_axes - 1
        num_axes = len(axes_list)
        split_points = list(reversed([(num_axes - (i + 1)) for i in range(num_splits)]))
        layout = split_points_to_groups(split_points, len(axes_list))

        # Find which group contains the parallel axis
        group_index = -1
        for idx, group in enumerate(layout):
            if parallel_axis_index in group:
                group_index = idx
                break

        # Move the parallel_axis group to the first position
        parallel_group = layout[0]
        if group_index > 0:
            parallel_group = layout.pop(group_index)
            layout.insert(0, parallel_group)

        # If parallel_axis is not the first axis in its group, make it the first
        if parallel_axis_index != parallel_group[0]:
            parallel_group.remove(parallel_axis_index)
            parallel_group.insert(0, parallel_axis_index)
    else:
        layout = [[i] for i in range(len(axes_list)) if i != parallel_axis_index]
        layout.insert(0, [parallel_axis_index])

    return [GPULayoutAssignment(axes_list, layout)]
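# Illustrative only: a small, self-contained demo of the regrouping performed in
# generate_comms_layout above, using plain index lists (hypothetical values).
layout_demo = [[0, 1, 2], [3], [4]]   # groups as produced by split_points_to_groups
parallel_index_demo = 3               # flattened position of the parallel axis
group_index_demo = next(i for i, g in enumerate(layout_demo) if parallel_index_demo in g)
if group_index_demo > 0:
    layout_demo.insert(0, layout_demo.pop(group_index_demo))
assert layout_demo == [[3], [0, 1, 2], [4]]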
def get_layout_transform(self, arg_layout, op_layout, arg):
    """
    Returns a reshape view of the argument strided to match the AssignOp axes.

    Arguments:
        arg_layout (GPULayoutAssignment): layout of the argument
        op_layout (GPULayoutAssignment): layout required by the op
        arg (TensorOp): Op producing the argument

    Returns:
        A GPUIndexOp which satisfies the requirements of the op_layout
    """
    arg_mem_order = flatten(arg_layout.axes)
    arg_view_axes = Axes.as_flattened_list(arg.axes)
    arg_axes = arg_layout.ng_axes
    out_groups = [[a] for a in arg_view_axes]
    return self.get_reshape(arg_mem_order, arg_axes, out_groups, arg)
def __init__(self, from_node, parallel_axis=None):
    super(SendOp, self).__init__(node=from_node,
                                 args=tuple([from_node]),
                                 axes=self.hetr_axes(from_node.axes, parallel_axis),
                                 dtype=from_node.dtype)

    # Add the native/original axes to the op, making sure the native axes keep
    # the original length of the parallel_axis
    self.native_axes = Axes.as_flattened_list(from_node.axes)
    if parallel_axis is not None and parallel_axis in self.native_axes:
        p_axis_idx = self.native_axes.index(parallel_axis)
        self.native_axes[p_axis_idx] = parallel_axis
    self.native_axes = make_axes(self.native_axes)
def generate_default_lut_layout(op):
    """
    Generates the default layout assignment for a lookup table operation on GPU.

    Arguments:
        op (LookupTableOp): op to generate layout for

    Returns:
        GPULayoutAssignment for this op
    """
    axes_list = Axes.as_flattened_list(op.axes)
    groups = Axes.as_nested_list(op.axes)
    layout = []
    for group in groups:
        if isinstance(group, list):
            layout.append([axes_list.index(a) for a in group])
        else:
            layout.append([axes_list.index(group)])
    return [GPULayoutAssignment(axes_list, layout)]
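# Illustrative only: Axes.as_nested_list keeps flattened (grouped) axes together,
# so a nested structure like [['N', 'C'], 'H'] (hypothetical names) maps to index
# groups [[0, 1], [2]] against the flat list ['N', 'C', 'H'].
nested_demo = [['N', 'C'], 'H']
flat_demo = ['N', 'C', 'H']
layout_demo = [[flat_demo.index(a) for a in g] if isinstance(g, list)
               else [flat_demo.index(g)] for g in nested_demo]
assert layout_demo == [[0, 1], [2]]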
def generate_ew_layouts(clss, axes, max_out_axes):
    """
    Generates a set of possible layouts for an elementwise operation.

    Arguments:
        axes: List of axes in the output of the operation
        max_out_axes: The maximum number of strided axes supported by the
            kernel for this operation

    Returns:
        A list of layout possibilities for this operation
    """
    # Get list of individual axes
    axes_list = Axes.as_flattened_list(axes)

    # Need to divide op axes into `max_out_axes` sets
    if len(axes_list) > max_out_axes:
        groups = get_split_groups(len(axes_list), max_out_axes)
        num_groups = max_out_axes
    else:
        groups = [[[i] for i in range(len(axes_list))]]
        num_groups = len(axes_list)

    # Find all permutations of these axis groups
    permutations = enumerate_axis_orders(tuple(range(num_groups)))

    if permutations:
        # Create EW layouts
        layouts = []
        for group in groups:
            for order in permutations:
                layout_spec = [group[i] for i in order]
                layouts.append(clss(axes_list, layout_spec))
    else:
        layouts = [clss(axes_list, [])]

    return layouts
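# A minimal sketch of the assumed contract of `enumerate_axis_orders`: yield every
# permutation of the group indices so that each ordering can be offered as a
# candidate layout. This is an assumption for illustration, not the library code.
from itertools import permutations as _permutations

def enumerate_axis_orders_sketch(indices):
    return list(_permutations(indices))

# e.g. three axis groups produce 3! = 6 candidate orderings
assert len(enumerate_axis_orders_sketch((0, 1, 2))) == 6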
def __init__(self, op, arg, axes):
    super(GPUFixedLayoutConstraint, self).__init__(op, arg)
    self.order = Axes.as_flattened_list(axes)
def __init__(self, op, arg):
    super(GPUEWLayoutConstraint, self).__init__(op, arg)
    if isinstance(op, ReductionOp):
        self.red_axes = Axes.as_flattened_list(op.reduction_axes)
    else:
        self.red_axes = None
def get_flattened_axes(x):
    """
    Ordered list of axes visible to MKLDNN
    """
    return Axes.as_flattened_list(x)
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the graph rooted at `root` using serde (serialization).

    Returns:
        new_root of the cloned graph, along with the new and replaced send nodes
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    # Prune ops that are not control_deps of new_root;
    # deserialization includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by
                # adding an additional sender with the same input data as the
                # original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)

            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

            # Point cloned args at their already-cloned counterparts
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)

            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph rooted at `root` using serde (serialization).

    Returns:
        new_root of the cloned graph, along with the new and replaced send nodes
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(op, (ScatterRecvOp, GatherSendOp,
                               AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)

            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes, parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

            # Point cloned args at their already-cloned counterparts
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)

            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def __init__(self, op, arg):
    self.op = op
    self.arg = arg

    # Build mapping of arg axis position to axis position in buffer.
    # Arg axes may be re-ordered, cast, broadcast, or sliced between the original
    # buffer and the op being used for this constraint.
    predecessor_op = arg
    while isinstance(predecessor_op, SequentialOp):
        predecessor_op = predecessor_op.value_tensor

    self.arg_axes_list = Axes.as_flattened_list(arg.axes)
    self.mappings = {}
    self.sliced_out = []
    for i in range(len(self.arg_axes_list)):
        self.mappings[i] = i

    while not (predecessor_op.is_device_op or isinstance(predecessor_op, TensorValueOp)):
        pred_axes = Axes.as_flattened_list(predecessor_op.axes)
        pred_arg_axes = Axes.as_flattened_list(predecessor_op.args[0].axes)
        if isinstance(predecessor_op, (BroadcastOp, ExpandDims)):
            # Mark mappings that point at broadcast axes and shift the rest down
            bcast_axes = [pred_axes.index(a) for a in pred_axes
                          if a not in pred_arg_axes]
            bcast_mappings = [a for a in self.mappings if self.mappings[a] in bcast_axes]
            for bcast in bcast_mappings:
                self.mappings[bcast] = "bcast"
            for a, p in self.mappings.items():
                if isinstance(p, int):
                    offset = 0
                    for bcast_axis in bcast_axes:
                        if p > bcast_axis:
                            offset += 1
                    self.mappings[a] = p - offset

            for i in range(len(self.sliced_out)):
                if self.sliced_out[i][0] in bcast_axes:
                    self.sliced_out[i] = (self.sliced_out[i][0], "bcast")
                else:
                    new_axis_index = pred_arg_axes.index(pred_axes[self.sliced_out[i][0]])
                    self.sliced_out[i] = (new_axis_index, self.sliced_out[i][1])
        elif isinstance(predecessor_op, TensorSliceOp):
            old_index = 0
            new_indexes = []
            for index, axis in enumerate(pred_arg_axes):
                if isinstance(predecessor_op.slices[index], int):
                    # Axis is sliced out entirely by an integer index
                    self.sliced_out.append((index, predecessor_op.slices[index]))
                else:
                    if predecessor_op.slices[index] != slice(None, None, None):
                        new_index = ("slice", index, predecessor_op.slices[index])
                    else:
                        new_index = index

                    # Remap this axis
                    for a, p in self.mappings.items():
                        if isinstance(p, int) and p == old_index:
                            new_indexes.append((a, new_index))
                    old_index += 1

            for a, p in new_indexes:
                self.mappings[a] = p
        elif isinstance(predecessor_op, (ReorderAxes, Transpose)):
            # Permute mappings to follow the axis reordering
            new_indexes = []
            for a, p in self.mappings.items():
                if isinstance(p, int):
                    new_indexes.append((a, pred_arg_axes.index(pred_axes[p])))
            for a, p in new_indexes:
                self.mappings[a] = p

            for i in range(len(self.sliced_out)):
                new_axis_index = pred_arg_axes.index(pred_axes[self.sliced_out[i][0]])
                self.sliced_out[i] = (new_axis_index, self.sliced_out[i][1])
        elif isinstance(predecessor_op, (AxesCastOp, Flatten, Unflatten)):
            # These ops do not change the flattened axis ordering, so no remapping
            # is needed.
            pass
        else:
            raise RuntimeError("Unexpected op {} between the argument and its "
                               "buffer".format(type(predecessor_op).__name__))

        predecessor_op = predecessor_op.args[0]
        while isinstance(predecessor_op, SequentialOp):
            predecessor_op = predecessor_op.value_tensor
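# Illustrative only: how the broadcast case above rewrites the position mappings.
# If the predecessor broadcast inserted an axis at position 1 (hypothetical), the
# mapping that pointed at it becomes "bcast", and mappings pointing past it shift
# down to positions in the broadcast's argument.
mappings_demo = {0: 0, 1: 1, 2: 2}
bcast_axes_demo = [1]
for a, p in list(mappings_demo.items()):
    if p in bcast_axes_demo:
        mappings_demo[a] = "bcast"
    else:
        mappings_demo[a] = p - sum(1 for b in bcast_axes_demo if p > b)
assert mappings_demo == {0: 0, 1: "bcast", 2: 1}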