def broadcast_parameters(params, root_rank=0):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.get_params()`.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Raises:
        ValueError: If `params` is neither a dict nor a list.
    """
    if isinstance(params, dict):
        # Sorted by key so the positional names assigned below do not depend
        # on the dict's insertion order.
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run broadcasts. Each tensor is named by its position in the list.
    for index, (_, p) in enumerate(params):
        broadcast_(p, root_rank, str(index))

    # Make sure tensors pushed to MXNet engine get processed such that all
    # workers are synced before starting training.
    for _, p in params:
        p.wait_to_read()
def broadcast_parameters(params, root_rank=0):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `Module.get_params()` or the
    `Block.collect_params()`.

    Returns immediately, without validating `params`, when `size()` is 1.

    Arguments:
        params: One of the following:
            - dict of parameters to broadcast
            - ParameterDict to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Raises:
        ValueError: If `params` is neither a dict nor a ParameterDict.
    """
    if size() == 1:
        return

    tensors = []
    names = []
    if isinstance(params, dict):
        # Build the lists directly rather than `zip(*params.items())`:
        # for an empty dict the zip yields nothing and unpacking it into two
        # names raises ValueError.
        names = list(params.keys())
        tensors = list(params.values())
    elif isinstance(params, mx.gluon.parameter.ParameterDict):
        for name, p in sorted(params.items()):
            try:
                tensor = p.data()
            except mx.gluon.parameter.DeferredInitializationError:
                # Inject wrapper method with post-initialization broadcast to
                # handle parameters with deferred initialization
                new_init = _append_broadcast_init(p, root_rank)
                p._init_impl = types.MethodType(new_init, p)
            else:
                tensors.append(tensor)
                names.append(name)
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run broadcasts.
    for tensor, name in zip(tensors, names):
        broadcast_(tensor, root_rank, name=str(name))
def broadcast_parameters(params, root_rank=0):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.get_params()`.

    The tensors are updated in place; nothing is returned.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Raises:
        ValueError: If `params` is neither a dict nor a list.
    """
    if isinstance(params, dict):
        # Sorted by key so the positional names assigned below do not depend
        # on the dict's insertion order.
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run broadcasts. Each tensor is named by its position and waited on
    # before the next broadcast is issued.
    for index, (_, p) in enumerate(params):
        broadcast_(p, root_rank, str(index))
        p.wait_to_read()
def broadcast_parameters(params, root_rank=0):
    """
    Send every parameter tensor from the root rank to all other processes.
    Typically called with the result of `Module.get_params()` or
    `Block.collect_params()`.

    Parameters of a ParameterDict whose initialization is still deferred
    are skipped and therefore not broadcast.

    Arguments:
        params: One of the following:
            - dict of parameters to broadcast
            - ParameterDict to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.

    Raises:
        ValueError: If `params` is neither a dict nor a ParameterDict.
    """
    if isinstance(params, dict):
        ordered_items = sorted(params.items())
        tensors = [item[1] for item in ordered_items]
    elif not isinstance(params, mx.gluon.parameter.ParameterDict):
        raise ValueError('invalid params of type: %s' % type(params))
    else:
        tensors = []
        for _, param in sorted(params.items()):
            try:
                data = param.data()
            except mx.gluon.parameter.DeferredInitializationError:
                # skip broadcasting deferred init param
                continue
            tensors.append(data)

    # Issue one broadcast per tensor, named by its position in the list.
    for position, data in enumerate(tensors):
        broadcast_(data, root_rank, str(position))

    # Wait on every tensor so that work pushed to the MXNet engine has been
    # processed and all workers are synced before training starts.
    for data in tensors:
        data.wait_to_read()
def broadcast_parameters(params, root_rank=0, prefix=None):
    """Send every parameter from the root rank to all other processes.

    Typically called with the result of `Module.get_params()` or
    `Block.collect_params()`. Does nothing when `size()` is 1.

    Arguments:
        params: One of the following:
            - dict of parameters to broadcast
            - ParameterDict to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        prefix: The prefix of the parameters to broadcast.
                If multiple `broadcast_parameters` are called in the same
                program, they must be specified by different prefixes to
                avoid tensor name collision.

    Raises:
        ValueError: If `params` is not one of the accepted container types.
    """
    if size() == 1:
        return

    assert prefix is None or isinstance(prefix, str)
    prefix = prefix or ""

    # ParameterDict is accepted only when this MXNet build provides it.
    try:
        from mxnet.gluon.parameter import ParameterDict
    except ImportError:
        valid_types = (dict, )
    else:
        valid_types = (dict, ParameterDict)

    if not isinstance(params, valid_types):
        raise ValueError('invalid params of type: %s' % type(params))

    pending = []  # (tensor, broadcast name) pairs, in sorted key order
    for key, p in sorted(params.items()):
        # we use the key of params instead of param.name, since
        # param.name is no longer unique in MXNet 2.0
        full_name = prefix + str(key)
        if isinstance(p, mx.gluon.parameter.Parameter):
            try:
                tensor = p.data()
            except mx.gluon.parameter.DeferredInitializationError:
                # Inject wrapper method with post-initialization broadcast
                # to handle parameters with deferred initialization
                new_init = _append_broadcast_init(p, root_rank, full_name)
                p._init_impl = types.MethodType(new_init, p)
                continue
        else:
            tensor = p
        pending.append((tensor, full_name))

    # Run broadcasts.
    for tensor, full_name in pending:
        broadcast_(tensor, root_rank, name=full_name)
def broadcast_object(obj, root_rank=0, name=None):
    """
    Pickle an object on the root rank and broadcast it to every process.

    The payload length is broadcast first (under `name + '.sz'`) so that
    the other ranks can allocate a receive buffer of the right size; the
    pickled bytes follow under `name + '.t'`.

    Arguments:
        obj: An object capable of being serialized without losing any
             context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        name: Optional name to use during broadcast, will default to the
              class type.

    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = type(obj).__name__

    is_root = rank() == root_rank
    if is_root:
        stream = io.BytesIO()
        cloudpickle.dump(obj, stream)
        payload = mx.nd.array(bytearray(stream.getvalue()), dtype='byte')
        length = mx.nd.array([payload.size], dtype='int')
        broadcast_(length, root_rank, name + '.sz')
    else:
        length = mx.nd.empty(shape=[1], dtype='int')
        broadcast_(length, root_rank, name + '.sz')
        # Receive buffer sized from the length announced by the root rank.
        payload = mx.nd.empty(shape=[length.asscalar()], dtype='byte')

    broadcast_(payload, root_rank, name + '.t')

    if is_root:
        # The root already holds the object; hand it back unchanged.
        return obj

    received = io.BytesIO(payload.asnumpy().tobytes())
    return cloudpickle.load(received)
def wrapped_init_impl(self, *args, **kwargs):
    """Run the wrapped initializer, then broadcast the parameter's data.

    `init_impl` and `root_rank` are free variables taken from the enclosing
    scope, which is not visible in this chunk.
    """
    init_impl(*args, **kwargs)
    broadcast_(self.data(), root_rank=root_rank)
    # Wait on the data before returning to the caller.
    initialized = self.data()
    initialized.wait_to_read()
def wrapped_init_impl(self, *args, **kwargs):
    """Run the wrapped initializer, then broadcast the data under `name`.

    `init_impl`, `root_rank` and `name` are free variables taken from the
    enclosing scope, which is not visible in this chunk.
    """
    init_impl(*args, **kwargs)
    initialized = self.data()
    broadcast_(initialized, root_rank=root_rank, name=name)