コード例 #1
0
ファイル: comm_nodes.py プロジェクト: QiJune/ngraph
def calculate_gather_axes(axes, gather_axis, num_devices):
    """
    Build the axes of a gathered tensor.

    Every axis in ``axes`` that compares equal to ``gather_axis`` is replaced
    by a new axis with the same name whose length is multiplied by
    ``num_devices``; all other axes are kept unchanged.

    Arguments:
        axes: Iterable of axes describing one device's slice.
        gather_axis: Axis along which the per-device slices are gathered.
        num_devices: Number of devices contributing a slice.

    Returns:
        The resulting axes, built with ``make_axes``.
    """
    gathered = []
    for axis in axes:
        if gather_axis == axis:
            axis = make_axis(axis.length * num_devices, axis.name)
        gathered.append(axis)
    return make_axes(gathered)
コード例 #2
0
ファイル: gpulayout.py プロジェクト: ugiwgh/ngraph
def gpu_constraint_factory(op, arg):
    """
    Create the binary layout constraint tying ``arg`` to ``op`` on the GPU.

    Op types are checked in a fixed order and the first match decides which
    constraint class is built.

    Arguments:
        op: Computation graph op which runs on the device
        arg: Argument to the op for which to generate a constraint

    Returns:
        Binary layout constraint object

    Raises:
        ValueError: if no constraint is defined for the type of ``op``.
    """
    if isinstance(op, AssignOp):
        return GPUAssignLayoutConstraint(op, arg)

    if isinstance(op, (UnaryElementWiseOp, BinaryElementWiseOp,
                       ReductionOp, OneHotOp)):
        return GPUEWLayoutConstraint(op, arg)

    if isinstance(op, (TensorSizeOp, Fill)):
        return GPUBinaryLayoutConstraint(op, arg)

    if isinstance(op, DotOp):
        return GPUDotLayoutConstraint(op, arg)

    # Convolution / deconvolution / pooling kernels keep the argument's axes.
    if isinstance(op, (ConvolutionOp, bprop_conv, update_conv,
                       DeconvolutionOp, DeconvDerivOp,
                       PoolingOp, BpropPoolOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    if isinstance(op, (LookupTableOp, update_lut, bprop_lut)):
        return GPULutLayoutConstraint(op, arg)

    if isinstance(op, RngOp):
        return GPUBinaryLayoutConstraint(op, arg)

    if isinstance(op, (GPUQueueSendOp, GPUQueueRecvOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    if isinstance(op, (GPUCudaScatterSendOp, GPUCudaGatherSendOp)):
        # Put the op's 'parallel' metadata axes ahead of its remaining axes.
        parallel_axes = make_axes(op.metadata['parallel'])
        ordered_axes = parallel_axes + (op.axes - parallel_axes)
        return GPUFixedLayoutConstraint(op, arg, ordered_axes)

    if isinstance(op, (GPUCudaAllReduceOp, CTCOp)):
        return GPUFixedLayoutConstraint(op, arg, arg.axes)

    raise ValueError("Layouts not implemented for op type {}".format(op))
コード例 #3
0
ファイル: comm_nodes.py プロジェクト: leonllm/ngraph
def set_parallel_axes(axes, parallel_axis):
    """
    Substitute ``parallel_axis`` for every axis equal to it in ``axes``.

    Top-level axes equal to ``parallel_axis`` are replaced by the
    ``parallel_axis`` object itself. Iterable (flattened) entries have the
    same substitution applied to each of their component axes.

    Arguments:
        axes: Axes to rewrite; walked via ``Axes.as_nested_list``.
        parallel_axis: Axis object to put in place of any axis equal to it.

    Returns:
        The rewritten axes, built with ``make_axes``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` on Python 3 and directly in ``collections`` on 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    new_axes = []
    for axis in Axes.as_nested_list(axes):
        if axis == parallel_axis:
            # NOTE(review): presumably axis equality does not compare every
            # attribute (e.g. length), so the parallel_axis object itself is
            # substituted -- confirm against Axis.__eq__.
            axis = parallel_axis
        elif isinstance(axis, Iterable):
            # flattened axis: substitute among its component axes
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)

    return make_axes(new_axes)
コード例 #4
0
ファイル: comm_nodes.py プロジェクト: rsumner31/ngraph
 def hetr_axes(self, axes, parallel_axis):
     """
     Compute the axes of this GatherRecvOp, keeping the full parallel_axis.

     Starts from the base-class hetr_axes result. If parallel_axis is among
     ``axes`` and the first length recorded for its name in that result
     differs from parallel_axis.length (for example a per-device
     parallel_axis.length//num_devices), every axis equal to parallel_axis
     is replaced by parallel_axis itself so the full length is restored.
     """
     base_axes = super(GatherRecvOp, self).hetr_axes(axes, parallel_axis)
     if parallel_axis not in axes:
         return base_axes

     found_length = base_axes.find_by_name(parallel_axis.name).lengths[0]
     if found_length != parallel_axis.length:
         base_axes = make_axes([
             parallel_axis if axis == parallel_axis else axis
             for axis in base_axes
         ])
     return base_axes
コード例 #5
0
ファイル: comm_nodes.py プロジェクト: rsumner31/ngraph
def set_parallel_axes(axes, parallel_axis):
    """
    Substitute ``parallel_axis`` for every axis equal to it in ``axes``.

    Top-level axes equal to ``parallel_axis`` are replaced by the
    ``parallel_axis`` object itself. Iterable (flattened) entries have the
    same substitution applied to each component. Their original names (read
    from ``axes[i].name``) are written back onto the rebuilt axes afterwards.

    Arguments:
        axes: Axes to rewrite; walked via ``Axes.as_nested_list`` and indexed
            by position to read flattened-axis names.
        parallel_axis: Axis object to put in place of any axis equal to it.

    Returns:
        The rewritten axes, built with ``make_axes``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` on Python 3 and directly in ``collections`` on 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    new_axes = []
    flat_names = {}  # position -> original name of a flattened axis
    for i, axis in enumerate(Axes.as_nested_list(axes)):
        if axis == parallel_axis:
            axis = parallel_axis
        elif isinstance(axis, Iterable):
            flat_names[i] = axes[i].name
            axis = [parallel_axis if a == parallel_axis else a for a in axis]
        new_axes.append(axis)
    new_axes = make_axes(new_axes)

    # NOTE(review): presumably make_axes gives rebuilt flattened axes a new
    # name; restore the original ones so name-based lookups keep working.
    for i, name in flat_names.items():
        new_axes[i].name = name
    return new_axes
コード例 #6
0
ファイル: comm_nodes.py プロジェクト: rsumner31/ngraph
    def __init__(self, from_node, parallel_axis=None):
        # Axes this op exposes, derived from the sending node's axes
        # (hetr_axes may be overridden by subclasses).
        send_axes = self.hetr_axes(from_node.axes, parallel_axis)
        super(SendOp, self).__init__(
            node=from_node,
            args=(from_node,),
            axes=send_axes,
            dtype=from_node.dtype,
        )

        # native_axes: the node's own (flattened) axes, with the first axis
        # equal to parallel_axis swapped for parallel_axis itself so that it
        # carries the original parallel_axis length.
        native = Axes.as_flattened_list(from_node.axes)
        if parallel_axis is not None and parallel_axis in native:
            native[native.index(parallel_axis)] = parallel_axis
        self.native_axes = make_axes(native)