Example #1
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DGLGraphDataParallel, self).__init__()
        self.use_cuda = True

        # No GPU available: keep the module on CPU and skip parallel setup.
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            self.use_cuda = False
            return

        # Default to all visible GPUs, gathering outputs on the first one.
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        # _get_device_index and _check_balance are the private helpers that
        # torch.nn.DataParallel itself relies on.
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(
            self.device_ids[0]))

        # Warn if the selected GPUs are imbalanced in memory or compute.
        _check_balance(self.device_ids)

        # With a single device there is nothing to scatter; just move the
        # module onto that GPU.
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])
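
A minimal construction sketch for the snippet above. The wrapped model is a stand-in (any nn.Module works), assuming DGLGraphDataParallel and its torch imports are in scope:

    import torch.nn as nn

    model = nn.Linear(16, 4)  # hypothetical module; a GNN would be typical for DGL
    parallel = DGLGraphDataParallel(model)                # replicate across all visible GPUs
    pinned = DGLGraphDataParallel(model, device_ids=[0])  # pin to a single GPU (cuda:0)
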
Example #2
    def __init__(self, flow: Flow, device_ids=None, output_device=None, dim=0):
        # The parent is initialized with the flow's inverse transform.
        super(DataParallelFlow, self).__init__(flow.inverse)

        # No GPU available: keep the flow on CPU and skip parallel setup.
        if not torch.cuda.is_available():
            self.flow = flow
            self.device_ids = []
            return

        # Default to all visible GPUs, gathering outputs on the first one.
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.flow = flow
        self.device_ids = device_ids
        self.output_device = output_device

        # Warn if the selected GPUs are imbalanced in memory or compute.
        _check_balance(self.device_ids)

        # With a single device, just move the flow onto that GPU.
        if len(self.device_ids) == 1:
            self.flow.cuda(device_ids[0])
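
A hedged usage sketch. Flow is whatever base class the signature above refers to, and my_flow is a hypothetical instance exposing the inverse attribute that super().__init__ expects:

    # my_flow: any Flow instance with an `inverse` transform (assumption).
    parallel_flow = DataParallelFlow(my_flow)                  # all visible GPUs
    two_gpu_flow = DataParallelFlow(my_flow, device_ids=[0, 1], output_device=0)
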
Example #3
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        # No GPU available: keep the module on CPU and skip parallel setup.
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        # Default to all visible GPUs, gathering outputs on the first one.
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        # Normalize device identifiers via torch's private helper.
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)

        # Warn if the selected GPUs are imbalanced in memory or compute.
        _check_balance(self.device_ids)

        # With a single device, just move the module onto that GPU.
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])
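
A minimal sketch of how this wrapper would be constructed, mirroring torch.nn.DataParallel usage (the network is a placeholder):

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))
    parallel_net = DataParallel(net)                             # all visible GPUs
    one_gpu_net = DataParallel(net, device_ids=[0], output_device=0)
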
Example #4
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(ModuleParallel, self).__init__()

        # No GPU available: keep the module on CPU and skip parallel setup.
        if not T.cuda.is_available():
            self.module = module
            self.device_ids = []
        else:
            # Default to all visible GPUs, gathering outputs on the first one.
            if device_ids is None:
                device_ids = list(range(T.cuda.device_count()))
            if output_device is None:
                output_device = device_ids[0]

            self.dim = dim
            self.module = module
            self.device_ids = list(
                map(lambda x: _get_device_index(x, True), device_ids))
            self.output_device = _get_device_index(output_device, True)

            # Warn if the selected GPUs are imbalanced in memory or compute.
            _check_balance(self.device_ids)

            # With a single device, just move the module onto that GPU.
            if len(self.device_ids) == 1:
                self.module.cuda(device_ids[0])

        # Collect the method names every bare nn.Module defines, then expose
        # the wrapped module's remaining methods through this wrapper.
        self._def_methods = {
            mn
            for mn, _ in inspect.getmembers(Module(), inspect.ismethod)
        }
        self._def_methods.remove('forward')
        self.make_module_methods_parallel()

        # Mirror the wrapped module's non-default, non-method attributes on
        # the wrapper so they remain accessible after wrapping.
        def_attrs = {
            k
            for k, v in Module().__dict__.items() if not inspect.ismethod(v)
        }
        for k, v in module.__dict__.items():
            if not inspect.ismethod(v) and k not in def_attrs:
                setattr(self, k, v)
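
A sketch of the attribute-mirroring behavior the final loop provides. Net and hidden_size are hypothetical; ModuleParallel and its torch imports are assumed to be in scope:

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)
            self.hidden_size = 2  # custom attribute, not defined on a bare nn.Module

        def forward(self, x):
            return self.fc(x)

    wrapped = ModuleParallel(Net())
    print(wrapped.hidden_size)  # 2 -- copied onto the wrapper by the loop above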