Example #1
from tensorflow import keras, nn

def stage2(lr, partial, labels):
    logits = keras.layers.Dense(10)(partial)
    per_example_loss = keras.losses.sparse_categorical_crossentropy(
        y_true=labels, y_pred=logits, from_logits=True)
    # In a custom training loop, the optimizer does an allreduce *sum*, not
    # an average, of the gradients across the distributed workers. Therefore
    # we want to divide the loss here by the *global* batch size, which is
    # done by the `tf.nn.compute_average_loss()` function.
    loss = nn.compute_average_loss(per_example_loss)
    return lr, loss
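
None of these snippets shows the scaling in isolation, so here is a small self-contained sketch of what `tf.nn.compute_average_loss()` computes; the tensor values are made up for illustration.

import tensorflow as tf

# Hypothetical per-replica losses, for illustration only.
per_example_loss = tf.constant([1.0, 2.0, 3.0, 4.0])

# Outside a tf.distribute strategy there is a single replica, so the divisor
# defaults to the local batch size and the result is a plain mean: 10/4 = 2.5.
print(tf.nn.compute_average_loss(per_example_loss))

# With an explicit global batch size (e.g. 2 replicas x 4 examples each),
# every replica contributes sum(loss)/8, so the cross-replica *sum* of the
# gradients equals the gradient of the true global average: 10/8 = 1.25.
print(tf.nn.compute_average_loss(per_example_loss, global_batch_size=8))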
Example #2
# Imports follow TensorFlow's internal module layout, which this snippet uses:
from tensorflow import keras, nn
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops

def step_fn(inputs):
  """The computation to run on each worker."""
  features, labels = inputs
  with backprop.GradientTape() as tape:
    pred = model(features, training=True)
    loss = keras.losses.binary_crossentropy(labels, pred)
    loss = nn.compute_average_loss(loss)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))

  # Threshold the sigmoid output at 0.5 to get hard class predictions.
  actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
  accuracy.update_state(labels, actual_pred)
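
`step_fn` closes over `model`, `optimizer`, and `accuracy` and assumes a surrounding distributed loop. A minimal sketch of that loop, assuming a `MirroredStrategy` and toy stand-in data; none of the definitions below are from the original:

import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Stand-in definitions for the names step_fn() closes over.
  model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
  optimizer = keras.optimizers.SGD()
  accuracy = keras.metrics.Accuracy()

features = tf.random.normal([64, 8])
labels = tf.cast(tf.random.uniform([64, 1]) > 0.5, tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for inputs in dist_dataset:
  # Each replica receives its shard of the 16-example global batch.
  strategy.run(step_fn, args=(inputs,))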
Example #3
# Imports follow TensorFlow's internal module layout, which this snippet uses:
from tensorflow import keras, nn
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops

def replica_fn(iterator):
  batch_data, labels = next(iterator)
  with backprop.GradientTape() as tape:
    pred = model(batch_data, training=True)
    # Reduction.NONE keeps the per-example losses so that
    # compute_average_loss() can divide by the global batch size.
    loss = nn.compute_average_loss(
        keras.losses.BinaryCrossentropy(
            reduction=losses_utils.ReductionV2.NONE)(labels, pred))
  gradients = tape.gradient(loss, model.trainable_variables)

  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
  accuracy.update_state(labels, actual_pred)
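
Because `replica_fn` pulls its own batch with `next(iterator)`, it fits the `ClusterCoordinator` pattern, where the function is scheduled onto workers together with a per-worker iterator. A hedged sketch of that setup; the strategy, dataset factory, and step count below are assumptions, not part of the original:

import tensorflow as tf

def dataset_fn():
  # Hypothetical toy dataset factory.
  features = tf.random.normal([64, 10])
  labels = tf.cast(tf.random.uniform([64, 1]) > 0.5, tf.float32)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(8).repeat()

# TFConfigClusterResolver reads the cluster spec from the TF_CONFIG
# environment variable; an actual parameter-server cluster must be running.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

step = tf.function(replica_fn)  # schedule() expects a tf.function
for _ in range(100):            # hypothetical number of steps
  coordinator.schedule(step, args=(per_worker_iterator,))
coordinator.join()              # block until all scheduled steps finish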
Example #4

from tensorflow import keras, nn

# _BATCH_SIZE is a module-level constant defined elsewhere in the original.
def compute_loss(labels, logits, reg_losses):
    pred_loss = keras.losses.mean_squared_error(labels, logits)
    scaled_loss = nn.compute_average_loss(pred_loss,
                                          global_batch_size=_BATCH_SIZE)
    # Regularization terms are identical on every replica, so they are
    # scaled by 1/num_replicas rather than by the batch size.
    l2_loss = nn.scale_regularization_loss(reg_losses)
    return scaled_loss + l2_loss
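
A short sketch of how such a `compute_loss()` can be wired into a replica step, assuming a regularized Keras model whose `model.losses` supplies `reg_losses`; the model, strategy, and `_BATCH_SIZE` value here are illustrative stand-ins:

import tensorflow as tf
from tensorflow import keras

_BATCH_SIZE = 8  # illustrative value; the original defines it elsewhere

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = keras.Sequential([
      keras.layers.Dense(1, kernel_regularizer=keras.regularizers.l2(1e-4))
  ])
  optimizer = keras.optimizers.SGD()

def step_fn(features, labels):
  with tf.GradientTape() as tape:
    logits = model(features, training=True)
    # model.losses collects the layers' L2 penalties, which compute_loss()
    # scales by 1/num_replicas via scale_regularization_loss().
    loss = compute_loss(labels, logits, model.losses)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

features = tf.random.normal([_BATCH_SIZE, 4])
labels = tf.random.normal([_BATCH_SIZE, 1])
strategy.run(step_fn, args=(features, labels))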