Example #1
def _allreduce(grads):
    # Normalize a single tensor into a list of gradients.
    if not isinstance(grads, (list, tuple)): grads = [grads]
    dev = MakeDevice(inputs=grads)
    mode = mpi.GetParallelMode() + '_ALLREDUCE'
    # Reuse one collective module per (device, mode) key.
    key = 'Collective/{}/{}'.format(dev, mode.lower())
    module = get_module(Collective, key, dev, mode=mode)
    return module.forward(grads)
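For orientation, the get_module call above appears to build a collective module once per key and reuse it on later calls. Below is a minimal, self-contained sketch of that caching pattern; module_cache and SimpleCollective are hypothetical stand-ins, not part of Dragon's API.

# Illustrative sketch only: a keyed module cache in the spirit of get_module.
module_cache = {}

class SimpleCollective(object):
    def __init__(self, dev, mode):
        self.dev, self.mode = dev, mode

    def forward(self, grads):
        # A real module would launch the allreduce across ranks here.
        return grads

def get_module(cls, key, dev, **kwargs):
    # Construct the module on first use, then reuse it for the same key.
    if key not in module_cache:
        module_cache[key] = cls(dev, **kwargs)
    return module_cache[key]

m1 = get_module(SimpleCollective, 'Collective/cpu/mpi_allreduce', 'cpu', mode='MPI_ALLREDUCE')
m2 = get_module(SimpleCollective, 'Collective/cpu/mpi_allreduce', 'cpu', mode='MPI_ALLREDUCE')
assert m1 is m2  # same key, same cached module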
Example #2
def _allreduce(grads):
    # No-op unless MPI has been initialized.
    if not mpi.Is_Init(): return
    if not isinstance(grads, (list, tuple)): grads = [grads]
    ctx = MakeContext(inputs=grads)
    mode = mpi.GetParallelMode() + '_ALLREDUCE'
    # Key the cached module by device type, device index, and mode.
    key = 'torch/ops/collective/{}:{}/{}'.format(
        ctx[0].lower(), ctx[1], mode.lower())
    module = get_module(Collective, key, ctx, mode=mode)
    return module.forward(grads)
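This revision returns early when MPI is not initialized and derives the cache key from a (device type, device index) context tuple instead of a device string. As a rough illustration of the key format, assuming a 'CUDA' device at index 0 and an 'MPI' parallel mode (both assumed values, not confirmed by the snippet):

# Assumed values: ctx = (device_type, device_index), mode from GetParallelMode().
ctx = ('CUDA', 0)
mode = 'MPI' + '_ALLREDUCE'
key = 'torch/ops/collective/{}:{}/{}'.format(ctx[0].lower(), ctx[1], mode.lower())
print(key)  # torch/ops/collective/cuda:0/mpi_allreduce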
Example #3
def _inject_update_ops(graph_def, updater):
    """Inject the update ops GraphDef.

    The ``updater`` should have generated its update targets beforehand.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of the graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return
    updater.register_in_workspace()

    grads, update_ops = [], []
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot

    # Build update ops according to the updater
    for (param, grad), arguments in updater._param_group:
        if _workspace.HasTensor(grad):
            grads.append(grad)
            arguments = dict(arguments, **extra_arguments)
            update_ops.append(
                _proto_utils.MakeOperatorDef(
                    op_type=updater.type(),
                    inputs=[grad],
                    outputs=[param],
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))
        else:
            _logging.info('Skip updating Tensor({}).'.format(param))

    # Check data parallel if necessary
    if _mpi.Is_Init():
        (rank, group), arguments = _mpi.AllowParallel(), {}
        if rank != -1:
            arguments['mode'] = '%s_ALLREDUCE' % _mpi.GetParallelMode()
            arguments['root'], (arguments['comm'], arguments['group']) \
                = group[0], _mpi.CreateGroup(root=group[0], incl=group)
            update_ops.insert(
                0,
                _proto_utils.MakeOperatorDef(
                    op_type='CollectiveUpdate',
                    inputs=grads,
                    outputs=grads,
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))

    graph_def.op.extend(update_ops)
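Note the ordering: under data parallelism, the CollectiveUpdate op is inserted at index 0, so gradients are reduced across ranks before any per-parameter update runs. A minimal sketch of the resulting op layout, using plain dicts as hypothetical stand-ins for OperatorDef messages:

# Dicts stand in for OperatorDef protos; this only demonstrates the ordering.
grads = ['w1_grad', 'w2_grad']
update_ops = [
    {'op_type': 'SGDUpdate', 'inputs': ['w1_grad'], 'outputs': ['w1']},
    {'op_type': 'SGDUpdate', 'inputs': ['w2_grad'], 'outputs': ['w2']},
]

# Mirrors update_ops.insert(0, ...) above: the collective op runs first.
update_ops.insert(0, {'op_type': 'CollectiveUpdate', 'inputs': grads, 'outputs': grads})

assert update_ops[0]['op_type'] == 'CollectiveUpdate'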
Example #4
def GraphDef_Update(meta_graph, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should have generated its update targets beforehand.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of the meta graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    updater._prefix = meta_graph.name + '_'
    extra_arguments = updater._extra_kwargs
    extra_arguments['domain'] = updater._prefix
    parallel_arguments = {}

    # Wrap each hyper-parameter as a tensor for the C++ core (CC)
    for k, v in updater._hyper_params.items():
        ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))

    # check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(k, v))

    for tensors, arguments in updater._tuples:
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdateTarget()
        u_target.type = updater._type
        _, u_target.name = GetOperatorName()
        for tensor in tensors:
            u_target.tensor.append(tensor)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        meta_graph.u_target.extend([u_target])
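The hyper-parameter loop above feeds each scalar into the workspace as a one-element float32 tensor, keyed by the updater's prefix. A small self-contained sketch of that wrapping; feed_tensor and the dict-backed workspace are hypothetical stand-ins for ws.FeedTensor:

import numpy as np

workspace = {}  # hypothetical stand-in for the Dragon workspace

def feed_tensor(name, value):
    workspace[name] = value

prefix = 'GraphName_'
hyper_params = {'base_lr': 0.01, 'momentum': 0.9}
for k, v in hyper_params.items():
    # Each scalar becomes a 1-element float32 array under prefix + name.
    feed_tensor(prefix + k, np.array([v], dtype=np.float32))

print(sorted(workspace))  # ['GraphName_base_lr', 'GraphName_momentum']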
Example #5
def GraphDef_Update(meta_graph, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should have generated its update targets beforehand.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of the meta graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    # Fall back to the graph name if no slot is set
    if updater._slot is None:
        updater._slot = meta_graph.name
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot
    parallel_arguments = {}

    updater.register_in_workspace()

    # check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(k, v))

    for pair, arguments in updater._param_group:
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdateTarget()
        u_target.type = updater.type()
        _, u_target.name = GetOperatorName()
        for t in pair:
            u_target.tensor.append(t)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        meta_graph.u_target.extend([u_target])
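One detail in the kwargs merge is easy to miss: dict(arguments, **extra_arguments) copies the per-parameter arguments first, so any key they share with extra_arguments (such as 'slot') is overridden by the updater-level value. A quick demonstration with made-up values:

arguments = {'lr_mult': 2.0, 'slot': 'per_param'}
extra_arguments = {'slot': 'GraphName'}  # updater-level value wins on conflict
kwargs = dict(arguments, **extra_arguments)
print(kwargs)  # {'lr_mult': 2.0, 'slot': 'GraphName'}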
Example #6
def GraphDef_Update(graph_def, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should have generated its update targets beforehand.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of the graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot
    parallel_arguments = {}

    updater.register_in_workspace()

    # Check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            graph_def.arg.add().CopyFrom(MakeArgument(k, v))

    for pair, arguments in updater._param_group:
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdaterProto()
        u_target.type = updater.type()
        u_target.name = OperatorHelper.get_name()
        u_target.tensor.extend(pair)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        graph_def.updater.extend([u_target])
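Across Examples #4-#6 the function keeps the same overall flow, but this version serializes each target as a pb.UpdaterProto appended to graph_def.updater rather than an UpdateTarget in meta_graph.u_target. For a sense of what one appended entry carries, here is a dict used as a hypothetical stand-in for the proto (field names follow the snippet; the values are invented):

# Hypothetical contents of one UpdaterProto entry.
u_target = {
    'type': 'SGDUpdate',        # updater.type()
    'name': 'Op_42',            # OperatorHelper.get_name()
    'tensor': ['w', 'w_grad'],  # the (param, grad) pair
    'arg': {'slot': 'GraphName', 'lr_mult': 1.0},  # merged kwargs
}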