Example #1
    @staticmethod
    def backward(ctx, grad_slice_bottom_outputs):
        rank = dist.get_rank()

        if ctx.feature_order is not None and ctx.device_feature_order is not None:
            grad_slice_bottom_outputs = grad_slice_bottom_outputs[
                :, ctx.device_feature_order, :]

        grad_local_bottom_outputs = torch.empty(
            sum(ctx.batch_sizes_per_gpu),
            ctx.vectors_per_gpu[rank] * ctx.vector_dim,
            device=grad_slice_bottom_outputs.device,
            dtype=grad_slice_bottom_outputs.dtype)
        # all_to_all only takes a list, while split() returns a tuple
        grad_local_bottom_outputs_split = list(
            grad_local_bottom_outputs.split(ctx.batch_sizes_per_gpu, dim=0))

        split_grads = [
            t.contiguous()
            for t in grad_slice_bottom_outputs.view(
                ctx.batch_sizes_per_gpu[rank], -1
            ).split([ctx.vector_dim * n for n in ctx.vectors_per_gpu], dim=1)
        ]

        torch.distributed.all_to_all(grad_local_bottom_outputs_split,
                                     split_grads)

        return (grad_local_bottom_outputs.view(
            grad_local_bottom_outputs.shape[0], -1,
            ctx.vector_dim), None, None, None, None, None)
Example #2
    def __init__(self,
                 num_entries: int,
                 device: str = 'cuda',
                 batch_size: int = 32768,
                 numerical_features: Optional[int] = None,
                 categorical_feature_sizes: Optional[Sequence[int]] = None,
                 device_mapping: Optional[Dict[str, Any]] = None):
        if device_mapping:
            # distributed setting
            rank = get_rank()
            numerical_features = (numerical_features
                                  if device_mapping["bottom_mlp"] == rank
                                  else None)
            categorical_feature_sizes = device_mapping["embedding"][rank]

        self.cat_features_count = (len(categorical_feature_sizes)
                                   if categorical_feature_sizes is not None else 0)
        self.num_features_count = numerical_features if numerical_features is not None else 0

        self.tot_fea = 1 + self.num_features_count + self.cat_features_count  # label + numerical + categorical
        self.batch_size = batch_size
        self.batches_per_epoch = math.ceil(num_entries / batch_size)
        self.categorical_feature_sizes = categorical_feature_sizes
        self.device = device

        self.tensor = torch.randint(low=0,
                                    high=2,
                                    size=(self.batch_size, self.tot_fea),
                                    device=self.device).float()
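
The snippet ends after building a single random batch. A minimal sketch of how batches could then be served, assuming a [label | numerical | categorical] column layout behind tot_fea (this __getitem__ and that layout are assumptions, not part of the source):

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, idx: int):
        # Hypothetical accessor: every index reuses the same pre-generated
        # random batch, split by the assumed column layout.
        if idx >= self.batches_per_epoch:
            raise IndexError()
        label = self.tensor[:, 0]
        numerical = self.tensor[:, 1:1 + self.num_features_count]
        categorical = self.tensor[:, 1 + self.num_features_count:].long()
        return numerical, categorical, label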
Example #3
    @staticmethod
    def forward(ctx,
                local_bottom_outputs: torch.Tensor,
                batch_sizes_per_gpu: Sequence[int],
                vector_dim: int,
                vectors_per_gpu: Sequence[int],
                feature_order: Optional[torch.Tensor] = None,
                device_feature_order: Optional[torch.Tensor] = None):
        """
        Args:
            ctx : Pytorch convention
            local_bottom_outputs (Tensor): Concatenated output of bottom model
            batch_sizes_per_gpu (Sequence[int]):
            vector_dim (int):
            vectors_per_gpu (Sequence[int]): Note, bottom MLP is considered as 1 vector
            device_feature_order:
            feature_order:

        Returns:
            slice_embedding_outputs (Tensor): Patial output from bottom model to feed into data parallel top model
        """
        rank = dist.get_rank()

        ctx.world_size = dist.get_world_size()
        ctx.batch_sizes_per_gpu = batch_sizes_per_gpu
        ctx.vector_dim = vector_dim
        ctx.vectors_per_gpu = vectors_per_gpu
        ctx.feature_order = feature_order
        ctx.device_feature_order = device_feature_order

        # The buffer shouldn't need to be zeroed out. If not zeroing it affects accuracy, there must be a bug.
        bottom_output_buffer = [
            torch.empty(batch_sizes_per_gpu[rank],
                        n * vector_dim,
                        device=local_bottom_outputs.device,
                        dtype=local_bottom_outputs.dtype)
            for n in vectors_per_gpu
        ]

        torch.distributed.all_to_all(
            bottom_output_buffer,
            list(local_bottom_outputs.split(batch_sizes_per_gpu, dim=0)))
        slice_bottom_outputs = torch.cat(bottom_output_buffer,
                                         dim=1).view(batch_sizes_per_gpu[rank],
                                                     -1, vector_dim)

        # Feature reordering is only for consistency across different device-mapping configurations
        if feature_order is not None and device_feature_order is not None:
            return slice_bottom_outputs[:, feature_order, :]

        return slice_bottom_outputs
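
To see the shape contract behind this exchange without spawning processes, here is a single-process sketch of the same split/cat bookkeeping; all sizes are made-up toy values:

import torch

# Toy sizes (assumptions): 2 ranks, vector_dim 3; rank 0 holds 1 vector
# (the bottom MLP), rank 1 holds 2 embedding vectors.
batch_sizes_per_gpu = [4, 4]
vector_dim = 3
vectors_per_gpu = [1, 2]

# Before the exchange, a rank holds its own vectors for the *global* batch,
# flattened to (sum(batch_sizes_per_gpu), n_local * vector_dim).
local_bottom_outputs = torch.randn(sum(batch_sizes_per_gpu),
                                   vectors_per_gpu[0] * vector_dim)

# After all_to_all, each rank owns one buffer per peer; concatenating along
# dim=1 and viewing gives (local_batch, total_vectors, vector_dim).
bottom_output_buffer = [torch.randn(batch_sizes_per_gpu[0], n * vector_dim)
                        for n in vectors_per_gpu]
slice_bottom_outputs = torch.cat(bottom_output_buffer, dim=1).view(
    batch_sizes_per_gpu[0], -1, vector_dim)
assert slice_bottom_outputs.shape == (4, sum(vectors_per_gpu), vector_dim)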
Example #4
    def create_collate_fn(self) -> Optional[Callable]:
        if self._device_mapping is not None:
            # selection of categorical features assigned to this device
            device_cat_features = torch.tensor(
                self._device_mapping["embedding"][get_rank()], device=self._flags.base_device, dtype=torch.long)
        else:
            device_cat_features = None

        orig_stream = torch.cuda.current_stream() if self._flags.base_device == 'cuda' else None
        return functools.partial(
            collate_array,
            device=self._flags.base_device,
            orig_stream=orig_stream,
            num_numerical_features=self._flags.num_numerical_features,
            selected_categorical_features=device_cat_features
        )
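
Downstream, collate_array presumably uses selected_categorical_features to slice out this rank's categorical columns; a runnable toy illustration of that indexing (the shapes and values are invented, and the slicing step is an assumption about collate_array):

import torch

# Pick this rank's categorical columns out of the full categorical batch.
device_cat_features = torch.tensor([0, 2], dtype=torch.long)  # assumed: this rank owns features 0 and 2
full_batch = torch.randint(0, 10, (4, 5))         # (batch, all categorical features)
local_batch = full_batch[:, device_cat_features]  # (batch, local categorical features)
assert local_batch.shape == (4, 2)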
Example #5
def create_dataset_factory(
        flags,
        feature_spec: FeatureSpec,
        device_mapping: Optional[dict] = None) -> DatasetFactory:
    """
    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type
    num_numerical_features = feature_spec.get_number_of_numerical_features()
    if is_distributed() or device_mapping:
        assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
        rank = get_rank()
        local_categorical_positions = device_mapping["embedding"][rank]
        numerical_features_enabled = device_mapping["bottom_mlp"] == rank
    else:
        local_categorical_positions = list(
            range(len(feature_spec.get_categorical_feature_names())))
        numerical_features_enabled = True

    if dataset_type == "parametric":
        local_categorical_names = feature_spec.cat_positions_to_names(
            local_categorical_positions)
        return ParametricDatasetFactory(
            flags=flags,
            feature_spec=feature_spec,
            numerical_features_enabled=numerical_features_enabled,
            categorical_features_to_read=local_categorical_names)
    if dataset_type == "synthetic_gpu":
        local_numerical_features = num_numerical_features if numerical_features_enabled else 0
        world_categorical_sizes = feature_spec.get_categorical_sizes()
        local_categorical_sizes = [
            world_categorical_sizes[i] for i in local_categorical_positions
        ]
        return SyntheticGpuDatasetFactory(
            flags,
            local_numerical_features_num=local_numerical_features,
            local_categorical_feature_sizes=local_categorical_sizes)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
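
For reference, a device_mapping of the shape this factory expects might look like the following for two ranks; the concrete values are illustrative, not taken from the source:

# Illustrative two-rank mapping (assumed values): rank 0 hosts the bottom
# MLP, and the categorical feature positions are split across the ranks.
device_mapping = {
    "bottom_mlp": 0,
    "embedding": [[0, 1, 2], [3, 4]],
}
factory = create_dataset_factory(flags, feature_spec, device_mapping)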
Example #6
def create_dataset_factory(flags,
                           device_mapping: Optional[dict] = None
                           ) -> DatasetFactory:
    """
    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type

    if dataset_type == "binary":
        return BinaryDatasetFactory(flags, device_mapping)

    if dataset_type == "split":
        if is_distributed():
            assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
            rank = get_rank()
            return SplitBinaryDatasetFactory(
                flags=flags,
                numerical_features=device_mapping["bottom_mlp"] == rank,
                categorical_features=device_mapping["embedding"][rank])
        return SplitBinaryDatasetFactory(
            flags=flags,
            numerical_features=True,
            categorical_features=range(
                len(get_categorical_feature_sizes(flags))))

    if dataset_type == "synthetic_gpu":
        return SyntheticGpuDatasetFactory(flags, device_mapping)

    if dataset_type == "synthetic_disk":
        return SyntheticDiskDatasetFactory(flags, device_mapping)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
Example #7
def dist_evaluate(model, data_loader, data_cache):
    """Test distributed DLRM model

    Args:
        model (DistDLRM):
        data_loader (torch.utils.data.DataLoader):
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    device_mapping = dist_model.get_criteo_device_mapping(world_size)
    vectors_per_gpu = device_mapping['vectors_per_gpu']

    # The test batch size can be large; scale the default print frequency so progress still prints
    default_print_freq = max(16384 * 2000 // FLAGS.test_batch_size, 1)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f} ms'))
    local_embedding_device_mapping = torch.tensor(
        device_mapping['embedding'][rank],
        device=FLAGS.device,
        dtype=torch.long)
    with torch.no_grad():
        # ROC could be computed per batch and the AUC aggregated globally, but that code isn't available here,
        # so pack all outputs and labels together and compute AUC once. y_true and y_score naming follows sklearn.
        y_true = []
        y_score = []
        data_stream = torch.cuda.Stream()
        stop_time = time()

        if not data_cache:
            eval_data_iter = dataset.prefetcher(iter(data_loader), data_stream)
        else:
            print("Use cached eval data")
            eval_data_iter = data_cache
        for step, (numerical_features, categorical_features,
                   click) in enumerate(eval_data_iter):
            if data_cache is not None and len(data_cache) < steps_per_epoch:
                data_cache.append(
                    (numerical_features, categorical_features, click))
            last_batch_size = None
            if click.shape[0] != FLAGS.test_batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.debug("Pad the last test batch of size %d to %d",
                              last_batch_size, FLAGS.test_batch_size)
                padding_size = FLAGS.test_batch_size - last_batch_size
                padding_numerical = torch.empty(
                    padding_size,
                    numerical_features.shape[1],
                    device=numerical_features.device,
                    dtype=numerical_features.dtype)
                numerical_features = torch.cat(
                    (numerical_features, padding_numerical), dim=0)
                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size,
                        categorical_features.shape[1],
                        device=categorical_features.device,
                        dtype=categorical_features.dtype)
                    categorical_features = torch.cat(
                        (categorical_features, padding_categorical), dim=0)

            if FLAGS.dataset_type != "dist":
                categorical_features = categorical_features[:,
                                                            local_embedding_device_mapping]

            if FLAGS.fp16 and numerical_features is not None:
                numerical_features = numerical_features.to(torch.float16)
            bottom_out = model.bottom_model(numerical_features,
                                            categorical_features)
            batch_size_per_gpu = FLAGS.test_batch_size // world_size
            from_bottom = dist_model.bottom_to_top(bottom_out,
                                                   batch_size_per_gpu,
                                                   model.embedding_dim,
                                                   vectors_per_gpu)

            output = model.top_model(from_bottom).squeeze()

            buffer_dtype = torch.float16 if FLAGS.fp16 else torch.float32
            output_receive_buffer = torch.empty(FLAGS.test_batch_size,
                                                device=FLAGS.device,
                                                dtype=buffer_dtype)
            torch.distributed.all_gather(
                list(output_receive_buffer.split(batch_size_per_gpu)), output)
            if last_batch_size is not None:
                output_receive_buffer = output_receive_buffer[:last_batch_size]

            y_true.append(click)
            y_score.append(output_receive_buffer.float())

            if step % print_freq == 0 and step != 0:
                torch.cuda.synchronize()
                metric_logger.update(step_time=(time() - stop_time) * 1000 /
                                     print_freq)
                stop_time = time()
                metric_logger.print(header=f"Test: [{step}/{steps_per_epoch}]")

        auc = metrics.roc_auc_score(torch.cat(y_true),
                                    torch.sigmoid(torch.cat(y_score).float()))

    return auc
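
The all_gather above works because split() returns views into output_receive_buffer, so each rank's scores land in a contiguous shard of one buffer; a single-process sketch of that pattern, including trimming the last-batch padding (all sizes invented):

import torch

world_size, batch_size_per_gpu = 4, 8
test_batch_size = world_size * batch_size_per_gpu

# all_gather writes each rank's output into one shard of the buffer;
# here a plain loop stands in for the collective.
output_receive_buffer = torch.empty(test_batch_size)
for rank, shard in enumerate(output_receive_buffer.split(batch_size_per_gpu)):
    shard.copy_(torch.full((batch_size_per_gpu,), float(rank)))

# The real last batch was shorter; drop the padded tail before scoring.
last_batch_size = 27
scores = output_receive_buffer[:last_batch_size]
assert scores.shape[0] == last_batch_size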