def _train_with_recompute(n_steps):
    """Trains a single large model with gradient checkpointing using tf.recompute_grad.

    Args:
        n_steps: Number of training steps to run.

    Returns:
        A list of the per-step loss tensors, one per training step.
    """
    img_dim, n_channels, batch_size = 256, 1, 4
    x, y = _get_dummy_data(img_dim, n_channels, batch_size)
    # This model is the same model as _get_big_cnn_model but split into 3 parts,
    # so each partition can be recomputed independently during the backward pass.
    models = _get_split_cnn_model(img_dim,
                                  n_channels,
                                  num_partitions=3,
                                  blocks_per_partition=2)
    model1, model2, model3 = models
    # Apply gradient checkpointing to the submodels using tf.recompute_grad:
    # activations inside each wrapped submodel are discarded after the forward
    # pass and recomputed during backprop, trading compute for memory.
    model1_re = tf.recompute_grad(model1)
    model2_re = tf.recompute_grad(model2)
    model3_re = tf.recompute_grad(model3)
    optimizer = optimizers.SGD()
    tr_vars = (model1.trainable_variables + model2.trainable_variables +
               model3.trainable_variables)
    losses = []
    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            logits1 = model1_re(x)
            logits2 = model2_re(logits1)
            logits3 = model3_re(logits2)
            loss = _compute_loss(logits3, y)
            losses.append(loss)
        # Compute and apply gradients OUTSIDE the tape context. Doing this
        # inside the `with` block makes the tape record the gradient
        # computation itself, inflating memory use and defeating the point
        # of gradient checkpointing.
        grads = tape.gradient(loss, tr_vars)
        optimizer.apply_gradients(zip(grads, tr_vars))
        # Drop the gradient references promptly to keep peak memory low.
        del grads
    return losses
# Example #2
    def _TestVariablesGradient(self, inputs, test_model, vars_to_grad):
        """Computes gradients of `test_model` w.r.t. `vars_to_grad`.

        Runs the model twice under one persistent tape — once wrapped in
        tf.recompute_grad and once plain — so the caller can compare the
        checkpointed gradients against the reference gradients.

        Args:
            inputs: Input tensor(s) fed to the model.
            test_model: The model under test.
            vars_to_grad: Variables to differentiate with respect to.

        Returns:
            A `(grads_re, grads)` pair: gradients through the recomputed
            model and through the original model, respectively.
        """
        recomputed_model = tf.recompute_grad(test_model)

        # A persistent tape is required because gradient() is called twice.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(vars_to_grad)
            output_recomputed = recomputed_model(inputs)
            output_plain = test_model(inputs)

        grads_recomputed = tape.gradient(output_recomputed, vars_to_grad)
        grads_plain = tape.gradient(output_plain, vars_to_grad)

        return grads_recomputed, grads_plain