Example #1
def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.
    :param images: batch of input images
    :param labels: labels for the images, in the same order
    :param batch: batch index; variables must be broadcast on the first batch
    :return: loss for this step across all workers
    """
    with tf.GradientTape() as tape:
        # record forward-pass operations on the tape for automatic differentiation
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # wrap tf.GradientTape with SMDataParallel's DistributedGradientTape so gradients are AllReduced across workers
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(
        grads, model.trainable_variables))  # learn on this worker

    if batch == 0:
        # broadcast the initial model and optimizer variables from rank 0 so all workers share one set of variables
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # all loss values from all workers reduced to one loss value
    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
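
Example #1 assumes that model, loss, optimizer, and dist already exist at module level. For reference, here is a minimal driver sketch showing how such a step function is typically wired up, assuming dist is smdistributed.dataparallel.tensorflow; the model, dummy dataset, and output path below are hypothetical stand-ins, not the original script.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

dist.init()

# Pin each worker process to a single GPU based on its local rank.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

# Hypothetical model/loss/optimizer; the snippet above references these as globals.
model = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Scale the learning rate by the number of workers, a common practice in data parallelism.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3 * dist.size())

# Dummy dataset, sharded so each worker sees a distinct slice of the data.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([256, 28, 28]),
     tf.random.uniform([256], maxval=10, dtype=tf.int64))
).shard(dist.size(), dist.rank()).batch(32)

for batch, (images, labels) in enumerate(dataset):
    loss_value = training_step(images, labels, batch)
    if dist.rank() == 0:
        print(f"step {batch}: averaged loss {float(loss_value):.4f}")

# Save from rank 0 only to avoid concurrent writes (the path is a hypothetical example).
if dist.rank() == 0:
    model.save('/opt/ml/model/mnist')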
Example #2
def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        with tf.GradientTape() as tape:
            inputs, targets = batch
            outputs = model(inputs, training=True)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value:.4f}")

        if SDP_ENABLED:
            if i == 0:
                sdp.broadcast_variables(model.variables, root_rank=0)
                sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    train_results = {"loss": loss_value.numpy()}
    return train_results
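
Example #2 relies on an SDP_ENABLED flag, an sdp module alias, and tqdm that are defined outside the snippet. A minimal sketch of the assumed preamble follows; the flag name and how it is derived are hypothetical, while the smdistributed.dataparallel.tensorflow import and init() call are the library's documented entry points.

import os

import tensorflow as tf
from tqdm import tqdm

# Hypothetical flag; in practice it is usually derived from a CLI argument or environment variable.
SDP_ENABLED = os.environ.get("USE_SMDATAPARALLEL", "false").lower() == "true"

if SDP_ENABLED:
    import smdistributed.dataparallel.tensorflow as sdp
    sdp.init()
    # Each worker process sees only the GPU matching its local rank.
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")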
Example #3
def training_step(mnist_model, loss, opt, images, labels, batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    ########################################################
    ####### 4. SageMaker Distributed Data Parallel  ########
    #######  - Optimize AllReduce operation         ########
    ########################################################

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)

    #######################################################

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    ########################################################
    ####### 5. SageMaker Distributed Data Parallel   #######
    #######  - Broadcast the initial model variables #######
    #######    from rank 0 to ranks 1 ~ n            #######
    ########################################################

    if batch == 0:
        # SMDataParallel: Broadcast model and optimizer variables
        smdp.broadcast_variables(mnist_model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    #######################################################

    return loss_value
Example #4
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
Example #5
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap TensorFlow's GradientTape with SMDataParallel's DistributedGradientTape,
    # which uses an AllReduce to combine gradient values across workers before they are applied to the model weights.
    tape = smdataparallel.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Broadcast model and optimizer variables after the first batch so all workers start in sync
    if first_batch:
        smdataparallel.broadcast_variables(mnist_model.variables, root_rank=0)
        smdataparallel.broadcast_variables(opt.variables(), root_rank=0)

    return loss_value
Example #6
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # SMDataParallel: all_reduce call
    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
Example #7
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        train_pred = model(images, training=True)
        loss_value = loss(labels, train_pred)
    # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if first_batch:
        # Change: Broadcast model and optimizer variables
        smdp.broadcast_variables(model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    # Change: all_reduce call
    train_loss_value = smdp.oob_allreduce(
        loss_value)  # Average the loss across workers

    train_loss(train_loss_value)
    train_accuracy(labels, train_pred)
    return
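
Example #7 writes into train_loss and train_accuracy metric objects and returns nothing, so the surrounding epoch loop reads the metrics instead. Below is a sketch of that assumed setup using the same names as the snippet; the epoch count and dataset are hypothetical.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smdp

train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

EPOCHS = 3  # hypothetical
for epoch in range(EPOCHS):
    train_loss.reset_states()
    train_accuracy.reset_states()
    # train_dataset: assumed tf.data pipeline, sharded per worker as in the earlier sketches
    for step, (images, labels) in enumerate(train_dataset):
        training_step(images, labels, step == 0)
    if smdp.rank() == 0:
        print(f"epoch {epoch}: loss={float(train_loss.result()):.4f}, "
              f"accuracy={float(train_accuracy.result()):.4f}")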