def _update_model(self):
    """Apply the accumulated gradients and publish the updated model.

    The caller must already hold ``self._lock`` (asserted below).

    Steps:
      1. Every gradient in ``self._gradient_sum`` is divided by
         ``self._grad_to_wait`` and paired with its model variable.
      2. Gradients in ``self._gradient_sum_indexed`` are paired with
         their model variables as they are.
      3. For each layer in ``self._edl_embedding_gradients`` the gradient
         indices are de-duplicated, the matching embedding vectors are
         fetched from ``EmbeddingService`` and wrapped in a temporary
         ``tf.Variable`` so the optimizer can update them.
      4. The optimizer applies all pairs, the updated embedding rows are
         written back through ``self._update_edl_embedding_table``, the
         model version is updated and all accumulators are reset.

    Raises:
        RuntimeError: If the embedding service does not know one of the
            embedding keys that received a gradient.
    """
    assert self._lock.locked()
    grad_var = []

    # (grad, var) pairs excluding keras Embedding layer and
    # ElasticDL Embedding layer
    for k in self._gradient_sum:
        self._gradient_sum[k] = self._gradient_sum[k] / self._grad_to_wait
        grad_var.append((self._gradient_sum[k], self._model[k]))

    # (grad, var) pair of Keras Embedding layer
    for k in self._gradient_sum_indexed:
        grad_var.append((self._gradient_sum_indexed[k], self._model[k]))

    # (grad, var) pair of ElasticDL Embedding layer
    edl_embedding_offset = len(grad_var)
    # Layer names and unique ids are recorded only for layers that really
    # contribute a variable to ``grad_var``; otherwise a skipped layer
    # would shift the three sequences zipped together below and the wrong
    # ids would be written back with another layer's vectors.
    updated_layer_names = []
    unique_ids_list = []
    for layer_name, grads in self._edl_embedding_gradients.items():
        unique_ids, idx = tf.unique(grads.indices)
        # Re-index the gradient rows against the de-duplicated id list so
        # they line up with the rows of the temporary variable.
        grads_idx_transformed = tf.IndexedSlices(grads.values, idx)
        keys = [
            Embedding.get_key([layer_name, i]) for i in unique_ids.numpy()
        ]
        embeddings, unknown_keys = EmbeddingService.lookup_embedding(
            embedding_service_endpoint=self._embedding_service_endpoint,
            keys=keys,
        )
        if unknown_keys:
            raise RuntimeError(
                "Master reviced %d unknown embedding keys: %s ..."
                % (len(unknown_keys), str(unknown_keys[0]))
            )
        if not embeddings:
            # Nothing to update for this layer.
            continue
        embeddings = np.concatenate(embeddings, axis=0).reshape(
            len(keys), -1
        )
        embedding_var = tf.Variable(embeddings)
        grad_var.append((grads_idx_transformed, embedding_var))
        updated_layer_names.append(layer_name)
        unique_ids_list.append(unique_ids)

    # TODO: support optimizer with slots such as Adam, FTRL
    self._opt.apply_gradients(grad_var)

    # report updated embedding table to EmbeddingService
    self._update_edl_embedding_table(
        zip(
            updated_layer_names,
            unique_ids_list,
            [v for _, v in grad_var[edl_embedding_offset:]],
        )
    )
    self._update_model_version()
    self._gradient_sum.clear()
    self._gradient_sum_indexed.clear()
    self._edl_embedding_gradients.clear()
    self._grad_n = 0
def _report_to_kv_store(self):
    """Send the current embedding vectors and slot values to the kv store.

    For every layer in ``self._unique_ids_all_layers`` the rows of the
    layer's embedding variable are emitted first, followed by the rows of
    each slot variable named in ``self._allowed_slot_names``. Rows are
    paired with the layer's ids positionally. Everything is written with
    one ``EmbeddingService.update_embedding`` call.
    """
    kv_keys = []
    kv_values = []

    def collect(key_prefix, row_ids, rows):
        # Pair every id with its row; the kv key is the prefix plus the id.
        for row_id, row in zip(row_ids, rows):
            kv_keys.append(Embedding.get_key(key_prefix + [row_id]))
            kv_values.append(row)

    for layer_name, layer_ids in self._unique_ids_all_layers.items():
        collect(
            [layer_name],
            layer_ids,
            self._get_embedding_variable(layer_name).numpy(),
        )
        for slot_name in self._allowed_slot_names:
            collect(
                [layer_name, slot_name],
                layer_ids,
                self._get_slot_variable(layer_name, slot_name).numpy(),
            )

    EmbeddingService.update_embedding(
        kv_keys, kv_values, self._kv_store_endpoint
    )
def test_embedding_service(self):
    """Start a redis-backed embedding service, use it, then stop it.

    Checks that the cluster endpoint and connection are created, that a
    ``set(..., nx=True)`` succeeds once and returns ``None`` when the key
    already exists, and that the service reports a successful stop.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        endpoint = start_redis_instances(work_dir)

        # start
        service = EmbeddingService(endpoint)
        endpoint = service._create_redis_cluster()
        # wait for cluster up-running
        time.sleep(1)
        self.assertFalse(endpoint is None)

        # connection
        cluster = service._get_embedding_cluster()
        self.assertFalse(cluster is None)

        # set value to a key
        first_write = cluster.set("test_key", "OK", nx=True)
        self.assertTrue(first_write)

        # set value to a key existed
        second_write = cluster.set("test_key", "OK", nx=True)
        self.assertTrue(second_write is None)
        stored_value = cluster.get("test_key")
        self.assertEqual(b"OK", stored_value)

        # close
        stopped = service.stop_embedding_service()
        self.assertTrue(stopped)
def _update_edl_embedding_table(self, name_var_list):
    """Write updated embedding vectors back to the distributed storage.

    Arguments:
        name_var_list: An iterable of
            ``(layer name, unique ids, embedding variable)`` triples.
            ``unique ids`` and ``embedding variable`` must both expose
            ``.numpy()``.

    The ids of every triple are turned into embedding keys, the rows of
    every variable are collected as vectors, and both lists are sent in a
    single ``EmbeddingService.update_embedding`` call. Nothing is sent
    when no vectors were collected.
    """
    table_keys = []
    table_vectors = []
    for layer_name, unique_ids, embedding_var in name_var_list:
        for row_id in unique_ids.numpy():
            table_keys.append(Embedding.get_key([layer_name, row_id]))
        # Iterating the array yields one row per unique id.
        table_vectors.extend(embedding_var.numpy())

    if not table_vectors:
        return

    EmbeddingService.update_embedding(
        keys=table_keys,
        embedding_vectors=table_vectors,
        embedding_service_endpoint=self._embedding_service_endpoint,
    )
def lookup_embedding(
    self, ids, layer_name, initializer="uniform", embedding_table_dim=128
):
    """Fetch the embedding vectors of ``ids`` for one layer.

    Keys unknown to the embedding service are initialized with
    ``initializer`` (resolved through ``tf.keras.initializers.get``),
    written with ``set_if_not_exist=True`` and then read back, so the
    values returned are the ones actually stored in the service.

    Arguments:
        ids: Iterable of embedding ids to look up.
        layer_name: Name of the embedding layer the ids belong to.
        initializer: Keras initializer identifier used for unknown ids.
        embedding_table_dim: Length of a single embedding vector.

    Returns:
        A ``numpy.ndarray`` of shape ``(len(ids), embedding_table_dim)``.

    Raises:
        Exception: If some keys are still unknown after being written.
    """
    keys = [Embedding.get_key([layer_name, item_id]) for item_id in ids]
    endpoint = self._embedding_service_endpoint
    vectors, missing_positions = EmbeddingService.lookup_embedding(
        keys=keys, embedding_service_endpoint=endpoint
    )

    if missing_positions:
        # Initialize unknown_keys' embedding vectors and write into Redis.
        missing_keys = [keys[pos] for pos in missing_positions]
        init_fn = tf.keras.initializers.get(initializer)
        fresh_vectors = np.concatenate(
            [
                init_fn(shape=[1, embedding_table_dim]).numpy()
                for _ in missing_keys
            ],
            axis=0,
        )
        EmbeddingService.update_embedding(
            keys=missing_keys,
            embedding_vectors=fresh_vectors,
            embedding_service_endpoint=endpoint,
            set_if_not_exist=True,
        )

        # Lookup unknown_keys' embedding vectors
        stored_vectors, still_missing = EmbeddingService.lookup_embedding(
            keys=missing_keys, embedding_service_endpoint=endpoint
        )
        if still_missing:
            failed_keys = [missing_keys[pos] for pos in still_missing]
            raise Exception(
                "Update embedding vector: %s failed." % str(failed_keys)
            )

        # Fill the gaps of the first lookup with the stored values.
        for pos, vector in zip(missing_positions, stored_vectors):
            vectors[pos] = vector

    stacked = np.concatenate(vectors, axis=0)
    return stacked.reshape((len(keys), embedding_table_dim))
def _lookup_embeddings_and_slots(self, grads_and_vars):
    """Look up embedding vectors and slot values from the kv store.

    Embedding keys and slot keys are fetched in one request. Slot keys
    that the kv store does not know yet are initialized locally through
    ``self._initialize_unknown_slot``.

    Arguments:
        grads_and_vars: A list of (gradient, layer name) pairs.

    Returns:
        A tuple of (`embedding_values`, `slot_values`).
        `embedding_values` is a python dictionary of
        {layer name: `embedding_vectors`} where `embedding_vectors` is a
        2D `numpy.ndarray`. `slot_values` is a python dictionary of
        {layer name: {slot name: `slot_values`}} where `slot_values` is
        a 2D `numpy.ndarray`.

    Raises:
        RuntimeError: If any unknown embedding key exists.
    """
    (
        embed_keys,
        slot_keys,
        embed_key_index,
        slot_key_index,
    ) = self._generate_lookup_keys(grads_and_vars)

    # Embedding keys come first, so positions below ``embed_count``
    # refer to embedding keys and the rest refer to slot keys.
    all_keys = embed_keys + slot_keys
    embed_count = len(embed_keys)

    values, missing_positions = EmbeddingService.lookup_embedding(
        keys=all_keys, embedding_service_endpoint=self._kv_store_endpoint
    )

    if missing_positions:
        # raise Error if an unknown embedding key exists
        first_missing = missing_positions[0]
        if first_missing < embed_count:
            raise RuntimeError(
                "Failed to get key %s from kv store."
                % embed_keys[first_missing]
            )

        # initialize unknown slots
        for pos in missing_positions:
            slot_key = all_keys[pos]
            layer_name = _get_embedding_layer_name_from_key(slot_key)
            slot_name = _get_slot_name_from_key(slot_key)
            values[pos] = self._initialize_unknown_slot(
                layer_name, slot_name
            )

    embed_values = _parse_lookup_values(
        values[:embed_count], embed_key_index
    )
    slot_values = _parse_lookup_values(
        values[embed_count:], slot_key_index
    )
    return embed_values, slot_values
def test_lookup_and_update_embedding(self):
    """Round-trip embedding vectors through the embedding service.

    Covers three cases: a plain update followed by a lookup, an update
    with ``set_if_not_exist=True`` that must not overwrite stored values,
    and a lookup of keys that were never written.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        endpoint = start_redis_instances(work_dir)

        # start
        service = EmbeddingService(endpoint)
        endpoint = service._create_redis_cluster()
        # wait for cluster up-running
        time.sleep(1)

        first_batch = np.random.rand(100, 10).astype(np.float32)
        keys = ["test_%d" % i for i in range(first_batch.shape[0])]
        row_count = len(keys)

        def as_matrix(vectors):
            # Stack the per-key vectors into one (row_count, dim) array.
            stacked = np.concatenate(vectors, axis=0)
            return stacked.reshape((row_count, -1))

        EmbeddingService.update_embedding(keys, first_batch, endpoint)
        fetched, missing_idx = EmbeddingService.lookup_embedding(
            keys, endpoint, parse_type=np.float32
        )
        self.assertTrue(len(missing_idx) == 0)
        fetched = as_matrix(fetched)
        self.assertTrue(np.equal(first_batch, fetched).all())

        # Test set_if_not_exist
        second_batch = np.random.rand(100, 10).astype(np.float32)
        self.assertFalse(np.equal(first_batch, second_batch).all())
        EmbeddingService.update_embedding(
            keys, second_batch, endpoint, set_if_not_exist=True
        )
        fetched, missing_idx = EmbeddingService.lookup_embedding(
            keys, endpoint, parse_type=np.float32
        )
        fetched = as_matrix(fetched)
        self.assertTrue(np.equal(first_batch, fetched).all())
        self.assertFalse(np.equal(second_batch, fetched).all())

        # Test non-exist keys
        absent_keys = ["test_no_exist_%d" % i for i in range(10)]
        fetched, missing_idx = EmbeddingService.lookup_embedding(
            absent_keys, endpoint, parse_type=np.float32
        )
        self.assertTrue(len(missing_idx) == 10)
        self.assertTrue(len(fetched) == 10)

        # Close
        self.assertTrue(service.stop_embedding_service())
def main():
    """Run the ElasticDL master process until all tasks are finished.

    In order, this function:
      * parses the command line and creates the logger,
      * optionally starts the TensorBoard service,
      * builds the task dispatcher and loads the model and optimizer,
      * derives the job type from which data directories were given,
      * optionally creates the checkpoint and evaluation services,
      * starts the embedding service when the model has ``Embedding``
        layers,
      * starts the gRPC master service and, if requested, the workers,
      * polls the task dispatcher every 30 seconds until it is finished
        (or a ``KeyboardInterrupt`` arrives), then stops the evaluation
        service and the gRPC server,
      * finally keeps the process alive for as long as TensorBoard is
        active.
    """
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start tensorboard service if required
    if args.tensorboard_log_dir:
        logger.info(
            "Starting tensorboard service with log directory %s",
            args.tensorboard_log_dir,
        )
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    model_inst = load_model_from_module(
        args.model_def, model_module, args.model_params
    )
    optimizer = model_module[args.optimizer]()

    # The job type is inferred from the combination of data directories.
    if all(
        (
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
        )
    ):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all(
        (
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
        )
    ):
        job_type = JobType.EVALUATION_ONLY
    elif all(
        (
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
        )
    ):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (
        job_type == JobType.TRAINING_WITH_EVALUATION
        or job_type == JobType.EVALUATION_ONLY
    ):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        # Logger-side formatting, like every other log call in here.
        logger.info(
            "Embedding service start succeeded. The endpoint is %s.",
            str(embedding_service_endpoint),
        )

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        # An unbuilt model has no variables to hand over yet.
        init_var=model_inst.trainable_variables if model_inst.built else [],
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
        ]

        logger.info(">>> master pod envs argument is %s", args.envs)
        env_dict = parse_envs(args.envs)
        env = [V1EnvVar(name=key, value=env_dict[key]) for key in env_dict]

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # ``worker_manager`` stays None when no workers were requested, so it
    # is checked too before it is asked to start the TensorBoard service.
    if tb_service and worker_manager:
        worker_manager.start_tensorboard_service()

    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        # An interrupt skips the final status update and checkpoint and
        # goes straight to shutdown.
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running..."
        )
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning(
                    "Unable to keep TensorBoard running. "
                    "It has already terminated"
                )
                break
    logger.info("Master stopped")