Example #1
def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.
    :param images: batch of input images
    :param labels: labels for the images, in the same order
    :param batch: batch number; variables must be broadcast on the first batch
    :return: average loss for this step across all workers
    """
    with tf.GradientTape() as tape:
        # record the forward pass and loss so gradients can be computed by automatic differentiation
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape so gradients are averaged across workers
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    # apply the averaged gradients on this worker
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if batch == 0:
        # Broadcast initial model and optimizer variables from rank 0 so all workers share one set of variables
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # Average the loss across all workers so every rank returns the same value
    loss_value = dist.oob_allreduce(loss_value)
    return loss_value
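The snippet above relies on setup outside the function: SMDataParallel must be initialized, each worker pinned to one GPU, and the learning rate scaled by the number of workers. A minimal sketch of that assumed boilerplate follows; the model, dataset, and hyperparameters are illustrative placeholders, not taken from the original script.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

dist.init()  # initialize SMDataParallel

# Pin this worker to a single GPU, selected by its local rank on the host.
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], "GPU")

# Placeholder model/loss/optimizer; the learning rate is scaled by the worker count
# because the effective global batch size grows with dist.size().
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001 * dist.size())

# train_dataset is an assumed tf.data.Dataset of (images, labels) batches.
for batch_idx, (images, labels) in enumerate(train_dataset):
    loss_value = training_step(images, labels, batch_idx)
    if dist.rank() == 0 and batch_idx % 50 == 0:
        print("step", batch_idx, "loss", float(loss_value))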
Example #2
def allreduce(model, optimizer, gradient_accumulator, loss, mlm_loss, mlm_acc,
              sop_loss, sop_acc):
    scaled_grads = gradient_accumulator.gradients
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    # Converting any IndexedSlices to dense tensors (equivalent to sparse_as_dense=True) gives a
    # mild 2% speedup, from 0.62 it/s to 0.63 it/s, on multi-node BERT-large.
    grads = [
        tf.convert_to_tensor(grad)
        if grad is not None and isinstance(grad, tf.IndexedSlices) else grad
        for grad in grads
    ]

    # TODO: Does placing this clip before or after allreduce affect accuracy?
    # Placing it before allreduce has a regularization effect: no single example can contribute as much.
    # Placing it before also gives a 20% speedup when training BERT-large, probably because the
    # gradient operations can be fused by XLA.
    (grads, grad_norm) = tf.clip_by_global_norm(grads, clip_norm=max_grad_norm)

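    # Global L2 norm of all trainable weights, tracked alongside grad_norm for monitoring.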
    weight_norm = tf.math.sqrt(
        tf.math.reduce_sum(
            [tf.norm(var, ord=2)**2 for var in model.trainable_variables]))

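    # Allreduce each gradient across workers, compressing tensors to fp16 for communication.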
    grads = [
        smddp.allreduce(grad,
                        param_index=idx,
                        num_params=len(grads),
                        compression=smddp.Compression.fp16)
        if grad is not None else None for idx, grad in enumerate(grads)
    ]

    optimizer.apply_gradients([
        (tf.cast(grad, var.dtype), var)
        for (grad, var) in zip(grads, model.trainable_variables)
        if grad is not None
    ])

    # Clear the gradient accumulator
    gradient_accumulator.reset()

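    # Average the reported metrics across workers (out-of-band allreduce, outside the gradient path).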
    loss = smddp.oob_allreduce(loss)
    mlm_loss = smddp.oob_allreduce(mlm_loss)
    mlm_acc = smddp.oob_allreduce(mlm_acc)
    sop_loss = smddp.oob_allreduce(sop_loss)
    sop_acc = smddp.oob_allreduce(sop_acc)

    return loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm
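allreduce() above assumes mixed-precision training with gradient accumulation: scaled gradients are collected over several micro-batches, and a single allreduce plus optimizer step is issued per accumulation window. A rough sketch of that calling side follows; model, forward_pass, gradient_accumulator, and micro_batches are placeholders (the accumulator is assumed to be callable and to expose the .gradients/.reset() interface used above), not part of the original script.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smddp

smddp.init()

# Mixed precision: gradients come out scaled, which is why allreduce() calls
# optimizer.get_unscaled_gradients() before clipping.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-4))

max_grad_norm = 1.0  # clipping threshold read by allreduce() above

# Accumulate scaled gradients over the micro-batches of one accumulation window.
for micro_batch in micro_batches:
    with tf.GradientTape() as tape:
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc = forward_pass(model, micro_batch)
        scaled_loss = optimizer.get_scaled_loss(loss)
    gradient_accumulator(tape.gradient(scaled_loss, model.trainable_variables))

# One allreduce + apply_gradients per window; returns averaged metrics and the norms.
loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = allreduce(
    model, optimizer, gradient_accumulator, loss, mlm_loss, mlm_acc, sop_loss, sop_acc)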
Example #3
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            probs = model(images, training=True)
            loss_value = loss_fn(labels, probs)
            acc_value = acc(labels, probs)

        # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = dist.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # SMDataParallel: Broadcast model and optimizer variables
            dist.broadcast_variables(model.variables, root_rank=0)
            dist.broadcast_variables(opt.variables(), root_rank=0)

        # SMDataParallel: all_reduce call
        loss_value = dist.oob_allreduce(
            loss_value)  # Average the loss across workers
        acc_value = dist.oob_allreduce(acc_value)

        return loss_value, acc_value
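Snippets like the one above are usually driven by a per-worker data pipeline and rank-aware logging. A possible sketch under those assumptions follows; training_step and model come from the snippet above, while raw_dataset, batch_size_per_worker, and the output path are illustrative assumptions.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

# Each worker reads a disjoint shard of the data, so the global batch is split across GPUs.
dataset = (raw_dataset  # assumed tf.data.Dataset of (image, label) pairs
           .shard(dist.size(), dist.rank())
           .shuffle(10000)
           .batch(batch_size_per_worker)
           .prefetch(tf.data.AUTOTUNE))

for batch_idx, (images, labels) in enumerate(dataset):
    loss_value, acc_value = training_step(images, labels, batch_idx == 0)
    # oob_allreduce has already averaged these across workers, so rank 0 alone can report them.
    if dist.rank() == 0 and batch_idx % 100 == 0:
        print("step", batch_idx, "loss", float(loss_value), "acc", float(acc_value))

# Save from a single worker to avoid concurrent writes to the same location.
if dist.rank() == 0:
    model.save("/opt/ml/model")  # illustrative SageMaker output path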
Example #4
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    loss_value = dist.oob_allreduce(
        loss_value)  # Average the loss across workers
    return loss_value
Example #5
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            train_pred = model(images, training=True)
            loss_value = loss(labels, train_pred)
        # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = smdp.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # Change: Broadcast model and optimizer variables
            smdp.broadcast_variables(model.variables, root_rank=0)
            smdp.broadcast_variables(opt.variables(), root_rank=0)

        # Change: all_reduce call
        train_loss_value = smdp.oob_allreduce(
            loss_value)  # Average the loss across workers

        train_loss(train_loss_value)
        train_accuracy(labels, train_pred)
        return
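Example #5 updates Keras metric objects instead of returning values. A minimal sketch of the metric setup and epoch loop it assumes follows; the metric types, epochs, and train_dataset are illustrative, while smdp is the smdistributed.dataparallel.tensorflow alias used in the snippet.

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smdp

smdp.init()

train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

for epoch in range(epochs):
    train_loss.reset_states()
    train_accuracy.reset_states()
    for batch_idx, (images, labels) in enumerate(train_dataset):
        # Broadcast variables only once, on the very first batch of training.
        training_step(images, labels, epoch == 0 and batch_idx == 0)
    if smdp.rank() == 0:
        print("epoch", epoch,
              "loss", float(train_loss.result()),
              "accuracy", float(train_accuracy.result()))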