Example #1
    def test_restart_ps(self):
        model_def = "mnist_functional_api.mnist_functional_api.custom_model"
        num_data = 8
        training_data = [
            get_random_batch(self._batch_size) for _ in range(num_data)
        ]
        workers = []
        self._create_pserver(model_def, 2)
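        # Train two workers on the same batches; the second run restarts
        # the ps mid-training, then the final variables are compared.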
        for w in range(2):
            self._reset_pserver()
            arguments = [
                "--worker_id",
                0,
                "--job_type",
                elasticdl_pb2.TRAINING,
                "--minibatch_size",
                self._batch_size,
                "--model_zoo",
                self._model_zoo_path,
                "--model_def",
                model_def,
                "--distribution_strategy",
                DistributionStrategy.PARAMETER_SERVER,
            ]
            args = parse_worker_args(arguments)
            tf.keras.backend.clear_session()
            tf.random.set_seed(22)
            worker = Worker(args, ps_channels=self._channels)
            workers.append(worker)
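            # Build the model by running one forward pass before training.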
            worker._run_model_call_before_training(training_data[0][0])
            for i in range(num_data):
                worker.get_model()
                w_loss, w_grads = worker.training_process_eagerly(
                    training_data[i][0], training_data[i][1]
                )
                worker.report_gradient(w_grads)
                if w == 1 and i == 3:
                    # Restart the ps for the 2nd worker at i == 3.
                    # self._restart_pserver(model_def)
                    self._reset_pserver()
                    # `report_variable` will be called in `get_model` to
                    # initialize variables on the ps with worker variables.
                    worker.get_model()
                    # Send the grads again because they have not been
                    # applied to the worker variables.
                    worker.report_gradient(w_grads)

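        # Both workers should converge to identical non-embedding
        # variables despite the ps restart.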
        for var_name in workers[0]._non_embed_vars:
            np.testing.assert_array_equal(
                workers[0]._non_embed_vars[var_name].numpy(),
                workers[1]._non_embed_vars[var_name].numpy(),
            )
Example #2
    def test_compare_onebatch_train(self):
        model_def = "mnist_functional_api.mnist_functional_api.custom_model"
        self._create_pserver(model_def, 2)
        images, labels = get_random_batch(self._batch_size)
        # TODO(yunjian.lmh): test optimizer wrapper
        arguments = [
            "--worker_id",
            0,
            "--job_type",
            elasticdl_pb2.TRAINING,
            "--minibatch_size",
            self._batch_size,
            "--model_zoo",
            self._model_zoo_path,
            "--model_def",
            model_def,
            "--distribution_strategy",
            DistributionStrategy.PARAMETER_SERVER,
        ]
        args = parse_worker_args(arguments)

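        # Fix the seed so the worker run and the local baseline below
        # start from identical initial weights.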
        tf.keras.backend.clear_session()
        tf.random.set_seed(22)

        worker = Worker(args, ps_channels=self._channels)
        worker._run_model_call_before_training(images)
        worker.get_model()
        w_loss, w_grads = worker.training_process_eagerly(images, labels)
        worker.report_gradient(w_grads)

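        # Rebuild the same model locally as a single-process baseline.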
        tf.keras.backend.clear_session()
        tf.random.set_seed(22)

        (
            model,
            dataset_fn,
            loss_fn,
            opt_fn,
            eval_metrics_fn,
            prediction_outputs_processor,
            create_data_reader_fn,
            callback_list,
        ) = get_model_spec(
            model_zoo=self._model_zoo_path,
            model_def=model_def,
            dataset_fn="dataset_fn",
            model_params=None,
            loss="loss",
            optimizer="optimizer",
            eval_metrics_fn="eval_metrics_fn",
            prediction_outputs_processor="PredictionOutputsProcessor",
            custom_data_reader="custom_data_reader",
            callbacks="callbacks",
        )

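        # Run one local training step on the same batch.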
        with tf.GradientTape() as tape:
            output = model.call(images, training=True)
            labels = tf.reshape(labels, [-1])
            loss = loss_fn(labels, output)
        grads = tape.gradient(loss, model.trainable_variables)
        opt_fn().apply_gradients(zip(grads, model.trainable_variables))

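        # Each non-embedding parameter on its ps shard should match the
        # locally trained baseline exactly.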
        for v in model.trainable_variables:
            ps_id = string_to_id(v.name, len(self._channels))
            ps_v = self._pservers[ps_id].parameters.get_non_embedding_param(
                v.name
            )
            np.testing.assert_array_equal(ps_v.numpy(), v.numpy())
Example #3
    def _worker_train(self, train_db, test_db, dataset, stop_step):
        if dataset == "mnist":
            model_def = (
                "mnist_functional_api.mnist_functional_api.custom_model"
            )
        elif dataset == "frappe":
            model_def = (
                "deepfm_functional_api.deepfm_functional_api.custom_model"
            )
        else:
            raise ValueError("dataset %s is not supported", dataset)
        arguments = [
            "--worker_id",
            0,
            "--job_type",
            elasticdl_pb2.TRAINING,
            "--minibatch_size",
            self._batch_size,
            "--model_zoo",
            self._model_zoo_path,
            "--model_def",
            model_def,
            "--distribution_strategy",
            "ParameterServerStrategy",
        ]
        args = parse_worker_args(arguments)

        worker = Worker(args, ps_channels=self._channel)
        acc_meter = tf.keras.metrics.Accuracy()
        worker_results = []
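        # Training loop: pull the latest model, run one step eagerly,
        # and push the gradients back to the ps.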
        for step, (x, y) in enumerate(train_db):
            if step == 0:
                worker._run_model_call_before_training(x)

            worker.get_model(step, elasticdl_pb2.MINIMUM)

            w_loss, w_grads = worker.training_process_eagerly(x, y)
            worker.report_gradient(w_grads)

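            # Every 20 steps, evaluate on the test set and record
            # (loss, accuracy).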
            if step % 20 == 0:
                worker.get_model(step, elasticdl_pb2.MINIMUM)
                for (x_test, y_test) in test_db:
                    out = worker.forward_process(x_test)
                    if dataset == "mnist":
                        acc_meter.update_state(
                            tf.argmax(out, axis=1), y_test
                        )
                    else:
                        out["probs"] = tf.reshape(out["probs"], [-1])
                        acc_meter.update_state(
                            tf.where(
                                out["probs"] < 0.5,
                                x=tf.zeros_like(y_test),
                                y=tf.ones_like(y_test),
                            ),
                            y_test,
                        )
                worker_results.append(
                    (float(w_loss.numpy()), float(acc_meter.result().numpy()))
                )
                acc_meter.reset_states()

            if step > stop_step:
                break
        return worker_results