Example #1
    def test_collectives(self):
        dist.Backend.register_backend("dummy", PythonProcessGroupTest.create_dummy)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '6789'
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test all_gather
        input_tensor = torch.ones(2, 2) * 7
        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
        dist.all_gather(output_tensor_list, input_tensor)

        for tensor in output_tensor_list:
            self.assertEqual(tensor, input_tensor)

        # test all_reduce
        input_tensor = torch.ones(2, 2) * 7
        dist.all_reduce(input_tensor)
        self.assertEqual(input_tensor, torch.ones(2, 2) * 7 + 2)

        # test broadcast
        input_tensor = torch.zeros(2, 2)
        dist.broadcast(input_tensor, 0, async_op=True).wait()
        self.assertEqual(torch.ones(2, 2), input_tensor)

        # test reduce_scatter
        output_tensor = torch.zeros(2, 2)
        input_tensor_list = [torch.ones(2, 2) for _ in range(self.world_size)]
        dist.reduce_scatter(output_tensor, input_tensor_list)
        self.assertEqual(output_tensor, torch.zeros(2, 2) + 1)

        dist.destroy_process_group()
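For reference, here is a minimal sketch of the reduce_scatter contract the test above exercises, assuming a process group has already been initialized with a backend that supports reduce_scatter for the tensors' device: rank r receives the element-wise reduction of entry r of every rank's input list.

import torch
import torch.distributed as dist

def reduce_scatter_demo():
    # Assumes dist.init_process_group(...) has already been called.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Entry r of the list is destined for rank r; every rank passes such a list.
    input_list = [torch.full((2, 2), float(r)) for r in range(world_size)]
    output = torch.empty(2, 2)
    dist.reduce_scatter(output, input_list)  # default op is ReduceOp.SUM
    # With SUM, the local rank ends up with world_size * rank in every element.
    assert torch.equal(output, torch.full((2, 2), float(rank * world_size)))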
Example #2
 def forward(ctx, op, group, tensor, *input_tensor_list):
     ctx.group = group
     dist.reduce_scatter(tensor,
                         list(input_tensor_list),
                         op=op,
                         group=group)
     return tensor
Example #3
 def _run_partial_tensor_n_reshard(self,
                                   reshard_spec,
                                   input_size,
                                   world_size,
                                   reduce_op,
                                   dtype=torch.float):
     results = []
     results_compare = []
     for _ in range(0, world_size):
         tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
         results.append(tensor)
         results_compare.append(tensor.clone().detach())
     pg = dist.distributed_c10d._get_default_group()
     partial_tensor = _PartialTensor(torch.cat(results),
                                     pg,
                                     reduce_op=reduce_op)
     local_sharded_result = partial_tensor.reshard(reshard_spec)
     local_shards = local_sharded_result.local_shards()
     if pg.size() > world_size:
         chunk_mode_res = (input_size[0] * world_size) % pg.size()
         padding = [0] * (len(input_size) * 2)
         padding[-1] = pg.size() - chunk_mode_res
         results_compare = list(
             torch.nn.functional.pad(
                 torch.cat(results_compare),
                 tuple(padding),
                 "constant",
                 0,
             ).chunk(pg.size()))
     local_result_compare = torch.empty_like(results_compare[0])
     dist.reduce_scatter(local_result_compare,
                         results_compare,
                         op=reduce_op)
     self.assertEqual(1, len(local_shards))
     self.assertEqual(local_shards[0].tensor, local_result_compare)
Example #4
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
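The chunking convention the code above relies on ([4, 4, 4, 1] rather than [3, 3, 3, 4]) can be made concrete with the following hypothetical re-implementations of get_split_size and get_chunked_dim_size; the real helpers live in torch.distributed._shard and may differ in detail.

def get_split_size(dim_size: int, chunks: int) -> int:
    # Ceiling division: the first chunks - 1 shards are as large as possible.
    return (dim_size + chunks - 1) // chunks

def get_chunked_dim_size(dim_size: int, split_size: int, idx: int) -> int:
    # Size of chunk `idx`; the last chunk gets whatever remains (possibly 0).
    return max(min(split_size, dim_size - split_size * idx), 0)

# Example: dim_size=13 split over 4 ranks gives split_size=4 and
# chunk sizes [4, 4, 4, 1], matching the comment in the code above.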
Example #5
 def forward(ctx, op, group, tensor, *input_tensor_list):
     ctx.group = group
     input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
     dist.reduce_scatter(tensor,
                         list(input_tensor_list),
                         op=op,
                         group=group)
     return tensor
Example #6
 def backward(ctx, *grads):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     grad_list = list(grads)
     grad_out = torch.zeros_like(grad_list[rank], requires_grad=True)
     dist.reduce_scatter(grad_out, grad_list, op=ReduceOp.SUM)
     # Gradient correction for DistCrossEntropy
     grad_out = grad_out * world_size
     return (grad_out, None)
Example #7
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[int] = []
        for placement in weight._sharding_spec.placements:
            sharded_dim_size = get_chunked_dim_size(input_t_size[0],
                                                    split_size,
                                                    placement.rank())
            input_idx = placement.rank() * split_size
            indices += range(input_idx, input_idx + sharded_dim_size)

        input_t = input_t.index_select(
            0, torch.tensor(indices, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
Example #8
    def forward_backward(self, label, features, optimizer):
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size],
            device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                        features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0],
                                        grad.size()[1]],
                                  device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features)

        # feature gradient reduce-scatter
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad.mul_(self.world_size)
        # backward backbone
        return x_grad, loss_v
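The no_grad block above implements a numerically stable softmax cross-entropy whose row-wise max and sum are reduced across ranks. A single-process sketch of the same recipe (no collectives, all names illustrative) may make the gradient formula easier to follow:

import torch

def stable_softmax_ce(logits: torch.Tensor, labels: torch.Tensor):
    # Subtract the row-wise max (globally reduced in the distributed version).
    max_fc = logits.max(dim=1, keepdim=True)[0]
    probs = torch.exp(logits - max_fc)
    probs = probs / probs.sum(dim=1, keepdim=True)
    # Cross-entropy: -log(prob of the target class), averaged over the batch.
    loss = -probs.gather(1, labels[:, None]).clamp_min(1e-30).log().mean()
    # Gradient wrt logits: softmax - one_hot, averaged over the batch size.
    grad = probs.clone()
    grad[torch.arange(logits.size(0)), labels] -= 1.0
    return loss, grad / logits.size(0)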
Example #9
def _torch_reduce_scatter(output_tensors, input_tensors, partition_sizes, rank,
                          world_size):
    """"""
    for part_idx, part_size in enumerate(partition_sizes):
        output_t = output_tensors[part_idx]
        input_t = input_tensors[part_idx]

        input_list = []
        for i in range(world_size):
            _input = input_t.narrow(0, i * part_size, part_size)
            input_list.append(_input)

        dist.reduce_scatter(output_t, input_list)

    torch.cuda.synchronize()
Example #10
 def flush(self) -> None:
     if self.offset == 0:
         assert len(self.callbacks) == 0
         return
     # reduce-scatter bucket
     dist.reduce_scatter(self.output_shard[:self.offset],
                         list(self.data[:, :self.offset].unbind(0)),
                         group=self.group)
     # execute post-reduction callbacks
     for callback_fn in self.callbacks:
         callback_fn()
     # reuse input bucket but allocate a fresh output shard
     self.data[:, :self.offset].zero_()
     self.offset = 0
     self.callbacks.clear()
     self.output_shard = torch.zeros_like(self.data[0])
Example #11
 def reduce_scatter_fn(output_tensor,
                       input_tensor,
                       group=None,
                       async_op=False):
     from torch.distributed import reduce_scatter, get_world_size
     from torch import chunk
     input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
     return reduce_scatter(output_tensor, input_tensor_lst, group=group, async_op=async_op)
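A hedged usage sketch for reduce_scatter_fn, assuming an initialized process group and that output_tensor matches one chunk of input_tensor: every rank passes the full tensor and receives the reduced chunk corresponding to its rank.

import torch
import torch.distributed as dist

world_size = dist.get_world_size()
# The full tensor is chunked into world_size pieces; rank r receives the
# element-wise sum of chunk r across all ranks.
full = torch.ones(world_size * 4)
out = torch.empty(4)
reduce_scatter_fn(out, full)
# With the default SUM op, every element of `out` equals world_size.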
Example #12
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()

        dist.init_process_group('xla', world_size=world_size, rank=rank)

        input_size = (32, 3)
        inputs = torch.ones(input_size).split(input_size[0] // world_size)
        output = torch.zeros_like(inputs[0])
        xinputs = [i.to(device) for i in inputs]
        xoutput = output.to(device)
        dist.reduce_scatter(xoutput, xinputs)
        expected = torch.ones_like(inputs[0]) * world_size
        assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}'
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
Example #13
    def reduce_scatter(self, collectiveArgs, retFlag=False, pair=False):
        retObj = dist.reduce_scatter(
            output=collectiveArgs.opTensor,
            input_list=collectiveArgs.ipTensor,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )  # synchronicity is maintained in runColl

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(retObj)

        if retFlag:
            return retObj
Example #14
 def flush(self) -> None:
     """Flush content of the bucket."""
     if self.offset == 0:
         assert len(self.callbacks) == 0
         return
     # reduce-scatter bucket
     if hasattr(dist, "_reduce_scatter_base"):
         dist._reduce_scatter_base(  # type: ignore
             self.output_shard[:self.offset],
             self.data[:, :self.offset].contiguous(),
             group=self.group)
     else:
         dist.reduce_scatter(self.output_shard[:self.offset],
                             list(self.data[:, :self.offset].unbind(0)),
                             group=self.group)
     # execute post-reduction callbacks
     for callback_fn in self.callbacks:
         callback_fn()
     # reuse input bucket but allocate a fresh output shard
     self.data[:, :self.offset].zero_()
     self.offset = 0
     self.callbacks.clear()
     self.output_shard = torch.zeros_like(self.data[0])
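Newer PyTorch releases rename the fused single-tensor path from dist._reduce_scatter_base to dist.reduce_scatter_tensor; a small compatibility helper along these lines (a sketch, not the library's own code) keeps a flush() like the one above working across versions.

def _fused_reduce_scatter(output, flat_input, group):
    # Prefer the public fused API, fall back to the private one, and finally
    # to the list-based collective for very old PyTorch versions.
    if hasattr(dist, "reduce_scatter_tensor"):
        dist.reduce_scatter_tensor(output, flat_input, group=group)
    elif hasattr(dist, "_reduce_scatter_base"):
        dist._reduce_scatter_base(output, flat_input, group=group)
    else:
        dist.reduce_scatter(output, list(flat_input.unbind(0)), group=group)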
Example #15
 def _run_partial_tensor_n_reshard(self,
                                   reshard_spec,
                                   input_size,
                                   world_size,
                                   reduce_op,
                                   dtype=torch.float):
     results = []
     results_compare = []
     for _ in range(0, world_size):
         tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
         results.append(tensor)
         results_compare.append(tensor.clone().detach())
     pg = dist.distributed_c10d._get_default_group()
     partial_tensor = _PartialTensor(torch.cat(results),
                                     pg,
                                     reduce_op=reduce_op)
     local_sharded_result = partial_tensor.reshard(reshard_spec)
     local_shards = local_sharded_result.local_shards()
     local_result_compare = torch.empty_like(results_compare[0])
     dist.reduce_scatter(local_result_compare,
                         results_compare,
                         op=reduce_op)
     self.assertEqual(1, len(local_shards))
     self.assertEqual(local_shards[0].tensor, local_result_compare)
Example #16
def main(local_rank):
    dist.init_process_group(backend='nccl', init_method='env://')
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = dist.get_world_size()
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=False)

    backbone = backbones.iresnet100(False).to(local_rank)
    backbone.train()

    # Broadcast init parameters
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[cfg.local_rank])
    backbone.train()

    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(rank=dist.get_rank(),
                                                 local_rank=local_rank,
                                                 world_size=cfg.world_size)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': dist_sample_classifer.parameters()
    }],
                    lr=cfg.lr,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    #
    total_step = int(
        len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            total_label, norm_weight = dist_sample_classifer.prepare(
                label, optimizer)
            features = F.normalize(backbone(img))

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * cfg.world_size,
                                         cfg.embedding_size,
                                         device=local_rank)
            dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                            features.data)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient all-reduce
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)
            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()
            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print(
                    "Speed %d samples/sec   Loss %.4f   Epoch: %d   Global Step: %d   Required: %1.f hours"
                    % ((cfg.batch_size * global_step /
                        (time.time() - train_start) * cfg.world_size),
                       losses.avg, epoch, global_step, time_for_end))
                losses.reset()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            import os
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(),
                       os.path.join(cfg.output,
                                    str(epoch) + 'backbone.pth'))
    dist.destroy_process_group()
Example #17
    def reduce_scatter_gradients(self, postscale_gradients,
                                 gradient_predivide_factor, gradient_average):
        world_size = dist.get_world_size(group=self.dp_process_group)
        local_rank = dist.get_rank(group=self.dp_process_group)

        for i, group in enumerate(self.fp16_groups):
            partition_param_map = {}
            param_partition_map = {}
            my_params = set()

            # [rank] -> [comm] -> partition
            num_comm_intervals = self.num_comm_intervals_per_group[i]
            all_sub_partitions = []
            for rank in range(world_size):
                # gsp is list of partitions indexed by comm_idx
                #FIXME: currently hardcoding fp16, should infer dtype
                grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
                    comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
                    comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
                    sub_partition_size=self.sub_partition_sizes[i],
                    dtype=torch.half,  # self.params_in_rank_sub_partitions[i][rank][0][0].dtype
                    num_comm_intervals=self.num_comm_intervals_per_group[i],
                    default_device='cuda',  # self.params_in_rank_sub_partitions[i][rank][0][0].device
                    return_partition_params=True)
                all_sub_partitions.append(grad_sub_partitions)

                # create map from partition -> params in that partition
                for comm_idx, part in enumerate(grad_sub_partitions):
                    partition_param_map[part] = (partition_params[comm_idx],
                                                 param_offsets[comm_idx])

                for comm_idx, params in enumerate(partition_params):
                    for pidx, p in enumerate(params):
                        # store the parameters we care about locally
                        if rank == local_rank:
                            my_params.add(p)
                        # map from param -> partitions
                        if p in param_partition_map:
                            param_partition_map[p].append(
                                grad_sub_partitions[comm_idx])
                        else:
                            param_partition_map[p] = [
                                grad_sub_partitions[comm_idx]
                            ]

                assert len(grad_sub_partitions) == num_comm_intervals

            if not postscale_gradients:
                raise NotImplementedError(
                    "pre-scale_gradients is not implemented")

            all_comm_partitions = []
            for comm_idx in range(num_comm_intervals):
                single_comm_all_partitions = []
                for rank in range(world_size):
                    single_comm_all_partitions.append(
                        all_sub_partitions[rank][comm_idx])
                dist.reduce_scatter(
                    output=single_comm_all_partitions[local_rank],
                    input_list=single_comm_all_partitions,
                    group=self.dp_process_group)

                if gradient_average:
                    for partition in single_comm_all_partitions:
                        partition.mul_(gradient_predivide_factor / world_size)

                all_comm_partitions.append(single_comm_all_partitions)

            for p in my_params:
                partitions = param_partition_map[p]
                parts = []
                for part in partitions:
                    params, offsets = partition_param_map[part]
                    found = False
                    for p_idx, _p in enumerate(params):
                        if p.__hash__() == _p.__hash__():
                            found = True
                            if offsets[p_idx][0] is not None:
                                my_part = part.narrow(0, offsets[p_idx][0],
                                                      offsets[p_idx][1])
                                parts.append(my_part)
                    assert found
                if p is not None:
                    updated_grad = _unflatten_dense_tensors(
                        torch.cat(parts), [p])
                    p.grad.copy_(updated_grad[0])
Example #18
def main(local_rank):
    cfg.local_rank = local_rank
    # cfg.rank = dist.get_rank()
    # cfg.world_size = dist.get_world_size()

    backbone = backbones.iresnet50(False)

    weights = torch.load("pytorch/partial_fc_glint360k_r50/16backbone.pth",
                         map_location=torch.device('cpu'))
    backbone.load_state_dict(weights)
    backbone = backbone.float()
    backbone = backbone.eval()

    # embedding 512

    img1 = cv2.imread('boy_1.jpg')
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img1 = image_preprocessing(img1)
    img1 = np.ones([112, 112, 3], dtype=np.float32)
    img1 = img1.transpose([2, 0, 1])
    img1 = np.expand_dims(img1, axis=0)
    img1 = torch.from_numpy(img1).float()
    img1 = torch.autograd.Variable(img1, requires_grad=False).to('cpu')

    img2 = cv2.imread('man_2.jpg')
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
    img2 = image_preprocessing(img2)
    img2 = img2.transpose([2, 0, 1])
    img2 = np.expand_dims(img2, axis=0)
    img2 = torch.from_numpy(img2).float()
    img2 = torch.autograd.Variable(img2, requires_grad=False).to('cpu')
    with torch.no_grad():
        v1 = backbone.forward(img1)
        v2 = backbone.forward(img2)

    v1 = np.asarray(v1)
    import pickle
    pickle.dump(v1, open("sample.pkl", "wb"))
    print(v1)

    result = cosine_dist(v1, v2)
    print(result)

    exit(0)
    # Broadcast init parameters
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[cfg.local_rank])
    backbone.train()

    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(rank=dist.get_rank(),
                                                 local_rank=local_rank,
                                                 world_size=cfg.world_size)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': dist_sample_classifer.parameters()
    }],
                    lr=cfg.lr,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    #
    total_step = int(
        len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            total_label, norm_weight = dist_sample_classifer.prepare(
                label, optimizer)
            features = F.normalize(backbone(img))

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * cfg.world_size,
                                         cfg.embedding_size,
                                         device=local_rank)
            dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                            features.data)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient all-reduce
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)
            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()
            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print(
                    "Speed %d samples/sec   Loss %.4f   Epoch: %d   Global Step: %d   Required: %1.f hours"
                    % ((cfg.batch_size * global_step /
                        (time.time() - train_start) * cfg.world_size),
                       losses.avg, epoch, global_step, time_for_end))
                losses.reset()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            import os
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(),
                       os.path.join(cfg.output,
                                    str(epoch) + 'backbone.pth'))
    dist.destroy_process_group()
Example #19
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(),
                                                 world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base"):
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened,
                                          group=group)  # type: ignore
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size,
                                                     first_input_size)
        offset = bucket.offset
        bucket.data[:, offset:offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset:offset +
                                              first_input_size].view_as(
                                                  first_input)
            bucket.callbacks.append(functools.partial(callback_fn,
                                                      result_view))
Example #20
def main(local_rank, world_size, init_method='tcp://127.0.0.1:23499'):
    dist.init_process_group(backend='nccl',
                            init_method=init_method,
                            rank=local_rank,
                            world_size=world_size)
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = world_size
    print(cfg.rank, dist.get_world_size())
    trainset = MXFaceDataset(root_dir='/root/face_datasets/webface/',
                             local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    trainloader = DataLoaderX(local_rank=local_rank,
                              dataset=trainset,
                              batch_size=cfg.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)
    backbone = iresnet50(False).to(cfg.local_rank)
    backbone.train()
    # backbone = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    backbone = torch.nn.parallel.DistributedDataParallel(
        backbone, broadcast_buffers=False, device_ids=[dist.get_rank()])
    backbone.train()
    sub_start, sub_classnum = get_sub_class(cfg.rank, dist.get_world_size())
    print(sub_start, sub_classnum)
    classifier_head = classifier(cfg.embedding_size,
                                 sub_classnum,
                                 sample_rate=0.4)
    cosface = CosFace(s=64.0, m=0.4)
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': classifier_head.parameters()
    }],
                    0.1,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)
    warm_up_with_multistep_lr = lambda epoch: (
        (epoch + 1) / (4 + 1))**2 if epoch < -1 else 0.1**len(
            [m for m in [20, 29] if m - 1 <= epoch])
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warm_up_with_multistep_lr)
    n_epochs = 33
    start_epoch = 0

    if cfg.local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')
    global_step = 0
    loss_fun = nn.CrossEntropyLoss()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(trainloader):
            start = time.time()
            label_gather, norm_weight = classifier_head.prepare(
                label, optimizer)
            x = F.normalize(backbone(img))
            x_gather = torch.zeros(x.size()[0] * cfg.world_size,
                                   cfg.embedding_size,
                                   device=cfg.local_rank)
            dist.all_gather(list(x_gather.chunk(cfg.world_size, dim=0)),
                            x.data)
            x_gather.requires_grad = True

            logits = classifier_head(x_gather, norm_weight)

            logits = cosface(logits, label_gather)

            with torch.no_grad():
                max_v = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_v, dist.ReduceOp.MAX)
                exp = torch.exp(logits - max_v)
                sum_exp = exp.sum(dim=1, keepdims=True)
                dist.all_reduce(sum_exp, dist.ReduceOp.SUM)
                exp.div_(sum_exp.clamp_min(1e-20))
                grad = exp
                index = torch.where(label_gather != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, label_gather[index, None], 1)

                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, label_gather[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-20).log_().mean() * (-1)

                grad[index] -= one_hot
                grad.div_(grad.size()[0])

            logits.backward(grad)
            if x_gather.grad is not None:
                x_gather.grad.detach_()
            x_grad = torch.zeros_like(x)
            dist.reduce_scatter(
                x_grad, list(x_gather.grad.chunk(cfg.world_size, dim=0)))
            x.backward(x_grad)
            optimizer.step()
            classifier_head.update()
            optimizer.zero_grad()
            if cfg.rank == 0:
                print(x_gather.grad.max(), x_gather.grad.min())
                print('loss_v', loss_v.item(), global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print('lr',
                      optimizer.state_dict()['param_groups'][0]['lr'],
                      global_step)
                print(cfg.batch_size / (time.time() - start))

            global_step += 1
        scheduler.step()
        if cfg.rank == 0:
            torch.save(backbone.module.state_dict(),
                       "models/" + str(epoch) + 'backbone.pth')
    dist.destroy_process_group()
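A hedged launch sketch for the script above: because main takes (local_rank, world_size, init_method), one process per local GPU can be started with torch.multiprocessing.spawn.

if __name__ == "__main__":
    import torch.multiprocessing as mp
    n_gpus = torch.cuda.device_count()
    # spawn passes the process index as the first argument (local_rank).
    mp.spawn(main, args=(n_gpus,), nprocs=n_gpus)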
Example #21
    def forward_backward(self, label, features, optimizer, feature_w):
        self._iters += 1
        total_label, norm_weight, index_positive = self.prepare(
            label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size],
            device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                        features.data)
        total_features.requires_grad = True

        if feature_w is not None:
            total_feature_w = torch.zeros(
                size=[self.batch_size * self.world_size, self.embedding_size],
                device=self.device)
            dist.all_gather(
                list(total_feature_w.chunk(self.world_size, dim=0)),
                feature_w.data)

        if self.vpl_mode >= 0:
            self.prepare_queue_lambda(total_label, self._iters)
            _lambda = self.queue_lambda.view(self.num_local, 1)
            injected_weight = norm_weight * (1.0 -
                                             _lambda) + self.queue * _lambda
            injected_norm_weight = normalize(injected_weight)
            logits = self.forward(total_features, injected_norm_weight)
        else:
            logits = self.forward(total_features, norm_weight)

        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0],
                                        grad.size()[1]],
                                  device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient all-reduce
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size
        #vpl set queue
        if self.vpl_mode >= 0:
            if feature_w is None:
                self.set_queue(total_features.detach(), total_label,
                               index_positive, self._iters)
            else:
                self.set_queue(total_feature_w, total_label, index_positive,
                               self._iters)
        # backward backbone
        return x_grad, loss_v
Example #22
    def forward_backward(self, label, features, optimizer, x):
        """
        Partial fc forward and backward with model parallel

        label: tensor
            Label tensor on each rank (GPU)
        features: tensor
            Features tensor on each rank (GPU)
        optimizer: optimizer
            Optimizer for partial fc

        Returns:
        --------
        x_grad: tensor
            The gradient of features.
        loss_v: tensor
            Loss value for cross entropy.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size],
            device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                        features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        #print(logits.size())
        logits, g = self.margin_softmax(logits, total_label, x)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0],
                                        grad.size()[1]],
                                  device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + g

            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient all-reduce
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v
Example #23
    def reduce_scatter_gradients(self, postscale_gradients,
                                 gradient_predivide_factor, gradient_average):
        world_size = dist.get_world_size(group=self.dp_process_group)
        local_rank = dist.get_rank(group=self.dp_process_group)

        for i, group in enumerate(self.fp16_groups):
            partition_param_map = {}
            param_partition_map = {}
            my_params = set()

            # [rank] -> [comm] -> partition
            num_comm_intervals = self.num_comm_intervals_per_group[i]
            all_sub_partitions = []
            for rank in range(world_size):
                # gsp is list of partitions indexed by comm_idx
                #FIXME: currently hardcoding fp16, should infer dtype
                grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
                    comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
                    comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
                    sub_partition_size=self.sub_partition_sizes[i],
                    dtype=torch.half,  # self.params_in_rank_sub_partitions[i][rank][0][0].dtype
                    num_comm_intervals=self.num_comm_intervals_per_group[i],
                    default_device='cuda',  # self.params_in_rank_sub_partitions[i][rank][0][0].device
                    return_partition_params=True)
                all_sub_partitions.append(grad_sub_partitions)

                # create map from partition -> params in that partition
                for comm_idx, part in enumerate(grad_sub_partitions):
                    partition_param_map[part] = (partition_params[comm_idx],
                                                 param_offsets[comm_idx])

                for comm_idx, params in enumerate(partition_params):
                    for pidx, p in enumerate(params):
                        # store the parameters we care about locally
                        if rank == local_rank:
                            my_params.add(p)
                        # map from param -> partitions
                        if p in param_partition_map:
                            param_partition_map[p].append(
                                grad_sub_partitions[comm_idx])
                        else:
                            param_partition_map[p] = [
                                grad_sub_partitions[comm_idx]
                            ]

                assert len(grad_sub_partitions) == num_comm_intervals

            if not postscale_gradients:
                raise NotImplementedError(
                    "pre-scale_gradients is not implemented")

            all_comm_partitions = []
            for comm_idx in range(num_comm_intervals):
                single_comm_all_partitions = []
                for rank in range(world_size):
                    single_comm_all_partitions.append(
                        all_sub_partitions[rank][comm_idx])
                dist.reduce_scatter(
                    output=single_comm_all_partitions[local_rank],
                    input_list=single_comm_all_partitions,
                    group=self.dp_process_group)

                if gradient_average:
                    for partition in single_comm_all_partitions:
                        partition.mul_(gradient_predivide_factor / world_size)

                all_comm_partitions.append(single_comm_all_partitions)

            # stitch together all rank sub partitions for each comm idx
            flat_comm_grads = []
            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
                flat_comm_grads.append(torch.cat(rank_partitions))

            flat_all_grads = torch.cat(flat_comm_grads)

            # copy back reduced gradients but only those needed for this local rank
            for param, updated_grad in zip(
                    self.fp16_groups[i],
                    _unflatten_dense_tensors(flat_all_grads,
                                             self.fp16_groups[i])):
                if param in my_params:
                    param.grad.copy_(updated_grad)
Example #24
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the CUDA process.
        local_shard_t: row-wise shared local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size, input_t_size[1], device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
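A worked example of the row re-arrangement above, with values chosen purely for illustration: take world_size=4, input_t_size[0]=13 (so chunk sizes are [4, 4, 4, 1]) and placements visiting ranks in the order [2, 0, 3, 1].

# placements (by chunk idx): [rank 2, rank 0, rank 3, rank 1]
# input_split_sizes (by rank): [4, 1, 4, 4], sharded_dim_size_max = 4
# indices (by rank):  rank 0 -> [4, 5, 6, 7]    (chunk idx 1)
#                     rank 1 -> [12]            (chunk idx 3)
#                     rank 2 -> [0, 1, 2, 3]    (chunk idx 0)
#                     rank 3 -> [8, 9, 10, 11]  (chunk idx 2)
# indices_flatten = [4, 5, 6, 7, 12, 0, 1, 2, 3, 8, 9, 10, 11]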
Example #25
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
backend = 'nccl'
dist.init_process_group(backend)
rank = dist.get_rank()

# simple dist
print("pid is {}, rank is {}".format(os.getpid(), rank))

torch.cuda.set_device(rank)
world_size = dist.get_world_size()
var = torch.tensor(rank + 10).to("cuda:{}".format(rank))
# every rank builds the same list [0, 1, ..., world_size - 1]
var_list = [torch.tensor(i).to("cuda:{}".format(rank)) for i in range(world_size)]

dist.reduce_scatter(var, var_list, op=dist.ReduceOp.SUM, async_op=False)
print('Pid is ', os.getpid(), ', Rank ', rank, ', tensor data is: ', var)
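A hedged sanity check of the result above (assumes the SUM op and that every rank builds the same var_list = [0, 1, ..., world_size - 1]): rank r should end up holding world_size * r.

expected = torch.tensor(rank * world_size).to("cuda:{}".format(rank))
assert torch.equal(var, expected), "rank {}: {} != {}".format(rank, var, expected)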
Example #26
    def forward_backward(self, features, targets, optimizer):
        """
        Partial FC forward, which will sample positive weights and part of negative weights,
        then compute logits and get the grad of features.
        """
        total_targets = self.prepare(targets, optimizer)

        if self.world_size > 1:
            total_features = concat_all_gather(features)
        else:
            total_features = features.detach()

        total_features.requires_grad_(True)

        logits = self.forward(total_features)
        logits = self.cls_layer(logits, total_targets)

        # from ipdb import set_trace; set_trace()
        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            if self.world_size > 1:
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)

            if self.world_size > 1:
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_targets != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0],
                                        grad.size()[1]],
                                  device=grad.device)
            one_hot.scatter_(1, total_targets[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_targets[index, None])
            if self.world_size > 1:
                dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            grad[index] -= one_hot
            grad.div_(logits.size(0))

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features)
        # reduce-scatter the gathered feature gradient back to each rank's local chunk
        if self.world_size > 1:
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(self.world_size,
                                                       dim=0)))
        else:
            x_grad = total_features.grad
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v
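The example above relies on a concat_all_gather helper that is not shown. A minimal sketch of one common implementation (an assumption, not the author's exact code; it presumes features has the same shape on every rank and ignores gradient flow, which is why the example re-enables requires_grad on the gathered result):

import torch
import torch.distributed as dist

def concat_all_gather(tensor):
    # Gather a copy of `tensor` from every rank and concatenate along dim 0.
    # Plain all_gather does not propagate gradients back to the inputs.
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor.contiguous())
    return torch.cat(gathered, dim=0)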
Example #27
def _handle_row_wise_sharding(
    input,
    world_size,
    weight,
    local_shard,
    offsets,
    per_sample_weights,
    mode,
    max_norm,
    norm_type,
    padding_idx,
    rank,
    pg,
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of IDs used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
            Note that the embedding vector at padding_idx is
            excluded from the reduction.
        rank: rank of the current CUDA process.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    # We sort each interval defined by offset. If 2D, each interval is a row.
    input_size = input.size()
    (
        input_split_sorted_list,
        input_split_sorted_indices,
        split_sizes_1d,
        split_sizes_1d_with_padding,
    ) = _input_split_sort(input, offsets, padding_idx)

    # Within each interval of the sorted list, we first need to distribute
    # each ID to its bucket (rank), and also perform the rearrangement in case
    # a placement idx does not equal the rank.
    # We then compute some simple stats on each interval for the next step.
    # If the user specifies per_sample_weights, we need to rearrange them to
    # stay in sync with the IDs and then distribute them to each rank.
    (
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        per_sample_weights,
        sharded_dim_size_max,
        padding_idx,
    ) = _sorted_input_distribute_prepare(
        input_split_sorted_list,
        input_split_sorted_indices,
        world_size,
        input,
        weight,
        per_sample_weights,
        rank,
        padding_idx,
    )

    # Send ID/offsets/per_sample_weights to different bucket(rank).
    (
        gathered_input,
        output_offsets_tensor_list,
        output_split_sizes,
        gathered_per_sample_weights,
    ) = _distribute_input(
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        sharded_dim_size_max,
        world_size,
        input,
        per_sample_weights,
        pg,
    )

    # Perform the embedding bag look-up and aggregation
    results = []
    for i, inp in enumerate(gathered_input):
        per_sample_weights = (
            gathered_per_sample_weights[i]
            if gathered_per_sample_weights is not None
            else None
        )
        # If the input for this rank is empty, passing in max_norm causes
        # errors in CUDA.
        if max_norm is not None and inp.size(0) == 0:
            max_norm = None

        # Perform local embedding look up and aggregation.
        result = torch.nn.functional.embedding_bag(
            inp,
            local_shard,
            offsets=output_offsets_tensor_list[i],
            mode=mode if mode != "mean" else "sum",
            per_sample_weights=per_sample_weights,
            max_norm=max_norm,
            norm_type=norm_type,
            padding_idx=padding_idx,
        )
        if mode != "max":
            results.append(result)
        # For the max case, if there is no look-up on some rank, that rank's
        # result will be all zeros. In that case we need to set the rows to
        # negative infinity; otherwise, negative values would be rounded up
        # to zero in the final aggregation.
        elif inp.size(0) == 0:
            result[:] = -float("Inf")
            results.append(result)
        else:
            for idx, current_offset in enumerate(output_offsets_tensor_list[i]):
                next_offset = current_offset
                if idx == len(output_offsets_tensor_list[i]) - 1:
                    next_offset = output_split_sizes[i]
                else:
                    next_offset = output_offsets_tensor_list[i][idx + 1]
                # When there is no interval on the current rank, or all IDs in
                # it are equal to padding_idx, we need to ensure they do not
                # contribute to the final result.
                if (current_offset == next_offset) or (
                    padding_idx is not None
                    and not torch.any(
                        torch.ne(inp[current_offset:next_offset], padding_idx)
                    )
                ):
                    result[idx] = -float("Inf")
            results.append(result)

    # Gather all the aggregated results appropriately by using reduce_scatter.
    row_size = input.size(0) if len(input_size) > 1 else len(split_sizes_1d)
    gathered_output = torch.empty(row_size, weight.size(1), device=input.device)
    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    dist.reduce_scatter(gathered_output, results, op=op, group=pg)

    # For mean, we cannot do the division until the very end because the sum
    # of the means is not equal to the mean of the sums (the divisors differ).
    if mode == "mean":
        split_sizes_1d_tensor = torch.tensor(
            split_sizes_1d_with_padding, dtype=torch.float, device=input.device
        )
        # Make sure divisor is not zero.
        split_sizes_1d_tensor[split_sizes_1d_tensor == 0.0] = 1.0
        return (
            torch.div(gathered_output.t().contiguous(), split_sizes_1d_tensor)
            .t()
            .contiguous()
        )

    # Return the appropriate local result.
    return gathered_output
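A small standalone illustration (not from the original code) of why the mean mode above is computed as a SUM reduce_scatter followed by a single division at the very end: averaging per-rank partial means would use the wrong divisor, whereas summing the partial sums and dividing once by each bag's true size gives the correct mean.

import torch

# Hypothetical bag of 5 IDs whose embeddings end up split across 2 ranks:
# rank 0 looked up 3 of them, rank 1 looked up 2.
rank0_part = torch.tensor([1.0, 2.0, 3.0])
rank1_part = torch.tensor([10.0, 20.0])

wrong = (rank0_part.mean() + rank1_part.mean()) / 2   # 8.5 -- mean of means
right = (rank0_part.sum() + rank1_part.sum()) / 5     # 7.2 -- true bag mean
print(wrong.item(), right.item())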
Example #28
def _handle_row_wise_sharding(input, world_size, weight, local_shard, offsets,
                              per_sample_weights, mode, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of IDs used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    # We sort each interval defined by offset. If 2D, each interval is a row.
    input_size = input.size()
    (
        input_split_sorted_list,
        input_split_sorted_indices,
        split_sizes_1d,
    ) = _input_split_sort(input, offsets)

    # Within each interval of the sorted list, we first need to distribute
    # each ID to its bucket (rank), and also perform the rearrangement in case
    # a placement idx does not equal the rank.
    # We then compute some simple stats on each interval for the next step.
    # If the user specifies per_sample_weights, we need to rearrange them to
    # stay in sync with the IDs and then distribute them to each rank.
    (
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        per_sample_weights,
        sharded_dim_size_max,
    ) = _sorted_input_distribute_prepare(
        input_split_sorted_list,
        input_split_sorted_indices,
        world_size,
        input,
        weight,
        per_sample_weights,
    )

    # Send ID/offsets/per_sample_weights to different bucket(rank).
    (
        gathered_input,
        output_offsets_tensor_list,
        output_split_sizes,
        gathered_per_sample_weights,
    ) = _distribute_input(
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        sharded_dim_size_max,
        world_size,
        input,
        per_sample_weights,
        pg,
    )

    # Perform the embedding bag look-up and aggregation
    results = []
    for i, inp in enumerate(gathered_input):
        per_sample_weights = (gathered_per_sample_weights[i]
                              if gathered_per_sample_weights is not None else
                              None)
        result = torch.nn.functional.embedding_bag(
            inp,
            local_shard,
            offsets=output_offsets_tensor_list[i],
            mode=mode if mode != "mean" else "sum",
            per_sample_weights=per_sample_weights,
        )
        if mode != "max":
            results.append(result)
        # For the max case, if there is no look-up on some rank, that rank's
        # result will be all zeros. In that case we need to set the rows to
        # negative infinity; otherwise, negative values would be rounded up
        # to zero in the final aggregation.
        elif inp.size(0) == 0:
            result[:] = -float("Inf")
            results.append(result)
        else:
            for idx, current_offset in enumerate(
                    output_offsets_tensor_list[i]):
                next_offset = current_offset
                if idx == len(output_offsets_tensor_list[i]) - 1:
                    next_offset = output_split_sizes[i]
                else:
                    next_offset = output_offsets_tensor_list[i][idx + 1]
                if current_offset == next_offset:
                    result[idx] = -float("Inf")
            results.append(result)

    # Gather all the aggregated results appropriately by using reduce_scatter.
    row_size = input.size(0) if len(input_size) > 1 else len(split_sizes_1d)
    gathered_output = torch.empty(row_size,
                                  weight.size(1),
                                  device=input.device)
    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    dist.reduce_scatter(gathered_output, results, op=op, group=pg)

    # For mean, we cannot do the division until the very end because the sum
    # of the means is not equal to the mean of the sums (the divisors differ).
    if mode == "mean":
        # For a 2D input, we just divide by the number of columns (the bag size).
        if len(input_size) > 1:
            return torch.div(gathered_output, float(input.size(1)))
        # For a 1D input, we need to divide by the size of each bag defined by the offsets.
        else:
            split_sizes_1d_tensor = torch.tensor(split_sizes_1d,
                                                 dtype=torch.float,
                                                 device=input.device)
            # Make sure divisor is not zero.
            split_sizes_1d_tensor[split_sizes_1d_tensor == 0.0] = 1.0
            return (torch.div(gathered_output.t().contiguous(),
                              split_sizes_1d_tensor).t().contiguous())

    # Return the appropriate local result.
    return gathered_output
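To see why the max path above fills empty rows with -inf before the reduce_scatter, here is a tiny standalone illustration using plain tensors instead of a process group (the values are made up): a rank that has no look-up for a bag must not let its placeholder zeros win the element-wise max against negative embedding values.

import torch

# Per-rank partial results for one bag with 2-dim embeddings; rank 1 had no
# IDs falling into its shard for this bag, so its row is all zeros.
rank0_partial = torch.tensor([[-0.5, -2.0]])
rank1_partial = torch.zeros(1, 2)

wrong = torch.maximum(rank0_partial, rank1_partial)   # tensor([[0., 0.]])
rank1_partial[:] = -float("inf")                      # mask the empty contribution
right = torch.maximum(rank0_partial, rank1_partial)   # tensor([[-0.5, -2.0]])
print(wrong, right)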