Example #1
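This test trains MNIST through an ElasticDL worker against a parameter server, then replays the same training loop locally under the same seed and asserts that the recorded (loss, accuracy) pairs match step for step. The snippet assumes import tensorflow as tf and the project's test helpers (get_mnist_dataset, get_model_spec) are already in scope in the test module.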
    def test_compare_mnist_train(self):
        model_def = "mnist_functional_api.mnist_functional_api.custom_model"
        self._create_pserver(model_def, 2)
        db, test_db = get_mnist_dataset(self._batch_size)
        stop_step = 20

        self._create_worker(1)
        worker_results = self._worker_train(0,
                                            train_db=db,
                                            test_db=test_db,
                                            stop_step=stop_step)

        tf.keras.backend.clear_session()
        tf.random.set_seed(22)

        acc_meter = tf.keras.metrics.Accuracy()

        (
            model,
            dataset_fn,
            loss_fn,
            opt_fn,
            eval_metrics_fn,
            prediction_outputs_processor,
            create_data_reader_fn,
            callbacks_list,
        ) = get_model_spec(
            model_zoo=self._model_zoo_path,
            model_def=model_def,
            dataset_fn="dataset_fn",
            model_params=None,
            loss="loss",
            optimizer="optimizer",
            eval_metrics_fn="eval_metrics_fn",
            prediction_outputs_processor="PredictionOutputsProcessor",
            custom_data_reader="custom_data_reader",
            callbacks="callbacks",
        )
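        # get_model_spec resolves each of the named symbols (dataset_fn, loss,
        # optimizer, eval_metrics_fn, ...) from the model zoo module given by
        # model_def.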
        # Create the optimizer once; calling opt_fn() inside the loop would
        # build a fresh optimizer each step and discard any accumulated slot
        # state (e.g. momentum) between updates.
        opt = opt_fn()
        local_results = []
        for step, (x, y) in enumerate(db):
            with tf.GradientTape() as tape:
                out = model.call(x, training=True)
                ll = loss_fn(y, out)
            grads = tape.gradient(ll, model.trainable_variables)
            opt.apply_gradients(zip(grads, model.trainable_variables))

            # Every 20 steps, evaluate on the held-out set and record the
            # current (loss, accuracy) pair for comparison with the worker run.
            if step % 20 == 0:
                for (x, y) in test_db:
                    out = model.call(x, training=False)
                    acc_meter.update_state(tf.argmax(out, axis=1), y)

                local_results.append(
                    (float(ll.numpy()), float(acc_meter.result().numpy())))
                acc_meter.reset_states()

            if step > stop_step:
                break

        # Both runs are seeded identically and consume the same data, so the
        # worker's (loss, accuracy) history should match the local one exactly.
        for w, l in zip(worker_results, local_results):
            self.assertTupleEqual(w, l)
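Exact tuple equality on floats is only meaningful here because both runs are seeded identically and execute the same ops in the same order. A minimal standalone sketch of that assumption (hypothetical, not part of the test suite):

import tensorflow as tf

def sample_after_seed(seed=22):
    # Reset graph state and reseed, as test_compare_mnist_train does.
    tf.keras.backend.clear_session()
    tf.random.set_seed(seed)
    return float(tf.random.normal([1]).numpy()[0])

# Identical seeds and op order give bitwise-identical draws.
assert sample_after_seed() == sample_after_seed()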
Example #2

    def test_collect_gradients_with_allreduce_failure_case(self):
        worker = self._workers[1]
        train_db, _ = get_mnist_dataset(self._batch_size)
        for step, (x, y) in enumerate(train_db):
            if step == 0:
                worker._run_model_call_before_training(x)
            if step == self._test_steps:
                break
            # Passing None simulates a step that produced no gradients; the
            # robust allreduce path should report failure rather than raise.
            self.assertEqual(
                worker._calculate_grads_and_report_with_allreduce(None),
                False,
                "Should fail when no data is received",
            )
Example #3

    def test_collect_gradients_with_allreduce_success_case(self):
        worker = self._workers[0]
        train_db, _ = get_mnist_dataset(self._batch_size)
        for step, (x, y) in enumerate(train_db):
            if step == 0:
                worker._run_model_call_before_training(x)
            w_loss, w_grads = worker.training_process_eagerly(x, y)
            if step == self._test_steps:
                break
            # With real gradients from an eager training step, the robust
            # allreduce collection is expected to succeed.
            self.assertEqual(
                worker._collect_gradients_with_allreduce_robust(w_grads),
                True,
            )
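Note that both allreduce helpers signal failure through their boolean return value rather than by raising, which is why these two tests assert against False and True; a caller can then decide per step whether to retry, skip, or abort the batch.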