def _run_model_call_before_training(self, features):
    """Call `self._model.call` before training for two things:
        * Create variables and report to ps if not created.
        * Check whether there is an embedding layer that is called
          more than once during one forward-pass.
    """
    if self._embedding_layers:
        with tf.GradientTape() as tape:
            self._set_tape_for_embedding(tape)
            _ = self._model.call(features)
    else:
        _ = self._model.call(features)
    self._non_embed_vars = {}
    for var in get_non_embedding_trainable_vars(
        self._model, self._embedding_layers
    ):
        self._non_embed_vars[var.name] = var

    self._var_created = True
    if self._use_multi_ps:
        self.init_ps_var_partition()

    if self._need_embedding_layer_check:
        self._train_eagerly = False
        for layer in self._embedding_layers:
            if len(layer.embedding_and_ids) > 1:
                self._train_eagerly = True
                self.logger.warning(
                    "ElasticDL embedding layer %s is called more than "
                    "once, this will make the training process unable "
                    "to accelerate with tf.function." % (layer.name)
                )
        self._need_embedding_layer_check = False

    self._reset_embedding()
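# A minimal sketch (not part of the worker) of the situation the check above
# guards against: a model that calls the same embedding layer twice in one
# forward pass, e.g. a table shared by two id features. With ElasticDL's
# Embedding layer, each call records one (embedding, ids) pair, so
# `len(layer.embedding_and_ids) > 1` and training falls back to eager mode.
# Plain `tf.keras.layers.Embedding` and the feature names below are
# illustrative assumptions, not the original code.
import tensorflow as tf


class SharedEmbeddingModel(tf.keras.Model):
    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        # One embedding table reused for both id features.
        self.shared_embedding = tf.keras.layers.Embedding(vocab_size, dim)
        self.dense = tf.keras.layers.Dense(1)

    def call(self, features):
        # Two calls to the same layer within a single forward pass.
        user_vec = self.shared_embedding(features["user_id"])
        item_vec = self.shared_embedding(features["item_id"])
        return self.dense(tf.concat([user_vec, item_vec], axis=-1))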
def _train_edl_embedding_with_optimizer_wrapper(
    model, opt_keras, X, Y, loss_fn, embed_dims, random_seed
):
    """Train model with optimizer wrapper."""
    tf.random.set_seed(random_seed)
    optimizer = OptimizerWrapper(opt_keras, None, embed_dims)

    # initialization process related to embedding layer and optimizer wrapper
    embed_layers = find_layer(model, Embedding)

    # training process
    for train_iter, (features, labels) in enumerate(zip(X, Y)):
        with tf.GradientTape() as tape:
            for layer in embed_layers:
                layer.set_tape(tape)
            outputs = model.call(features)
            loss = loss_fn(outputs, labels)

        # Need to get non-embedding variables inside the for loop because the
        # model creates variables after the first time `model.call` is called
        if not train_iter:
            non_embed_vars = get_non_embedding_trainable_vars(
                model, embed_layers
            )
        embed_items = []
        for layer in embed_layers:
            embed_items.extend(
                [
                    (bet, layer.name, ids)
                    for bet, ids in layer.embedding_and_ids
                ]
            )

        grads = tape.gradient(
            loss, non_embed_vars + [var for var, _, _ in embed_items]
        )

        # TODO: no need to merge gradients from the same embedding layer
        # after `optimizer_wrapper` supports grads_and_vars with duplicated
        # layer names
        non_embed_vars_n = len(non_embed_vars)
        non_embed_grads = grads[:non_embed_vars_n]
        embed_grads_dict = {}
        for (_, layer_name, ids), grad in zip(
            embed_items, grads[non_embed_vars_n:]
        ):
            if layer_name in embed_grads_dict:
                merged_grads = embed_grads_dict[layer_name]
                embed_grads_dict[layer_name] = tf.IndexedSlices(
                    tf.concat([merged_grads.values, grad.values], axis=0),
                    tf.concat([merged_grads.indices, ids], axis=0),
                )
            else:
                embed_grads_dict[layer_name] = tf.IndexedSlices(
                    grad.values, ids
                )

        optimizer.apply_gradients(
            list(zip(non_embed_grads, non_embed_vars))
            + [
                (grad, layer_name)
                for layer_name, grad in embed_grads_dict.items()
            ]
        )

        for layer in embed_layers:
            layer.reset()
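# A standalone sketch of the gradient-merge step in the loop above, using
# made-up tensors. When one embedding layer is called twice per forward pass,
# `tape.gradient` returns one sparse gradient per call; the loop concatenates
# their values and ids into a single tf.IndexedSlices per layer name before
# handing them to the optimizer wrapper. The shapes and ids here are
# hypothetical.
import tensorflow as tf

# Pretend gradients from two calls to the same 8-dim embedding layer,
# each touching different rows of the table.
ids_1 = tf.constant([0, 3], dtype=tf.int64)
grad_1 = tf.IndexedSlices(tf.ones([2, 8]), ids_1)
ids_2 = tf.constant([3, 5], dtype=tf.int64)
grad_2 = tf.IndexedSlices(2.0 * tf.ones([2, 8]), ids_2)

# Same merge as above: one IndexedSlices covering all rows touched by the
# layer in this step. Duplicate ids (3 here) are kept as-is; how they are
# combined is left to the optimizer wrapper.
merged = tf.IndexedSlices(
    tf.concat([grad_1.values, grad_2.values], axis=0),
    tf.concat([grad_1.indices, grad_2.indices], axis=0),
)
print(merged.indices.numpy())  # [0 3 3 5]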
def set_model(self, model_inst):
    """Set model instance to worker."""
    self._model = model_inst
    self._train_eagerly = False
    self._init_embedding_layer()
    self._var_created = self._model.built
    self._non_embed_vars = []
    if self._var_created:
        self._non_embed_vars = get_non_embedding_trainable_vars(
            self._model, self._embedding_layers
        )
def set_model(self, model_inst):
    """Set model instance to worker."""
    self._model = model_inst
    self._train_eagerly = False
    self._init_embeddings()
    self._var_created = self._model.built
    self._non_embed_vars = {}
    if self._var_created:
        for var in get_non_embedding_trainable_vars(
            self._model, self._embedding_layers
        ):
            self._non_embed_vars[var.name] = var
        if self._use_multi_ps:
            self.init_ps_var_partition()