Example #1
from typing import List, Optional

import tensorflow as tf
from tensorflow.keras import Model  # assumed: ``Model`` here is tf.keras.Model or a project-specific subclass


def transfer_weights(
    model: Model,
    new_model: Model,
    optimizer: tf.keras.optimizers.Optimizer,
    new_optimizer: tf.keras.optimizers.Optimizer,
    ignore_weights: Optional[List[tf.Variable]] = None,
) -> None:
    """Transfers the weights of ``model`` to ``new_model``, along with the
    matching optimizer slot variables (e.g. Adam moments)."""
    if type(model) is not type(new_model):
        raise ValueError(
            "Transferring weights to another model type is not supported")
    if ignore_weights is None:
        ignore_weights = []
    # tf.Variable objects are not hashable, so compare them by reference.
    ignore_weights_ref = set(weight.experimental_ref()
                             for weight in ignore_weights)
    weights = model.weights
    new_weights = new_model.weights
    for weight, new_weight in zip(weights, new_weights):
        if new_weight.experimental_ref() not in ignore_weights_ref:
            new_weight.assign(weight)
            # Copy the optimizer slots that exist in both optimizers for this weight.
            for slot_name in new_optimizer.get_slot_names():
                if slot_name not in optimizer.get_slot_names():
                    continue
                new_slot = new_optimizer.get_slot(new_weight, slot_name)
                slot = optimizer.get_slot(weight, slot_name)
                new_slot.assign(slot)
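
For reference, a minimal usage sketch (not part of the original example): it assumes a TF 2.x setup where the optimizer exposes get_slot_names()/get_slot() (the legacy Keras optimizer API; on TF >= 2.11 use tf.keras.optimizers.legacy.Adam), and it creates the Adam slot variables by applying a zero gradient before transferring.

import tensorflow as tf

def build_model() -> tf.keras.Model:
    return tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

model, new_model = build_model(), build_model()
optimizer = tf.keras.optimizers.Adam()      # tf.keras.optimizers.legacy.Adam on TF >= 2.11
new_optimizer = tf.keras.optimizers.Adam()  # idem

# Adam slots ("m" and "v") only exist after the optimizer has processed the
# variables, e.g. after one (zero) gradient update.
for opt, mdl in ((optimizer, model), (new_optimizer, new_model)):
    zero_grads = [tf.zeros_like(v) for v in mdl.trainable_variables]
    opt.apply_gradients(zip(zero_grads, mdl.trainable_variables))

transfer_weights(model, new_model, optimizer, new_optimizer)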
Example #2
from typing import List

import tensorflow as tf

# ``VariableUpdate`` and ``update_variable`` are helpers defined elsewhere in
# the module this snippet comes from; they are used here as-is.


def update_variable_and_slots(
    update: VariableUpdate,
    mapping: List[int],
    ref_optimizer: tf.keras.optimizers.Optimizer,
    new_optimizer: tf.keras.optimizers.Optimizer,
) -> List[tf.Variable]:
    """Update a vocabulary variable and its associated optimizer slots (if any)."""
    variables = [update_variable(update, mapping)]
    ref_slot_names = ref_optimizer.get_slot_names()
    new_slot_names = new_optimizer.get_slot_names()
    for slot_name in ref_slot_names:
        # Only transfer slots that both optimizers define.
        if slot_name not in new_slot_names:
            continue
        ref_slot = ref_optimizer.get_slot(update.ref_variable, slot_name)
        new_slot = new_optimizer.get_slot(update.new_variable, slot_name)
        # Reindex the slot variable with the same vocabulary mapping as its weight.
        slot_update = VariableUpdate(ref_slot, new_slot, update.vocab_axis)
        variables.append(update_variable(slot_update, mapping))
    return variables
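
The snippet above relies on VariableUpdate and update_variable, which are defined elsewhere in its source module and are not shown here. Purely as an illustrative assumption (not the module's actual definition), a VariableUpdate consistent with how it is used above could be a small dataclass:

import dataclasses

import tensorflow as tf

@dataclasses.dataclass
class VariableUpdate:
    # Field names mirror the attribute accesses in the snippet above.
    ref_variable: tf.Variable   # variable from the reference (old) checkpoint
    new_variable: tf.Variable   # corresponding variable in the new model
    vocab_axis: int             # axis of the variable indexed by the vocabulary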