Ejemplo n.º 1
0
def broadcast_parameters(params, root_rank=0):
    """
    Send every parameter from the root rank to all other processes.

    Typically called with the result of `model.get_params()`.

    Arguments:
        params: One of the following:
            - dict of parameters; entries are broadcast in sorted item order
            - list of parameters, or list of (name, parameter) tuples
        root_rank: The rank of the process whose values are broadcast to
                   all other processes.

    Raises:
        ValueError: if `params` is neither a dict nor a list.
    """
    if not isinstance(params, (dict, list)):
        raise ValueError('invalid params of type: %s' % type(params))

    if isinstance(params, dict):
        pairs = sorted(params.items())
    else:
        # Accept both named (name, tensor) tuples and bare tensors by
        # normalising every entry to a pair.
        pairs = [entry if isinstance(entry, tuple) else (None, entry)
                 for entry in params]

    # Each tensor is broadcast under its position in the list, not under
    # its own name.
    for index, (_, tensor) in enumerate(pairs):
        broadcast_(tensor, root_rank, str(index))

    # Only wait once all broadcasts have been issued, so that every worker
    # is synced before training starts.
    for _, tensor in pairs:
        tensor.wait_to_read()
Ejemplo n.º 2
0
def broadcast_parameters(params, root_rank=0):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `Module.get_params()` or the
    `Block.collect_params()`.

    Nothing is broadcast when `size()` reports a single process.

    Arguments:
        params: One of the following:
            - dict of parameters to broadcast (may be empty)
            - ParameterDict to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Raises:
        ValueError: if `params` is neither a dict nor a ParameterDict.
    """
    if size() == 1:
        return

    tensors = []
    names = []
    if isinstance(params, dict):
        # Build the two sequences separately: unpacking
        # `zip(*params.items())` into two names raises ValueError when the
        # dict is empty.
        names = list(params.keys())
        tensors = list(params.values())
    elif isinstance(params, mx.gluon.parameter.ParameterDict):
        for name, p in sorted(params.items()):
            try:
                tensors.append(p.data())
                names.append(name)
            except mx.gluon.parameter.DeferredInitializationError:
                # Inject wrapper method with post-initialization broadcast to
                # handle parameters with deferred initialization
                new_init = _append_broadcast_init(p, root_rank)
                p._init_impl = types.MethodType(new_init, p)
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run broadcasts, each one named after its parameter key.
    for tensor, name in zip(tensors, names):
        broadcast_(tensor, root_rank, name=str(name))
Ejemplo n.º 3
0
def broadcast_parameters(params, root_rank=0):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.get_params()`.

    Arguments:
        params: One of the following:
            - list of parameters (or of (name, parameter) tuples) to broadcast
            - dict of parameters to broadcast, processed in sorted item order
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Returns:
        None. Nothing is returned to the caller.

    Raises:
        ValueError: if `params` is neither a dict nor a list.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run broadcasts. Each tensor is broadcast under its positional index and
    # waited on before the next one is issued.
    for index, (_, p) in enumerate(params):
        broadcast_(p, root_rank, str(index))
        p.wait_to_read()
Ejemplo n.º 4
0
def broadcast_parameters(params, root_rank=0):
    """
    Send parameter tensors from the root rank to all other processes.

    Typically called with `Module.get_params()` or
    `Block.collect_params()`.

    Arguments:
        params: One of the following:
            - dict of parameters; values are broadcast in sorted item order
            - ParameterDict; parameters whose `data()` raises
              DeferredInitializationError are skipped
        root_rank: The rank of the process whose values are broadcast to
                   all other processes.

    Raises:
        ValueError: if `params` is neither a dict nor a ParameterDict.
    """
    if isinstance(params, dict):
        pending = [value for _key, value in sorted(params.items())]
    elif isinstance(params, mx.gluon.parameter.ParameterDict):
        pending = []
        for _key, param in sorted(params.items()):
            try:
                data = param.data()
            except mx.gluon.parameter.DeferredInitializationError:
                # Deferred-init parameters have no data yet; leave them out.
                continue
            pending.append(data)
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Issue every broadcast, naming each by its position in the list.
    for position, tensor in enumerate(pending):
        broadcast_(tensor, root_rank, str(position))

    # Then wait on each tensor so that all workers are synced before
    # training starts.
    for tensor in pending:
        tensor.wait_to_read()
Ejemplo n.º 5
0
def broadcast_parameters(params, root_rank=0, prefix=None):
    """Send parameters from the root rank to all other processes.

    Typically called with `Module.get_params()` or
    `Block.collect_params()`. Returns immediately when `size()` reports a
    single process.

    Arguments:
        params: One of the following:
            - dict of parameters to broadcast
            - ParameterDict to broadcast (when that class can be imported)
        root_rank: The rank of the process whose values are broadcast to
                   all other processes.
        prefix: Optional string prepended to every parameter key to form the
              broadcast name. Separate `broadcast_parameters` calls in the
              same program must use different prefixes to avoid tensor name
              collisions.

    Raises:
        ValueError: if `params` is not one of the accepted container types.
    """
    if size() == 1:
        return

    assert prefix is None or isinstance(prefix, str)
    prefix = prefix or ""

    # ParameterDict may be missing from the installed MXNet; accept a plain
    # dict only in that case.
    try:
        from mxnet.gluon.parameter import ParameterDict
    except ImportError:
        accepted = (dict, )
    else:
        accepted = (dict, ParameterDict)

    if not isinstance(params, accepted):
        raise ValueError('invalid params of type: %s' % type(params))

    pending = []
    for key, param in sorted(params.items()):
        # The key of params is used instead of param.name, since param.name
        # is no longer unique in MXNet 2.0.
        full_name = prefix + str(key)
        if isinstance(param, mx.gluon.parameter.Parameter):
            try:
                tensor = param.data()
            except mx.gluon.parameter.DeferredInitializationError:
                # Inject a wrapper init with a post-initialization broadcast
                # to handle parameters with deferred initialization.
                new_init = _append_broadcast_init(param, root_rank,
                                                  full_name)
                param._init_impl = types.MethodType(new_init, param)
                continue
        else:
            tensor = param
        pending.append((tensor, full_name))

    # Run broadcasts.
    for tensor, full_name in pending:
        broadcast_(tensor, root_rank, name=full_name)
Ejemplo n.º 6
0
def broadcast_object(obj, root_rank=0, name=None):
    """
    Pickle an object on the root rank and broadcast it to all processes.

    The transfer takes two broadcasts: first the payload size (under
    `name + '.sz'`), then the serialized bytes (under `name + '.t'`).

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process whose object is broadcast to all
                   other processes.
        name: Optional name to use during broadcast; defaults to the class
              name of `obj`.
    Returns:
        On the root rank, `obj` itself; on every other rank, the object
        deserialized from the broadcast bytes.
    """
    if name is None:
        name = type(obj).__name__

    is_root = rank() == root_rank

    if is_root:
        stream = io.BytesIO()
        cloudpickle.dump(obj, stream)
        payload = mx.nd.array(bytearray(stream.getvalue()), dtype='byte')
        length = mx.nd.array([payload.size], dtype='int')
    else:
        # Placeholder that the size broadcast fills in.
        length = mx.nd.empty(shape=[1], dtype='int')

    broadcast_(length, root_rank, name + '.sz')

    if not is_root:
        # The receive buffer can only be allocated once the size is known.
        payload = mx.nd.empty(shape=[length.asscalar()], dtype='byte')

    broadcast_(payload, root_rank, name + '.t')

    if is_root:
        return obj

    # NOTE: the received bytes are unpickled, so the root rank must be
    # trusted.
    return cloudpickle.load(io.BytesIO(payload.asnumpy().tobytes()))
Ejemplo n.º 7
0
 def wrapped_init_impl(self, *args, **kwargs):
     """
     Run the captured `init_impl`, then broadcast this object's data.

     `init_impl` and `root_rank` are free variables from the enclosing
     scope, which is not visible here. Presumably `init_impl` is the
     original, already-bound initializer of the parameter -- TODO confirm.
     """
     # `self` is not forwarded; only *args/**kwargs reach init_impl.
     init_impl(*args, **kwargs)
     # No explicit name is passed to this broadcast.
     broadcast_(self.data(), root_rank=root_rank)
     # Presumably blocks until the broadcast result is readable -- confirm
     # against the NDArray API.
     self.data().wait_to_read()
Ejemplo n.º 8
0
 def wrapped_init_impl(self, *args, **kwargs):
     """
     Run the captured `init_impl`, then broadcast this object's data by name.

     `init_impl`, `root_rank` and `name` are free variables from the
     enclosing scope, which is not visible here. Presumably `init_impl` is
     the original, already-bound initializer of the parameter -- TODO confirm.
     """
     # `self` is not forwarded; only *args/**kwargs reach init_impl.
     init_impl(*args, **kwargs)
     # The broadcast is issued under the captured `name`; this wrapper does
     # not wait on the result.
     broadcast_(self.data(), root_rank=root_rank, name=name)