def _train_with_recompute(n_steps):
  """Runs `n_steps` training steps with per-partition gradient checkpointing.

  Builds the same network as _get_big_cnn_model, split into three sequential
  partitions, wraps each partition in `tf.recompute_grad` (so its activations
  are recomputed during backprop instead of being stored), and trains the
  whole stack with SGD.

  Args:
    n_steps: Number of optimizer steps to perform.

  Returns:
    A list containing the loss tensor from each training step.
  """
  img_dim, n_channels, batch_size = 256, 1, 4
  x, y = _get_dummy_data(img_dim, n_channels, batch_size)
  # Identical architecture to _get_big_cnn_model, but split into thirds so
  # each partition can be checkpointed independently.
  part1, part2, part3 = _get_split_cnn_model(
      img_dim, n_channels, num_partitions=3, blocks_per_partition=2)
  # Checkpointed forward functions, one per partition, applied in order.
  recomputed = [tf.recompute_grad(p) for p in (part1, part2, part3)]
  optimizer = optimizers.SGD()
  train_vars = (part1.trainable_variables + part2.trainable_variables +
                part3.trainable_variables)
  step_losses = []
  for _ in range(n_steps):
    with tf.GradientTape() as tape:
      activations = x
      for forward_fn in recomputed:
        activations = forward_fn(activations)
      loss = _compute_loss(activations, y)
      step_losses.append(loss)
    grads = tape.gradient(loss, train_vars)
    optimizer.apply_gradients(zip(grads, train_vars))
    # Drop the gradient tensors promptly to keep peak memory low.
    del grads
  return step_losses
def _TestVariablesGradient(self, inputs, test_model, vars_to_grad):
  """Computes gradients of `test_model` w.r.t. `vars_to_grad`, twice.

  Runs the model both wrapped in `tf.recompute_grad` and unwrapped under a
  single tape, so callers can compare checkpointed vs. normal gradients.

  Args:
    inputs: Input tensor fed to the model.
    test_model: Callable model to differentiate.
    vars_to_grad: Variables (or tensors) to take gradients with respect to.

  Returns:
    A tuple `(grads_re, grads)`: gradients computed through the
    `tf.recompute_grad`-wrapped model, and through the plain model.
  """
  test_model_re = tf.recompute_grad(test_model)
  # The tape must be persistent because tape.gradient() is called twice.
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(vars_to_grad)
    out_re = test_model_re(inputs)
    out = test_model(inputs)
  grads_re = tape.gradient(out_re, vars_to_grad)
  grads = tape.gradient(out, vars_to_grad)
  # Persistent tapes hold their resources until deleted; release them now
  # that both gradient calls are done (as recommended by the GradientTape
  # documentation).
  del tape
  return grads_re, grads