def test_restart_ps(self):
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    num_data = 8
    training_data = [
        get_random_batch(self._batch_size) for _ in range(num_data)
    ]
    workers = []
    self._create_pserver(model_def, 2)
    for w in range(2):
        self._reset_pserver()
        arguments = [
            "--worker_id",
            0,
            "--job_type",
            elasticdl_pb2.TRAINING,
            "--minibatch_size",
            self._batch_size,
            "--model_zoo",
            self._model_zoo_path,
            "--model_def",
            model_def,
            "--distribution_strategy",
            DistributionStrategy.PARAMETER_SERVER,
        ]
        args = parse_worker_args(arguments)
        tf.keras.backend.clear_session()
        tf.random.set_seed(22)
        worker = Worker(args, ps_channels=self._channels)
        workers.append(worker)
        worker._run_model_call_before_training(training_data[0][0])
        for i in range(num_data):
            worker.get_model()
            w_loss, w_grads = worker.training_process_eagerly(
                training_data[i][0], training_data[i][1]
            )
            worker.report_gradient(w_grads)
            if w == 1 and i == 3:
                # Restart the ps for the 2nd worker at i == 3.
                # self._restart_pserver(model_def)
                self._reset_pserver()
                # `report_variable` will be called in `get_model` to
                # initialize variables on the ps with the worker's
                # local variables.
                worker.get_model()
                # Send the grads again, as these grads were never
                # applied to the worker's local variables and were
                # lost when the ps restarted.
                worker.report_gradient(w_grads)

    # Both workers should end up with identical non-embedding variables.
    for var_name in workers[0]._non_embed_vars:
        np.testing.assert_array_equal(
            workers[0]._non_embed_vars[var_name].numpy(),
            workers[1]._non_embed_vars[var_name].numpy(),
        )
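
# The recovery protocol exercised by `test_restart_ps`, in isolation: after
# a parameter server restart, the worker seeds the fresh ps from its local
# variables and then re-sends the gradients the restarted ps never applied.
# This is a minimal sketch using plain dicts in place of the ps and worker
# state; none of these names come from ElasticDL itself.
def _sketch_ps_restart_recovery():
    worker_vars = {"dense/kernel": 1.0, "dense/bias": 0.5}
    pending_grads = {"dense/kernel": 0.1, "dense/bias": -0.2}
    lr = 0.1

    ps_params = {}  # a freshly restarted, empty parameter server

    # Step 1: pulling the model from an empty ps triggers variable
    # reporting, initializing the ps from the worker's current copy.
    if not ps_params:
        ps_params.update(worker_vars)

    # Step 2: the worker re-sends the gradients that were lost in the
    # restart; the ps applies them (plain SGD here for illustration).
    for name, grad in pending_grads.items():
        ps_params[name] -= lr * grad

    return ps_params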
def test_compare_onebatch_train(self):
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    self._create_pserver(model_def, 2)
    images, labels = get_random_batch(self._batch_size)
    # TODO(yunjian.lmh): test optimizer wrapper
    arguments = [
        "--worker_id",
        0,
        "--job_type",
        elasticdl_pb2.TRAINING,
        "--minibatch_size",
        self._batch_size,
        "--model_zoo",
        self._model_zoo_path,
        "--model_def",
        model_def,
        "--distribution_strategy",
        DistributionStrategy.PARAMETER_SERVER,
    ]
    args = parse_worker_args(arguments)
    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    worker = Worker(args, ps_channels=self._channels)
    worker._run_model_call_before_training(images)
    worker.get_model()
    w_loss, w_grads = worker.training_process_eagerly(images, labels)
    worker.report_gradient(w_grads)

    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    (
        model,
        dataset_fn,
        loss_fn,
        opt_fn,
        eval_metrics_fn,
        prediction_outputs_processor,
        create_data_reader_fn,
        callback_list,
    ) = get_model_spec(
        model_zoo=self._model_zoo_path,
        model_def=model_def,
        dataset_fn="dataset_fn",
        model_params=None,
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        prediction_outputs_processor="PredictionOutputsProcessor",
        custom_data_reader="custom_data_reader",
        callbacks="callbacks",
    )

    with tf.GradientTape() as tape:
        output = model.call(images, training=True)
        labels = tf.reshape(labels, [-1])
        loss = loss_fn(labels, output)
    grads = tape.gradient(loss, model.trainable_variables)
    opt_fn().apply_gradients(zip(grads, model.trainable_variables))

    for v in model.trainable_variables:
        ps_id = string_to_id(v.name, len(self._channels))
        ps_v = self._pservers[ps_id].parameters.get_non_embedding_param(
            v.name
        )
        np.testing.assert_array_equal(ps_v.numpy(), v.numpy())
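
# `string_to_id` above routes each variable name to the parameter server
# that owns it. ElasticDL's actual implementation is not reproduced here;
# this sketch merely assumes a stable hash-mod-N partitioner, which is the
# property the assertion loop relies on: every worker must map `v.name` to
# the same ps instance on every call.
import hashlib


def _string_to_id_sketch(name, num_ps):
    # A stable digest (unlike Python's salted hash()) keeps the mapping
    # identical across processes and restarts.
    digest = hashlib.md5(name.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_ps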
def _worker_train(self, train_db, test_db, dataset, stop_step):
    if dataset == "mnist":
        model_def = (
            "mnist_functional_api.mnist_functional_api.custom_model"
        )
    elif dataset == "frappe":
        model_def = (
            "deepfm_functional_api.deepfm_functional_api.custom_model"
        )
    else:
        raise ValueError("dataset %s is not supported" % dataset)
    arguments = [
        "--worker_id",
        0,
        "--job_type",
        elasticdl_pb2.TRAINING,
        "--minibatch_size",
        self._batch_size,
        "--model_zoo",
        self._model_zoo_path,
        "--model_def",
        model_def,
        "--distribution_strategy",
        "ParameterServerStrategy",
    ]
    args = parse_worker_args(arguments)
    worker = Worker(args, ps_channels=self._channel)
    acc_meter = tf.keras.metrics.Accuracy()
    worker_results = []
    for step, (x, y) in enumerate(train_db):
        if step == 0:
            worker._run_model_call_before_training(x)
        worker.get_model(step, elasticdl_pb2.MINIMUM)
        w_loss, w_grads = worker.training_process_eagerly(x, y)
        worker.report_gradient(w_grads)
        if step % 20 == 0:
            worker.get_model(step, elasticdl_pb2.MINIMUM)
            # Evaluate on the test set every 20 training steps.
            for (x_test, y_test) in test_db:
                out = worker.forward_process(x_test)
                if dataset == "mnist":
                    acc_meter.update_state(tf.argmax(out, axis=1), y_test)
                else:
                    out["probs"] = tf.reshape(out["probs"], [-1])
                    acc_meter.update_state(
                        tf.where(
                            out["probs"] < 0.5,
                            x=tf.zeros_like(y_test),
                            y=tf.ones_like(y_test),
                        ),
                        y_test,
                    )
            worker_results.append(
                (float(w_loss.numpy()), float(acc_meter.result().numpy()))
            )
            acc_meter.reset_states()
        if step > stop_step:
            break
    return worker_results
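
# A hypothetical driver for `_worker_train` above: it builds small tf.data
# pipelines from random MNIST-shaped tensors and runs a few steps. The
# shapes, batch size, and the `test_case` parameter are assumptions for
# illustration, not values from the original test suite.
def _run_worker_train_sketch(test_case):
    images = tf.random.uniform((64, 28, 28), maxval=1.0)
    labels = tf.random.uniform((64,), maxval=10, dtype=tf.int64)
    train_db = tf.data.Dataset.from_tensor_slices((images, labels)).batch(16)
    test_db = tf.data.Dataset.from_tensor_slices((images, labels)).batch(16)
    # Returns a list of (loss, accuracy) tuples, one per evaluation point.
    return test_case._worker_train(
        train_db, test_db, dataset="mnist", stop_step=20
    )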