def calculate_gather_axes(axes, gather_axis, num_devices):
    """
    Build the axes of a gathered tensor.

    Every axis in ``axes`` that ``gather_axis`` compares equal to is replaced
    by a new axis created with ``make_axis`` from the original length
    multiplied by ``num_devices`` and the original name. All other axes are
    carried over unchanged.

    Arguments:
        axes: Iterable of axes to transform
        gather_axis: Axis along which the gather happens
        num_devices: Factor by which the gather axis length is multiplied

    Returns:
        The result of ``make_axes`` applied to the transformed axis list
    """
    gathered = []
    for axis in axes:
        if gather_axis == axis:
            # Scale the gathered axis; keep its name.
            axis = make_axis(axis.length * num_devices, axis.name)
        gathered.append(axis)
    return make_axes(gathered)
def gpu_constraint_factory(op, arg):
    """
    Select and build the binary layout constraint for an (op, argument) pair.

    The op is tested against each group of op types below, in order; the
    first group that matches decides which constraint class is instantiated.

    Arguments:
        op: Computation graph op which runs on the device
        arg: Argument to the op for which to generate a constraint

    Returns:
        Binary layout constraint object

    Raises:
        ValueError: If no layout constraint is implemented for the op's type
    """
    if isinstance(op, AssignOp):
        return GPUAssignLayoutConstraint(op, arg)

    # Element-wise style ops (including reductions and one-hot).
    if isinstance(op, (UnaryElementWiseOp, BinaryElementWiseOp,
                       ReductionOp, OneHotOp)):
        return GPUEWLayoutConstraint(op, arg)

    if isinstance(op, (TensorSizeOp, Fill)):
        return GPUBinaryLayoutConstraint(op, arg)

    if isinstance(op, DotOp):
        return GPUDotLayoutConstraint(op, arg)

    # Convolution / deconvolution / pooling ops and their derivative ops:
    # the argument's layout is fixed to the argument's own axes.
    if isinstance(op, (ConvolutionOp, bprop_conv, update_conv,
                       DeconvolutionOp, DeconvDerivOp,
                       PoolingOp, BpropPoolOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    if isinstance(op, (LookupTableOp, update_lut, bprop_lut)):
        return GPULutLayoutConstraint(op, arg)

    if isinstance(op, RngOp):
        return GPUBinaryLayoutConstraint(op, arg)

    if isinstance(op, (GPUQueueSendOp, GPUQueueRecvOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    if isinstance(op, (GPUCudaScatterSendOp, GPUCudaGatherSendOp)):
        # Fixed layout with the op's 'parallel' metadata axes placed first,
        # followed by the remaining axes of the op.
        parallel_axes = make_axes(op.metadata['parallel'])
        reordered_axes = parallel_axes + (op.axes - parallel_axes)
        return GPUFixedLayoutConstraint(op, arg, reordered_axes)

    if isinstance(op, (GPUCudaAllReduceOp, CTCOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    raise ValueError("Layouts not implemented for op type {}".format(op))
def set_parallel_axes(axes, parallel_axis):
    """
    Rebuild ``axes`` so that every axis comparing equal to ``parallel_axis``
    is replaced by the ``parallel_axis`` object itself.

    Entries of ``Axes.as_nested_list(axes)`` that are iterable are treated as
    flattened axes, and the same substitution is applied to each member.

    Arguments:
        axes: Axes to rebuild
        parallel_axis: Axis object to substitute for every equal axis

    Returns:
        The result of ``make_axes`` applied to the substituted nested list
    """
    # The ``collections.Iterable`` alias was removed in Python 3.10; the ABC
    # lives in ``collections.abc``. Fall back to the old location for
    # interpreters that do not provide ``collections.abc``.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    new_axes = []
    for axis in Axes.as_nested_list(axes):
        if axis == parallel_axis:
            axis = parallel_axis
        elif isinstance(axis, Iterable):
            # flattened axis
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)
    return make_axes(new_axes)
def hetr_axes(self, axes, parallel_axis):
    """
    Override hetr_axes so that GatherRecvOp carries the full length of the
    parallel_axis rather than parallel_axis.length//num_devices.

    The parent implementation is called first. If ``parallel_axis`` is in
    ``axes`` and the axis of the same name in the parent's result has a
    different length, every axis equal to ``parallel_axis`` in that result
    is replaced by ``parallel_axis`` itself.

    Arguments:
        axes: Axes passed through to the parent ``hetr_axes``
        parallel_axis: Axis whose full length must be kept

    Returns:
        The (possibly rebuilt) axes
    """
    arg_axes = super(GatherRecvOp, self).hetr_axes(axes, parallel_axis)
    if parallel_axis not in axes:
        return arg_axes

    # Length the parent assigned to the axis that shares parallel_axis's name.
    parent_length = arg_axes.find_by_name(parallel_axis.name).lengths[0]
    if parent_length != parallel_axis.length:
        restored = []
        for axis in arg_axes:
            restored.append(parallel_axis if axis == parallel_axis else axis)
        arg_axes = make_axes(restored)
    return arg_axes
def set_parallel_axes(axes, parallel_axis):
    """
    Rebuild ``axes`` so that every axis comparing equal to ``parallel_axis``
    is replaced by the ``parallel_axis`` object itself, keeping the names of
    flattened axes.

    Entries of ``Axes.as_nested_list(axes)`` that are iterable are treated as
    flattened axes: the substitution is applied to each member, and the name
    of the original axis at that position is written back onto the rebuilt
    axis at the same position.

    Arguments:
        axes: Axes to rebuild
        parallel_axis: Axis object to substitute for every equal axis

    Returns:
        The rebuilt axes produced by ``make_axes``, with flattened-axis names
        restored
    """
    # The ``collections.Iterable`` alias was removed in Python 3.10; the ABC
    # lives in ``collections.abc``. Fall back to the old location for
    # interpreters that do not provide ``collections.abc``.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    new_axes = []
    # Position -> original name, recorded for flattened axes only.
    flat_names = dict()
    for i, axis in enumerate(Axes.as_nested_list(axes)):
        if axis == parallel_axis:
            axis = parallel_axis
        elif isinstance(axis, Iterable):
            flat_names[i] = axes[i].name
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)
    new_axes = make_axes(new_axes)
    # Write the recorded names back onto the rebuilt flattened axes.
    for i, name in flat_names.items():
        new_axes[i].name = name
    return new_axes
def __init__(self, from_node, parallel_axis=None):
    """
    Create a send op that takes ``from_node`` as its only argument.

    The op's axes come from ``self.hetr_axes(from_node.axes, parallel_axis)``
    and its dtype from ``from_node``. The flattened axes of ``from_node`` are
    additionally stored as ``self.native_axes``.

    Arguments:
        from_node: Node whose value is sent; also supplies axes and dtype
        parallel_axis: Optional parallel axis; when it is present in the
            flattened axes of ``from_node``, its first occurrence there is
            replaced by this object
    """
    send_axes = self.hetr_axes(from_node.axes, parallel_axis)
    super(SendOp, self).__init__(node=from_node,
                                 args=(from_node,),
                                 axes=send_axes,
                                 dtype=from_node.dtype)

    # Add native/original axes to op.
    # Also ensure that the native axes has the original length
    # of the parallel_axis.
    native = Axes.as_flattened_list(from_node.axes)
    if parallel_axis is not None and parallel_axis in native:
        native[native.index(parallel_axis)] = parallel_axis
    self.native_axes = make_axes(native)