def evaluate(model, loss_fn, data_loader):
    """Test dlrm model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        data_loader (torch.utils.data.DataLoader):
    """
    model.eval()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq

    steps_per_epoch = len(data_loader)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time',
        utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    with torch.no_grad():
        y_true = []
        y_score = []

        timer = utils.StepTimer()
        batch_iter = iter(data_loader)

        timer.click()
        for step in range(len(data_loader)):
            numerical_features, categorical_features, click = next(batch_iter)

            categorical_features = categorical_features.to(base_device).to(
                torch.long)
            numerical_features = numerical_features.to(base_device)
            click = click.to(torch.float32).to(base_device)

            if FLAGS.fp16:
                numerical_features = numerical_features.half()

            output = model(numerical_features, categorical_features).squeeze()

            loss = loss_fn(output, click)
            y_true.append(click)
            y_score.append(output)

            loss_value = loss.item()
            timer.click()

            if timer.measured is not None:
                metric_logger.update(loss=loss_value, step_time=timer.measured)
                if step % print_freq == 0 and step > 0:
                    metric_logger.print(
                        header=F"Test: [{step}/{steps_per_epoch}]")

        y_true = torch.cat(y_true).cpu().numpy()
        y_score = torch.cat(y_score).cpu().numpy()
        auc = roc_auc_score(y_true=y_true, y_score=y_score)

    model.train()

    return metric_logger.loss.global_avg, auc, metric_logger.step_time.avg
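

# The timing pattern above assumes utils.StepTimer records a timestamp on every
# click() and exposes the interval between the two most recent clicks as
# `measured` (None until two clicks have happened). A minimal sketch of such a
# timer under that assumption (hypothetical, not the project's actual
# implementation):
import time as _time_mod

class _StepTimerSketch:
    def __init__(self):
        self._last = None
        self.measured = None  # seconds between the two most recent click() calls

    def click(self):
        now = _time_mod.perf_counter()
        if self._last is not None:
            self.measured = now - self._last
        self._last = now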
def evaluate(model, loss_fn, data_loader):
    """Test dlrm model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        data_loader (torch.utils.data.DataLoader):
    """
    model.eval()
    print_freq = FLAGS.print_freq
    prefetching_enabled = is_data_prefetching_enabled()

    steps_per_epoch = len(data_loader)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    with torch.no_grad():
        y_true = []
        y_score = []

        timer = utils.StepTimer()
        timer.click()

        input_pipeline = iter(data_loader)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, (numerical_features, categorical_features,
                   click) in enumerate(input_pipeline):
            if FLAGS.amp:
                numerical_features = numerical_features.half()

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features, categorical_features).squeeze()

            loss = loss_fn(output, click)
            y_true.append(click)
            y_score.append(output)

            loss_value = loss.item()
            timer.click()

            if timer.measured is not None:
                metric_logger.update(loss=loss_value, step_time=timer.measured)
                if step % print_freq == 0 and step > 0:
                    metric_logger.print(
                        header=f"Test: [{step}/{steps_per_epoch}]")

        y_true = torch.cat(y_true)
        y_score = torch.cat(y_score)

        before_auc_timestamp = time()
        auc = utils.roc_auc_score(y_true=y_true, y_score=y_score)
        print(f'AUC computation took: {time() - before_auc_timestamp:.2f} [s]')

    model.train()

    return metric_logger.loss.global_avg, auc, metric_logger.step_time.avg
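

# The prefetcher used above (when data prefetching is enabled) is assumed to be
# a generator that issues the host-to-device copies for each batch on a side
# CUDA stream; the consumer then calls torch.cuda.synchronize() before touching
# the tensors, which is why the evaluation loop synchronizes each step. A
# minimal sketch under that assumption (hypothetical helper, not the project's
# actual prefetcher):
def _prefetcher_sketch(batch_iter, data_stream, device='cuda'):
    for numerical_features, categorical_features, click in batch_iter:
        with torch.cuda.stream(data_stream):
            numerical_features = numerical_features.to(device, non_blocking=True)
            categorical_features = categorical_features.to(device, non_blocking=True)
            click = click.to(device, non_blocking=True)
        # the copies above are asynchronous; the consumer is responsible for
        # synchronizing before it reads the batch
        yield numerical_features, categorical_features, click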
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.nn.optim):
        data_loader_train (torch.utils.data.DataLoader):
        data_loader_test (torch.utils.data.DataLoader):
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict())

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(loss=loss_value,
                                         lr=optimizer.param_groups[0]["lr"])
                else:
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor)

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if (global_step % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    f"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model,
                                          FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

    stop_time = time()
    run_time_s = int(stop_time - start_time)

    print(
        f"Finished training in {run_time_s}s. "
        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
    )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
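

# Static loss scaling as used above: the loss is multiplied by FLAGS.loss_scale,
# which scales every gradient by the same factor and keeps small FP16 gradients
# away from underflow; for plain SGD, dividing the learning rate by the same
# factor (see scaled_lr in main below) cancels it out, so the weight update is
# unchanged:
#     w -= (lr / S) * (S * dL/dw) == lr * dL/dw
# A minimal sketch of the equivalent plain-SGD update (illustrative only,
# names are hypothetical):
def _static_loss_scale_sgd_step_sketch(parameters, lr, loss_scale):
    with torch.no_grad():
        for p in parameters:
            if p.grad is not None:
                # p.grad already holds loss_scale * dL/dp because the loss was scaled
                p.add_(p.grad, alpha=-lr / loss_scale)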
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    device = FLAGS.base_device

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size",
                      FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes,
                                        num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size,
                                              num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping[
        'bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(
        FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to the bottom model.
    # Compensate for it by scaling the lr of the model-parallel parts accordingly.
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr

    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = scaled_lr
    else:
        embedding_model_parallel_lr = scaled_lr / world_size
    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = scaled_lr
    else:
        MLP_model_parallel_lr = scaled_lr / world_size
    data_parallel_lr = scaled_lr

    if is_main_process():
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }, {
            'params': list(model.bottom_model.mlp.parameters()),
            'lr': MLP_model_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(
        device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model,
         model.bottom_model.mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if (FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered
            and is_main_process()):
        logging.warning(
            "Saving checkpoint without --bottom_features_ordered flag will result in "
            "a device-order dependent model. Consider using --bottom_features_ordered "
            "if you plan to load the checkpoint in different device configurations."
        )

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()
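    # The side stream lets the loss accumulation run off the default stream's
    # critical path: each step the side stream waits for the current stream and
    # accumulates there, so the default stream is not blocked on it.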

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[mlp_lrs, embedding_lrs],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported",
                              click.shape[0])
            else:
                output = model(numerical_features, categorical_features,
                               batch_sizes_per_gpu).squeeze()

                loss = loss_fn(
                    output, click[batch_indices[rank]:batch_indices[rank + 1]])

                if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
                    model.zero_grad()
                else:
                    # We don't need to accumulate gradients. Setting grad to None is faster than optimizer.zero_grad()
                    for param_group in itertools.chain(
                            embedding_optimizer.param_groups,
                            mlp_optimizer.param_groups):
                        for param in param_group['params']:
                            param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if FLAGS.Adam_MLP_optimizer:
                    scale_MLP_gradients(mlp_optimizer, world_size)
                mlp_optimizer.step()

                if FLAGS.Adam_embedding_optimizer:
                    scale_embeddings_gradients(embedding_optimizer, world_size)
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Averaging across a print_freq period to reduce the error.
                # Accurate timing would require a synchronize, which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path,
                                          epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
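

# scale_MLP_gradients / scale_embeddings_gradients above are assumed to divide
# the accumulated gradients by world_size before the Adam step: with Adam,
# scaling the learning rate is not equivalent to scaling the gradients (as it
# is for SGD), so the averaging that DDP would normally provide has to be done
# on the gradients themselves. A minimal sketch under that assumption
# (hypothetical helper, not the project's actual implementation):
def _scale_gradients_sketch(optimizer, world_size):
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if param.grad is not None:
                param.grad.div_(world_size)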
def dist_evaluate(model, data_loader):
    """Test distributed DLRM model

    Args:
        model (DistDLRM):
        data_loader (torch.utils.data.DataLoader):
    """
    model.eval()

    device = FLAGS.base_device
    world_size = dist.get_world_size()

    batch_sizes_per_gpu = [
        FLAGS.test_batch_size // world_size for _ in range(world_size)
    ]
    test_batch_size = sum(batch_sizes_per_gpu)

    if FLAGS.test_batch_size != test_batch_size:
        print(f"Rounded test_batch_size to {test_batch_size}")
    print(f"Batch sizes per GPU {batch_sizes_per_gpu}")

    # The test batch size could be big; make sure it still prints
    default_print_freq = max(524288 * 100 // test_batch_size, 1)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))

    with torch.no_grad():
        timer = utils.StepTimer()

        # ROC could be computed per batch and the AUC aggregated globally, but that is not
        # implemented here, so all outputs and labels are collected and the AUC is computed
        # once at the end. The y_true and y_score naming follows sklearn.
        y_true = []
        y_score = []
        data_stream = torch.cuda.Stream()

        batch_iter = prefetcher(iter(data_loader), data_stream)

        timer.click()

        for step in range(len(data_loader)):
            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            last_batch_size = None
            if click.shape[0] != test_batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.warning("Pad the last test batch of size %d to %d",
                                last_batch_size, test_batch_size)
                padding_size = test_batch_size - last_batch_size
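                # torch.distributed.all_gather requires every rank to contribute
                # tensors of identical shape, so the last (smaller) batch is padded
                # up to test_batch_size here and the padded tail is sliced off again
                # after the gather below; the padding values themselves are never used.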

                if numerical_features is not None:
                    padding_numerical = torch.empty(
                        padding_size,
                        numerical_features.shape[1],
                        device=numerical_features.device,
                        dtype=numerical_features.dtype)
                    numerical_features = torch.cat(
                        (numerical_features, padding_numerical), dim=0)

                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size,
                        categorical_features.shape[1],
                        device=categorical_features.device,
                        dtype=categorical_features.dtype)
                    categorical_features = torch.cat(
                        (categorical_features, padding_categorical), dim=0)

            output = model(numerical_features, categorical_features,
                           batch_sizes_per_gpu).squeeze()

            output_receive_buffer = torch.empty(test_batch_size, device=device)
            torch.distributed.all_gather(
                list(output_receive_buffer.split(batch_sizes_per_gpu)), output)
            if last_batch_size is not None:
                output_receive_buffer = output_receive_buffer[:last_batch_size]

            if FLAGS.auc_device == "CPU":
                click = click.cpu()
                output_receive_buffer = output_receive_buffer.cpu()

            y_true.append(click)
            y_score.append(output_receive_buffer)

            timer.click()

            if timer.measured is not None:
                metric_logger.update(step_time=timer.measured)
                if step % print_freq == 0 and step > 0:
                    metric_logger.print(
                        header=f"Test: [{step}/{steps_per_epoch}]")

        if is_main_process():
            auc = utils.roc_auc_score(
                torch.cat(y_true), torch.sigmoid(torch.cat(y_score).float()))
        else:
            auc = None

        torch.distributed.barrier()

    model.train()

    return auc
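

# utils.roc_auc_score above is assumed to be a torch-based implementation that
# can stay on the GPU; an equivalent CPU-side computation with scikit-learn
# would look like this (illustrative only):
from sklearn.metrics import roc_auc_score as _sk_roc_auc_score

def _auc_on_cpu_sketch(y_true, y_score):
    # the gathered scores are raw logits, hence the sigmoid before scoring
    return _sk_roc_auc_score(y_true.cpu().numpy(),
                             torch.sigmoid(y_score.float()).cpu().numpy())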
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    feature_spec = load_feature_spec(FLAGS)

    cat_feature_count = len(get_embedding_sizes(feature_spec, None))
    validate_flags(cat_feature_count)

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    feature_spec = load_feature_spec(FLAGS)
    world_embedding_sizes = get_embedding_sizes(feature_spec, max_table_size=FLAGS.max_table_size)
    world_categorical_feature_sizes = np.asarray(world_embedding_sizes)
    device_mapping = get_device_mapping(world_embedding_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))
    # cumulative offsets of each GPU's slice within the global batch; rank r later
    # consumes click[batch_indices[r]:batch_indices[r + 1]]

    # Embedding sizes for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()
    num_numerical_features = feature_spec.get_number_of_numerical_features()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping,
                                                           feature_spec=feature_spec)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to the bottom model.
    # Compensate for it by scaling the lr of the model-parallel parts accordingly.
    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = FLAGS.lr
    else:
        embedding_model_parallel_lr = FLAGS.lr / world_size

    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = FLAGS.lr
    else:
        MLP_model_parallel_lr = FLAGS.lr / world_size

    data_parallel_lr = FLAGS.lr

    if is_main_process():
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr},
            {'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    scaler = torch.cuda.amp.GradScaler(enabled=FLAGS.amp, growth_interval=int(1e9))
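    # growth_interval=int(1e9) means the scale is effectively never increased after
    # initialization (it can still back off on inf/nan), i.e. near-static loss scaling.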

    def parallelize(model):
        if world_size <= 1:
            return model

        if use_gpu:
            model.top_model = parallel.DistributedDataParallel(model.top_model)
        else:  # Use other backend for CPU
            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
        return model

    if FLAGS.mode == 'test':
        model = parallelize(model)
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return
    elif FLAGS.mode == 'inference_benchmark':
        if world_size > 1:
            raise ValueError('Inference benchmark only supports single-GPU mode.')

        results = {}

        if FLAGS.amp:
            # can use pure FP16 for inference
            model = model.half()

        for batch_size in FLAGS.inference_benchmark_batch_sizes:
            FLAGS.test_batch_size = batch_size
            _, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping, feature_spec=feature_spec)

            latencies = inference_benchmark(model=model, data_loader=data_loader_test,
                                            num_batches=FLAGS.inference_benchmark_steps,
                                            cuda_graphs=FLAGS.cuda_graphs)

            # drop the first 10 as a warmup
            latencies = latencies[10:]

            mean_latency = np.mean(latencies)
            mean_inference_throughput = batch_size / mean_latency
            subresult = {f'mean_inference_latency_batch_{batch_size}': mean_latency,
                         f'mean_inference_throughput_batch_{batch_size}': mean_inference_throughput}
            results.update(subresult)
        dllogger.log(data=results, step=tuple())
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    # the last (smaller) batch is skipped in the training loop below
    steps_per_epoch = len(data_loader_train) - 1
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 2

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.8f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)

    lr_scheduler = utils.LearningRateScheduler(optimizers=[mlp_optimizer, embedding_optimizer],
                                               base_lrs=[mlp_lrs, embedding_lrs],
                                               warmup_steps=FLAGS.warmup_steps,
                                               warmup_factor=FLAGS.warmup_factor,
                                               decay_start_step=FLAGS.decay_start_step,
                                               decay_steps=FLAGS.decay_steps,
                                               decay_power=FLAGS.decay_power,
                                               end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    def zero_grad(model):
        if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
            model.zero_grad()
        else:
            # We don't need to accumulate gradients. Setting grad to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

    def forward_backward(model, *args):

        numerical_features, categorical_features, click = args
        with torch.cuda.amp.autocast(enabled=FLAGS.amp):
            output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()
            loss = loss_fn(output, click[batch_indices[rank]: batch_indices[rank + 1]])

        scaler.scale(loss).backward()

        return loss

    def weight_update():
        if not FLAGS.freeze_mlps:
            if FLAGS.Adam_MLP_optimizer:
                scale_MLP_gradients(mlp_optimizer, world_size)
            scaler.step(mlp_optimizer)

        if not FLAGS.freeze_embeddings:
            if FLAGS.Adam_embedding_optimizer:
                scale_embeddings_gradients(embedding_optimizer, world_size)
            scaler.unscale_(embedding_optimizer)
            embedding_optimizer.step()
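            # The embedding optimizer steps outside of the scaler: its gradients are
            # unscaled explicitly above, whereas scaler.step(mlp_optimizer) unscales
            # the MLP gradients internally before stepping.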

        scaler.update()

    trainer = CudaGraphWrapper(model, forward_backward, parallelize, zero_grad,
                               cuda_graphs=FLAGS.cuda_graphs)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            # One of the batches will be smaller because the dataset size
            # isn't necessarily a multiple of the batch size.
            # TODO: isn't dropping it here a change of behavior?
            if click.shape[0] != FLAGS.batch_size:
                continue

            lr_scheduler.step()
            loss = trainer.train_step(numerical_features, categorical_features, click)

            # need to wait for the gradients before the weight update
            torch.cuda.current_stream().wait_stream(trainer.stream)
            weight_update()
            moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                # Averaging across a print_freq period to reduce the error.
                # Accurate timing would require a synchronize, which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])

                eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

                moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(trainer.model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. ")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. ")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    if is_main_process():
        dllogger.log(data=results, step=tuple())
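

# The forward_backward / weight_update pair above follows the standard
# torch.cuda.amp recipe: autocast for the forward pass, backward on the scaled
# loss, scaler.step() to unscale and apply, scaler.update() to adjust the scale.
# A minimal self-contained sketch of that recipe (hypothetical model and data,
# illustrative only):
def _amp_step_sketch():
    device = 'cuda'
    model = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(),
                                torch.nn.Linear(64, 1)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    loss_fn = torch.nn.BCEWithLogitsLoss()

    x = torch.randn(32, 16, device=device)
    y = torch.randint(0, 2, (32, 1), device=device).float()

    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # backward through the scaled loss
    scaler.step(optimizer)         # unscales the gradients, skips the step on inf/nan
    scaler.update()                # grows or shrinks the scale for the next step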
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.nn.optim):
        data_loader_train (torch.utils.data.DataLoader):
        data_loader_test (torch.utils.data.DataLoader):
    """
    model.train()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time',
        utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    for epoch in range(FLAGS.epochs):

        batch_iter = iter(data_loader_train)
        for step in range(len(data_loader_train)):
            timer.click()

            global_step = steps_per_epoch * epoch + step

            numerical_features, categorical_features, click = next(batch_iter)

            categorical_features = categorical_features.to(base_device).to(
                torch.long)
            numerical_features = numerical_features.to(base_device)
            click = click.to(base_device).to(torch.float32)

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    F"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            optimizer.zero_grad()
            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            loss_value = loss.item()

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if global_step < FLAGS.benchmark_warmup_steps:
                metric_logger.update(loss=loss_value,
                                     lr=optimizer.param_groups[0]["lr"])
            else:
                unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
                metric_logger.update(loss=loss_value / unscale_factor,
                                     step_time=timer.measured,
                                     lr=optimizer.param_groups[0]["lr"] *
                                     unscale_factor)

            if step % print_freq == 0 and step > 0:
                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if ((global_step + 1) % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
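

# A hypothetical way the single-GPU pieces above could be wired together; the
# optimizer choice and argument names are assumptions, not the project's actual
# entry point:
def _single_gpu_run_sketch(model, data_loader_train, data_loader_test, scaled_lr):
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
    optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)
    train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr)
    loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
    print(f"Final test loss {loss:.5f}, auc {auc:.6f}")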