def transfer_weights(
    model: Model,
    new_model: Model,
    optimizer: tf.keras.optimizers.Optimizer,
    new_optimizer: tf.keras.optimizers.Optimizer,
    ignore_weights: list = None,
) -> None:
    """Transfer the weights of :obj:`model` and the matching optimizer slot
    variables to :obj:`new_model`.

    Args:
      model: The model to read weights from.
      new_model: The model to assign weights to. Must be the same type as
        :obj:`model`.
      optimizer: The optimizer holding slot variables for :obj:`model`.
      new_optimizer: The optimizer holding slot variables for :obj:`new_model`.
      ignore_weights: Weights of :obj:`new_model` to leave untouched, compared
        by variable reference. Defaults to no weights.

    Raises:
      ValueError: If :obj:`model` and :obj:`new_model` are not the same type.
    """
    if type(model) is not type(new_model):
        raise ValueError(
            "Transferring weights to another model type is not supported")
    if ignore_weights is None:
        ignore_weights = []
    ignore_weights_ref = {weight.experimental_ref() for weight in ignore_weights}
    # Hoisted out of the per-weight loop: the slot name sets are loop-invariant,
    # and only slots known to both optimizers can be transferred.
    transferable_slot_names = [
        slot_name
        for slot_name in new_optimizer.get_slot_names()
        if slot_name in optimizer.get_slot_names()
    ]
    # NOTE(review): zip stops at the shorter weights list and assumes both
    # models expose their weights in the same order — TODO confirm callers
    # guarantee matching weight structures.
    for weight, new_weight in zip(model.weights, new_model.weights):
        if new_weight.experimental_ref() in ignore_weights_ref:
            continue
        new_weight.assign(weight)
        for slot_name in transferable_slot_names:
            new_slot = new_optimizer.get_slot(new_weight, slot_name)
            slot = optimizer.get_slot(weight, slot_name)
            new_slot.assign(slot)
def update_variable_and_slots(
    update: VariableUpdate,
    mapping: List[int],
    ref_optimizer: tf.keras.optimizers.Optimizer,
    new_optimizer: tf.keras.optimizers.Optimizer,
) -> List[tf.Variable]:
    """Update a vocabulary variable and its associated optimizer slots (if any).

    Args:
      update: The variable update to apply.
      mapping: The vocabulary mapping used to reorder entries.
      ref_optimizer: The optimizer holding slots for the reference variable.
      new_optimizer: The optimizer holding slots for the new variable.

    Returns:
      The list of updated variables: the vocabulary variable itself followed by
      each slot variable shared by both optimizers.
    """
    updated = [update_variable(update, mapping)]
    # Only slots that exist in both optimizers can be carried over.
    shared_slot_names = [
        name
        for name in ref_optimizer.get_slot_names()
        if name in new_optimizer.get_slot_names()
    ]
    for name in shared_slot_names:
        slot_update = VariableUpdate(
            ref_optimizer.get_slot(update.ref_variable, name),
            new_optimizer.get_slot(update.new_variable, name),
            update.vocab_axis,
        )
        updated.append(update_variable(slot_update, mapping))
    return updated