Example #1
File: api.py Project: xsacha/pytorch
    def gather(
        self,
        dst: int = 0,
        out: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
        sharded tensor.

        The API needs to be called on all ranks in SPMD fashion. All ranks should have
        the same ``dst``. ``out`` should be a tensor of the same size as the overall
        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

        Args:
            dst(int): The rank where full tensor is constructed.
                Default: 0
            out (:class:`torch.Tensor`, optional): The output full tensor.
                Must be provided ONLY on ``dst`` rank.
                Default: ``None``
        """
        rank = dist.get_rank(self._process_group)
        full_size = self.metadata().size
        _validate_output_tensor_for_gather(rank, dst, full_size, out)

        local_shards = self.local_shards()

        world_size = dist.get_world_size(self._process_group)

        gathered_shards = [None] * world_size
        # will revise this part with CPU support and use dist.gather()
        # once NCCL support for gather() is ready
        # https://github.com/pytorch/pytorch/issues/66187
        device = torch.device(f"cuda:{rank % world_size}")
        with torch.cuda.device(device):
            dist.all_gather_object(
                obj=local_shards,
                object_list=gathered_shards,
                group=self._process_group,
            )

        if rank == dst:
            dims = len(full_size)
            for shards in gathered_shards:
                if shards is None:
                    raise RuntimeError(
                        f'Gathered shards cannot be None on dst rank {dst}'
                    )
                for shard in shards:
                    metadata = shard.metadata
                    tensor = shard.tensor

                    out_narrow_view = out
                    for dim in range(dims):
                        out_narrow_view = out_narrow_view.narrow(
                            dim,
                            metadata.shard_offsets[dim],
                            metadata.shard_lengths[dim],
                        )

                    out_narrow_view.copy_(tensor)
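A minimal usage sketch of the SPMD contract described in the docstring above: every rank calls ``gather`` with the same ``dst``, and only the destination rank passes a full-size ``out`` buffer. This is an illustrative sketch, assuming ``st`` is an already-constructed ShardedTensor and the process group has been initialized; the variable names are not part of the snippet.

import torch
import torch.distributed as dist

rank = dist.get_rank()
# allocate the full-size output buffer only on the destination rank
out = torch.empty(st.size()) if rank == 0 else None
st.gather(dst=0, out=out)
if rank == 0:
    print(out.shape)  # full (unsharded) size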
Example #2
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        # STEP 1: Validate the ShardMetadata locally
        process_group = (
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        current_rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)

        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
        global_tensor_size = _flatten_tensor_size(global_size)

        if len(local_shards) > 0:
            local_sharded_tensor_metadata = \
                build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
        if world_size > 1:
            gathered_metadatas = [None for _ in range(world_size)]

            dist.all_gather_object(
                gathered_metadatas,
                local_sharded_tensor_metadata,
                group=process_group
            )
        else:
            gathered_metadatas = [local_sharded_tensor_metadata]

        global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
        tensor_properties = global_sharded_tensor_metadata.tensor_properties

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        # prepare initialization
        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata
        )
        sharded_tensor = cls.__new__(cls,
                                     spec,
                                     global_sharded_tensor_metadata.size,
                                     dtype=tensor_properties.dtype,
                                     layout=tensor_properties.layout,
                                     pin_memory=tensor_properties.pin_memory,
                                     requires_grad=tensor_properties.requires_grad)
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        # attach local_shards to the ShardedTensor created
        sharded_tensor._local_shards = local_shards

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
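For context, a hedged sketch of how this classmethod might be invoked: two ranks each own one half of a global 4x4 tensor and describe their shard with a ShardMetadata. Import paths and field names (``shard_sizes`` vs. ``shard_lengths``) vary across PyTorch versions, so treat the names below as assumptions; an initialized 2-rank process group (e.g. gloo) is also assumed.

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.sharding_spec import ShardMetadata

rank = dist.get_rank()
local = torch.rand(2, 4)  # this rank's half of the global 4x4 tensor
meta = ShardMetadata(shard_offsets=[2 * rank, 0],
                     shard_sizes=[2, 4],
                     placement=f"rank:{rank}/cpu")
st = ShardedTensor._init_from_local_shards([Shard(tensor=local, metadata=meta)], 4, 4)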
Example #3
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        # STEP 1: Validate the ShardMetadata locally
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)

        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
        local_shards_device = torch.device("cpu")
        global_tensor_size = _flatten_tensor_size(global_size)

        if len(local_shards) > 0:
            local_sharded_tensor_metadata, local_shards_device = \
                build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        gathered_metadatas = [None for _ in range(world_size)]

        if local_shards_device.type == "cuda":
            # with GPU/NCCL, we need to set a device for all_gather_object
            # to use as we need to know which device we should put the
            # serialized tensor on before the NCCL collective.
            with torch.cuda.device(local_shards_device):
                dist.all_gather_object(gathered_metadatas,
                                       local_sharded_tensor_metadata,
                                       group=process_group)
        else:
            dist.all_gather_object(gathered_metadatas,
                                   local_sharded_tensor_metadata,
                                   group=process_group)

        global_sharded_tensor_metadata = build_global_metadata(
            gathered_metadatas)

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        # prepare initialization
        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        # add to metadata and local_shards
        sharded_tensor._metadata = global_sharded_tensor_metadata
        sharded_tensor._local_shards = local_shards
        # make an EnumerableShardingSpec for sharded tensors initialized from this API.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        sharded_tensor._sharding_spec = EnumerableShardingSpec(
            global_sharded_tensor_metadata.shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
Example #4
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor,
                             src_rank=src_rank,
                             process_group=process_group)

    return st
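A hedged usage sketch for the API above, assuming two ranks with one GPU each and an initialized process group; the ChunkShardingSpec placements and tensor shape are illustrative, not part of the snippet.

from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
full = torch.rand(8, 4, device=f"cuda:{dist.get_rank()}")  # same shape on every rank
st = _shard_tensor(full, spec, src_rank=0)  # rank 0's data is scattered; each rank keeps only its shard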
Example #5
    def _verify_sequence_number_across_pg(self, pg, verify_pg):

        seq_num = pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
        # We use a separate pg to verify the sequence numbers, otherwise these
        # collectives will themselves increment the sequence number.
        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
        self.assertEqual(len(set(obj_list)), 1)
        return obj_list[0]
Example #6
def _validate(model, process_group, assert_fn):
    module_states = [param.detach().cpu() for param in model.parameters()]
    module_states.extend([buffer.detach().cpu() for buffer in model.buffers()])
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, module_states, group=process_group)
    rank0_states = olist[0]
    for state in olist[1:]:
        for p1, p2 in zip(rank0_states, state):
            assert_fn(p1, p2)
Example #7
    def all_gather_object(self, object: T) -> List[T]:
        """
        Same as c10d::all_gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = cast(List[T],
                               [None] * dist.get_world_size(self.group))

            dist.all_gather_object(object_list=gather_objs,
                                   obj=object,
                                   group=self.group)
        else:
            gather_objs = [object]
        return gather_objs
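Sketch of how the wrapper above behaves; ``dist_wrapper`` is a placeholder for an instance of the enclosing (unnamed) class. When ``use_dist`` is False it simply returns ``[obj]``, so single-process code paths need no special casing.

results = dist_wrapper.all_gather_object({"rank_data": 123})
# distributed: one entry per rank; non-distributed: a one-element list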
Example #8
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data)
    return data_list
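Hypothetical use of the helper above: each rank contributes a small picklable dict and every rank receives the gathered list in rank order (assumes the default process group is initialized and the module's ``dist`` import is in scope).

stats = {"rank": dist.get_rank(), "loss": 0.25}
gathered = all_gather(stats)        # len(gathered) == world_size
mean_loss = sum(s["loss"] for s in gathered) / len(gathered)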
Example #9
    def _test_sequence_num_set_default_pg(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        default_pg = c10d._get_default_group()
        seq_num = default_pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(obj_list, seq_num)
        self.assertEqual(len(set(obj_list)), 1)
Example #10
    def _test_sequence_num_set_new_group(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        subgroup = dist.new_group([0, 1])
        subgroup_seq = subgroup._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(obj_list, subgroup_seq)
        self.assertEqual(len(set(obj_list)), 1)
Example #11
    def _test_sequence_num_incremented(self, process_group, ranks):
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        assert dist.get_world_size(process_group) == dist.get_world_size(
            verify_pg)

        initial_num = (
            self._verify_sequence_number_across_pg(pg=process_group,
                                                   verify_pg=verify_pg)
            if not c10d.distributed_c10d._rank_not_in_group(process_group) else
            -1)

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if not c10d.distributed_c10d._rank_not_in_group(process_group):
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 2]:
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 2 should be lagging by 1.
            if not c10d.distributed_c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [
                    None for _ in range(dist.get_world_size(verify_pg))
                ]
                dist.all_gather_object(obj_list, (rank, seq_num),
                                       group=verify_pg)
                rank_to_seq_num = {rank: num for (rank, num) in obj_list}
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
                expected_same = {
                    rank_to_seq_num[i]
                    for i in rank_to_seq_num.keys() if i not in [0, 2]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
Example #12
def all_gather_object(data: Any) -> List[Any]:
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Note:
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place. In this case,
        the device used is given by torch.cuda.current_device() and it is the user’s
        responsibility to ensure that this is set so that each rank has an individual
        GPU, via torch.cuda.set_device().

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]

    data_list = [None] * get_world_size()
    dist.all_gather_object(data_list, data)
    return data_list
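Per the NCCL note in the docstring above, callers are expected to pin each rank to its own GPU before using object collectives. A minimal sketch, assuming one GPU per rank and a launcher that exports LOCAL_RANK (an assumption, not part of the snippet):

import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)  # all_gather_object serializes onto this device under NCCL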
Example #13
def _tensorpipe_exchange_and_check_all_device_maps(my_name, my_device_count,
                                                   my_device_maps, my_devices,
                                                   group):
    gathered: List[Tuple[str, int, Dict[str, Dict[torch.device, torch.device]],
                         List[torch.device]]] = [("", 0, {}, [])
                                                 for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices),
        group)
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps,
                          all_devices)

    # passed all checks, construct reverse mapping and get list of devices handled by this agent
    reverse_device_maps = _create_reverse_mapping(my_name, all_names,
                                                  all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps,
                                     reverse_device_maps)
    return reverse_device_maps, my_devices
Example #14
def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    All-gathers module states across ranks and calls ``assert_fn`` on each pair
    of corresponding states from rank 0 and a nonzero rank. For example, if
    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
    states are equal across ranks.
    """
    # Include names for debugging convenience
    named_module_states = [(param_name, param.detach().cpu())
                           for param_name, param in model.named_parameters()]
    named_module_states += [(buffer_name, buffer.detach().cpu())
                            for buffer_name, buffer in model.named_buffers()]
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    rank0_states = olist[0]
    for state in olist[1:]:
        for (_, p1), (_, p2) in zip(rank0_states, state):
            assert_fn(p1, p2)
Example #15
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output
Example #16
def main_per_process(rank, world_size, args):
    init_process(rank, world_size)
    start_epoch = 1
    if args.wandb and rank == 0:
        run_name = get_run_name(args)
        wandb.init(project='myproject', entity='myaccount')
        wandb.run.name = run_name
        wandb.config.update(args)
    if rank == 0:
        output_cuda_info()

    # load dataset
    train_val_split = 0.2
    batch_size_per_proc = int(args.batch_size / world_size)
    train_set, val_set, test_set = load_cifar10(train_val_split,
                                                args.pretrained)

    # create sampler for ddp
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=world_size, rank=rank, shuffle=False)

    # create data loader for ddp
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size_per_proc,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size_per_proc,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             sampler=val_sampler)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # create ddp model
    model = make_model(args.model,
                       10,
                       pretrained=args.pretrained,
                       fix_param=args.fixparam)
    model = model.to(rank)
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if rank == 0:
        output_summary(ddp_model, train_loader)

    # settings for training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epoch)

    # synchronize
    dist.barrier()

    # start training
    print(f'[{datetime.now()}]#{rank}: start training')
    for epoch in range(start_epoch, start_epoch + args.epoch):
        if rank == 0:
            print(f'Epoch[{epoch}/{args.epoch}]')
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        dist.barrier()  # synchronize

        # train and validate
        train_loss, train_acc = train_epoch(ddp_model, train_loader, optimizer,
                                            criterion, rank)
        val_loss, val_acc = validate_epoch(ddp_model, val_loader, criterion,
                                           rank)
        dist.barrier()  # synchronize

        # sharing loss and accuracy among all GPUs (processes)
        train_loss_list = [0.] * world_size
        train_acc_list = [0.] * world_size
        val_loss_list = [0.] * world_size
        val_acc_list = [0.] * world_size
        dist.all_gather_object(train_loss_list, train_loss)
        dist.all_gather_object(train_acc_list, train_acc)
        dist.all_gather_object(val_loss_list, val_loss)
        dist.all_gather_object(val_acc_list, val_acc)

        # save data to wandb
        if args.wandb and rank == 0:
            avg_train_loss = sum(train_loss_list) / world_size
            avg_train_acc = sum(train_acc_list) / world_size
            avg_val_loss = sum(val_loss_list) / world_size
            avg_val_acc = sum(val_acc_list) / world_size
            wandb.log({
                'acc': avg_train_acc,
                'loss': avg_train_loss,
                'val_acc': avg_val_acc,
                'val_loss': avg_val_loss,
                'lr': scheduler.get_last_lr()[0]
            })
        scheduler.step()
    print(f'[{datetime.now()}]#{rank}: finished training')

    if rank == 0:
        print('# final test')
        test_loss, test_acc, class_acc = final_test(model, test_loader,
                                                    criterion, rank)
        for key, value in class_acc.items():
            print(f'{key} : {value: .3f}')

        # save data to wandb
        if args.wandb:
            wandb.log({'test_acc': test_acc, 'test_loss': test_loss})
            wandb.finish()
        print('# all finished')
Example #17
File: api.py Project: skn123/pytorch
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None):
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards,
                                                 tensor.size(),
                                                 process_group=pg)
Example #18
def eval(
    model: torch.nn.Module,
    eval_loader: torch.utils.data.DataLoader,
    is_main: bool,
    previous_best: dict,
    img_size: List[int],
    save_dir: pathlib.Path = None,
) -> Tuple[dict, List[str]]:
    """ Evalulate the model against the evaulation set. Save the best weights if
    specified. Use the pycocotools package for metrics.

    Args:
        model: The model to evaluate.
        eval_loader: The eval dataset loader.
        previous_best: The current best eval metrics.
        save_dir: Where to save the model weights.

    Returns:
        The updated best metrics and a list of the metrics that improved.
    """
    detections = []
    labels = []
    for images_batch, category_ids_batch, boxes_batch in eval_loader:
        # Send ground truth to BoundingBox
        for boxes, categories in zip(boxes_batch, category_ids_batch):
            image_boxes = []
            for box, category in zip(boxes, categories.squeeze(0)):
                image_boxes.append(
                    pascal_voc.BoundingBox(box / torch.Tensor(img_size * 2),
                                           1.0,
                                           category.int().item()))

            labels.append(image_boxes)

        if torch.cuda.is_available():
            images_batch = images_batch.cuda()

        if isinstance(model, parallel.DistributedDataParallel):
            detection_batch = model.module.get_boxes(images_batch)
        else:
            detection_batch = model.get_boxes(images_batch)
        detections.extend(detection_batch)

    if torch.distributed.is_initialized():
        labels_list = [None] * distributed.get_world_size()
        detections_list = [None] * distributed.get_world_size()
        distributed.all_gather_object(detections_list, detections)
        distributed.all_gather_object(labels_list, labels)

    if is_main:

        if torch.distributed.is_initialized():
            labels = []
            for label_group in labels_list:
                labels.extend(label_group)
            detections = []
            for detections_group in detections_list:
                detections.extend(detections_group)

        if isinstance(model, parallel.DistributedDataParallel):
            num_classes = model.module.num_classes
        else:
            num_classes = model.num_classes

        metrics = pascal_voc.compute_metrics(detections,
                                             labels,
                                             class_ids=list(
                                                 range(num_classes)))

        # If these are the first results, set the previous to the current.
        previous_best = metrics if not previous_best else previous_best

        improved = []
        for (metric, old), new in zip(previous_best.items(), metrics.values()):
            if new >= old:
                improved.append(metric)
                previous_best[metric] = new

        return previous_best, improved
    else:
        return None, None
Example #19
def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out
Example #20
import torch.distributed as dist

if dist.get_rank() == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# ruleid: pickles-in-torch-distributed
dist.broadcast_object_list(objects, src=0)

# ruleid: pickles-in-torch-distributed
dist.all_gather_object(output, gather_objects[dist.get_rank()])

# ruleid: pickles-in-torch-distributed
dist.gather_object(gather_objects[dist.get_rank()],
                   output if dist.get_rank() == 0 else None,
                   dst=0)

# ruleid: pickles-in-torch-distributed
dist.scatter_object_list(output_list, objects, src=0)
Example #21
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(object_list=output, obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")
Example #22
def shard_parameter(module: torch.nn.Module,
                    param_name: str,
                    sharding_spec: ShardingSpec,
                    src_rank=0,
                    process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(
            f'module: {module} does not have parameter with name: {param_name}'
        )

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}'
        )

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    with torch.cuda.device(tensor.device):
        dist.all_gather_object(gathered_list, (src_rank, sharding_spec),
                               group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_lengths=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_lengths[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).contiguous()

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=local_shard.dtype,
            layout=local_shard.layout,
            requires_grad=local_shard.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=local_shard.is_pinned(),
        ))

    st = ShardedTensor._init_from_local_shards(local_shards,
                                               sharded_tensor_metadata,
                                               process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
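Sketch of how the function above might be invoked, assuming two ranks with one GPU each and an initialized process group; the placement strings and module are illustrative, not part of the snippet.

from torch.distributed._shard.sharding_spec import ChunkShardingSpec

rank = dist.get_rank()
linear = torch.nn.Linear(16, 16).cuda(rank)
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
shard_parameter(linear, "weight", spec, src_rank=0)
# linear.weight is now a ShardedTensor holding only the local shard on each rank.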
Example #23
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        if not set(all_device_maps[source_node].keys()).issubset(all_names):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices[target_node]:
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )

    # passed all checks, construct reverse mapping for return values
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }

    if not my_devices:
        devices_set: Set[torch.device] = set()
        for _, map_ in my_device_maps.items():
            devices_set.update(map_.keys())
        for _, map_ in reverse_device_maps.items():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)

    return reverse_device_maps, my_devices
Example #24
    def eval_epoch_end(self, outputs, mode):
        # if the user specifies a single validation dataloader, PTL gives a list of dictionaries instead of a list of lists of dictionaries
        if isinstance(outputs[0], dict):
            outputs = [outputs]

        loss_list = []
        sb_score_list = []
        for dataloader_idx, output in enumerate(outputs):
            if dataloader_idx == 0:
                eval_loss = getattr(self, f'{mode}_loss').compute()
            else:
                eval_loss = getattr(self,
                                    f'{mode}_loss_{dataloader_idx}').compute()

            translations = list(
                itertools.chain(*[x['translations'] for x in output]))
            ground_truths = list(
                itertools.chain(*[x['ground_truths'] for x in output]))
            assert len(translations) == len(ground_truths)

            # Gather translations and ground truths from all workers
            tr_and_gt = [None for _ in range(self.world_size)]
            # we also need to drop pairs where ground truth is an empty string
            dist.all_gather_object(
                tr_and_gt, [(t, g)
                            for (t, g) in zip(translations, ground_truths)
                            if g.strip() != ''])
            if self.global_rank == 0:
                _translations = []
                _ground_truths = []
                for rank in range(0, self.world_size):
                    _translations += [t for (t, g) in tr_and_gt[rank]]
                    _ground_truths += [g for (t, g) in tr_and_gt[rank]]

                if self.tgt_language in ['ja']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="ja-mecab")
                elif self.tgt_language in ['zh']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="zh")
                else:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="13a")

                # because the reduction op later is average (over word_size)
                sb_score = sacre_bleu.score * self.world_size

                dataset_name = "Validation" if mode == 'val' else "Test"
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(translations)}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Val Loss = {eval_loss}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Sacre BLEU = {sb_score / self.world_size}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
                )
                for i in range(0, 3):
                    ind = random.randint(0, len(translations) - 1)
                    logging.info("    " + '\u0332'.join(f"Example {i}:"))
                    logging.info(f"    Prediction:   {translations[ind]}")
                    logging.info(f"    Ground Truth: {ground_truths[ind]}")
            else:
                sb_score = 0.0

            loss_list.append(eval_loss.cpu().numpy())
            sb_score_list.append(sb_score)
            if dataloader_idx == 0:
                self.log(f"{mode}_loss", eval_loss, sync_dist=True)
                self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True)
                getattr(self, f'{mode}_loss').reset()
            else:
                self.log(f"{mode}_loss_dl_index_{dataloader_idx}",
                         eval_loss,
                         sync_dist=True)
                self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}",
                         sb_score,
                         sync_dist=True)
                getattr(self, f'{mode}_loss_{dataloader_idx}').reset()

        if len(loss_list) > 1:
            self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
            self.log(f"{mode}_sacreBLEU_avg",
                     np.mean(sb_score_list),
                     sync_dist=True)
Example #25
def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    """
    Load a distributed state_dict in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested `state_dict`.

    When loading ShardedTensor instances, each rank only
    reads data for their local shards.

    All tensors in ``state_dict`` must be allocated on their
    destination device prior to calling this function.

    All non-tensor data is loaded using `torch.load()` and modified in place
    on state_dict.

    Users must call `load_state_dict` on the root module to ensure load
    post-processing and non-tensor data properly propagate.

    This function can also be used for local inference to load a checkpoint
    produced by ``save_state_dict`` without having a process group initialized,
    by passing ``no_dist=True`` and using Tensors instead of ShardedTensors.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load. Note that this
            state_dict will be updated in place.
        storage_reader (StorageReader): StorageReader used to load data from.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank synchronization
        coordinator_rank (int): Rank to use to coordinate the checkpoint, rank0 is used by default
        no_dist (bool): Don't attempt to load in SPMD style. Defaults to ``False``.

    Returns:
        None.

    Examples
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemLoader("/checkpoint/1")

        >>> torch.distributed._shard.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_loader,
        >>> )

        >>> # module.load_state_dict() function might have customized steps
        >>> # to flush the state_dict, must call it to
        >>> # ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note:: load_state_dict uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place. In this
        case, the device used is given by ``torch.cuda.current_device()`` and it
        is the user's responsibility to ensure that this is set so that each rank
        has an individual GPU, via ``torch.cuda.set_device()``
    """
    try:
        metadata = storage_reader.read_metadata()
        bytes_read_requests, tensor_read_requests = _reshard_and_prepare_read_request(
            state_dict=state_dict, metadata_from_storage=metadata
        )
        bytes_futures = storage_reader.read_bytes(bytes_read_requests)
        tensor_futures = storage_reader.read_tensors(tensor_read_requests)

        bytes_futures.wait()

        # Additional steps are required to convert the bytes to their original type.
        # Note that this is NOT in place:
        # it creates a new object and replaces what's in the state dict.
        for req in bytes_read_requests:
            # Ensure the BytesIO is rewound
            req.bytes.seek(0)
            state_dict[req.fqn] = torch.load(req.bytes)

        tensor_futures.wait()
        result = None
    except BaseException as e:
        result = e

    global_result: Optional[CheckpointException] = None
    if not no_dist:
        all_errors = [None] * dist.get_world_size(process_group)

        dist.all_gather_object(
            object_list=all_errors,
            obj=result,
            group=process_group)

        node_failures = cast(Dict[int, BaseException], {i: err for i, err in enumerate(all_errors) if err is not None})
        if len(node_failures) > 0:
            global_result = CheckpointException("failed to read checkpoint", node_failures)
    elif result is not None:
        global_result = CheckpointException("failed to read storage", {coordinator_rank : result})

    if global_result is not None:
        raise global_result
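The error-propagation idiom used in the ``not no_dist`` branch above can be useful on its own: every rank all-gathers its (possibly ``None``) local exception so that all ranks agree on whether the step failed. A standalone sketch, assuming an initialized process group; ``risky_local_work`` is a hypothetical per-rank step.

result = None
try:
    risky_local_work()          # hypothetical: any per-rank operation that may raise
except BaseException as e:
    result = e

all_errors = [None] * dist.get_world_size()
dist.all_gather_object(all_errors, result)

failed = {i: err for i, err in enumerate(all_errors) if err is not None}
if failed:
    raise RuntimeError(f"step failed on ranks {sorted(failed)}")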