def validate_flags(cat_feature_count):
    if FLAGS.max_table_size is not None and not FLAGS.hash_indices:
        raise ValueError('Hash indices must be True when setting a max_table_size')

    if FLAGS.base_device == 'cpu':
        if FLAGS.embedding_type in ('joint_fused', 'joint_sparse'):
            print('WARNING: CUDA joint embeddings are not supported on CPU')
            FLAGS.embedding_type = 'joint'

        if FLAGS.amp:
            print('WARNING: Automatic mixed precision not supported on CPU')
            FLAGS.amp = False

        if FLAGS.optimized_mlp:
            print('WARNING: Optimized MLP is not supported on CPU')
            FLAGS.optimized_mlp = False

    if FLAGS.embedding_type == 'custom_cuda':
        if (not is_distributed()) and FLAGS.embedding_dim == 128 and cat_feature_count == 26:
            FLAGS.embedding_type = 'joint_fused'
        else:
            FLAGS.embedding_type = 'joint_sparse'

    if FLAGS.embedding_type == 'joint_fused' and FLAGS.embedding_dim != 128:
        print('WARNING: Joint fused can be used only with embedding_dim=128. Changed embedding type to joint_sparse.')
        FLAGS.embedding_type = 'joint_sparse'

    if FLAGS.dataset is None and (FLAGS.dataset_type != 'synthetic_gpu' or
                                  FLAGS.synthetic_dataset_use_feature_spec):
        raise ValueError('Dataset argument has to specify a path to the dataset')

    FLAGS.inference_benchmark_batch_sizes = [int(x) for x in FLAGS.inference_benchmark_batch_sizes]
    FLAGS.top_mlp_sizes = [int(x) for x in FLAGS.top_mlp_sizes]
    FLAGS.bottom_mlp_sizes = [int(x) for x in FLAGS.bottom_mlp_sizes]
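The embedding-type fallbacks above are easier to follow in isolation. Below is a minimal, self-contained sketch of just the 'custom_cuda' and 'joint_fused' resolution rules, using types.SimpleNamespace as a stand-in for the real FLAGS object; the helper name and the `distributed` argument are illustrative, not part of the original module.

from types import SimpleNamespace


def resolve_embedding_type(flags, distributed: bool, cat_feature_count: int) -> str:
    # Mirrors the branches above: the fused CUDA kernel is only selected for
    # single-process runs with embedding_dim=128 and 26 categorical features.
    if flags.embedding_type == 'custom_cuda':
        if (not distributed) and flags.embedding_dim == 128 and cat_feature_count == 26:
            return 'joint_fused'
        return 'joint_sparse'
    if flags.embedding_type == 'joint_fused' and flags.embedding_dim != 128:
        return 'joint_sparse'
    return flags.embedding_type


# A single-GPU, Criteo-style configuration resolves to the fused kernel.
example_flags = SimpleNamespace(embedding_type='custom_cuda', embedding_dim=128)
assert resolve_embedding_type(example_flags, distributed=False, cat_feature_count=26) == 'joint_fused'
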
def create_embeddings(embedding_type: str,
                      categorical_feature_sizes: Sequence[int],
                      embedding_dim: int,
                      device: str = "cuda",
                      hash_indices: bool = False,
                      fp16: bool = False) -> Embeddings:
    if embedding_type == "joint":
        return JointEmbedding(categorical_feature_sizes,
                              embedding_dim,
                              device=device,
                              hash_indices=hash_indices)
    elif embedding_type == "joint_fused":
        assert not is_distributed(), "Joint fused embedding is not supported in the distributed mode. " \
                                     "You may want to use 'joint_sparse' option instead."
        return FusedJointEmbedding(categorical_feature_sizes,
                                   embedding_dim,
                                   device=device,
                                   hash_indices=hash_indices,
                                   amp_train=fp16)
    elif embedding_type == "joint_sparse":
        return JointSparseEmbedding(categorical_feature_sizes,
                                    embedding_dim,
                                    device=device,
                                    hash_indices=hash_indices)
    elif embedding_type == "multi_table":
        return MultiTableEmbeddings(categorical_feature_sizes,
                                    embedding_dim,
                                    hash_indices=hash_indices,
                                    device=device)
    else:
        raise NotImplementedError(f"unknown embedding type: {embedding_type}")
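For context, the "joint" variants returned above pack all categorical tables into a single parameter tensor and shift each feature's indices by a per-table offset. The snippet below is a minimal sketch of that idea built on plain torch.nn.Embedding; the real JointEmbedding/FusedJointEmbedding classes add hashing, AMP support and custom CUDA kernels on top.

import torch
from torch import nn


class TinyJointEmbedding(nn.Module):
    """Illustrative only: one big table indexed with per-feature offsets."""

    def __init__(self, categorical_feature_sizes, embedding_dim):
        super().__init__()
        boundaries = torch.tensor([0] + list(categorical_feature_sizes)).cumsum(0)
        self.register_buffer("offsets", boundaries[:-1])  # first row of each table
        self.embedding = nn.Embedding(int(boundaries[-1]), embedding_dim)

    def forward(self, categorical_inputs):
        # categorical_inputs: [batch_size, num_features], one index per feature
        return self.embedding(categorical_inputs + self.offsets)  # [batch, features, dim]


emb = TinyJointEmbedding([10, 20, 30], embedding_dim=8)
print(emb(torch.tensor([[1, 2, 3], [0, 0, 0]])).shape)  # torch.Size([2, 3, 8])
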
Example #3
    def create_datasets(self) -> Tuple[Dataset, Dataset]:
        synthetic_train, synthetic_test = create_synthetic_datasets(self._flags)

        if is_distributed():
            self._synchronized_write(synthetic_train, synthetic_test)
        else:
            self._write(synthetic_train, synthetic_test)

        return create_real_datasets(self._flags, self._flags.synthetic_dataset_dir)
Example #4
    def create_datasets(self) -> Tuple[Dataset, Dataset]:
        synthetic_train, synthetic_test = create_synthetic_datasets(self._flags)

        if is_distributed():
            self._synchronized_write(synthetic_train, synthetic_test)
        else:
            self._write(synthetic_train, synthetic_test)

        return create_real_datasets(
            self._flags, self._flags.synthetic_dataset_dir,
            SplitCriteoDataset, "train", "test",
            prefetch_depth=10
        )
def create_dataset_factory(
        flags,
        feature_spec: FeatureSpec,
        device_mapping: Optional[dict] = None) -> DatasetFactory:
    """
    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type
    num_numerical_features = feature_spec.get_number_of_numerical_features()
    if is_distributed() or device_mapping:
        assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
        rank = get_rank()
        local_categorical_positions = device_mapping["embedding"][rank]
        numerical_features_enabled = device_mapping["bottom_mlp"] == rank
    else:
        local_categorical_positions = list(
            range(len(feature_spec.get_categorical_feature_names())))
        numerical_features_enabled = True

    if dataset_type == "parametric":
        local_categorical_names = feature_spec.cat_positions_to_names(
            local_categorical_positions)
        return ParametricDatasetFactory(
            flags=flags,
            feature_spec=feature_spec,
            numerical_features_enabled=numerical_features_enabled,
            categorical_features_to_read=local_categorical_names)
    if dataset_type == "synthetic_gpu":
        local_numerical_features = num_numerical_features if numerical_features_enabled else 0
        world_categorical_sizes = feature_spec.get_categorical_sizes()
        local_categorical_sizes = [
            world_categorical_sizes[i] for i in local_categorical_positions
        ]
        return SyntheticGpuDatasetFactory(
            flags,
            local_numerical_features_num=local_numerical_features,
            local_categorical_feature_sizes=local_categorical_sizes)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
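In the distributed branch above, device_mapping decides which categorical features each rank reads and which rank handles the numerical features. Below is a small, self-contained illustration of that selection with a made-up two-rank mapping; the dict keys follow the ones used above, the sizes are arbitrary.

# Hypothetical mapping for a 2-GPU run: rank 0 owns the bottom MLP and the first
# two categorical features, rank 1 owns the remaining three.
device_mapping = {
    "bottom_mlp": 0,
    "embedding": [[0, 1], [2, 3, 4]],
}
world_categorical_sizes = [1000, 500, 250, 100, 50]

for rank in range(2):
    local_categorical_positions = device_mapping["embedding"][rank]
    numerical_features_enabled = device_mapping["bottom_mlp"] == rank
    local_categorical_sizes = [world_categorical_sizes[i] for i in local_categorical_positions]
    print(rank, numerical_features_enabled, local_categorical_sizes)
# 0 True [1000, 500]
# 1 False [250, 100, 50]
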
Example #6
def create_dataset_factory(flags,
                           device_mapping: Optional[dict] = None
                           ) -> DatasetFactory:
    """
    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type

    if dataset_type == "binary":
        return BinaryDatasetFactory(flags, device_mapping)

    if dataset_type == "split":
        if is_distributed():
            assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
            rank = get_rank()
            return SplitBinaryDatasetFactory(
                flags=flags,
                numerical_features=device_mapping["bottom_mlp"] == rank,
                categorical_features=device_mapping["embedding"][rank])
        return SplitBinaryDatasetFactory(
            flags=flags,
            numerical_features=True,
            categorical_features=range(
                len(get_categorical_feature_sizes(flags))))

    if dataset_type == "synthetic_gpu":
        return SyntheticGpuDatasetFactory(flags, device_mapping)

    if dataset_type == "synthetic_disk":
        return SyntheticDiskDatasetFactory(flags, device_mapping)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
Example #7
    def create_sampler(self, dataset: Dataset) -> Optional[Sampler]:
        return RandomDistributedSampler(dataset) if is_distributed() else RandomSampler(dataset)
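RandomDistributedSampler is a project-specific sampler; with stock PyTorch classes the same choice would look roughly like the sketch below (DistributedSampler shuffles by default and additionally shards the dataset across ranks). This is an approximation, not the repository's implementation.

import torch
from torch.utils.data import RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def pick_sampler(dataset, distributed: bool):
    # Shuffled sampling either way; DistributedSampler also partitions the dataset across ranks.
    return DistributedSampler(dataset, shuffle=True) if distributed else RandomSampler(dataset)


dataset = TensorDataset(torch.arange(100))
sampler = pick_sampler(dataset, distributed=False)  # RandomSampler over 100 samples
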
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    device = FLAGS.base_device

    if not is_distributed():
        raise NotImplementedError(
            "This file is only for distributed training.")

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size",
                      FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes,
                                        num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size,
                                              num_gpus=world_size)
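    # Cumulative batch boundaries: rank r owns global samples
    # [batch_indices[r], batch_indices[r + 1]), used later to slice the labels.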
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of the embedding tables assigned to this rank
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(
        FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP averages gradients through allreduce(mean), which does not apply to the bottom model.
    # Compensate by further scaling its learning rate (see scaled_lrs below).
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    embedding_optimizer = torch.optim.SGD([
        {
            'params': model.bottom_model.embeddings.parameters(),
            'lr': scaled_lrs[0]
        },
    ])
    mlp_optimizer = apex_optim.FusedSGD([
        {
            'params': model.bottom_model.mlp.parameters(),
            'lr': scaled_lrs[0]
        },
        {
            'params': model.top_model.parameters(),
            'lr': scaled_lrs[1]
        },
    ])

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(
        device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model,
         model.bottom_model.mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # fall back to torch DDP on CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(F"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning(
            "Saving checkpoint without --bottom_features_ordered flag will result in "
            "a device-order dependent model. Consider using --bottom_features_ordered "
            "if you plan to load the checkpoint in different device configurations."
        )

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print every 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[scaled_lrs, [scaled_lrs[0]]],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    F"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported",
                              click.shape[0])
            else:
                output = model(numerical_features, categorical_features,
                               batch_sizes_per_gpu).squeeze()
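                # `output` covers only this rank's slice of the global batch,
                # so the labels are sliced with batch_indices to match.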

                loss = loss_fn(
                    output, click[batch_indices[rank]:batch_indices[rank + 1]])

                # We don't need to accumulate gradients. Setting grad to None is faster than optimizer.zero_grad()
                for param_group in itertools.chain(
                        embedding_optimizer.param_groups,
                        mlp_optimizer.param_groups):
                    for param in param_group['params']:
                        param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                mlp_optimizer.step()
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(F"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Average over a print_freq period to reduce the error.
                # Accurate timing would need a synchronize, which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(F"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path,
                                          epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
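
The FLAGS.flag_values_dict() calls above suggest absl-style flags; if so, the module would typically end with an entry point along these lines (an assumption, not part of the original listing):

if __name__ == '__main__':
    # Assumes absl flags: app.run parses FLAGS from argv and then calls main(argv).
    from absl import app
    app.run(main)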