def draw(self, model, num_samples, num_steps=80):
    assert num_samples % self.ngpus == 0, 'batch size must be divisible'
    gpus = list(range(self.ngpus))
    with timer('genxs'):
        xs = [
            torch.rand(num_samples // self.ngpus, 3, 32, 32, device=i) * 2 - 1
            for i in gpus
        ]
        for x in xs:
            x.requires_grad_(True)
    with timer('replicate model'):
        model.eval()
        models = replicate(model, gpus, detach=True)
        model.train()
    energy_befores = [
        m.energy(x).mean().item() for m, x in zip(models, xs)
    ]
    print('energy befores:', energy_befores)
    for _ in range(num_steps):
        preds = parallel_apply(models, xs, devices=gpus)
        energies = [energy_of_pred(p).sum() for p in preds]
        g = torch.autograd.grad(energies, xs, retain_graph=True)
        # print('norms:', [gg.norm(2) for gg in g])
        for i in gpus:
            # gradient step on the energy, then add exploration noise
            xs[i].data.add_(g[i], alpha=-1)
            xs[i].data.add_(torch.randn_like(xs[i]), alpha=.01)
    energy_afters = [m.energy(x).mean().item() for m, x in zip(models, xs)]
    print('energy afters:', energy_afters)
    return torch.cuda.comm.gather(xs)
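# --- Hedged sketch (not from the original code): the per-step Langevin-style update
# that draw() runs on each GPU shard, written for a single device. `energy_fn` is a
# hypothetical stand-in for model.energy / energy_of_pred.
import torch

def langevin_step(x, energy_fn, step_size=1.0, noise_scale=0.01):
    """One gradient-descent-plus-noise step on the energy, applied in place to x.data."""
    x.requires_grad_(True)
    energy = energy_fn(x).sum()
    (grad,) = torch.autograd.grad(energy, x)
    x.data.add_(grad, alpha=-step_size)                  # move downhill in energy
    x.data.add_(torch.randn_like(x), alpha=noise_scale)  # inject exploration noise
    return x

# usage: start from uniform noise in [-1, 1], as draw() does
x = torch.rand(8, 3, 32, 32) * 2 - 1
x = langevin_step(x, lambda t: t.pow(2).flatten(1).sum(dim=1))  # toy quadratic energy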
def __init__(self, module, conf):
    super(AllReduceDataParallel, self).__init__()

    # init the general config variables.
    self.graph = conf.graph
    self.distributed = conf.distributed
    self.comm_device = conf.comm_device

    # devices available locally (normally in terms of the current node).
    self.device_ids = self.graph.device
    assert len(self.device_ids) == len(set(self.device_ids))
    self.output_device = self.device_ids[0]

    # put model on output device.
    self.module = module.cuda() if conf.graph.on_cuda else module

    # prepare local intra-node all-reduce objects.
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # register grad-reduction hooks
    self.__register_hooks()
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])

    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
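# --- Hedged usage sketch for the functional data_parallel above. It assumes the helper
# lives in a module (here hypothetically `parallel_utils`) whose imports already provide
# torch, scatter_kwargs, replicate, parallel_apply and gather, and that >= 2 GPUs exist.
import torch
import torch.nn as nn
# from parallel_utils import data_parallel   # hypothetical module name

model = nn.Linear(128, 10).cuda(0)             # parameters live on device_ids[0]
batch = torch.randn(32, 128, device='cuda:0')  # batch dim is split across GPUs

out = data_parallel(model, batch, device_ids=[0, 1])
print(out.shape, out.device)                   # torch.Size([32, 10]) cuda:0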
def wrapped(self, *inputs, **module_kwargs):
    if (not hasattr(self, '_is_replica')) and inputs[0].is_cuda:
        device_count = torch.cuda.device_count()
        if inputs[0].shape[0] % device_count != 0:
            import os
            cuda_visible_devices = os.environ[
                'CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else ''
            raise ValueError(
                'batch size (%d) must be divisible by the number of GPUs (%d) used\n'
                ' CUDA_VISIBLE_DEVICES: %s'
                % (inputs[0].shape[0], device_count, cuda_visible_devices))

        if device_count > 1:
            # modified from pytorch (torch.nn.parallel.DataParallel)
            device_ids = list(range(device_count))
            output_device = device_ids[0]
            inputs, kwargs = scatter_kwargs(inputs, module_kwargs, device_ids)
            replicas = replicate(self, device_ids[:len(inputs)])
            # add a _is_replica flag to avoid infinite loop
            # from recursively calling parallel_apply
            for replica in replicas:
                replica._is_replica = True
            outputs = parallel_apply(replicas, inputs, kwargs)
            return gather(outputs, output_device)
    return self._forward_worker(*inputs, **module_kwargs)
def wrapper(network, *inputs, **kwargs):
    inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
    if len(device_ids) == 1:
        return getattr(network, func_name)(*inputs[0], **kwargs[0])
    replicas = replicate(network, device_ids[:len(inputs)])
    outputs = parallel_apply(replicas, func_name, inputs, kwargs,
                             device_ids[:len(replicas)])
    return gather(outputs, output_device)
def _ddp_init_helper(self):
    """
    Initialization helper function that does the following:

    (1) replicating the module from device[0] to the other devices
    (2) bucketing the parameters for reductions
    (3) resetting the bucketing states
    (4) registering the grad hooks
    (5) passing a handle of DDP to SyncBatchNorm Layer
    """
    if self.device_ids and len(self.device_ids) > 1:
        # only create replicas for single-device CUDA modules
        #
        # TODO: we don't need to replicate params in here. they're always going to
        # be broadcasted using larger blocks in broadcast_coalesced, so it might be
        # better to not pollute the caches with these small blocks
        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for module_copy in self._module_copies[1:]:
            for param, copy_param in zip(self.module.parameters(),
                                         module_copy.parameters()):
                copy_param.requires_grad = param.requires_grad
    else:
        self._module_copies = [self.module]

    self.modules_params = [list(m.parameters()) for m in self._module_copies]
    self.modules_buffers = [list(m.buffers()) for m in self._module_copies]

    param_list = [
        list(filter(lambda p: p.requires_grad, module.parameters()))
        for module in self._module_copies
    ]

    # The bucket size limit is specified in the constructor.
    # Additionally, we allow for a single small bucket for parameters
    # that are defined first, such that their gradients don't spill into
    # a much larger bucket, adding unnecessary latency after gradient
    # computation finishes. Experiments showed 1MB is a reasonable value.
    bucket_indices = dist._compute_bucket_assignment_by_size(
        param_list[0], [1024 * 1024, self.bucket_bytes_cap])

    # Note: reverse list of buckets because we want to approximate the
    # order in which their gradients are produced, and assume they
    # are used in the forward pass in the order they are defined.
    self.reducer = dist.Reducer(param_list, list(reversed(bucket_indices)),
                                self.process_group)

    # passing a handle to torch.nn.SyncBatchNorm layer
    self._passing_sync_batchnorm_handle(self._module_copies)
def forward(self, inputs, *targets, **kwargs):
    # inputs should already be scattered across devices;
    # scatter the targets the same way here
    if not self.device_ids:
        return self.module(inputs, *targets, **kwargs)
    targets, kwargs = scatter_kwargs(targets, kwargs, self.device_ids)
    if len(self.device_ids) == 1:
        return self.module(inputs, *targets[0], **kwargs[0])
    replicas = replicate(self.module, self.device_ids[:len(inputs)])
    outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
    return Reduce.apply(*outputs) / len(outputs)
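# --- Hedged sketch of the pattern above without the custom Reduce autograd function:
# replicate a criterion, apply it to per-device (output, target) pairs, then average
# the per-device losses. Assumes >= 2 GPUs; `outputs` would normally come from a
# data-parallel forward whose gather step was skipped.
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply, gather

criterion = nn.CrossEntropyLoss().cuda(0)
outputs = [torch.randn(16, 10, device=d, requires_grad=True) for d in ('cuda:0', 'cuda:1')]
targets = [torch.randint(0, 10, (16,), device=d) for d in ('cuda:0', 'cuda:1')]

replicas = replicate(criterion, [0, 1])
losses = parallel_apply(replicas, list(zip(outputs, targets)), devices=[0, 1])
loss = gather(losses, target_device=0, dim=0).mean()  # average the per-device losses
loss.backward()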
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None,
                  dont_scatter=False,
                  dont_gather=False):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])

    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    #print('getting device_ids')
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    #print(device_ids)
    if output_device is None:
        output_device = device_ids[0]
    if not dont_scatter:
        do_scatter_lists = isinstance(inputs[0], list)
        if do_scatter_lists:
            inputs, module_kwargs = scatter_lists(inputs, module_kwargs, device_ids)
        else:
            inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs,
                                                   device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    #print('getting used device_ids')
    used_device_ids = device_ids[:len(inputs)]
    #print(used_device_ids)
    #print('making model replicas')
    replicas = replicate(module, used_device_ids)
    #print('applying model')
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    if dont_gather:
        return tuple([[out[i] for out in outputs]
                      for i in range(len(outputs[0]))])
    #print('gathering result')
    return gather(outputs, output_device, dim)
def my_data_parallel(module, inputs, device_ids=None,
                     dim=0, module_kwargs=None):
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if len(inputs) == 1:
        return module(inputs[0])
    #print('my data parallel, len(inputs)', len(inputs))
    replicas = replicate(module, device_ids[:len(inputs)])
    outputs = my_parallel_apply(replicas, inputs, module_kwargs)
    return outputs
def custom_data_parallel(module,
                         inputs,
                         device_ids=None,
                         output_device=None,
                         dim=0,
                         module_kwargs=None,
                         chunk_sizes=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids,
                                           dim, chunk_sizes)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None):
    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
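# --- Hedged sketch: the four primitives the helper above composes, spelled out
# step by step on two GPUs (assumes >= 2 CUDA devices are visible).
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply, gather, scatter

module = nn.Linear(64, 4).cuda(0)
inputs = torch.randn(8, 64, device='cuda:0')

chunks = scatter(inputs, [0, 1], dim=0)            # split the batch across devices
replicas = replicate(module, [0, 1])               # copy parameters/buffers per device
outs = parallel_apply(replicas, [(c,) for c in chunks], devices=[0, 1])
result = gather(outs, target_device=0, dim=0)      # reassemble on device_ids[0]
assert result.shape == (8, 4)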
def replicate(self, module, device_ids):
    modules = replicate(module, device_ids)
    execute_replication_callbacks(modules)
    return modules
def __init__(self,
             module,
             device_ids=None,
             distributed=True,
             graph=None,
             comm_device=None,
             push_sum=True,
             verbose=True):
    super(SimpleGossipDataParallel, self).__init__()

    # whether we're using multiple agents for training
    self.distributed = distributed

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.device_ids = device_ids
    self.output_device = self.device_ids[0]

    # put model on output device
    self.module = module.cuda(self.output_device)

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # prepare inter-node gossip objects
    if self.distributed:
        assert dist.is_initialized()
        # set distributed configuration properties
        self.graph = graph
        self.push_sum = push_sum
        self.gossip_enable = True
        if comm_device is None:
            comm_device = torch.device('cpu')
        self.comm_device = comm_device

        # logger used to print to stdout
        self.logger = make_logger(dist.get_rank(), verbose)

        # initialize gossiper to push-sum or push-pull protocol
        if self.push_sum:
            self.gossiper = PushSum(
                msg=_flatten_tensors(list(self.module.parameters())),
                device=self.comm_device,
                graph=self.graph,
                logger=self.logger)
        else:
            self.gossiper = PushPull(
                msg=_flatten_tensors(list(self.module.parameters())),
                device=self.comm_device,
                graph=self.graph,
                logger=self.logger)
    else:
        # logger used to print to stdout
        self.logger = make_logger(0, verbose)

    # register hook for gradient reduction on all GPUs available locally
    self.register_backward_hook(self.__make_backward_hook())
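# --- Hedged sketch (not the repo's _flatten_tensors): how a flat "message" buffer for
# gossip can be built from, and scattered back into, a model's parameters using the
# private torch._utils helpers. The repo's own helper may differ in detail.
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
params = [p.data for p in model.parameters()]

msg = _flatten_dense_tensors(params)          # single contiguous buffer to send/average
msg.mul_(0.5)                                 # e.g. mix with a neighbor's message

for p, mixed in zip(params, _unflatten_dense_tensors(msg, params)):
    p.copy_(mixed)                            # write the mixed values back into the model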
def _ddp_init_helper(self):
    """
    Initialization helper function that does the following:

    (1) replicating the module from device[0] to the other devices
    (2) bucketing the parameters for reductions
    (3) resetting the bucketing states
    (4) registering the grad hooks
    (5) passing a handle of DDP to SyncBatchNorm Layer
    """
    if len(self.device_ids) > 1:
        # TODO: we don't need to replicate params in here. they're always going to
        # be broadcasted using larger blocks in broadcast_coalesced, so it might be
        # better to not pollute the caches with these small blocks
        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for module_copy in self._module_copies[1:]:
            for param, copy_param in zip(self.module.parameters(),
                                         module_copy.parameters()):
                copy_param.requires_grad = param.requires_grad
    else:
        self._module_copies = [self.module]

    self.modules_params = [list(m.parameters()) for m in self._module_copies]
    self.modules_buffers = [list(m.buffers()) for m in self._module_copies]

    # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
    param_buckets = []

    # Split the parameters into buckets and by types as well
    # We only need to bucket and reduce parameters that require grad and
    # this is also true for backward since only the backward hooks for
    # parameters that require grad will be registered with gradient
    # reduction functions
    params_to_bucket = [[] for _ in self._module_copies]
    for dev_idx, m in enumerate(self._module_copies):
        for p in m.parameters():
            if p.requires_grad:
                params_to_bucket[dev_idx].append(p)

    param_buckets = [
        dist._dist_bucket_tensors(dev_params_to_bucket,
                                  int(self.bucket_bytes_cap),
                                  fine_grained=False)
        for dev_params_to_bucket in params_to_bucket
    ]

    self.bucket_sizes = []
    self.bucket_map = {}

    # We transpose param_buckets, so the loop is over buckets.
    # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
    for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
        self.bucket_sizes.append(0)
        # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
        # of params from each device.
        for param_tuple in zip(*param_buckets_tuple):
            if not param_tuple[0].requires_grad:
                continue
            for p in param_tuple:
                self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx])
            self.bucket_sizes[bucket_idx] += 1

    self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                     for _ in range(len(self.device_ids))]
                    for i in range(len(self.bucket_sizes))]
    # The number of params ready in each bucket
    self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))]
                               for i in range(len(self.bucket_sizes))]

    # coalesced bucket for only device 0
    self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
    # We will always reduce the buckets in reverse order,
    # that is, following the order: n - 1, n - 2, ..., 0
    self.next_bucket = len(self.bucket_sizes) - 1
    # When all buckets are reduced, this will be set to True. This flag is
    # useful for sanity checks to ensure that each iteration's backward has
    # always reduced all buckets
    self.all_buckets_reduced = False
    self.check_previous_reduction = False
    self.ready_buckets_not_reduced = set()
    self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
    self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
    self._register_grad_hooks()

    # passing a handle to torch.nn.SyncBatchNorm layer
    self._passing_sync_batchnorm_handle(self._module_copies)
def __init__(self,
             module,
             device_ids=None,
             distributed=True,
             master_addr=None,
             master_port=None,
             backend=None,
             world_size=None,
             rank=None,
             graph=None,
             mixing=None,
             comm_device=None,
             lr=0.1,
             momentum=0.9,
             weight_decay=1e-4,
             nesterov=True,
             verbose=True):
    super(BilatGossipDataParallel, self).__init__()

    # whether we're using multiple agents for training
    self.distributed = distributed

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.output_device = device_ids[0]
    self.device_ids = device_ids

    # put model on output device
    self.module = module.cuda(self.output_device)

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # prepare inter-node gossip objects
    if self.distributed:
        # communicate over cpu's if not specified
        if comm_device is None:
            comm_device = torch.device('cpu')
        self.__cpu_comm = comm_device.type == 'cpu'

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'graph': graph,
            'master_addr': master_addr,
            'master_port': master_port,
            'backend': backend,
            'world_size': world_size,
            'rank': rank,
            'mixing': mixing,
            'lr': lr,
            'momentum': momentum,
            'nesterov': nesterov,
            'weight_decay': weight_decay
        }
        self.num_updates = 0

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # prepare parameters for gossip
        self.gossip_enable = True
        self.gossip_params = []
        self.gossip_grads = []
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            cp.requires_grad = p.requires_grad
            self.gossip_params.append(cp)
            if p.requires_grad:
                g = cp.clone().zero_().detach_()
                g = g.cpu().pin_memory() if self.__cpu_comm else g.cuda()
                self.gossip_grads.append(g)

        self.gossip_queue = mp.Queue()
        self.gossip_lock = mp.Lock()
        self.gossip_enable_flag = mp.Event()
        self.train_write_flag = mp.Event()   # signal train-proc write event
        self.gossip_read_flag = mp.Event()   # signal gossip-proc read event
        self.gossip_update_flag = mp.Event()  # signal to gossip-proc that it needs an update
        self._lr = mp.Value('f', lr, lock=self.gossip_lock)
        self.gossip_thread = mp.Process(
            target=BilatGossipDataParallel._gossip_target,
            args=(self.dist_config, self.gossip_enable_flag,
                  self.train_write_flag, self.gossip_read_flag,
                  self.gossip_update_flag, self._lr, self.gossip_lock,
                  self.gossip_queue))
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()

        # pass handle to gossip_params and gossip_grads, and put in shared
        # memory
        self.gossip_queue.put((self.gossip_params, self.gossip_grads))
    else:
        # logger used to print to stdout
        self.logger = make_logger(0, verbose)

    # register ps/grad-reduction hooks
    self.__register_hooks()
def replicate(self, module, device_ids):
    if self.replicas is None:
        from torch.nn.parallel.replicate import replicate
        self.replicas = replicate(module, device_ids,
                                  not torch.is_grad_enabled())
    return self.replicas
def __init__(self,
             module,
             device_ids=None,
             distributed=True,
             graph=None,
             mixing=None,
             comm_device=None,
             push_sum=True,
             rank=None,
             world_size=None,
             overlap=False,
             synch_freq=0,
             verbose=True):
    super(GossipDataParallel, self).__init__()

    # whether we're using multiple agents for training
    self.distributed = distributed

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.output_device = device_ids[0]
    self.device_ids = device_ids

    # put model on output device
    self.module = module.cuda(self.output_device)

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # prepare inter-node gossip objects
    if self.distributed:
        if world_size is None or rank is None:
            assert dist.is_initialized()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

        # communicate over cpu's if not specified
        if comm_device is None:
            comm_device = torch.device('cpu')
        self.__cpu_comm = comm_device.type == 'cpu'

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'comm_device': comm_device,
            'graph': graph,
            'mixing': mixing,
            'push_sum': push_sum,
            'rank': rank,
            'world_size': world_size
        }
        self.overlap = overlap
        self.synch_freq = synch_freq
        self.num_updates = 0
        self.asynch = synch_freq > 0

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # push-sum weight=1.0 ==> distributed averaging
        self.ps_weight = 1.0
        self.is_ps_numerator = False

        # prepare parameters for gossip
        self.gossip_enable = True
        self.gossiping = False
        self.params_mixed = True
        self.gossip_ps_factor = [None]
        self.gossip_ps_weight = [self.ps_weight]
        self.gossip_params = []
        self.gossip_device_buffer = []
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            self.gossip_params.append(cp)

        # prepare gossip process control objects
        self.gossip_lock = threading.Lock()
        self.gossip_flag = threading.Event()
        self.train_flag = threading.Event()
        self.gossip_thread = threading.Thread(
            target=GossipDataParallel._gossip_target,
            args=(self.dist_config, self.gossip_flag, self.train_flag,
                  self.gossip_lock, self.gossip_params,
                  self.gossip_device_buffer, self.gossip_ps_weight,
                  self.gossip_ps_factor))
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()

        # wait for thread to complete initialization
        self.gossip_flag.wait()
        self.gossip_flag.clear()

        # lazy mixing avoids additional bias/de-bias steps
        self.lazy_mixing = not self.asynch \
            and self.dist_config['mixing'].is_regular()
        self.lazy_ps_factor = self.gossip_ps_factor[0]
        self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))
    else:
        self.params_mixed = True
        # logger used to print to stdout
        self.logger = make_logger(0, verbose)

    # register ps/grad-reduction hooks
    self.__register_hooks()
def parallel_forward(self, dense_x, lS_o, lS_i):
    ### prepare model (overwrite) ###
    # WARNING: # of devices must be >= batch size in parallel_forward call
    batch_size = dense_x.size()[0]
    ndevices = min(self.ndevices, batch_size, len(self.emb_l))
    device_ids = range(ndevices)
    # WARNING: must redistribute the model if mini-batch size changes (this is common
    # for the last mini-batch, when # of elements in the dataset / batch size is not even)
    if self.parallel_model_batch_size != batch_size:
        self.parallel_model_is_not_prepared = True

    if self.sync_dense_params or self.parallel_model_is_not_prepared:
        # replicate mlp (data parallelism)
        self.bot_l_replicas = replicate(self.bot_l, device_ids)
        self.top_l_replicas = replicate(self.top_l, device_ids)
        # distribute embeddings (model parallelism)
        t_list = []
        for k, emb in enumerate(self.emb_l):
            d = torch.device("cuda:" + str(k % ndevices))
            emb.to(d)
            t_list.append(emb.to(d))
        self.emb_l = nn.ModuleList(t_list)
        self.parallel_model_batch_size = batch_size
        self.parallel_model_is_not_prepared = False

    ### prepare input (overwrite) ###
    # scatter dense features (data parallelism)
    # print(dense_x.device)
    dense_x = scatter(dense_x, device_ids, dim=0)
    # distribute sparse features (model parallelism)
    if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
        sys.exit("ERROR: corrupted model input detected in parallel_forward call")

    t_list = []
    i_list = []
    for k, _ in enumerate(self.emb_l):
        d = torch.device("cuda:" + str(k % ndevices))
        t_list.append(lS_o[k].to(d))
        i_list.append(lS_i[k].to(d))
    lS_o = t_list
    lS_i = i_list

    ### compute results in parallel ###
    # bottom mlp
    # WARNING: Note that self.bot_l_replicas is a list of bottom mlp modules
    # that have been replicated across devices, while dense_x is a tuple of dense
    # inputs that has been scattered across devices on the first (batch) dimension.
    # The output is a list of tensors scattered across devices according to the
    # distribution of dense_x.
    x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
    # debug prints
    # print(x)

    # embeddings
    ly = self.apply_emb(lS_o, lS_i, self.emb_l)
    # debug prints
    # print(ly)

    # butterfly shuffle (implemented inefficiently for now)
    # WARNING: Note that at this point we have the result of the embedding lookup
    # for the entire batch on each device. We would like to obtain partial results
    # corresponding to all embedding lookups, but part of the batch on each device,
    # matching the distribution of the output of the bottom mlp, so that both
    # can be used for subsequent interactions on each device.
    if len(self.emb_l) != len(ly):
        sys.exit("ERROR: corrupted intermediate result in parallel_forward call")

    t_list = []
    for k, _ in enumerate(self.emb_l):
        d = torch.device("cuda:" + str(k % ndevices))
        y = scatter(ly[k], device_ids, dim=0)
        t_list.append(y)
    # adjust the list to be ordered per device
    ly = list(map(lambda y: list(y), zip(*t_list)))
    # debug prints
    # print(ly)

    # interactions
    z = []
    for k in range(ndevices):
        zk = self.interact_features(x[k], ly[k])
        z.append(zk)
    # debug prints
    # print(z)

    # top mlp
    # WARNING: Note that self.top_l_replicas is a list of top mlp modules that
    # have been replicated across devices, while z is a list of interaction results
    # that by construction are scattered across devices on the first (batch) dim.
    # The output is a list of tensors scattered across devices according to the
    # distribution of z.
    p = parallel_apply(self.top_l_replicas, z, None, device_ids)

    ### gather the distributed results ###
    p0 = gather(p, self.output_d, dim=0)

    # clamp output if needed
    if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
        z0 = torch.clamp(p0, min=self.loss_threshold,
                         max=(1.0 - self.loss_threshold))
    else:
        z0 = p0

    return z0
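# --- Hedged sketch of the "butterfly shuffle" step above, with plain tensors instead of
# embedding lookups: each device starts with the full batch for its own table, and ends
# with its batch shard for every table (assumes 2 GPUs and 2 tables).
import torch
from torch.nn.parallel import scatter

device_ids = [0, 1]
batch, dim = 8, 4
# ly[k]: full-batch output of embedding table k, living on device k
ly = [torch.randn(batch, dim, device='cuda:{}'.format(k)) for k in device_ids]

t_list = [scatter(ly[k], device_ids, dim=0) for k in device_ids]   # split each table's output
per_device = [list(shards) for shards in zip(*t_list)]             # transpose: device -> tables

# per_device[d][k] holds table k's rows for device d's batch shard, already on cuda:d
assert per_device[1][0].device.index == 1 and per_device[1][0].shape == (batch // 2, dim)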
def replicate(self, module, device_ids):
    return replicate(module, device_ids, not torch.is_grad_enabled())
def __init__(self, module, device_ids=None, output_device=None, dim=0):
    super(DistributedDataParallel, self).__init__()
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    self.dim = dim
    self.module = module
    self.device_ids = device_ids
    self.output_device = output_device

    # Sync params and buffers:
    # broadcast the model from the master (node with rank 0) to all nodes;
    # we can borrow this in our parameter server settings
    for p in self.module.state_dict().values():
        dist.broadcast(p, 0)

    if len(device_ids) > 1:
        # TODO: we don't need to replicate params in here. they're always going to
        # be broadcasted using larger blocks in broadcast_coalesced, so it might be
        # better to not pollute the caches with these small blocks
        self._module_copies = replicate(self.module, self.device_ids)
        self._module_copies[0] = self.module

        for module_copy in self._module_copies[1:]:
            for param, copy_param in zip(self.module.parameters(),
                                         module_copy.parameters()):
                copy_param.detach_()
                copy_param.requires_grad = param.requires_grad
    else:
        self._module_copies = [self.module]

    # Split parameters into buckets that will coalesce reductions
    # TODO: different types need different buckets
    t = None
    for p in self.module.parameters():
        tp = type(p.data)
        if t is not None and t is not tp:
            raise ValueError("DistributedDataParallel requires all parameters' "
                             "data to be of the same type")
        t = tp

    self.bucket_sizes = []
    self.bucket_map = {}
    MB = 1024 * 1024
    self.broadcast_bucket_size = 10 * MB  # used for param sync before forward
    bucket_bytes_cap = 1 * MB

    bucket_bytes = bucket_bytes_cap  # to init the first bucket immediately
    for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)):
        if bucket_bytes >= bucket_bytes_cap:
            self.bucket_sizes.append(0)
            bucket_bytes = 0
        self.bucket_sizes[-1] += 1
        for p in param_tuple:
            self.bucket_map[p] = len(self.bucket_sizes) - 1
        bucket_bytes += p.numel() * p.element_size()

    self.buckets = [[[] for _ in range(len(self.device_ids))]
                    for _ in range(len(self.bucket_sizes))]
    self.bucket_events = [[None] * len(self.device_ids)
                          for _ in range(len(self.bucket_sizes))]
    self.reduced = [False] * len(self.bucket_sizes)

    self._register_grad_hooks()

    self.dispatch_lock = threading.Lock()
    self._start_reduction_threads()
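# --- Hedged sketch (standalone, not the DDP internals): grouping parameters into
# reduction buckets of roughly `bucket_bytes_cap` bytes, in the same spirit as the
# bucketing loop above.
import torch
import torch.nn as nn

def bucket_by_bytes(params, bucket_bytes_cap=1 * 1024 * 1024):
    """Return a list of buckets; each bucket is a list of parameters whose total
    size stays near the byte cap (a new bucket starts once the cap is reached)."""
    buckets, current, current_bytes = [], [], 0
    for p in params:
        if current and current_bytes >= bucket_bytes_cap:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(p)
        current_bytes += p.numel() * p.element_size()
    if current:
        buckets.append(current)
    return buckets

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512), nn.Linear(512, 10))
buckets = bucket_by_bytes(list(model.parameters()), bucket_bytes_cap=256 * 1024)
print([sum(p.numel() * p.element_size() for p in b) for b in buckets])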
def replicate(self, module, device_ids):
    return replicate(module, device_ids)
def __init__(self,
             module,
             device_ids=None,
             distributed=True,
             world_size=None,
             rank=None,
             comm_device=None,
             verbose=True):
    super(AllReduceDataParallel, self).__init__()

    # whether we're using multiple agents for training
    self.distributed = distributed

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.output_device = device_ids[0]
    self.device_ids = device_ids

    # put model on output device
    self.module = module.cuda(self.output_device)

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # prepare inter-node gossip objects
    if self.distributed:
        if world_size is None or rank is None:
            assert dist.is_initialized()
            world_size = dist.get_world_size()
            rank = dist.get_rank()

        # communicate over cpu's if not specified
        if comm_device is None:
            comm_device = torch.device('cpu')
        self.__cpu_comm = comm_device.type == 'cpu'

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'comm_device': comm_device,
            'rank': rank
        }

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # prepare parameters for gossip
        self.ps_factor = 1. / world_size
        self.gossip_enable = True
        self.gossiping = False
        self.params_mixed = True
        self.gossip_params = []
        self.gossip_device_buffer = []
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            self.gossip_params.append(cp)

        # prepare gossip process control objects
        self.gossip_flag = threading.Event()
        self.train_flag = threading.Event()
        self.gossip_thread = threading.Thread(
            target=AllReduceDataParallel._gossip_target,
            args=(self.dist_config, self.gossip_flag, self.train_flag,
                  self.gossip_params, self.gossip_device_buffer))
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()

        # wait for thread to complete initialization
        self.gossip_flag.wait()
        self.gossip_flag.clear()
    else:
        self.params_mixed = True
        # logger used to print to stdout
        self.logger = make_logger(0, verbose)

    # register grad-reduction hooks
    self.__register_hooks()
def replicate(self, flow, device_ids):
    return replicate(flow, device_ids)
def __init__(self,
             module,
             device_ids=None,
             rank=None,
             world_size=None,
             graph=None,
             mixing=None,
             comm_device=None,
             push_sum=True,
             overlap=False,
             synch_freq=0,
             verbose=False,
             use_streams=True,
             nprocs_per_node=1,
             local_node_group=None):
    super(GossipDataParallel, self).__init__()

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.output_device = device_ids[0]
    self.device_ids = device_ids

    self.nprocs_per_node = nprocs_per_node

    if world_size is None or rank is None:
        assert dist.is_initialized()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    self.process_rank = rank

    if self.nprocs_per_node > 1:
        self.local_rank = self.process_rank % self.nprocs_per_node
        world_size //= nprocs_per_node
        rank //= nprocs_per_node
        if local_node_group is None:
            for node in range(world_size):
                node_processes_ranks = list(
                    range(node * self.nprocs_per_node,
                          (node + 1) * self.nprocs_per_node))
                # Process group to communicate between processes on this
                # machine
                new_local_group = create_process_group(node_processes_ranks)
                if self.process_rank in node_processes_ranks:
                    self.local_node_group = new_local_group
        else:
            self.local_node_group = local_node_group
    else:
        self.local_rank = 0

    # put model on output device
    self.module = module
    first_param_dtype = next(self.module.parameters()).dtype

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(), cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # choose communication device based on backend
    if comm_device is None:
        cpu_comm = True if dist.get_backend() == 'gloo' else False
        comm_device = torch.device('cpu') if cpu_comm else torch.device('cuda')
    self.__cpu_comm = comm_device.type == 'cpu'

    if graph is None:
        graph = NPDDEGraph(rank, world_size, self.nprocs_per_node, self.local_rank)

    if mixing is None:
        mixing = UniformMixing(graph, comm_device)

    # distributed backend config
    self.dist_config = {
        'verbose': verbose,
        'comm_device': comm_device,
        'graph': graph,
        'mixing': mixing,
        'push_sum': push_sum,
        'rank': rank,
        'process_rank': self.process_rank,
        'world_size': world_size,
        'cpu_comm': self.__cpu_comm
    }
    self.overlap = overlap
    self.synch_freq = synch_freq
    self.num_updates = 0
    self.asynch = synch_freq > 0

    # logger used to print to stdout
    self.logger = make_logger(rank, verbose)

    # push-sum weight=1.0 ==> distributed averaging
    self.ps_weight = torch.ones(1, device=comm_device).type(first_param_dtype)
    self.nprocs_per_node_device = torch.tensor([self.nprocs_per_node],
                                               device=comm_device,
                                               dtype=first_param_dtype)
    self.is_ps_numerator = False

    # prepare parameters for gossip
    self.gossip_enable = True
    self.gossiping = False
    self.params_mixed = True
    self.gossip_ps_factor = torch.zeros(1, device=comm_device).type(first_param_dtype)
    self.gossip_ps_weight = self.ps_weight.clone()
    self.gossip_params = []
    self.gossip_device_buffer = []
    for p in module.parameters():
        cp = p.clone().detach_()
        cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
        self.gossip_params.append(cp)
        self.gossip_device_buffer.append(cp)

    # prepare gossip process control objects
    self.gossip_lock = threading.Lock()
    self.gossip_flag = threading.Event()
    self.train_flag = threading.Event()

    if self.dist_config['comm_device'].type != 'cpu' and use_streams:
        self.gossip_stream = torch.cuda.Stream()
    else:
        self.gossip_stream = torch.cuda.current_stream()

    if self.process_rank % self.nprocs_per_node == 0:
        self.gossip_thread = threading.Thread(
            target=GossipDataParallel._gossip_target,
            args=(self.dist_config, self.gossip_flag, self.train_flag,
                  self.gossip_lock, self.gossip_params,
                  self.gossip_device_buffer, self.gossip_ps_weight,
                  self.gossip_ps_factor, self.gossip_stream))
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()
    else:
        self.gossip_flag.set()

    # wait for thread to complete initialization
    self.gossip_flag.wait()
    self.gossip_flag.clear()

    # lazy mixing avoids additional bias/de-bias steps
    self.lazy_mixing = (not self.asynch
                        and self.dist_config['mixing'].is_regular()
                        and not self.overlap)
    self.lazy_ps_factor = self.gossip_ps_factor.clone()
    self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))

    # register ps/grad-reduction hooks
    self.__register_hooks()
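# --- Hedged sketch (local simulation, no torch.distributed): the push-sum bookkeeping
# behind ps_weight / is_ps_numerator above. Each node gossips (numerator, weight) pairs;
# dividing the numerator by the weight de-biases the estimate toward the true average.
import torch

values = [torch.tensor([1.0]), torch.tensor([3.0]), torch.tensor([11.0])]
numerators = [v.clone() for v in values]      # biased "ps numerator" copies
weights = [torch.ones(1) for _ in values]     # ps_weight starts at 1.0 on every node

ring = [(0, 1), (1, 2), (2, 0)]               # each node pushes to one out-neighbor
for _ in range(50):
    sent = [(n / 2, w / 2) for n, w in zip(numerators, weights)]
    for i in range(len(values)):              # keep half, then receive from in-neighbor
        numerators[i], weights[i] = sent[i][0].clone(), sent[i][1].clone()
    for src, dst in ring:
        numerators[dst] += sent[src][0]
        weights[dst] += sent[src][1]

print([float(n / w) for n, w in zip(numerators, weights)])  # all approach 5.0 = mean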