예제 #1
0
def main(argv):
    """Benchmark the average training-step time of the top MLP.

    Builds the MLP returned by ``create_top_mlp`` on CUDA, optionally wraps it
    with apex AMP (``FLAGS.fp16``, opt level O1) and apex DDP (when more than
    one process is launched), then runs full training steps (dot interaction,
    top MLP forward, backward, SGD step) on random dummy inputs and prints the
    mean step time in milliseconds.

    Args:
        argv: Command-line arguments; unused, configuration is read from ``FLAGS``.

    Raises:
        ValueError: If ``FLAGS.num_iters`` is not positive (the average step
            time would otherwise divide by zero).
    """
    if FLAGS.num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {FLAGS.num_iters}.")

    _, world_size, _ = dist.init_distributed_mode()

    top_mlp = create_top_mlp().to("cuda")
    print(top_mlp)

    optimizer = torch.optim.SGD(top_mlp.parameters(), lr=1.)

    if FLAGS.fp16:
        top_mlp, optimizer = amp.initialize(top_mlp,
                                            optimizer,
                                            opt_level="O1",
                                            loss_scale=1)

    if world_size > 1:
        top_mlp = parallel.DistributedDataParallel(top_mlp)

    dummy_bottom_mlp_output = torch.rand(FLAGS.batch_size,
                                         EMBED_DIM,
                                         device="cuda")
    dummy_embedding_output = torch.rand(FLAGS.batch_size,
                                        26 * EMBED_DIM,
                                        device="cuda")

    if FLAGS.fp16:
        dummy_bottom_mlp_output = dummy_bottom_mlp_output.to(torch.half)
        dummy_embedding_output = dummy_embedding_output.to(torch.half)

    def train_step():
        """Run one forward/backward/optimizer step on the dummy inputs."""
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output],
                                          FLAGS.batch_size)
        output = top_mlp(interaction_out).squeeze()
        dummy_loss = output.mean()
        optimizer.zero_grad()
        if FLAGS.fp16:
            with amp.scale_loss(dummy_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dummy_loss.backward()
        optimizer.step()

    # Warm up with complete training steps (not just forward passes) so that
    # first-time backward/optimizer costs are not charged to the timed loop.
    for _ in range(100):
        train_step()

    start_time = utils.timer_start()
    for _ in range(FLAGS.num_iters):
        train_step()
    stop_time = utils.timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    print(F"Average step time: {elapsed_time:.4f} ms.")
예제 #2
0
def main(argv):
    """Run distributed (model-parallel embeddings + data-parallel top MLP) DLRM training.

    Each rank creates only the embedding tables that
    ``dist_model.get_criteo_device_mapping`` assigns to it. ``bottom_mlp_sizes``
    is passed only on the rank named by ``device_mapping['bottom_mlp']``, and
    the top MLP is wrapped in DistributedDataParallel. Training runs for
    ``FLAGS.epochs`` epochs with MLPerf logging, AUC is evaluated every
    ``test_freq`` global steps, and the function returns early once AUC exceeds
    ``FLAGS.auc_threshold`` (or when ``FLAGS.profile_steps`` is reached).

    Args:
        argv: Command-line arguments; unused, configuration is read from ``FLAGS``.

    Raises:
        NotImplementedError: If only a single process is launched.
        ValueError: If the train or test batch size is not divisible by the
            world size.
    """
    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    # Initialize distributed mode
    use_gpu = "cpu" not in FLAGS.device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    if world_size == 1:
        raise NotImplementedError(
            "This file is only for distributed training.")

    mlperf_logger.mlperf_submission_log('dlrm')
    mlperf_logger.log_event(key=mlperf_logger.constants.SEED, value=FLAGS.seed)
    mlperf_logger.log_event(key=mlperf_logger.constants.GLOBAL_BATCH_SIZE,
                            value=FLAGS.batch_size)

    # Only print cmd args on rank 0
    if rank == 0:
        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    # Check arguments sanity
    if FLAGS.batch_size % world_size != 0:
        raise ValueError(
            F"Batch size {FLAGS.batch_size} is not divisible by world_size {world_size}."
        )
    if FLAGS.test_batch_size % world_size != 0:
        raise ValueError(
            F"Test batch size {FLAGS.test_batch_size} is not divisible by world_size {world_size}."
        )

    # Load config file; the per-rank model arguments are derived from it below.
    with open(FLAGS.model_config, "r") as f:
        config = json.load(f)

    world_categorical_feature_sizes = np.asarray(
        config.pop('categorical_feature_sizes'))
    device_mapping = dist_model.get_criteo_device_mapping(world_size)
    vectors_per_gpu = device_mapping['vectors_per_gpu']
    # Sizes of the embedding tables this rank is going to create
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = config.pop('bottom_mlp_sizes')
    if rank != device_mapping['bottom_mlp']:
        bottom_mlp_sizes = None

    model = dist_model.DistDlrm(
        categorical_feature_sizes=categorical_feature_sizes,
        bottom_mlp_sizes=bottom_mlp_sizes,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        **config,
        device=FLAGS.device,
        use_embedding_ext=FLAGS.use_embedding_ext)
    print(model)

    dist.setup_distributed_print(rank == 0)

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    embedding_optimizer = torch.optim.SGD([
        {
            'params': model.bottom_model.joint_embedding.parameters(),
            'lr': scaled_lrs[0]
        },
    ])
    mlp_optimizer = apex_optim.FusedSGD([{
        'params': model.bottom_model.bottom_mlp.parameters(),
        'lr': scaled_lrs[0]
    }, {
        'params': model.top_model.parameters(),
        'lr': scaled_lrs[1]
    }])

    if FLAGS.fp16:
        (model.top_model,
         model.bottom_model.bottom_mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.bottom_mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1,
             cast_model_outputs=torch.float16)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Too many arguments to pass for distributed training. Use plain train code here instead of
    # defining a train function

    # Print per 16384 * 2000 samples by default. Clamped to >= 1 so a very large
    # batch size cannot make ``step % print_freq`` divide by zero.
    default_print_freq = max(1, 16384 * 2000 // FLAGS.batch_size)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f} ms'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=FLAGS.device)
    moving_loss_stream = torch.cuda.Stream()

    local_embedding_device_mapping = torch.tensor(
        device_mapping['embedding'][rank],
        device=FLAGS.device,
        dtype=torch.long)

    # LR is logged twice for now because of a compliance checker bug
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
                            value=FLAGS.lr)
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS,
                            value=FLAGS.warmup_steps)

    # use logging keys from the official HP table and not from the logging library
    mlperf_logger.log_event(key='sgd_opt_base_learning_rate', value=FLAGS.lr)
    mlperf_logger.log_event(key='lr_decay_start_steps',
                            value=FLAGS.decay_start_step)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_steps',
                            value=FLAGS.decay_steps)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_poly_power',
                            value=FLAGS.decay_power)

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[scaled_lrs, [scaled_lrs[0]]],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    eval_data_cache = [] if FLAGS.cache_eval_data else None

    start_time = time()
    stop_time = time()

    print("Creating data loaders")
    dist_dataset_args = {
        "numerical_features": rank == 0,
        "categorical_features": device_mapping['embedding'][rank]
    }

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP)
    mlperf_logger.barrier()
    mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START)
    mlperf_logger.barrier()

    data_loader_train, data_loader_test = dataset.get_data_loader(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.test_batch_size,
        FLAGS.device,
        dataset_type=FLAGS.dataset_type,
        shuffle=FLAGS.shuffle,
        **dist_dataset_args)

    steps_per_epoch = len(data_loader_train)

    # Default 20 tests per epoch; clamped to >= 1 so that short epochs cannot
    # make ``global_step % test_freq`` divide by zero.
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else max(
        1, steps_per_epoch // 20)

    # Rank r trains on rows [r * batch_size_per_gpu, (r + 1) * batch_size_per_gpu)
    # of every (padded) global batch; see how ``click`` is sliced below.
    batch_size_per_gpu = FLAGS.batch_size // world_size

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        mlperf_logger.barrier()
        mlperf_logger.log_start(key=mlperf_logger.constants.BLOCK_START,
                                metadata={
                                    mlperf_logger.constants.FIRST_EPOCH_NUM:
                                    epoch + 1,
                                    mlperf_logger.constants.EPOCH_COUNT:
                                    1
                                })
        mlperf_logger.barrier()
        mlperf_logger.log_start(
            key=mlperf_logger.constants.EPOCH_START,
            metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})

        if FLAGS.profile_steps is not None:
            torch.cuda.profiler.start()
        for step, (numerical_features, categorical_features,
                   click) in enumerate(
                       dataset.prefetcher(iter(data_loader_train),
                                          data_stream)):
            torch.cuda.current_stream().wait_stream(data_stream)

            global_step = steps_per_epoch * epoch + step
            lr_scheduler.step()

            # Slice out categorical features if not using the "dist" dataset
            if FLAGS.dataset_type != "dist":
                categorical_features = categorical_features[:,
                                                            local_embedding_device_mapping]

            # Numerical features may be absent on this rank (the loader is only
            # asked for them on rank 0), so guard on the tensor being cast.
            if FLAGS.fp16 and numerical_features is not None:
                numerical_features = numerical_features.to(torch.float16)

            last_batch_size = None
            if click.shape[0] != FLAGS.batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.debug("Pad the last batch of size %d to %d",
                              last_batch_size, FLAGS.batch_size)
                padding_size = FLAGS.batch_size - last_batch_size
                if numerical_features is not None:
                    padding_numerical = torch.empty(
                        padding_size,
                        numerical_features.shape[1],
                        device=numerical_features.device,
                        dtype=numerical_features.dtype)
                    numerical_features = torch.cat(
                        (numerical_features, padding_numerical), dim=0)
                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size,
                        categorical_features.shape[1],
                        device=categorical_features.device,
                        dtype=categorical_features.dtype)
                    categorical_features = torch.cat(
                        (categorical_features, padding_categorical), dim=0)
                padding_click = torch.empty(padding_size,
                                            device=click.device,
                                            dtype=click.dtype)
                click = torch.cat((click, padding_click))

            bottom_out = model.bottom_model(numerical_features,
                                            categorical_features)

            from_bottom = dist_model.bottom_to_top(bottom_out,
                                                   batch_size_per_gpu,
                                                   config['embedding_dim'],
                                                   vectors_per_gpu)

            local_click = click[rank * batch_size_per_gpu:(rank + 1) *
                                batch_size_per_gpu]
            if last_batch_size is not None:
                # Ranks below ``full_ranks`` hold only real samples, rank
                # ``full_ranks`` holds ``remainder`` real samples followed by
                # padding, and every later rank holds padding only.
                full_ranks, remainder = divmod(last_batch_size,
                                               batch_size_per_gpu)
                if rank < full_ranks:
                    loss = loss_fn(
                        model.top_model(from_bottom).squeeze().float(),
                        local_click)
                elif rank == full_ranks and remainder > 0:
                    # reshape(-1) keeps a single-sample output 1-D after
                    # squeeze() so it matches the shape of the click slice.
                    top_out = model.top_model(
                        from_bottom[:remainder]).squeeze().float().reshape(-1)
                    loss = loss_fn(top_out, local_click[:remainder])
                else:
                    # Back propagate nothing for padded samples
                    loss = 0. * model.top_model(
                        from_bottom).squeeze().float().mean()
            else:
                loss = loss_fn(
                    model.top_model(from_bottom).squeeze().float(),
                    local_click)

            # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(
                    embedding_optimizer.param_groups,
                    mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            mlp_optimizer.step()
            embedding_optimizer.step()

            moving_loss_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(moving_loss_stream):
                # detach() so the running sum does not record autograd history
                # and keep every step's graph reachable until the next reset.
                moving_loss += loss.detach()
            if step == 0:
                print(F"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.synchronize()
                # Averaging cross a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.
                metric_logger.update(step_time=(time() - stop_time) * 1000 /
                                     print_freq,
                                     loss=moving_loss.item() / print_freq /
                                     (FLAGS.loss_scale if FLAGS.fp16 else 1),
                                     lr=mlp_optimizer.param_groups[1]["lr"] *
                                     (FLAGS.loss_scale if FLAGS.fp16 else 1))
                stop_time = time()
                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.avg / 1000 *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )
                # Reset in place on the accumulation stream so moving_loss stays a tensor.
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss.zero_()

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                mlperf_epoch_index = global_step / steps_per_epoch + 1

                mlperf_logger.barrier()
                mlperf_logger.log_start(key=mlperf_logger.constants.EVAL_START,
                                        metadata={
                                            mlperf_logger.constants.EPOCH_NUM:
                                            mlperf_epoch_index
                                        })
                auc = dist_evaluate(model, data_loader_test, eval_data_cache)
                mlperf_logger.log_event(
                    key=mlperf_logger.constants.EVAL_ACCURACY,
                    value=float(auc),
                    metadata={
                        mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index
                    })
                print(F"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()
                mlperf_logger.barrier()
                mlperf_logger.log_end(key=mlperf_logger.constants.EVAL_STOP,
                                      metadata={
                                          mlperf_logger.constants.EPOCH_NUM:
                                          mlperf_epoch_index
                                      })

                if auc > FLAGS.auc_threshold:
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                                          metadata={
                                              mlperf_logger.constants.STATUS:
                                              mlperf_logger.constants.SUCCESS
                                          })

                    mlperf_logger.barrier()
                    mlperf_logger.log_end(
                        key=mlperf_logger.constants.EPOCH_STOP,
                        metadata={
                            mlperf_logger.constants.EPOCH_NUM: epoch + 1
                        })
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(
                        key=mlperf_logger.constants.BLOCK_STOP,
                        metadata={
                            mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1
                        })

                    run_time_s = int(stop_time - start_time)
                    # max(..., 1) avoids dividing by a run time truncated to 0 s.
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / max(run_time_s, 1):.1f} records/s."
                    )
                    return

            if FLAGS.profile_steps is not None and global_step == FLAGS.profile_steps:
                torch.cuda.profiler.stop()
                logging.warning("Profile run, stopped at step %d.",
                                global_step)
                return

        mlperf_logger.barrier()
        mlperf_logger.log_end(
            key=mlperf_logger.constants.EPOCH_STOP,
            metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})
        mlperf_logger.barrier()
        mlperf_logger.log_end(
            key=mlperf_logger.constants.BLOCK_STOP,
            metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1})

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                          metadata={
                              mlperf_logger.constants.STATUS:
                              mlperf_logger.constants.ABORTED
                          })
def main(argv):
    """Train (or, with ``FLAGS.mode == 'test'``, only evaluate) a DistributedDlrm.

    Embedding tables are sharded across ranks according to
    ``get_device_mapping``. ``bottom_mlp_sizes`` is passed only on the rank named
    by ``device_mapping['bottom_mlp']``, and the top MLP is wrapped in
    DistributedDataParallel. The function optionally loads a checkpoint before
    training and saves one afterwards. It evaluates AUC every ``test_freq``
    global steps, exits the process once ``FLAGS.auc_threshold`` is reached,
    and stops training after ``FLAGS.max_steps`` global steps when that flag is
    set. Best AUC/epoch and average training throughput are logged via dllogger.

    Args:
        argv: Command-line arguments; unused, configuration is read from ``FLAGS``.
    """
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    device = FLAGS.base_device

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    # Round the test batch size down to a multiple of world_size.
    FLAGS.set_default("test_batch_size",
                      FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes,
                                        num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size,
                                              num_gpus=world_size)
    # Prefix sums: rank r owns rows [batch_indices[r], batch_indices[r + 1]) of
    # every global batch (this is how ``click`` is sliced for the loss below).
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    owns_bottom_mlp = rank == device_mapping['bottom_mlp']
    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if owns_bottom_mlp else None

    data_loader_train, data_loader_test = get_data_loaders(
        FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr

    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = scaled_lr
    else:
        embedding_model_parallel_lr = scaled_lr / world_size
    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = scaled_lr
    else:
        MLP_model_parallel_lr = scaled_lr / world_size
    data_parallel_lr = scaled_lr

    # The top-model group is always first (the lr logged below reads group 0).
    # The bottom-MLP group is added on the same rank that was given
    # ``bottom_mlp_sizes`` above, so the two decisions cannot disagree.
    mlp_params = [{
        'params': list(model.top_model.parameters()),
        'lr': data_parallel_lr
    }]
    mlp_lrs = [data_parallel_lr]
    if owns_bottom_mlp:
        mlp_params.append({
            'params': list(model.bottom_model.mlp.parameters()),
            'lr': MLP_model_parallel_lr
        })
        mlp_lrs.append(MLP_model_parallel_lr)

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(
        device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model,
         model.bottom_model.mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process(
    ):
        logging.warning(
            "Saving checkpoint without --bottom_features_ordered flag will result in "
            "a device-order dependent model. Consider using --bottom_features_ordered "
            "if you plan to load the checkpoint in different device configurations."
        )

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default. Clamped to >= 1 so a very large
    # batch size cannot make ``step % print_freq`` divide by zero.
    default_print_freq = max(1, 16384 * 2000 // FLAGS.batch_size)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    # Clamped to >= 1: a one-step epoch would otherwise give test_freq == 0.
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else max(
        1, steps_per_epoch - 1)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[mlp_lrs, embedding_lrs],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()
    reached_max_steps = False

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(steps_per_epoch):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                reached_max_steps = True
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported",
                              click.shape[0])
            else:
                output = model(numerical_features, categorical_features,
                               batch_sizes_per_gpu).squeeze()

                loss = loss_fn(
                    output, click[batch_indices[rank]:batch_indices[rank + 1]])

                if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
                    model.zero_grad()
                else:
                    # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad()
                    for param_group in itertools.chain(
                            embedding_optimizer.param_groups,
                            mlp_optimizer.param_groups):
                        for param in param_group['params']:
                            param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if FLAGS.Adam_MLP_optimizer:
                    scale_MLP_gradients(mlp_optimizer, world_size)
                mlp_optimizer.step()

                if FLAGS.Adam_embedding_optimizer:
                    scale_embeddings_gradients(embedding_optimizer, world_size)
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    # detach() so the running sum does not record autograd history.
                    moving_loss += loss.detach()

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

                # Reset in place so moving_loss stays a tensor even if the next
                # window contains no trained step (e.g. an unsupported last batch).
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss.zero_()

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    # max(..., 1) avoids dividing by a run time truncated to 0 s.
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / max(run_time_s, 1):.1f} records/s."
                    )
                    sys.exit()

        if reached_max_steps:
            # Stop the epoch loop too; otherwise every remaining epoch would
            # rebuild the prefetcher and report a misleading "Finished epoch".
            break

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path,
                                          epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
예제 #4
0
def main(argv):
    """Train, evaluate or benchmark a ``DistributedDlrm`` model.

    All configuration is read from the module-level ``FLAGS``; ``argv`` is not
    used. The behaviour is selected by ``FLAGS.mode``:

    * ``'test'``: wrap the model with ``parallelize`` and log the AUC returned
      by ``dist_evaluate`` on the test loader.
    * ``'inference_benchmark'``: single-process only. For every batch size in
      ``FLAGS.inference_benchmark_batch_sizes`` the test loader is rebuilt,
      ``FLAGS.inference_benchmark_steps`` batches are timed and the mean
      latency / throughput are logged.
    * any other value: train for up to ``FLAGS.epochs`` epochs, evaluating AUC
      every ``test_freq`` global steps. Training stops early once
      ``FLAGS.max_steps`` is exceeded, and the process exits via ``sys.exit()``
      as soon as the AUC reaches ``FLAGS.auc_threshold``. If
      ``FLAGS.save_checkpoint_path`` is set a checkpoint is written after
      training; the main process finally logs the best AUC, the (fractional)
      epoch it was reached at and the average training throughput.

    Args:
        argv: Command-line arguments (unused).

    Raises:
        ValueError: If ``'inference_benchmark'`` mode is requested with more
            than one process.
    """
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    # Any base_device whose name does not contain "cpu" is treated as a GPU run.
    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    feature_spec = load_feature_spec(FLAGS)

    cat_feature_count = len(get_embedding_sizes(feature_spec, None))
    validate_flags(cat_feature_count)

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    # Round the test batch size down to a multiple of world_size.
    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    # NOTE(review): this reloads the feature spec already loaded above. It looks
    # redundant unless load_feature_spec reads a flag changed in between
    # (test_batch_size) - confirm before removing one of the calls.
    feature_spec = load_feature_spec(FLAGS)
    world_embedding_sizes = get_embedding_sizes(feature_spec, max_table_size=FLAGS.max_table_size)
    world_categorical_feature_sizes = np.asarray(world_embedding_sizes)
    device_mapping = get_device_mapping(world_embedding_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    # Prefix sums of the per-rank batch sizes: rank r's part of a global batch
    # is the slice [batch_indices[r], batch_indices[r + 1]). forward_backward
    # uses it to select this rank's labels from `click`.
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # Sizes of the categorical features (embedding tables) mapped to this rank.
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()
    num_numerical_features = feature_spec.get_number_of_numerical_features()

    # Only the rank named by device_mapping['bottom_mlp'] passes bottom MLP sizes
    # to the model (other ranks pass None). The same predicate decides below
    # which rank optimizes the bottom MLP parameters, so model construction and
    # optimizer setup always agree on the owner.
    is_bottom_mlp_owner = rank == device_mapping['bottom_mlp']
    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if is_bottom_mlp_owner else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping,
                                                           feature_spec=feature_spec)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )

    # NOTE(review): presumably restricts print() output to the main process -
    # confirm in dist.setup_distributed_print.
    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr.
    # For Adam optimizers the lr is left unscaled; instead weight_update() passes
    # the gradients through scale_MLP_gradients / scale_embeddings_gradients
    # together with world_size.
    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = FLAGS.lr
    else:
        embedding_model_parallel_lr = FLAGS.lr / world_size

    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = FLAGS.lr
    else:
        MLP_model_parallel_lr = FLAGS.lr / world_size

    data_parallel_lr = FLAGS.lr

    # The top MLP is data-parallel and uses the unscaled lr; the bottom MLP is
    # optimized only on the rank that owns it. mlp_lrs lists the lr of each
    # param group in the same order and is passed to the LR scheduler as base lrs.
    if is_bottom_mlp_owner:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr},
            {'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    # torch.optim.SparseAdam only accepts sparse gradients.
    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    # growth_interval=1e9: the loss scale effectively never grows during a run,
    # but it is still backed off when inf/NaN gradients are detected.
    scaler = torch.cuda.amp.GradScaler(enabled=FLAGS.amp, growth_interval=int(1e9))

    def parallelize(model):
        """Wrap ``model.top_model`` in DistributedDataParallel for multi-process runs.

        Only the top MLP is data-parallel. GPU runs use
        ``parallel.DistributedDataParallel``; CPU runs use
        ``torch.nn.parallel.DistributedDataParallel``. The model is modified in
        place and returned; with ``world_size <= 1`` it is returned unchanged.
        """
        if world_size <= 1:
            return model

        if use_gpu:
            model.top_model = parallel.DistributedDataParallel(model.top_model)
        else:  # Use other backend for CPU
            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
        return model

    if FLAGS.mode == 'test':
        model = parallelize(model)
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return
    elif FLAGS.mode == 'inference_benchmark':
        if world_size > 1:
            raise ValueError('Inference benchmark only supports singleGPU mode.')

        results = {}

        if FLAGS.amp:
            # can use pure FP16 for inference
            model = model.half()

        for batch_size in FLAGS.inference_benchmark_batch_sizes:
            # Rebuild the test loader so it yields batches of this size.
            FLAGS.test_batch_size = batch_size
            _, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping, feature_spec=feature_spec)

            latencies = inference_benchmark(model=model, data_loader=data_loader_test,
                                            num_batches=FLAGS.inference_benchmark_steps,
                                            cuda_graphs=FLAGS.cuda_graphs)

            # drop the first 10 as a warmup
            # NOTE(review): if 10 or fewer latencies are returned, nothing is left
            # and np.mean yields nan - confirm inference_benchmark_steps > 10.
            latencies = latencies[10:]

            mean_latency = np.mean(latencies)
            mean_inference_throughput = batch_size / mean_latency
            subresult = {f'mean_inference_latency_batch_{batch_size}': mean_latency,
                         f'mean_inference_throughput_batch_{batch_size}': mean_inference_throughput}
            results.update(subresult)
        dllogger.log(data=results, step=tuple())
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default. Clamped to at least every step
    # so that `step % print_freq` below never divides by zero for huge batches.
    default_print_freq = max(1, 16384 * 2000 // FLAGS.batch_size)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    # last one will be dropped in the training loop
    steps_per_epoch = len(data_loader_train) - 1
    # By default test roughly once per epoch (every steps_per_epoch - 2 global
    # steps). Clamped to 1 so very small training sets (steps_per_epoch <= 2)
    # never produce a zero modulus in `global_step % test_freq`.
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else max(1, steps_per_epoch - 2)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.8f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)

    lr_scheduler = utils.LearningRateScheduler(optimizers=[mlp_optimizer, embedding_optimizer],
                                               base_lrs=[mlp_lrs, embedding_lrs],
                                               warmup_steps=FLAGS.warmup_steps,
                                               warmup_factor=FLAGS.warmup_factor,
                                               decay_start_step=FLAGS.decay_start_step,
                                               decay_steps=FLAGS.decay_steps,
                                               decay_power=FLAGS.decay_power,
                                               end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    def zero_grad(model):
        """Clear gradients before the next forward/backward pass.

        With any Adam optimizer ``model.zero_grad()`` is used; otherwise the
        ``grad`` of every parameter held by the two optimizers is set to None.
        """
        if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
            model.zero_grad()
        else:
            # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

    def forward_backward(model, *args):
        """Run the forward and backward pass for one batch.

        ``args`` is ``(numerical_features, categorical_features, click)``. The
        loss is computed under autocast (enabled by ``FLAGS.amp``) against this
        rank's slice of ``click``, and the grad-scaler-scaled loss is
        backpropagated. Returns the unscaled loss tensor.
        """
        numerical_features, categorical_features, click = args
        with torch.cuda.amp.autocast(enabled=FLAGS.amp):
            output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()
            loss = loss_fn(output, click[batch_indices[rank]: batch_indices[rank + 1]])

        scaler.scale(loss).backward()

        return loss

    def weight_update():
        """Step the MLP and embedding optimizers (unless frozen), then update the loss scale.

        For Adam optimizers the gradients are first passed through
        ``scale_MLP_gradients`` / ``scale_embeddings_gradients`` with
        ``world_size`` (the lr is not divided by world_size in that case).
        """
        if not FLAGS.freeze_mlps:
            if FLAGS.Adam_MLP_optimizer:
                scale_MLP_gradients(mlp_optimizer, world_size)
            scaler.step(mlp_optimizer)

        if not FLAGS.freeze_embeddings:
            if FLAGS.Adam_embedding_optimizer:
                scale_embeddings_gradients(embedding_optimizer, world_size)
            # NOTE(review): unlike scaler.step(), calling step() directly does not
            # skip the update when unscale_ finds inf/NaN gradients - confirm this
            # is intended.
            scaler.unscale_(embedding_optimizer)
            embedding_optimizer.step()

        scaler.update()

    # NOTE(review): CudaGraphWrapper presumably runs zero_grad + forward_backward
    # on its own stream (trainer.stream), optionally captured as a CUDA graph
    # when FLAGS.cuda_graphs is set - confirm.
    trainer = CudaGraphWrapper(model, forward_backward, parallelize, zero_grad,
                               cuda_graphs=FLAGS.cuda_graphs)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    # Set once FLAGS.max_steps is exceeded so the epoch loop stops as well.
    # Breaking only out of the step loop would make every remaining epoch fetch
    # a batch, re-print the stop message and report a bogus "Finished epoch".
    reached_max_steps = False

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                reached_max_steps = True
                break

            # One of the batches will be smaller because the dataset size
            # isn't necessarily a multiple of the batch size. #TODO isn't dropping here a change of behavior
            if click.shape[0] != FLAGS.batch_size:
                continue

            lr_scheduler.step()
            loss = trainer.train_step(numerical_features, categorical_features, click)

            # need to wait for the gradients before the weight update
            torch.cuda.current_stream().wait_stream(trainer.stream)
            weight_update()
            moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.

                # Step times are not recorded during the benchmark warmup so they
                # do not enter the throughput reported at the end.
                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])

                eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

                # Reset the accumulator; `moving_loss += loss` turns it back into a tensor.
                moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(trainer.model, data_loader_test)

                # NOTE(review): presumably None on ranks that do not compute the metric - confirm.
                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. ")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. ")

        if reached_max_steps:
            break

    # NOTE(review): step_time was registered with window_size=1, so .avg most
    # likely reflects only the last recorded step time rather than the whole
    # run - confirm against utils.SmoothedValue.
    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        # epoch/step are those of the last training iteration that ran.
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    if is_main_process():
        dllogger.log(data=results, step=tuple())