Example #1
 def draw(self, model, num_samples, num_steps=80):
     assert num_samples % self.ngpus == 0, 'batch size must be divisible by the number of GPUs'
     gpus = list(range(self.ngpus))
     with timer('genxs'):
         xs = [
             torch.rand(num_samples // self.ngpus, 3, 32, 32, device=i) * 2
             - 1 for i in gpus
         ]
         for x in xs:
             x.requires_grad_(True)
     with timer('replicate model'):
         model.eval()
         models = replicate(model, gpus, detach=True)
         model.train()
     energy_befores = [
         m.energy(x).mean().item() for m, x in zip(models, xs)
     ]
     print('energy befores:', energy_befores)
     for _ in range(num_steps):
         preds = parallel_apply(models, xs, devices=gpus)
         energies = [energy_of_pred(p).sum() for p in preds]
         g = torch.autograd.grad(energies, xs, retain_graph=True)
         # print('norms:', [gg.norm(2) for gg in g])
         for i in gpus:
             xs[i].data.add_(g[i], alpha=-1)
             xs[i].data.add_(torch.randn_like(xs[i]), alpha=0.01)
     energy_afters = [m.energy(x).mean().item() for m, x in zip(models, xs)]
     print('energy afters:', energy_afters)
     return torch.cuda.comm.gather(xs)
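For reference, a minimal single-device sketch of the same Langevin-style sampling loop, assuming a hypothetical model.energy that returns one energy value per sample (step size and noise scale are illustrative):

import torch

def langevin_draw(model, num_samples, num_steps=80, step_size=1.0, noise_scale=0.01):
    # start from uniform noise in [-1, 1]
    x = torch.rand(num_samples, 3, 32, 32) * 2 - 1
    x.requires_grad_(True)
    for _ in range(num_steps):
        energy = model.energy(x).sum()          # assumed per-sample energy head
        grad, = torch.autograd.grad(energy, x)
        x.data.add_(grad, alpha=-step_size)     # descend the energy surface
        x.data.add_(torch.randn_like(x), alpha=noise_scale)  # Langevin noise
    return x.detach()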
Example #2
    def __init__(self, module, conf):
        super(AllReduceDataParallel, self).__init__()

        # init the general config variables.
        self.graph = conf.graph
        self.distributed = conf.distributed
        self.comm_device = conf.comm_device

        # devices available locally (i.e., on the current node).
        self.device_ids = self.graph.device
        assert len(self.device_ids) == len(set(self.device_ids))
        self.output_device = self.device_ids[0]

        # put model on output device.
        self.module = module.cuda() if conf.graph.on_cuda else module

        # prepare local intra-node all-reduce objects.
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)

            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # register grad-reduction hooks
        self.__register_hooks()
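For intuition, a standalone sketch of what the replication block above achieves, assuming at least two visible GPUs: detached copies share no autograd history with the master module, so the trainability flags must be mirrored by hand to keep frozen layers frozen on every device.

import torch.nn as nn
from torch.nn.parallel import replicate

master = nn.Linear(8, 8).cuda(0)
master.bias.requires_grad = False            # e.g. a deliberately frozen parameter

copies = replicate(master, [0, 1], detach=True)
copies[0] = master                           # keep the master module as replica 0
for copy in copies[1:]:
    for p, cp in zip(master.parameters(), copy.parameters()):
        cp.requires_grad = p.requires_grad   # mirror the trainability flag

assert copies[1].bias.requires_grad is False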
Example #3
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
    This is the functional version of the DataParallel module.
    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids,
                                           dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
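A minimal usage sketch of the functional form above (the toy module and shapes are illustrative; this assumes at least one CUDA device):

import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()
x = torch.randn(32, 16, device='cuda:0')

# one call replicates the module, splits the batch across the visible GPUs,
# runs the replicas in parallel, and gathers the result onto device_ids[0]
y = data_parallel(model, x)
print(y.shape)  # torch.Size([32, 4])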
Example #4
    def wrapped(self, *inputs, **module_kwargs):
        if (not hasattr(self, '_is_replica')) and inputs[0].is_cuda:
            device_count = torch.cuda.device_count()
            if inputs[0].shape[0] % device_count != 0:
                import os
                cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
                raise ValueError(
                    'batch size (%d) must be divisible by the number of GPUs (%d) used\n CUDA_VISIBLE_DEVICES: %s'
                    % (inputs[0].shape[0], device_count, cuda_visible_devices))
            if device_count > 1:
                # modified from pytorch (torch.nn.parallel.DataParallel)
                device_ids = list(range(device_count))
                output_device = device_ids[0]
                inputs, kwargs = scatter_kwargs(inputs, module_kwargs,
                                                device_ids)
                replicas = replicate(self, device_ids[:len(inputs)])

                # add a _is_replica flag to avoid infinite loop
                # from recursively calling parallel_apply
                for replica in replicas:
                    replica._is_replica = True
                outputs = parallel_apply(replicas, inputs, kwargs)
                return gather(outputs, output_device)

        return self._forward_worker(*inputs, **module_kwargs)
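One way such a wrapper is typically wired up, sketched under the assumption that wrapped is reachable at class-definition time and that the class keeps its real computation in _forward_worker (the class and layer names are illustrative):

import torch.nn as nn

class AutoParallelNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def _forward_worker(self, x):
        # the per-replica computation that `wrapped` falls back to
        return self.fc(x)

    # any CUDA batch passed to forward() is now transparently scattered across
    # all visible GPUs; replicas skip the split thanks to the `_is_replica` flag
    forward = wrapped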
Example #5
 def wrapper(network, *inputs, **kwargs):
     inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
     if len(device_ids) == 1:
         return getattr(network, func_name)(*inputs[0], **kwargs[0])
     replicas = replicate(network, device_ids[:len(inputs)])
     outputs = parallel_apply(replicas, func_name, inputs, kwargs,
                              device_ids[:len(replicas)])
     return gather(outputs, output_device)
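Note that parallel_apply here is not the stock torch.nn.parallel.parallel_apply: it receives func_name and is expected to call that method on each replica. A stripped-down, sequential sketch of such a variant (the real helper would run one thread per device):

import torch

def parallel_apply_by_name(replicas, func_name, inputs, kwargs_tup, devices):
    outputs = []
    for replica, inp, kw, device in zip(replicas, inputs, kwargs_tup, devices):
        with torch.cuda.device(device):
            # dispatch the named method instead of forward()
            outputs.append(getattr(replica, func_name)(*inp, **kw))
    return outputs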
Example #6
    def _ddp_init_helper(self):
        """
        Initialization helper function that does the following:

        (1) replicating the module from device[0] to the other devices
        (2) bucketing the parameters for reductions
        (3) resetting the bucketing states
        (4) registering the grad hooks
        (5) passing a handle of DDP to SyncBatchNorm Layer
        """
        if self.device_ids and len(self.device_ids) > 1:
            # only create replicas for single-device CUDA modules
            #
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcasted using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module

            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(),
                                             module_copy.parameters()):
                    copy_param.requires_grad = param.requires_grad

        else:
            self._module_copies = [self.module]

        self.modules_params = [
            list(m.parameters()) for m in self._module_copies
        ]
        self.modules_buffers = [list(m.buffers()) for m in self._module_copies]

        param_list = [
            list(filter(lambda p: p.requires_grad, module.parameters()))
            for module in self._module_copies
        ]

        # The bucket size limit is specified in the constructor.
        # Additionally, we allow for a single small bucket for parameters
        # that are defined first, such that their gradients don't spill into
        # a much larger bucket, adding unnecessary latency after gradient
        # computation finishes. Experiments showed 1MB is a reasonable value.
        bucket_indices = dist._compute_bucket_assignment_by_size(
            param_list[0], [1024 * 1024, self.bucket_bytes_cap])

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(param_list, list(reversed(bucket_indices)),
                                    self.process_group)

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self._module_copies)
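For intuition, a simplified pure-Python sketch of size-based bucket assignment in the spirit of dist._compute_bucket_assignment_by_size (it ignores the dtype/device grouping the real helper performs):

def bucket_by_size(tensors, size_limits):
    # greedily pack tensor indices into buckets; a bucket closes once its byte
    # count reaches the current limit, and later buckets reuse the last limit
    # in `size_limits` (e.g. [1 MB, bucket_bytes_cap] as in the code above)
    buckets, current, current_bytes, limit_idx = [], [], 0, 0
    for idx, t in enumerate(tensors):
        current.append(idx)
        current_bytes += t.numel() * t.element_size()
        if current_bytes >= size_limits[min(limit_idx, len(size_limits) - 1)]:
            buckets.append(current)
            current, current_bytes = [], 0
            limit_idx += 1
    if current:
        buckets.append(current)
    return buckets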
Example #7
 def forward(self, inputs, *targets, **kwargs):
     # inputs should already be scattered
     # scattering the targets instead
     if not self.device_ids:
         return self.module(inputs, *targets, **kwargs)
     targets, kwargs = scatter_kwargs(targets, kwargs, self.device_ids)
     if len(self.device_ids) == 1:
         return self.module(inputs, *targets[0], **kwargs[0])
     replicas = replicate(self.module, self.device_ids[:len(inputs)])
     outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
     return Reduce.apply(*outputs) / len(outputs)
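The Reduce helper used above is an autograd Function that sums the per-device losses so a single scalar can be backpropagated through every replica. A minimal sketch of such a function (illustrative; real implementations use coalesced reductions):

import torch

class Reduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *inputs):
        # remember where each loss lives, then sum everything on the first device
        ctx.devices = [t.device for t in inputs]
        dest = inputs[0].device
        return sum(t.to(dest) for t in inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # route the same gradient back to each contributing device
        return tuple(grad_output.to(d) for d in ctx.devices)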
Example #8
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None,
                  dont_scatter=False,
                  dont_gather=False):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    #print('getting device_ids')
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    #print(device_ids)
    if output_device is None:
        output_device = device_ids[0]

    if not dont_scatter:
        do_scatter_lists = isinstance(inputs[0], list)
        if do_scatter_lists:
            inputs, module_kwargs = scatter_lists(inputs, module_kwargs,
                                                  device_ids)
        else:
            inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs,
                                                   device_ids, dim)

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    #print('getting used device_ids')
    used_device_ids = device_ids[:len(inputs)]
    #print(used_device_ids)
    #print('making model replicas')
    replicas = replicate(module, used_device_ids)
    #print('applying model')
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    if dont_gather:
        return tuple([[out[i] for out in outputs]
                      for i in range(len(outputs[0]))])
    #print('gathering result')
    return gather(outputs, output_device, dim)
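A usage sketch for the extra dont_gather flag, assuming at least two visible GPUs and a module that returns a tuple of outputs (which is what the transposition above expects):

import torch
import torch.nn as nn

class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(16, 4)
        self.b = nn.Linear(16, 2)

    def forward(self, x):
        return self.a(x), self.b(x)

model = TwoHead().cuda()
x = torch.randn(32, 16, device='cuda:0')

outs = data_parallel(model, x, dont_gather=True)
# outs[0] is the list of the first head's shards (one tensor per GPU) and
# outs[1] the second head's; nothing is copied back to a single device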
Example #9
def my_data_parallel(module, inputs, device_ids=None, dim=0, module_kwargs=None):
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if len(inputs) == 1:
        return module(inputs[0])

    #print('my data parallel, len(inputs)', len(inputs))
    replicas = replicate(module, device_ids[:len(inputs)])
    outputs = my_parallel_apply(replicas, inputs, module_kwargs)
    return outputs 
Example #10
def custom_data_parallel(module,
                         inputs,
                         device_ids=None,
                         output_device=None,
                         dim=0,
                         module_kwargs=None,
                         chunk_sizes=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(
                                   src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids,
                                           dim, chunk_sizes)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
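A usage sketch for the extra chunk_sizes argument, assuming two visible GPUs and that the scatter_kwargs variant used here accepts per-device chunk sizes: the batch is split unevenly, e.g. to leave headroom on the device that also collects the gathered output.

import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()
x = torch.randn(30, 16, device='cuda:0')

# put 10 samples on cuda:0 and 20 on cuda:1 instead of an even 15/15 split
y = custom_data_parallel(model, x, device_ids=[0, 1], chunk_sizes=[10, 20])
print(y.shape)  # torch.Size([30, 4])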
Example #11
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None):
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids,
                                           dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
Example #12
 def replicate(self, module, device_ids):
     modules = replicate(module, device_ids)
     execute_replication_callbacks(modules)
     return modules
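execute_replication_callbacks comes from synchronized-BatchNorm-style code: after the stock replication it gives every submodule a chance to wire itself up with its peer copies. A sketch of that protocol, modeled on the sync-BN implementation (treat the details as illustrative):

class CallbackContext(object):
    # empty scratch object; submodule j of every replica receives the same ctx
    pass

def execute_replication_callbacks(modules):
    master = modules[0]
    ctxs = [CallbackContext() for _ in master.modules()]
    for copy_id, module in enumerate(modules):
        for ctx, m in zip(ctxs, module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                # submodules defining this hook (e.g. a sync BatchNorm layer)
                # learn their copy id and share a context for communication
                m.__data_parallel_replicate__(ctx, copy_id)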
Example #13
    def __init__(self,
                 module,
                 device_ids=None,
                 distributed=True,
                 graph=None,
                 comm_device=None,
                 push_sum=True,
                 verbose=True):
        super(SimpleGossipDataParallel, self).__init__()

        # whether we're using multiple agents for training
        self.distributed = distributed

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.device_ids = device_ids
        self.output_device = self.device_ids[0]

        # put model on output device
        self.module = module.cuda(self.output_device)

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # prepare inter-node gossip objects
        if self.distributed:
            assert dist.is_initialized()

            # set distributed configuration properties
            self.graph = graph
            self.push_sum = push_sum
            self.gossip_enable = True
            if comm_device is None:
                comm_device = torch.device('cpu')
            self.comm_device = comm_device

            # logger used to print to stdout
            self.logger = make_logger(dist.get_rank(), verbose)

            # initialize gossiper to push-sum or push-pull protocol
            if self.push_sum:
                self.gossiper = PushSum(msg=_flatten_tensors(
                    list(self.module.parameters())),
                                        device=self.comm_device,
                                        graph=self.graph,
                                        logger=self.logger)
            else:
                self.gossiper = PushPull(msg=_flatten_tensors(
                    list(self.module.parameters())),
                                         device=self.comm_device,
                                         graph=self.graph,
                                         logger=self.logger)
        else:
            # logger used to print to stdout
            self.logger = make_logger(0, verbose)

        # register hook for gradient reduction on all GPUs available locally
        self.register_backward_hook(self.__make_backward_hook())
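For intuition, the push-sum protocol initialized above keeps a (message, weight) pair per node; each round a node pushes equal shares of both to its out-neighbors and itself, and its de-biased estimate of the average is message / weight. A toy single-process sketch of one synchronous round (illustrative):

import torch

def push_sum_round(messages, weights, out_neighbors):
    n = len(messages)
    new_msg = [torch.zeros_like(m) for m in messages]
    new_w = [0.0] * n
    for i in range(n):
        targets = out_neighbors[i] + [i]      # keep one share for yourself
        share = 1.0 / len(targets)
        for j in targets:
            new_msg[j] += share * messages[i]
            new_w[j] += share * weights[i]
    # node i's current estimate of the network average: new_msg[i] / new_w[i]
    return new_msg, new_w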
Example #14
    def _ddp_init_helper(self):
        """
        Initialization helper function that does the following:

        (1) replicating the module from device[0] to the other devices
        (2) bucketing the parameters for reductions
        (3) resetting the bucketing states
        (4) registering the grad hooks
        (5) passing a handle of DDP to SyncBatchNorm Layer
        """
        if len(self.device_ids) > 1:
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcasted using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module

            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(),
                                             module_copy.parameters()):
                    copy_param.requires_grad = param.requires_grad

        else:
            self._module_copies = [self.module]

        self.modules_params = [
            list(m.parameters()) for m in self._module_copies
        ]
        self.modules_buffers = [list(m.buffers()) for m in self._module_copies]

        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
        param_buckets = []

        # Split the parameters into buckets and by types as well
        # We only need to bucket and reduce parameters that require grad and
        # this is also true for backward since only the backward hooks for
        # parameters that require grad will be registered with gradient
        # reduction functions
        params_to_bucket = [[] for _ in self._module_copies]
        for dev_idx, m in enumerate(self._module_copies):
            for p in m.parameters():
                if p.requires_grad:
                    params_to_bucket[dev_idx].append(p)

        param_buckets = [
            dist._dist_bucket_tensors(dev_params_to_bucket,
                                      int(self.bucket_bytes_cap),
                                      fine_grained=False)
            for dev_params_to_bucket in params_to_bucket
        ]

        self.bucket_sizes = []
        self.bucket_map = {}

        # We transpose param_buckets, so the loop is over buckets.
        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
            self.bucket_sizes.append(0)
            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
            # of params from each device.
            for param_tuple in zip(*param_buckets_tuple):
                if not param_tuple[0].requires_grad:
                    continue
                for p in param_tuple:
                    self.bucket_map[p] = (bucket_idx,
                                          self.bucket_sizes[bucket_idx])
                self.bucket_sizes[bucket_idx] += 1

        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                         for _ in range(len(self.device_ids))]
                        for i in range(len(self.bucket_sizes))]
        # The number of params ready in each bucket
        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))]
                                   for i in range(len(self.bucket_sizes))]

        # coalesced bucket for only device 0
        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
        # We always reduce the buckets in reverse order,
        # that is, following the order: n - 1, n - 2, ..., 0
        self.next_bucket = len(self.bucket_sizes) - 1
        # When all buckets are reduced, this will be set to True. This flag is
        # useful for sanity checks to ensure that each iteration's backward has
        # always reduced all buckets
        self.all_buckets_reduced = False
        self.check_previous_reduction = False
        self.ready_buckets_not_reduced = set()
        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
        self._register_grad_hooks()

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self._module_copies)
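The structures built above are consumed by per-parameter gradient hooks. A simplified sketch of what such a hook does with bucket_map (illustrative only; the real _register_grad_hooks also coalesces the bucket and launches the asynchronous reduction in reverse bucket order):

def _make_param_hook(self, param, device_idx):
    bucket_idx, bucket_offset = self.bucket_map[param]

    def hook(*unused):
        # slot the ready gradient into its (bucket, device, offset) position
        self.buckets[bucket_idx][device_idx][bucket_offset] = param.grad.data
        self.buckets_ready_size[bucket_idx][device_idx] += 1
        if self.buckets_ready_size[bucket_idx][device_idx] == self.bucket_sizes[bucket_idx]:
            # the bucket is complete on this device; this is where the real
            # code coalesces it and kicks off the all-reduce
            pass

    return hook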
Example #15
    def __init__(self,
                 module,
                 device_ids=None,
                 distributed=True,
                 master_addr=None,
                 master_port=None,
                 backend=None,
                 world_size=None,
                 rank=None,
                 graph=None,
                 mixing=None,
                 comm_device=None,
                 lr=0.1,
                 momentum=0.9,
                 weight_decay=1e-4,
                 nesterov=True,
                 verbose=True):
        super(BilatGossipDataParallel, self).__init__()

        # whether we're using multiple agents for training
        self.distributed = distributed

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        # put model on output device
        self.module = module.cuda(self.output_device)

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # prepare inter-node gossip objects
        if self.distributed:

            # communicate over the CPU if not specified
            if comm_device is None:
                comm_device = torch.device('cpu')
            self.__cpu_comm = comm_device.type == 'cpu'

            # distributed backend config
            self.dist_config = {
                'verbose': verbose,
                'graph': graph,
                'master_addr': master_addr,
                'master_port': master_port,
                'backend': backend,
                'world_size': world_size,
                'rank': rank,
                'mixing': mixing,
                'lr': lr,
                'momentum': momentum,
                'nesterov': nesterov,
                'weight_decay': weight_decay
            }
            self.num_updates = 0

            # logger used to print to stdout
            self.logger = make_logger(rank, verbose)

            # prepare parameters for gossip
            self.gossip_enable = True
            self.gossip_params = []
            self.gossip_grads = []
            for p in module.parameters():
                cp = p.clone().detach_()
                cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
                cp.requires_grad = p.requires_grad
                self.gossip_params.append(cp)
                if p.requires_grad:
                    g = cp.clone().zero_().detach_()
                    g = g.cpu().pin_memory() if self.__cpu_comm else g.cuda()
                    self.gossip_grads.append(g)

            self.gossip_queue = mp.Queue()
            self.gossip_lock = mp.Lock()
            self.gossip_enable_flag = mp.Event()
            self.train_write_flag = mp.Event()  # signal train-proc write event
            self.gossip_read_flag = mp.Event()  # signal gossip-proc read event
            self.gossip_update_flag = mp.Event()  # signal to gossip-proc that an update is needed
            self._lr = mp.Value('f', lr, lock=self.gossip_lock)
            self.gossip_thread = mp.Process(
                target=BilatGossipDataParallel._gossip_target,
                args=(self.dist_config, self.gossip_enable_flag,
                      self.train_write_flag, self.gossip_read_flag,
                      self.gossip_update_flag, self._lr, self.gossip_lock,
                      self.gossip_queue))
            self.gossip_thread.daemon = True
            self.gossip_thread.name = 'Gossip-Thread'
            self.gossip_thread.start()

            # pass handle to gossip_params and gossip_grads, and put in shared
            # memory
            self.gossip_queue.put((self.gossip_params, self.gossip_grads))

        else:
            # logger used to print to stdout
            self.logger = make_logger(0, verbose)

        # register ps/grad-reduction hooks
        self.__register_hooks()
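The pinned CPU copies built above let the training process stage parameters for the gossip process with asynchronous copies. A minimal sketch of that staging pattern (illustrative):

import torch

def make_comm_buffer(param, comm_on_cpu):
    # snapshot a parameter into the buffer shared with the gossip process;
    # pinned CPU memory enables non_blocking GPU<->CPU copies
    buf = param.clone().detach()
    return buf.cpu().pin_memory() if comm_on_cpu else buf.cuda()

def publish(params, buffers):
    # copy the current parameter values into the communication buffers
    with torch.no_grad():
        for p, buf in zip(params, buffers):
            buf.copy_(p, non_blocking=True)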
Example #16
 def replicate(self, module, device_ids):
     if self.replicas is None:
         from torch.nn.parallel.replicate import replicate
         self.replicas = replicate(module, device_ids, not torch.is_grad_enabled())
     return self.replicas
Example #17
    def __init__(self,
                 module,
                 device_ids=None,
                 distributed=True,
                 graph=None,
                 mixing=None,
                 comm_device=None,
                 push_sum=True,
                 rank=None,
                 world_size=None,
                 overlap=False,
                 synch_freq=0,
                 verbose=True):
        super(GossipDataParallel, self).__init__()

        # whether we're using multiple agents for training
        self.distributed = distributed

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        # put model on output device
        self.module = module.cuda(self.output_device)

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # prepare inter-node gossip objects
        if self.distributed:
            if world_size is None or rank is None:
                assert dist.is_initialized()
                rank = dist.get_rank()
                world_size = dist.get_world_size()

            # communicate over the CPU if not specified
            if comm_device is None:
                comm_device = torch.device('cpu')
            self.__cpu_comm = comm_device.type == 'cpu'

            # distributed backend config
            self.dist_config = {
                'verbose': verbose,
                'comm_device': comm_device,
                'graph': graph,
                'mixing': mixing,
                'push_sum': push_sum,
                'rank': rank,
                'world_size': world_size
            }
            self.overlap = overlap
            self.synch_freq = synch_freq
            self.num_updates = 0
            self.asynch = synch_freq > 0

            # logger used to print to stdout
            self.logger = make_logger(rank, verbose)

            # push-sum weight=1.0 ==> distributed averaging
            self.ps_weight = 1.0
            self.is_ps_numerator = False

            # prepare parameters for gossip
            self.gossip_enable = True
            self.gossiping = False
            self.params_mixed = True
            self.gossip_ps_factor = [None]
            self.gossip_ps_weight = [self.ps_weight]
            self.gossip_params = []
            self.gossip_device_buffer = []
            for p in module.parameters():
                cp = p.clone().detach_()
                cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
                self.gossip_params.append(cp)

            # prepare gossip process control objects
            self.gossip_lock = threading.Lock()
            self.gossip_flag = threading.Event()
            self.train_flag = threading.Event()
            self.gossip_thread = threading.Thread(
                target=GossipDataParallel._gossip_target,
                args=(self.dist_config, self.gossip_flag, self.train_flag,
                      self.gossip_lock, self.gossip_params,
                      self.gossip_device_buffer, self.gossip_ps_weight,
                      self.gossip_ps_factor))
            self.gossip_thread.daemon = True
            self.gossip_thread.name = 'Gossip-Thread'
            self.gossip_thread.start()
            # wait for thread to complete initialization
            self.gossip_flag.wait()
            self.gossip_flag.clear()
            # lazy mixing avoids additional bias/de-bias steps
            self.lazy_mixing = not self.asynch \
                and self.dist_config['mixing'].is_regular()
            self.lazy_ps_factor = self.gossip_ps_factor[0]
            self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))
        else:
            self.params_mixed = True
            # logger used to print to stdout
            self.logger = make_logger(0, verbose)

        # register ps/grad-reduction hooks
        self.__register_hooks()
Example #18
    def parallel_forward(self, dense_x, lS_o, lS_i):
        ### prepare model (overwrite) ###
        # WARNING: the number of devices used is capped by the batch size and the number of embedding tables
        batch_size = dense_x.size()[0]
        ndevices = min(self.ndevices, batch_size, len(self.emb_l))
        device_ids = range(ndevices)
        # WARNING: must redistribute the model if the mini-batch size changes
        # (this is common for the last mini-batch, when the dataset size is not
        # evenly divisible by the batch size)
        if self.parallel_model_batch_size != batch_size:
            self.parallel_model_is_not_prepared = True

        if self.sync_dense_params or self.parallel_model_is_not_prepared:
            # replicate mlp (data parallelism)
            self.bot_l_replicas = replicate(self.bot_l, device_ids)
            self.top_l_replicas = replicate(self.top_l, device_ids)
            # distribute embeddings (model parallelism)
            t_list = []
            for k, emb in enumerate(self.emb_l):
                d = torch.device("cuda:" + str(k % ndevices))
                t_list.append(emb.to(d))
            self.emb_l = nn.ModuleList(t_list)
            self.parallel_model_batch_size = batch_size
            self.parallel_model_is_not_prepared = False

        ### prepare input (overwrite) ###
        # scatter dense features (data parallelism)
        # print(dense_x.device)
        dense_x = scatter(dense_x, device_ids, dim=0)
        # distribute sparse features (model parallelism)
        if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
            sys.exit(
                "ERROR: corrupted model input detected in parallel_forward call"
            )

        t_list = []
        i_list = []
        for k, _ in enumerate(self.emb_l):
            d = torch.device("cuda:" + str(k % ndevices))
            t_list.append(lS_o[k].to(d))
            i_list.append(lS_i[k].to(d))
        lS_o = t_list
        lS_i = i_list

        ### compute results in parallel ###
        # bottom mlp
        # WARNING: Note that self.bot_l_replicas is a list of bottom mlp modules
        # that have been replicated across devices, while dense_x is a tuple of dense
        # inputs that has been scattered across devices on the first (batch) dimension.
        # The output is a list of tensors scattered across devices according to the
        # distribution of dense_x.
        x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
        # debug prints
        # print(x)

        # embeddings
        ly = self.apply_emb(lS_o, lS_i, self.emb_l)
        # debug prints
        # print(ly)

        # butterfly shuffle (implemented inefficiently for now)
        # WARNING: Note that at this point we have the result of the embedding lookup
        # for the entire batch on each device. We would like to obtain partial results
        # corresponding to all embedding lookups, but only for part of the batch,
        # on each device, matching the distribution of the bottom mlp output so
        # that both can be used for the subsequent interactions on each device.
        if len(self.emb_l) != len(ly):
            sys.exit(
                "ERROR: corrupted intermediate result in parallel_forward call"
            )

        t_list = []
        for k, _ in enumerate(self.emb_l):
            d = torch.device("cuda:" + str(k % ndevices))
            y = scatter(ly[k], device_ids, dim=0)
            t_list.append(y)
        # adjust the list to be ordered per device
        ly = list(map(lambda y: list(y), zip(*t_list)))
        # debug prints
        # print(ly)

        # interactions
        z = []
        for k in range(ndevices):
            zk = self.interact_features(x[k], ly[k])
            z.append(zk)
        # debug prints
        # print(z)

        # top mlp
        # WARNING: Note that self.top_l_replicas is a list of top mlp modules that
        # have been replicated across devices, while z is a list of interaction results
        # that by construction are scattered across devices on the first (batch) dim.
        # The output is a list of tensors scattered across devices according to the
        # distribution of z.
        p = parallel_apply(self.top_l_replicas, z, None, device_ids)

        ### gather the distributed results ###
        p0 = gather(p, self.output_d, dim=0)

        # clamp output if needed
        if 0.0 < self.loss_threshold < 1.0:
            z0 = torch.clamp(p0,
                             min=self.loss_threshold,
                             max=(1.0 - self.loss_threshold))
        else:
            z0 = p0

        return z0
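The 'butterfly shuffle' above re-scatters each embedding table's full-batch output along the batch dimension and then transposes the nesting, so every device ends up holding all embeddings for its own slice of the batch. A shape-level sketch, assuming two visible GPUs (sizes are illustrative):

import torch
from torch.nn.parallel.scatter_gather import scatter

device_ids = [0, 1]
# toy setup: batch of 8, three embedding tables of dim 4, each table's
# full-batch output currently living on the device that owns that table
ly = [torch.randn(8, 4, device='cuda:%d' % (k % 2)) for k in range(3)]

# scatter each table's output over the batch dim, then regroup per device
t_list = [scatter(y, device_ids, dim=0) for y in ly]
per_device = list(map(list, zip(*t_list)))
# per_device[d] is a list of three (4, 4) tensors, all resident on device d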
Example #19
 def replicate(self, module, device_ids):
     return replicate(module, device_ids, not torch.is_grad_enabled())
Example #20
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DistributedDataParallel, self).__init__()

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device

        # Sync params and buffers
        # broadcast the model from the master (node with rank 0) to all nodes
        # we can borrow this in our parameter server settings
        for p in self.module.state_dict().values():
            dist.broadcast(p, 0)

        if len(device_ids) > 1:
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcasted using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module, self.device_ids)
            self._module_copies[0] = self.module
            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(), module_copy.parameters()):
                    copy_param.detach_()
                    copy_param.requires_grad = param.requires_grad
        else:
            self._module_copies = [self.module]

        # Split parameters into buckets that will coalesce reductions
        # TODO: different types need different buckets
        t = None
        for p in self.module.parameters():
            tp = type(p.data)
            if t is not None and t is not tp:
                raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type")
            t = tp

        self.bucket_sizes = []
        self.bucket_map = {}
        MB = 1024 * 1024
        self.broadcast_bucket_size = 10 * MB  # used for param sync before forward
        bucket_bytes_cap = 1 * MB
        bucket_bytes = bucket_bytes_cap  # to init the first bucket immediately
        for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)):
            if bucket_bytes >= bucket_bytes_cap:
                self.bucket_sizes.append(0)
                bucket_bytes = 0
            self.bucket_sizes[-1] += 1
            for p in param_tuple:
                self.bucket_map[p] = len(self.bucket_sizes) - 1
            bucket_bytes += p.numel() * p.element_size()

        self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
        self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
        self.reduced = [False] * len(self.bucket_sizes)

        self._register_grad_hooks()

        self.dispatch_lock = threading.Lock()
        self._start_reduction_threads()
Example #21
 def replicate(self, module, device_ids):
     return replicate(module, device_ids)
Example #22
    def __init__(self,
                 module,
                 device_ids=None,
                 distributed=True,
                 world_size=None,
                 rank=None,
                 comm_device=None,
                 verbose=True):
        super(AllReduceDataParallel, self).__init__()

        # whether we're using multiple agents for training
        self.distributed = distributed

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        # put model on output device
        self.module = module.cuda(self.output_device)

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # prepare inter-node gossip objects
        if self.distributed:
            if world_size is None or rank is None:
                assert dist.is_initialized()
                world_size = dist.get_world_size()
                rank = dist.get_rank()

            # communicate over the CPU if not specified
            if comm_device is None:
                comm_device = torch.device('cpu')
            self.__cpu_comm = comm_device.type == 'cpu'

            # distributed backend config
            self.dist_config = {
                'verbose': verbose,
                'comm_device': comm_device,
                'rank': rank
            }

            # logger used to print to stdout
            self.logger = make_logger(rank, verbose)

            # prepare parameters for gossip
            self.ps_factor = 1. / world_size
            self.gossip_enable = True
            self.gossiping = False
            self.params_mixed = True
            self.gossip_params = []
            self.gossip_device_buffer = []
            for p in module.parameters():
                cp = p.clone().detach_()
                cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
                self.gossip_params.append(cp)

            # prepare gossip process control objects
            self.gossip_flag = threading.Event()
            self.train_flag = threading.Event()
            self.gossip_thread = threading.Thread(
                target=AllReduceDataParallel._gossip_target,
                args=(self.dist_config, self.gossip_flag, self.train_flag,
                      self.gossip_params, self.gossip_device_buffer))
            self.gossip_thread.daemon = True
            self.gossip_thread.name = 'Gossip-Thread'
            self.gossip_thread.start()
            # wait for thread to complete initialization
            self.gossip_flag.wait()
            self.gossip_flag.clear()
        else:
            self.params_mixed = True
            # logger used to print to stdout
            self.logger = make_logger(0, verbose)

        # register grad-reduction hooks
        self.__register_hooks()
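The train_flag/gossip_flag pair above is a simple hand-off between the training loop and the background communication thread. A stripped-down sketch of that handshake (illustrative; mix stands in for whatever all-reduce or gossip routine averages the buffers in place):

import threading

def make_comm_thread(buffers, mix):
    train_flag, gossip_flag = threading.Event(), threading.Event()

    def worker():
        gossip_flag.set()              # signal: initialization finished
        while True:
            train_flag.wait()          # block until the trainer publishes buffers
            train_flag.clear()
            mix(buffers)               # average / mix the staged parameters
            gossip_flag.set()          # signal: mixed buffers are ready to read back

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    return train_flag, gossip_flag, thread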
Example #23
 def replicate(self, flow, device_ids):
     return replicate(flow, device_ids)
Example #24
    def __init__(self,
                 module,
                 device_ids=None,
                 rank=None,
                 world_size=None,
                 graph=None,
                 mixing=None,
                 comm_device=None,
                 push_sum=True,
                 overlap=False,
                 synch_freq=0,
                 verbose=False,
                 use_streams=True,
                 nprocs_per_node=1,
                 local_node_group=None):
        super(GossipDataParallel, self).__init__()

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        self.nprocs_per_node = nprocs_per_node

        if world_size is None or rank is None:
            assert dist.is_initialized()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        self.process_rank = rank

        if self.nprocs_per_node > 1:
            self.local_rank = self.process_rank % self.nprocs_per_node
            world_size //= nprocs_per_node
            rank //= nprocs_per_node
            if local_node_group is None:
                for node in range(world_size):
                    node_processes_ranks = list(
                        range(node * self.nprocs_per_node,
                              (node + 1) * self.nprocs_per_node))
                    # Process group to communicate between processes on this
                    # machine
                    new_local_group = create_process_group(
                        node_processes_ranks)
                    if self.process_rank in node_processes_ranks:
                        self.local_node_group = new_local_group
            else:
                self.local_node_group = local_node_group
        else:
            self.local_rank = 0

        # put model on output device
        self.module = module
        first_param_dtype = next(self.module.parameters()).dtype

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # choose communication device based on backend
        if comm_device is None:
            cpu_comm = dist.get_backend() == 'gloo'
            comm_device = torch.device('cpu') if cpu_comm else torch.device(
                'cuda')
        self.__cpu_comm = comm_device.type == 'cpu'

        if graph is None:
            graph = NPDDEGraph(rank, world_size, self.nprocs_per_node,
                               self.local_rank)

        if mixing is None:
            mixing = UniformMixing(graph, comm_device)

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'comm_device': comm_device,
            'graph': graph,
            'mixing': mixing,
            'push_sum': push_sum,
            'rank': rank,
            'process_rank': self.process_rank,
            'world_size': world_size,
            'cpu_comm': self.__cpu_comm
        }
        self.overlap = overlap
        self.synch_freq = synch_freq
        self.num_updates = 0
        self.asynch = synch_freq > 0

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # push-sum weight=1.0 ==> distributed averaging
        self.ps_weight = torch.ones(1,
                                    device=comm_device).type(first_param_dtype)
        self.nprocs_per_node_device = torch.tensor([self.nprocs_per_node],
                                                   device=comm_device,
                                                   dtype=first_param_dtype)
        self.is_ps_numerator = False

        # prepare parameters for gossip
        self.gossip_enable = True
        self.gossiping = False
        self.params_mixed = True
        self.gossip_ps_factor = torch.zeros(
            1, device=comm_device).type(first_param_dtype)
        self.gossip_ps_weight = self.ps_weight.clone()
        self.gossip_params = []
        self.gossip_device_buffer = []
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            self.gossip_params.append(cp)
            self.gossip_device_buffer.append(cp)

        # prepare gossip process control objects
        self.gossip_lock = threading.Lock()
        self.gossip_flag = threading.Event()
        self.train_flag = threading.Event()

        if self.dist_config['comm_device'].type != 'cpu' and use_streams:
            self.gossip_stream = torch.cuda.Stream()
        else:
            self.gossip_stream = torch.cuda.current_stream()

        if self.process_rank % self.nprocs_per_node == 0:
            self.gossip_thread = threading.Thread(
                target=GossipDataParallel._gossip_target,
                args=(self.dist_config, self.gossip_flag, self.train_flag,
                      self.gossip_lock, self.gossip_params,
                      self.gossip_device_buffer, self.gossip_ps_weight,
                      self.gossip_ps_factor, self.gossip_stream))
            self.gossip_thread.daemon = True
            self.gossip_thread.name = 'Gossip-Thread'
            self.gossip_thread.start()
        else:
            self.gossip_flag.set()
        # wait for thread to complete initialization
        self.gossip_flag.wait()
        self.gossip_flag.clear()
        # lazy mixing avoids additional bias/de-bias steps
        self.lazy_mixing = (not self.asynch
                            and self.dist_config['mixing'].is_regular()
                            and not self.overlap)
        self.lazy_ps_factor = self.gossip_ps_factor.clone()
        self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))

        # register ps/grad-reduction hooks
        self.__register_hooks()
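The ps_weight bookkeeping above implements the push-sum bias correction: before gossiping, parameters are converted to numerators (parameter times weight), and after mixing they are de-biased by dividing by the mixed weight. A minimal sketch of that pair of conversions, where ps_weight is a Python float or a scalar tensor on the parameters' device (illustrative):

import torch

def to_ps_numerator(params, ps_weight):
    # scale parameters by the current push-sum weight before they are mixed
    with torch.no_grad():
        for p in params:
            p.mul_(ps_weight)

def from_ps_numerator(params, ps_weight):
    # de-bias after mixing: divide by the (gossiped) push-sum weight
    with torch.no_grad():
        for p in params:
            p.div_(ps_weight)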