def test_restart_ps(self):
    """Check that training gives the same result across a PS reset.

    Two workers are trained one after the other on the same ``num_data``
    random batches, each starting from freshly reset parameter servers
    and the same TensorFlow seed. For the second worker only, the
    parameter servers are reset again right after the gradients of
    batch 3 were reported; the model is then fetched once more and the
    same gradients are re-sent. The test passes when both workers end up
    with identical non-embedding variables.

    The channels are closed in a ``finally`` clause so they are released
    even if training raises or the final comparison fails.
    """
    model_def = "mnist.mnist_functional_api.custom_model"
    num_data = 8
    # Worker index and batch index at which the PS reset is injected.
    restart_worker, restart_step = 1, 3
    training_data = [
        get_random_batch(self._batch_size) for _ in range(num_data)
    ]
    workers = []
    self._create_pserver(model_def, 2)
    try:
        for worker_idx in range(2):
            # Every worker starts against freshly reset parameter servers.
            self._reset_pserver()
            arguments = [
                "--worker_id",
                0,
                "--job_type",
                elasticdl_pb2.TRAINING,
                "--minibatch_size",
                self._batch_size,
                "--model_zoo",
                self._model_zoo_path,
                "--model_def",
                model_def,
                "--distribution_strategy",
                DistributionStrategy.PARAMETER_SERVER,
            ]
            args = parse_worker_args(arguments)

            # Same session state and seed for both workers, so that any
            # difference in the final variables comes from the PS reset.
            tf.keras.backend.clear_session()
            tf.random.set_seed(22)

            worker = Worker(args, ps_client=PSClient(self._channels))
            workers.append(worker)
            trainer = worker._trainer
            trainer._run_model_call_before_training(training_data[0][0])

            for step, batch in enumerate(training_data):
                trainer._get_model()
                _, w_grads = trainer._training_process_eagerly(
                    batch[0], batch[1]
                )
                trainer._report_gradient(w_grads)

                if worker_idx == restart_worker and step == restart_step:
                    # Restart ps for the 2nd worker at step 3 (done here
                    # by resetting the parameter servers).
                    self._reset_pserver()
                    # `push_dense_parameters` will be called in
                    # `get_model` to initialize variables on ps with
                    # worker variables.
                    trainer._get_model()
                    # Send the grads again as these grads are not
                    # applied on worker variables.
                    trainer._report_gradient(w_grads)

        first_vars = workers[0]._trainer._non_embed_vars
        second_vars = workers[1]._trainer._non_embed_vars
        for var_name in first_vars:
            np.testing.assert_array_equal(
                first_vars[var_name].numpy(),
                second_vars[var_name].numpy(),
            )
    finally:
        # Always release the channels, including on assertion failure.
        self._close_channels()
def test_compare_onebatch_train(self):
    """Compare one PS-backed training step with a purely local one.

    A worker runs a single batch and reports its gradients to two
    parameter servers. The same batch is then pushed through a model
    loaded with ``get_model_spec`` and updated with the optimizer from
    the model definition. Every trainable variable of the local model
    must equal the value stored on the parameter server that
    ``string_to_id`` selects for the variable's name.
    """
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    self._create_pserver(model_def, 2)
    images, labels = get_random_batch(self._batch_size)

    # TODO(yunjian.lmh): test optimizer wrapper
    worker_argv = [
        "--worker_id",
        0,
        "--job_type",
        elasticdl_pb2.TRAINING,
        "--minibatch_size",
        self._batch_size,
        "--model_zoo",
        self._model_zoo_path,
        "--model_def",
        model_def,
        "--distribution_strategy",
        DistributionStrategy.PARAMETER_SERVER,
    ]
    parsed_args = parse_worker_args(worker_argv)

    # --- Step 1: one training step through the worker and the PS. ---
    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    worker = Worker(parsed_args, ps_channels=self._channels)
    worker._run_model_call_before_training(images)
    worker.get_model()
    _, worker_grads = worker.training_process_eagerly(images, labels)
    worker.report_gradient(worker_grads)

    # --- Step 2: the same step on a local model, with the same seed. ---
    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    spec = get_model_spec(
        model_zoo=self._model_zoo_path,
        model_def=model_def,
        dataset_fn="dataset_fn",
        model_params=None,
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        prediction_outputs_processor="PredictionOutputsProcessor",
        custom_data_reader="custom_data_reader",
        callbacks="callbacks",
    )
    # Only the model, loss and optimizer factory are used below; the
    # remaining entries are unpacked to keep the 8-element contract.
    (
        model,
        _dataset_fn,
        loss_fn,
        opt_fn,
        _eval_metrics_fn,
        _prediction_outputs_processor,
        _create_data_reader_fn,
        _callback_list,
    ) = spec

    with tf.GradientTape() as tape:
        output = model.call(images, training=True)
        labels = tf.reshape(labels, [-1])
        loss = loss_fn(labels, output)
    local_grads = tape.gradient(loss, model.trainable_variables)
    optimizer = opt_fn()
    optimizer.apply_gradients(zip(local_grads, model.trainable_variables))

    # --- Step 3: each local variable must match its copy on the PS. ---
    for variable in model.trainable_variables:
        ps_index = string_to_id(variable.name, len(self._channels))
        ps_parameters = self._pservers[ps_index].parameters
        ps_value = ps_parameters.get_non_embedding_param(variable.name)
        np.testing.assert_array_equal(ps_value.numpy(), variable.numpy())