Example #1
    def forward(self, *inputs, **kwargs):
        """Execute forward propagation.

        Arguments:
            *inputs: variable-length input list
            **kwargs: variable-length keyword arguments
        """

        if self.mp_world_size > 1:
            if self.mpu is None:
                # Without an mpu, synchronize the inputs ourselves: move every
                # tensor argument to the current device, make it contiguous,
                # and broadcast rank 0's values so all model-parallel ranks
                # run the forward pass on identical data.
                inputs = list(inputs)
                for i, inp in enumerate(inputs):
                    if torch.is_tensor(inp):
                        inp = inp.to(torch.cuda.current_device())
                        if not inp.is_contiguous():
                            inp = inp.contiguous()
                        dist.broadcast(inp, 0)
                        inputs[i] = inp
                for k in kwargs:
                    if torch.is_tensor(kwargs[k]):
                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
                        if not kwargs[k].is_contiguous():
                            kwargs[k] = kwargs[k].contiguous()
                        dist.broadcast(kwargs[k], 0)
            outputs = self.model_orig_fwd(*inputs, **kwargs)
        else:
            if self.enable_cuda_graph:
                if self.cuda_graph_created:
                    outputs = self._graph_replay(*inputs, **kwargs)
                else:
                    self._create_cuda_graph(*inputs, **kwargs)
                    outputs = self._graph_replay(*inputs, **kwargs)
            else:
                outputs = self.module(*inputs, **kwargs)
        return outputs
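
The pattern above keeps model-parallel replicas in sync by moving every tensor argument to the current device, making it contiguous, and broadcasting rank 0's copy before the original forward runs. A minimal standalone sketch of the same idea, assuming processes launched with torchrun; the helper name run_replicated_forward and the toy model are illustrative, not part of DeepSpeed:

import os
import torch
import torch.distributed as dist

def run_replicated_forward(model, *inputs):
    """Broadcast rank 0's tensor inputs to all ranks, then run the forward pass."""
    device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
    synced = []
    for t in inputs:
        t = t.to(device).contiguous()
        dist.broadcast(t, src=0)  # in place: every rank now holds rank 0's data
        synced.append(t)
    return model(*synced)

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.manual_seed(0)  # keep weights identical across ranks, mirroring the RNG-state broadcast in Example #7
    model = torch.nn.Linear(8, 4).cuda()
    # Each rank constructs a different input; after the broadcast all ranks
    # compute on rank 0's input.
    x = torch.randn(2, 8, device="cuda") * (dist.get_rank() + 1)
    out = run_replicated_forward(model, x)
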
Example #2
    def _synchronize_tied_weights(self):
        # Keep tied modules identical across ranks: the lowest rank that owns
        # the weight broadcasts it to every other rank in the tie group.
        for key, comm in self.tied_comms.items():
            dist.broadcast(
                getattr(comm['module'], comm['weight_attr']),
                src=min(comm['ranks']),
                group=comm['group'],
            )
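
The tied-weight synchronization above relies on every rank that shares a module agreeing on a single source of truth: the lowest rank in the tie group broadcasts its copy and overwrites the parameter in place everywhere else. A minimal sketch of the same pattern in plain torch.distributed, assuming the tie group has already been created collectively (the helper name sync_tied_param is illustrative):

import torch
import torch.distributed as dist

def sync_tied_param(param: torch.Tensor, ranks, group):
    # The lowest owning rank is the source; broadcast overwrites the
    # parameter in place on every other member of the group.
    dist.broadcast(param, src=min(ranks), group=group)

# Usage (illustrative): an embedding shared between pipeline ranks 0 and 3.
# tie_group = dist.new_group(ranks=[0, 3])
# sync_tied_param(embedding.weight.data, [0, 3], tie_group)
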
Example #3
def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """Broadcast rank 0's copy of ``lst`` to every rank and return it as a list.

    NOTE: creates both communication and synchronization overhead so should be used
    sparingly
    """
    # Non-zero ranks only contribute the length; their placeholder values are
    # overwritten by the broadcast from rank 0.
    lst_tensor = torch.tensor(
        lst if dist.get_rank() == 0 else [-1] * len(lst),
        dtype=torch.long,
        device=torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])),
        requires_grad=False,
    )
    dist.broadcast(lst_tensor, src=0, async_op=False)

    return list(lst_tensor.cpu().numpy())
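
Because torch.distributed broadcasts tensors rather than Python objects, the helper round-trips the list through a CUDA tensor: rank 0 contributes the real values, the other ranks only contribute the length, and every rank reads the result back after the broadcast. A hedged usage sketch, assuming the process group is initialized and LOCAL_RANK is set by the launcher:

# Every rank calls with its local list; only its length matters on ranks != 0,
# since the helper replaces those values with placeholders before broadcasting.
partition_sizes = get_lst_from_rank0(partition_sizes)
# All ranks now hold rank 0's partition_sizes.

On recent PyTorch versions, dist.broadcast_object_list offers a similar facility for arbitrary picklable objects without the manual tensor packing.
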
Example #4
def test_partitioned_tensor_meta():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 7
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    my_meta = PartitionedTensor.from_meta(part.to_meta(), part.local_data,
                                          group)
    assert torch.equal(full, my_meta.full())
Example #5
def test_partitioned_tensor():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 4
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    assert len(part.local_size()) == 1
    assert part.local_size()[0] * world == full.numel()

    reconstructed = part.full()
    assert torch.equal(full, reconstructed)
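
Both tests follow the same recipe: broadcast a reference tensor from rank 0 so every rank starts from identical data, partition it across the group, and verify that the shards reassemble into the original. A hedged sketch of that recipe reduced to plain torch.distributed calls, without PartitionedTensor (the helper name partition_and_rebuild is illustrative and assumes the tensor divides evenly across ranks):

import torch
import torch.distributed as dist

def partition_and_rebuild(full: torch.Tensor, group=None):
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    flat = full.contiguous().view(-1)
    assert flat.numel() % world == 0, "sketch assumes an evenly divisible tensor"
    local = flat.chunk(world)[rank].clone()  # this rank's shard
    gathered = [torch.empty_like(local) for _ in range(world)]
    dist.all_gather(gathered, local, group=group)  # collect every shard, ordered by rank
    return torch.cat(gathered).view_as(full)
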
Example #6
def recv(tensor, src_stage, async_op=False):
    global _groups
    assert not async_op, "async_op=True is not supported"
    dest_stage = _grid.get_stage_id()
    _is_valid_send_recv(src_stage, dest_stage)

    src_rank = _grid.stage_to_global(stage_id=src_stage)

    if async_op:
        # Unreachable while the assert above rejects async_op=True; kept for a
        # future asynchronous path.
        global _async
        op = dist.irecv(tensor, src_rank)
        _async.append(op)
    else:
        if can_send_recv():
            return dist.recv(tensor, src_rank)
        else:
            # Backend cannot do direct point-to-point receives: fall back to a
            # broadcast over the two-rank group containing src and dest.
            group = _get_send_recv_group(src_stage, dest_stage)
            return dist.broadcast(tensor,
                                  src_rank,
                                  group=group,
                                  async_op=async_op)
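
When the backend cannot exchange tensors directly between the two stages, the code falls back to a broadcast over a two-rank group, which behaves like a point-to-point transfer: the source rank supplies the data and the single other member receives it in place. A minimal sketch of that fallback in plain torch.distributed; the groups cache is an assumption here, since dist.new_group must be called collectively and therefore ahead of time:

import torch
import torch.distributed as dist

def p2p_via_broadcast(tensor: torch.Tensor, src_rank: int, dst_rank: int, groups: dict):
    # `groups` maps a sorted (rank, rank) pair to a pre-built two-rank group.
    group = groups[tuple(sorted((src_rank, dst_rank)))]
    # On src_rank this sends `tensor`; on dst_rank it is overwritten in place.
    dist.broadcast(tensor, src=src_rank, group=group)
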
Example #7
    def __init__(self,
                 model,
                 triangular_masking=True,
                 mp_size=1,
                 training_mp_size=1,
                 ep_size=1,
                 mpu=None,
                 ep_group=None,
                 expert_mp_group=None,
                 checkpoint=None,
                 dtype=None,
                 injection_dict=None,
                 return_tuple=True,
                 replace_method='auto',
                 quantization_setting=None,
                 replace_with_kernel_inject=False,
                 moe=False,
                 moe_experts=1,
                 moe_type='standard',
                 config=None,
                 enable_cuda_graph=False,
                 save_mp_checkpoint_path=None):
        """
        Args:
            model: torch.nn.Module
            mp_size: model-parallel size
            mpu: model-parallel unit (used for Megatron-type models)
            checkpoint: the json-path, showing the address of model-checkpoints
                Example: {type: 'Megatron', 'checkpoints': [ckpt_mp0.pt, ckpt_mp1.pt], 'version': 1.0}
            dtype: data-type by which inference is executed
            injection_dict: the dictionary that shows the injection policy:
                Example: {BertLayer: HFBertLayerPolicy}
            return_tuple: if true, inference-API returns a tuple, otherwise a tensor
            replace_method: the injection method, this can be passed as auto if no injection-policy is defined, in which case the injection is automatic based on the available policies
            quantization_setting:
                one of None, Tuple(mlp_extra_grouping, quantize_groups), quantize_groups
            replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise,
            the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
        """
        global DS_INFERENCE_ENABLED
        DS_INFERENCE_ENABLED = True

        super().__init__()

        self.module = model

        self._get_model_config_generate(config)

        if hasattr(self.module, "config"):
            DSPolicy.hf_model_config = self.module.config

        self.mp_world_size = mp_size
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.injection_dict = injection_dict
        self.mp_group = None
        self.mpu = mpu
        self._validate_args(mpu)
        self.replace_method = replace_method
        self.quantize_merge_count = 1
        self.quantization_scales = None
        self.triangular_masking = triangular_masking
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.expert_mp_group = expert_mp_group
        self.enable_cuda_graph = enable_cuda_graph
        self.cuda_graph_created = False
        self.checkpoint_engine = TorchCheckpointEngine()
        self._init_quantization_setting(quantization_setting)

        if enable_cuda_graph:
            assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
                "If you want to use cuda graph, please upgrade torch to at least v1.10"

        if self.checkpoint and not replace_with_kernel_inject:
            self._load_checkpoint(self.checkpoint)

        # convert model to intended dtype
        if self.dtype:
            self._convert_to_dtype()

        if self.mpu:
            self.mp_world_size = dist.get_world_size(
                group=self.mpu.get_model_parallel_group())
            self.mp_group = mpu.get_model_parallel_group()
        elif self.mp_world_size > 1:
            self._create_model_parallel_group()

        moe, _ = has_moe_layers(self.module)

        if moe and dist.get_world_size() > 1:
            self._create_ep_parallel_group(moe_experts)

        if self.injection_dict:
            for client_module, injection_policy in self.injection_dict.items():
                self._apply_injection_policy(
                    client_module,
                    injection_policy,
                    return_tuple,
                    replace_with_kernel_inject,
                    moe,
                    moe_experts,
                    moe_type,
                    training_mp_size,
                    self.checkpoint if replace_with_kernel_inject else None,
                    save_mp_checkpoint_path=save_mp_checkpoint_path)
        elif replace_method == 'auto':
            self._apply_injection_policy(
                return_tuple=return_tuple,
                replace_with_kernel_inject=replace_with_kernel_inject,
                moe=moe,
                moe_experts=moe_experts,
                moe_type=moe_type,
                training_mp_size=training_mp_size,
                checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
                save_mp_checkpoint_path=save_mp_checkpoint_path)

        device = torch.cuda.current_device()
        self.module.to(device)

        if self.mp_world_size > 1:
            _rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
            dist.broadcast(_rng_state, 0)
            torch.cuda.set_rng_state(_rng_state.cpu())

        if self.mp_world_size > 1:
            self.model_orig_fwd = self.module.forward
            self.module.forward = self.forward
        else:
            self.module.register_forward_pre_hook(self._pre_forward_hook)
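
In practice this engine is rarely constructed directly; DeepSpeed exposes it through deepspeed.init_inference. A hedged usage sketch with a toy model; the keyword names mirror the constructor above, but their exact availability depends on the DeepSpeed version:

import torch
import deepspeed

model = torch.nn.Linear(8, 4)

# Wraps the model in the engine shown above. With mp_size > 1 it would also
# create the model-parallel group and broadcast the RNG state and the inputs
# from rank 0, as in Example #1.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.float,
                                  replace_method='auto')

x = torch.randn(2, 8, device=torch.cuda.current_device())
y = engine(x)  # the engine is an nn.Module, so it is called like the wrapped model
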