def test_compare_mnist_train(self):
    """Train the same MNIST model with an ElasticDL worker and with a
    plain local training loop, then check that both runs produce
    identical (loss, accuracy) tuples at each evaluation step."""
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    self._create_pserver(model_def, 2)
    db, test_db = get_mnist_dataset(self._batch_size)
    stop_step = 20
    self._create_worker(1)
    worker_results = self._worker_train(
        0, train_db=db, test_db=test_db, stop_step=stop_step
    )

    # Reset the session and seed so the local run starts from the same
    # initial state as the worker run.
    tf.keras.backend.clear_session()
    tf.random.set_seed(22)

    acc_meter = tf.keras.metrics.Accuracy()
    (
        model,
        dataset_fn,
        loss_fn,
        opt_fn,
        eval_metrics_fn,
        prediction_outputs_processor,
        create_data_reader_fn,
        callbacks_list,
    ) = get_model_spec(
        model_zoo=self._model_zoo_path,
        model_def=model_def,
        dataset_fn="dataset_fn",
        model_params=None,
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        prediction_outputs_processor="PredictionOutputsProcessor",
        custom_data_reader="custom_data_reader",
        callbacks="callbacks",
    )

    local_results = []
    for step, (x, y) in enumerate(db):
        with tf.GradientTape() as tape:
            out = model.call(x, training=True)
            ll = loss_fn(y, out)
        grads = tape.gradient(ll, model.trainable_variables)
        # A fresh optimizer instance is created for each step.
        opt_fn().apply_gradients(zip(grads, model.trainable_variables))

        # Evaluate on the test set every 20 steps and record a
        # (loss, accuracy) tuple so the results line up with the
        # worker's reported results.
        if step % 20 == 0:
            for (x, y) in test_db:
                out = model.call(x, training=False)
                acc_meter.update_state(tf.argmax(out, axis=1), y)
            local_results.append(
                (float(ll.numpy()), float(acc_meter.result().numpy()))
            )
            acc_meter.reset_states()

        if step > stop_step:
            break

    for w, l in zip(worker_results, local_results):
        self.assertTupleEqual(w, l)
def test_collect_gradients_with_allreduce_failure_case(self):
    """Reporting gradients over allreduce should return False when
    there are no gradients to reduce."""
    worker = self._workers[1]
    train_db, _ = get_mnist_dataset(self._batch_size)
    for step, (x, y) in enumerate(train_db):
        if step == 0:
            worker._run_model_call_before_training(x)
        if step == self._test_steps:
            break
    self.assertEqual(
        worker._calculate_grads_and_report_with_allreduce(None),
        False,
        "Should fail when no data is received",
    )
def test_collect_gradients_with_allreduce_success_case(self):
    """Gradients computed by the eager training process should be
    collected successfully over allreduce."""
    worker = self._workers[0]
    train_db, _ = get_mnist_dataset(self._batch_size)
    for step, (x, y) in enumerate(train_db):
        if step == 0:
            worker._run_model_call_before_training(x)
        w_loss, w_grads = worker.training_process_eagerly(x, y)
        if step == self._test_steps:
            break
    self.assertEqual(
        worker._collect_gradients_with_allreduce_robust(w_grads),
        True,
    )