Code Example #1
def time_mlp(mlp, input):
    """
    Args:
        mlp (ModuleList): MLP to benchmark, invoked as mlp(input)
        input (Tensor): dummy input batch fed to the MLP

    Returns:
        fwd_elapsed_time (float): FWD time in ms
        bwd_elapsed_time (float): BWD time in ms
    """
    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        mlp_output = mlp(input)
    stop_time = timer_stop()

    fwd_elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3

    grad = torch.rand_like(mlp_output)
    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        mlp_output.backward(grad, retain_graph=True)
    stop_time = timer_stop()

    bwd_elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    return fwd_elapsed_time, bwd_elapsed_time
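
These snippets rely on module-level helpers that are not shown: timer_start/timer_stop and FLAGS (presumably absl flags providing num_iters, batch_size, base_device, fp16, and so on). Since the timed work runs on the GPU, the timers presumably synchronize the device before reading the clock; a minimal sketch of what they might look like, under that assumption:

import time

import torch

def timer_start():
    # Make sure all previously queued GPU kernels have finished before the clock starts
    torch.cuda.synchronize()
    return time.time()

def timer_stop():
    # Wait for the timed kernels to finish before reading the clock
    torch.cuda.synchronize()
    return time.time()

Without the synchronization, the loop would only measure kernel launch overhead, since CUDA calls return asynchronously.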
Code Example #2
def time_interaction(interaction, bottom_mlp_output, embedding_outputs, batch_size):
    """
    Args:
        interaction (function): interaction op taking (bottom_mlp_output, embedding_outputs, batch_size)
        bottom_mlp_output (Tensor): output of the bottom MLP
        embedding_outputs (list): Sequence of embedding output tensors
        batch_size (int): batch size of the dummy inputs

    Returns:
        fwd_elapsed_time (float): FWD time in ms
        bwd_elapsed_time (float): BWD time in ms
    """
    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_output = interaction(bottom_mlp_output, embedding_outputs, batch_size)
    stop_time = timer_stop()
    fwd_elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3

    dummy_grad = torch.rand_like(interaction_output)
    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_output.backward(dummy_grad, retain_graph=True)
    stop_time = timer_stop()
    bwd_elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3

    return fwd_elapsed_time, bwd_elapsed_time
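
The interaction function itself is passed in by the caller and is not shown here. In DLRM-style models it is typically a dot-product interaction between the bottom-MLP output and the embedding vectors; a rough, hypothetical sketch compatible with the (bottom_mlp_output, embedding_outputs, batch_size) signature could look like the following (the real implementation may pad or fuse these steps differently):

import torch

def dot_interaction(bottom_mlp_output, embedding_outputs, batch_size):
    embed_dim = bottom_mlp_output.shape[1]
    # Stack the bottom MLP output and every embedding vector: [batch, num_features, embed_dim]
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    concat = concat.view(batch_size, -1, embed_dim)
    # Pairwise dot products between all feature vectors: [batch, num_features, num_features]
    pairwise = torch.bmm(concat, concat.transpose(1, 2))
    # Keep only the strictly lower-triangular entries to drop duplicates and self-interactions
    num_features = concat.shape[1]
    li, lj = torch.tril_indices(num_features, num_features, offset=-1, device=pairwise.device)
    flat = pairwise[:, li, lj]
    # Concatenate the raw bottom MLP output with the flattened interaction terms
    return torch.cat([bottom_mlp_output, flat], dim=1)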
Code Example #3
def time_optimizer(optimizer):
    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        optimizer.step()
    stop_time = timer_stop()
    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    return elapsed_time
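
time_optimizer only times optimizer.step(), so the parameters need populated gradients before it is called; otherwise the step is close to a no-op. A hypothetical usage sketch (the layer sizes are made up):

layer = torch.nn.Linear(512, 256).to("cuda")
optimizer = torch.optim.SGD(layer.parameters(), lr=1.0)

# One untimed forward/backward pass so every parameter has a .grad for step() to apply
layer(torch.rand(FLAGS.batch_size, 512, device="cuda")).mean().backward()

step_ms = time_optimizer(optimizer)
print(f"Average optimizer step time: {step_ms:.4f} ms")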
Code Example #4
def time_embeddings(model, input):
    """

    Args:
        model (Dlrm): model whose embedding tables are timed
        input (Tensor): with shape [num_categorical_features, batch_size]

    Returns:
        fwd_elapsed_time (float): FWD time in ms
        bwd_elapsed_time (float): BWD time in ms
    """
    # Put indices on the same device as corresponding embedding
    device_indices = []
    if not FLAGS.joint_embedding:
        for embedding_id, embedding in enumerate(model.embeddings):
            device_indices.append(input[embedding_id].to(model._embedding_device_map[embedding_id]))
    else:
        device_indices.append(input.t())

    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        for embedding_id, embedding in enumerate(model.embeddings):
            embedding(device_indices[embedding_id]).to(FLAGS.base_device)
    stop_time = timer_stop()
    fwd_elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3

    # Run an untimed pass to collect the embedding outputs
    model.zero_grad()
    embedding_outputs = []
    for embedding_id, embedding in enumerate(model.embeddings):
        embedding_outputs.append(embedding(device_indices[embedding_id]).to(FLAGS.base_device))

    concat_output = torch.cat(embedding_outputs)
    grad = torch.rand_like(concat_output)

    logging.info("Backward of embedding seems to be pure memcpyD2D.")
    bwd_elapsed_time = 0
    for _ in range(FLAGS.num_iters):
        start_time = timer_start()
        concat_output.backward(grad, retain_graph=True)
        stop_time = timer_stop()
        model.zero_grad()  # Sparse gradient will keep aggregating if not cleared
        bwd_elapsed_time += (stop_time - start_time) * 1e3
    bwd_elapsed_time /= FLAGS.num_iters

    return fwd_elapsed_time, bwd_elapsed_time
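
The input argument is a [num_categorical_features, batch_size] tensor of embedding indices. A sketch of how such a dummy input might be built, assuming an already-constructed Dlrm model; the 26 categorical features echo the 26 * EMBED_DIM used elsewhere in this file, and the vocabulary size is purely hypothetical:

num_categorical_features = 26   # assumed Criteo-style feature count
vocab_size = 100_000            # hypothetical embedding table size
dummy_indices = torch.randint(
    low=0,
    high=vocab_size,
    size=(num_categorical_features, FLAGS.batch_size),
    device=FLAGS.base_device,
)
fwd_ms, bwd_ms = time_embeddings(model, dummy_indices)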
Code Example #5
def main(argv):
    rank, world_size, gpu = dist.init_distributed_mode()

    top_mlp = create_top_mlp().to("cuda")
    print(top_mlp)

    optimizer = torch.optim.SGD(top_mlp.parameters(), lr=1.)

    if FLAGS.fp16:
        top_mlp, optimizer = amp.initialize(top_mlp,
                                            optimizer,
                                            opt_level="O1",
                                            loss_scale=1)

    if world_size > 1:
        top_mlp = parallel.DistributedDataParallel(top_mlp)
        model_without_ddp = top_mlp.module

    dummy_bottom_mlp_output = torch.rand(FLAGS.batch_size,
                                         EMBED_DIM,
                                         device="cuda")
    dummy_embedding_output = torch.rand(FLAGS.batch_size,
                                        26 * EMBED_DIM,
                                        device="cuda")
    dummy_target = torch.ones(FLAGS.batch_size, device="cuda")

    if FLAGS.fp16:
        dummy_bottom_mlp_output = dummy_bottom_mlp_output.to(torch.half)
        dummy_embedding_output = dummy_embedding_output.to(torch.half)

    # warm up GPU
    for _ in range(100):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output],
                                          FLAGS.batch_size)
        output = top_mlp(interaction_out)

    start_time = utils.timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output],
                                          FLAGS.batch_size)
        output = top_mlp(interaction_out).squeeze()
        dummy_loss = output.mean()
        optimizer.zero_grad()
        if FLAGS.fp16:
            with amp.scale_loss(dummy_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dummy_loss.backward()
        optimizer.step()
    stop_time = utils.timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    print(F"Average step time: {elapsed_time:.4f} ms.")
Code Example #6
def time_loss(loss_fn):
    dummy_out = torch.rand(FLAGS.batch_size, device=FLAGS.base_device, requires_grad=True)
    dummy_label = torch.rand(FLAGS.batch_size, device=FLAGS.base_device)
    if FLAGS.fp16:
        dummy_out = dummy_out.half()
        dummy_label = dummy_label.half()

    start_time = timer_start()
    for _ in range(FLAGS.num_iters):
        loss = loss_fn(dummy_out, dummy_label)
        loss.backward()
    stop_time = timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    return elapsed_time
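
loss_fn is supplied by the caller; since the dummy outputs and labels are uniform values in [0, 1), a binary cross-entropy loss is one plausible choice. A hypothetical usage:

loss_ms = time_loss(torch.nn.BCELoss())
print(f"Average loss fwd+bwd time: {loss_ms:.4f} ms")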