Exemplo n.º 1
0
def create_model():
    """Build the DLRM model described by the global ``FLAGS``.

    Side effect: ``FLAGS.top_mlp_sizes`` and ``FLAGS.bottom_mlp_sizes`` are
    rewritten in place as lists of ints.

    The model is created with ``Dlrm.from_dict`` and moved to
    ``FLAGS.base_device``. When ``FLAGS.load_checkpoint_path`` is set, weights
    are restored through a serial checkpoint loader (configured with
    ``device="cpu"``). The model is then moved to ``FLAGS.base_device`` again.

    Returns:
        The constructed, and optionally checkpoint-restored, model.
    """
    print("Creating model")

    # Flag values may arrive as strings; normalize them to ints.
    FLAGS.top_mlp_sizes = list(map(int, FLAGS.top_mlp_sizes))
    FLAGS.bottom_mlp_sizes = list(map(int, FLAGS.bottom_mlp_sizes))

    model_config = dict(
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        bottom_mlp_sizes=FLAGS.bottom_mlp_sizes,
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        interaction_op=FLAGS.interaction_op,
        categorical_feature_sizes=get_categorical_feature_sizes(FLAGS),
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        use_cpp_mlp=FLAGS.optimized_mlp,
        fp16=FLAGS.amp,
        base_device=FLAGS.base_device,
    )

    model = Dlrm.from_dict(model_config)
    print(model)
    model.to(FLAGS.base_device)

    checkpoint_path = FLAGS.load_checkpoint_path
    if checkpoint_path is None:
        return model

    loader = make_serial_checkpoint_loader(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        device="cpu")
    loader.load_checkpoint(model, checkpoint_path)
    # NOTE(review): presumably re-applied because the loader is configured
    # with device="cpu" and may leave tensors there — confirm.
    model.to(FLAGS.base_device)
    return model
Exemplo n.º 2
0
def create_synthetic_datasets(flags, device_mapping: Optional[Dict] = None):
    """Create the synthetic ``(train, test)`` dataset pair.

    Both datasets share the same settings:
    ``flags.synthetic_dataset_num_entries``, ``flags.num_numerical_features``,
    the categorical sizes from ``get_categorical_feature_sizes(flags)``, and
    ``device_mapping``. They differ only in batch size:
    ``flags.batch_size`` for train and ``flags.test_batch_size`` for test.

    Returns:
        tuple: ``(dataset_train, dataset_test)``.
    """
    def _build(batch_size):
        # One SyntheticDataset using the settings shared by both splits.
        return SyntheticDataset(
            num_entries=flags.synthetic_dataset_num_entries,
            batch_size=batch_size,
            numerical_features=flags.num_numerical_features,
            categorical_feature_sizes=get_categorical_feature_sizes(flags),
            device_mapping=device_mapping,
        )

    return _build(flags.batch_size), _build(flags.test_batch_size)
Exemplo n.º 3
0
def create_real_datasets(
    flags,
    path,
    dataset_class: type = SplitCriteoDataset,
    train_dataset_path="train",
    test_dataset_path="test",
    **kwargs
):
    """Create the ``(train, test)`` datasets stored under *path*.

    Args:
        flags: Provides ``batch_size``, ``test_batch_size`` and
            ``num_numerical_features``. It is also passed to
            ``get_categorical_feature_sizes``.
        path: Root directory containing both splits.
        dataset_class: Dataset type to instantiate for each split.
        train_dataset_path: Train split location, relative to *path*.
        test_dataset_path: Test split location, relative to *path*.
        **kwargs: Extra keyword arguments forwarded unchanged to both
            ``dataset_class`` constructions.

    Returns:
        tuple: ``(dataset_train, dataset_test)``.
    """
    train_dir = os.path.join(path, train_dataset_path)
    test_dir = os.path.join(path, test_dataset_path)
    categorical_sizes = get_categorical_feature_sizes(flags)

    def _make(data_path, batch_size):
        # Both splits use every categorical feature (indices 0..N-1).
        return dataset_class(
            data_path=data_path,
            batch_size=batch_size,
            numerical_features=flags.num_numerical_features,
            categorical_features=range(len(categorical_sizes)),
            categorical_feature_sizes=categorical_sizes,
            **kwargs
        )

    dataset_train = _make(train_dir, flags.batch_size)
    dataset_test = _make(test_dir, flags.test_batch_size)
    return dataset_train, dataset_test
Exemplo n.º 4
0
    def create_datasets(self) -> Tuple[Dataset, Dataset]:
        """Create the split-Criteo ``(train, test)`` datasets.

        Data is read from the ``train`` and ``test`` subdirectories of
        ``self._flags.dataset``. Numerical and categorical feature selection
        comes from ``self._numerical_features`` and
        ``self._categorical_features``.

        Returns:
            ``(dataset_train, dataset_test)``.
        """
        flags = self._flags
        train_path, test_path = (
            os.path.join(flags.dataset, split) for split in ("train", "test")
        )
        categorical_sizes = get_categorical_feature_sizes(flags)

        # Batch-wise shuffling does not support prefetching yet, so it is
        # disabled in that mode.
        if flags.shuffle_batch_order:
            prefetch_depth = 0
        else:
            prefetch_depth = 10

        shared_kwargs = dict(
            numerical_features=self._numerical_features,
            categorical_features=self._categorical_features,
            categorical_feature_sizes=categorical_sizes,
            prefetch_depth=prefetch_depth,
        )

        dataset_train = SplitCriteoDataset(
            data_path=train_path, batch_size=flags.batch_size, **shared_kwargs)
        dataset_test = SplitCriteoDataset(
            data_path=test_path, batch_size=flags.test_batch_size, **shared_kwargs)
        return dataset_train, dataset_test
Exemplo n.º 5
0
    def create_datasets(self) -> Tuple[Dataset, Dataset]:
        """Create the split-Criteo ``(train, test)`` datasets.

        Data is read from the ``train`` and ``test`` subdirectories of
        ``self._flags.dataset``. The train split uses ``batch_size`` and the
        test split uses ``test_batch_size``; all other settings are shared.

        Returns:
            ``(dataset_train, dataset_test)``.
        """
        flags = self._flags
        root = flags.dataset
        train_path = os.path.join(root, "train")
        test_path = os.path.join(root, "test")
        categorical_sizes = get_categorical_feature_sizes(flags)

        built = []
        for data_path, batch_size in ((train_path, flags.batch_size),
                                      (test_path, flags.test_batch_size)):
            built.append(SplitCriteoDataset(
                data_path=data_path,
                batch_size=batch_size,
                numerical_features=self._numerical_features,
                categorical_features=self._categorical_features,
                categorical_feature_sizes=categorical_sizes,
            ))
        return tuple(built)
Exemplo n.º 6
0
def create_real_datasets(flags, path, dataset_class: type = CriteoBinDataset):
    """Create the binary Criteo ``(train, test)`` datasets stored under *path*.

    Reads ``train_data.bin`` and ``test_data.bin`` from *path*. Each dataset
    receives the number of categorical features as an integer count, not as
    a list of indices.

    Returns:
        tuple: ``(dataset_train, dataset_test)``.
    """
    train_file, test_file = (
        os.path.join(path, file_name)
        for file_name in ("train_data.bin", "test_data.bin")
    )
    num_categorical = len(get_categorical_feature_sizes(flags))

    dataset_train = dataset_class(
        data_file=train_file,
        batch_size=flags.batch_size,
        subset=flags.dataset_subset,
        numerical_features=flags.num_numerical_features,
        categorical_features=num_categorical,
    )

    # NOTE(review): only the train split is given `subset`; the test split
    # relies on dataset_class's default — confirm this asymmetry is intended.
    dataset_test = dataset_class(
        data_file=test_file,
        batch_size=flags.test_batch_size,
        numerical_features=flags.num_numerical_features,
        categorical_features=num_categorical,
    )

    return dataset_train, dataset_test
Exemplo n.º 7
0
def create_dataset_factory(flags,
                           device_mapping: Optional[dict] = None
                           ) -> DatasetFactory:
    """
    Build the dataset factory selected by ``flags.dataset_type``.

    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    :param flags: parsed flags. ``flags.dataset_type`` picks the factory: "binary", "split", "synthetic_gpu"
        or "synthetic_disk". The whole object is forwarded to the chosen factory.
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment.
        For the "split" type in a distributed run it must be given. Its ``"bottom_mlp"`` entry is compared
        with the current rank to decide whether this rank reads numerical features. Its ``"embedding"``
        entry is indexed by rank to select this rank's categorical features.
    :return: the DatasetFactory instance for the requested dataset type
    :raises ValueError: a distributed "split" dataset is requested without ``device_mapping``
    :raises NotImplementedError: ``flags.dataset_type`` is not one of the supported types
    """
    dataset_type = flags.dataset_type

    if dataset_type == "binary":
        return BinaryDatasetFactory(flags, device_mapping)

    if dataset_type == "split":
        if is_distributed():
            # Explicit check rather than `assert`: assertions are stripped under
            # `python -O`, which would turn this into an opaque TypeError below.
            if device_mapping is None:
                raise ValueError("Distributed dataset requires information about model device mapping.")
            rank = get_rank()
            return SplitBinaryDatasetFactory(
                flags=flags,
                numerical_features=device_mapping["bottom_mlp"] == rank,
                categorical_features=device_mapping["embedding"][rank])
        # Single-process: this process reads every numerical and categorical feature.
        return SplitBinaryDatasetFactory(
            flags=flags,
            numerical_features=True,
            categorical_features=range(
                len(get_categorical_feature_sizes(flags))))

    if dataset_type == "synthetic_gpu":
        return SyntheticGpuDatasetFactory(flags, device_mapping)

    if dataset_type == "synthetic_disk":
        return SyntheticDiskDatasetFactory(flags, device_mapping)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
Exemplo n.º 8
0
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model.

    Runs up to ``FLAGS.epochs`` passes over ``data_loader_train``. Training
    metrics are logged every ``FLAGS.print_freq`` steps. The model is
    evaluated on ``data_loader_test`` every ``test_freq`` global steps, and
    a checkpoint is saved whenever test AUC improves.

    Training stops early when ``FLAGS.max_steps`` is exceeded; in that case
    the summary below is still logged. When ``FLAGS.auc_threshold`` is
    reached, the function returns immediately without logging the summary.

    Args:
        model (dlrm): Model called as ``model(numerical, categorical)``.
        loss_fn (torch.nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer. Its learning rate is
            driven each step by ``utils.lr_step``.
        data_loader_train (torch.utils.data.DataLoader): Yields
            ``(numerical_features, categorical_features, click)`` batches.
        data_loader_test (torch.utils.data.DataLoader): Passed to
            ``evaluate``.
        scaled_lr: Base learning rate passed to ``utils.lr_step`` as
            ``base_lr``.
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict())

    if FLAGS.test_freq is not None:
        test_freq = FLAGS.test_freq
    else:
        # Default is roughly once per epoch. It is clamped to 1 so that a
        # one-step epoch does not make `global_step % test_freq` divide by zero.
        test_freq = max(steps_per_epoch - 1, 1)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    # Initialized up front so the final summary works even when no training
    # step runs (zero epochs or an empty loader) and no evaluation happens.
    global_step = 0
    test_step_time = None
    reached_max_steps = False
    start_time = time()

    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                reached_max_steps = True
                break

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(loss=loss_value,
                                         lr=optimizer.param_groups[0]["lr"])
                else:
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor)

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    # NOTE: this also skips the evaluation check below for
                    # this step.
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if (global_step % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                # A separate name keeps the training `loss` from being shadowed.
                test_loss, auc, test_step_time = evaluate(model, loss_fn,
                                                          data_loader_test)
                print(
                    f"Epoch {epoch} step {step}. Test loss {test_loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model,
                                          FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    # A sub-second run truncates to 0s; avoid dividing by zero.
                    rate_denominator_s = max(run_time_s, 1)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / rate_denominator_s:.1f} records/s."
                    )
                    return

        if reached_max_steps:
            # Without this, every remaining epoch would restart the pipeline,
            # run one extra lr_step, print "Stopping." again, and inflate
            # global_step (and hence the reported speed).
            break

    stop_time = time()
    run_time_s = int(stop_time - start_time)
    # A sub-second run truncates to 0s; avoid dividing by zero.
    rate_denominator_s = max(run_time_s, 1)

    print(
        f"Finished training in {run_time_s}s. "
        f"Average speed {global_step * FLAGS.batch_size / rate_denominator_s:.1f} records/s."
    )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if test_step_time is not None:
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())