Example #1
0
        def soft_updates(local_model: tf.keras.Model,
                         target_model: tf.keras.Model) -> np.ndarray:
            """Blend the local model's weights into the target model's weights.

            Each entry of the result is ``TAU * local + (1 - TAU) * target`` for the
            corresponding weight arrays of the two models (``TAU`` comes from the
            enclosing scope).

            The weights are blended array by array instead of stacking them with
            ``np.array(model.get_weights())``. Keras weight lists are ragged (kernels
            and biases differ in shape). Stacking them raises on NumPy >= 1.24, and on
            older NumPy it fails outright when shapes share a leading dimension.

            Returns:
                A 1-D object ndarray holding one blended array per model weight,
                suitable for ``tf.keras.Model.set_weights``.

            Raises:
                ValueError: if the two models have a different number of weights.
            """
            local_weights = local_model.get_weights()
            target_weights = target_model.get_weights()

            if len(local_weights) != len(target_weights):
                raise ValueError(
                    f"models have a different number of weights: "
                    f"{len(local_weights)} vs {len(target_weights)}")

            # An explicit object array (filled element by element) keeps the
            # ndarray return type without NumPy trying to broadcast the arrays
            # into one rectangular block.
            new_weights = np.empty(len(local_weights), dtype=object)
            for i, (local, target) in enumerate(zip(local_weights, target_weights)):
                new_weights[i] = TAU * local + (1 - TAU) * target
            return new_weights
Example #2
0
def make_z_as_input(generator: tf.keras.Model, model_config: dict,
                    speech_config: dict):
    """Rebuild ``generator`` with ``create_generator_v2`` and copy its weights.

    The architecture hyper-parameters are read from ``model_config``
    (``g_enc_depths``, ``kwidth``, ``ratio``) and ``speech_config``
    (``window_size``). The new model receives ``generator``'s weights unchanged,
    so the two architectures must have matching weight lists.

    Returns:
        The newly built generator carrying ``generator``'s weights.
    """
    v2_kwargs = dict(
        g_enc_depths=model_config["g_enc_depths"],
        window_size=speech_config["window_size"],
        kwidth=model_config["kwidth"],
        ratio=model_config["ratio"],
    )
    rebuilt = create_generator_v2(**v2_kwargs)
    rebuilt.set_weights(generator.get_weights())
    return rebuilt
Example #3
0
def copy_original_weights(original_model: tf.keras.Model,
                          quantized_model: tf.keras.Model):
    """Copy the original model's weights into the quantized model.

    The quantized model's weights are walked in order. Entries whose name is
    matched by ``is_quantization_weight_name`` keep their current value. Every
    other entry receives the next original weight, in sequence.

    Raises:
        ValueError: if the two models do not have the same number of
            non-quantization weights.
    """
    source_values = original_model.get_weights()
    merged_values = quantized_model.get_weights()

    consumed = 0
    for position, variable in enumerate(quantized_model.weights):
        if is_quantization_weight_name(variable.name):
            # presumably quantizer-only variables with no counterpart in the
            # original model; leave them untouched
            continue
        if consumed >= len(source_values):
            raise ValueError('Not enought original model weights.')
        merged_values[position] = source_values[consumed]
        consumed += 1

    # every original weight must have found a slot in the quantized model
    if consumed < len(source_values):
        raise ValueError('Not enought quantized model weights.')

    quantized_model.set_weights(merged_values)
Example #4
0
    def __init__(self, model: tf.keras.Model):
        """Snapshot a Keras model's architecture and weights.

        Args:
            model: Model to be sent; its JSON config is stored in ``self.config``
                and a copy of each weight array in ``self.weights``.
        """
        self.config = model.to_json()
        # np.array copies each weight, so the snapshot owns plain ndarrays rather
        # than memoryview objects
        self.weights = list(map(np.array, model.get_weights()))
Example #5
0
def get_num_weights(model: tf.keras.Model) -> int:
    """Utility function to determine the number of weights in a keras model.

    Counts every scalar in every array returned by ``model.get_weights()``
    (trainable and non-trainable alike).

    Arguments:
        model {tf.keras.Model} -- the keras model

    Returns:
        int -- the number of weights (a plain Python int; 0 for a model without
        weights)
    """
    # np.size gives the exact element count as a Python int. The previous
    # tf.shape/tf.reduce_prod path returned an int32 tf.Tensor: it did not match
    # the annotated return type and could overflow for very large weights.
    return int(sum(np.size(weight) for weight in model.get_weights()))
Example #6
0
def polyak_averaging(model: tf.keras.Model, old_weights: list, alpha=0.99):
    """Blend ``model``'s current weights with ``old_weights`` in place.

    Source: Deep Learning Book (section 8.7.3)
        - the original implementation is: `w = alpha * w_old + (1.0 - alpha) * w_new`,
          here we use `w = alpha * w_new + (1.0 - alpha) * w_old` because it performs better for RL

    Weights are paired positionally with ``old_weights`` and the blended list is
    written back with ``model.set_weights``.
    """
    old_share = 1.0 - alpha
    averaged = [
        alpha * current + old_share * previous
        for previous, current in zip(old_weights, model.get_weights())
    ]
    model.set_weights(averaged)
Example #7
0
def transfer_weights(source_model: tf.keras.Model,
                     target_model: tf.keras.Model):
    """
    Copy all weights of ``source_model`` into ``target_model``.

    The two models must have compatible weight lists (same count and shapes);
    ``target_model`` is modified in place.

    Args:
        source_model: trained `tf.keras.Model`
        target_model: target `tf.keras.Model`

    Returns:
        the same ``target_model`` object, now holding the source weights
    """
    trained_weights = source_model.get_weights()
    target_model.set_weights(trained_weights)
    return target_model
def adv_perturbation_closed_form(model: tf.keras.Model,
                                 x: tf.Tensor,
                                 y: tf.Tensor,
                                 eps: float = 0.01) -> tf.Tensor:
    """
    Uses the closed form of projected descent for single-layer networks.

    Labels are mapped to +1 where ``y == 1.0`` and -1 elsewhere. Each sample is
    then shifted by ``-eps * label * sign(W)^T``, where ``W`` is the first array
    returned by ``model.get_weights()``.

    :param model: keras model with a single layer of weights
    :param x: input tensor
    :param y: Tensor with true labels
    :param eps: amount to scale perturbation by
    :return: a tensor with adversarial samples (``x`` plus the perturbation, after
        ``tf.squeeze`` drops any size-1 axes of the perturbation)
    """
    kernel_signs = tf.transpose(tf.sign(model.get_weights()[0]))
    signed_labels = tf.where(tf.equal(y, 1.0), 1.0, -1.0)
    # (batch, 1) @ (1, features) outer product; assumes a single-output kernel
    # of shape (features, 1) -- TODO confirm
    step = (-eps * signed_labels[:, None]) @ kernel_signs
    return x + tf.squeeze(step)
Example #9
0
def lr_finder(
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    loss_fn: tf.keras.losses.Loss,
    dataset,
    learn_rates: LrGenerator,
    losses: SmoothedLoss,
) -> Lr:
    """Train with a sweep of learning rates, recording the loss at each one.

    One ``train_step`` is run per learning rate produced by ``learn_rates()``,
    paired with successive ``(source, target)`` batches of ``dataset``. The sweep
    stops when ``losses.no_progress`` is true or either iterator runs out.

    To run the lr finder directly before training, the model's weights and the
    optimizer's learning rate are restored afterwards. This also happens if a
    training step raises.

    NOTE(review): the optimizer's internal state (e.g. slot variables, iteration
    count) is still advanced by the sweep; only weights and lr are restored.

    Returns:
        ``Lr(learn_rates, losses)``.
    """
    # Building first guarantees the weights exist so they can be snapshotted.
    model.build(dataset.element_spec[0].shape)
    initial_weights = model.get_weights()
    initial_lr = tf.keras.backend.get_value(optimizer.lr)
    try:
        for lr, (source, target) in zip(learn_rates(), dataset):
            tf.keras.backend.set_value(optimizer.lr, lr)
            loss = train_step(model, optimizer, loss_fn, source, target).numpy()
            if losses.no_progress:
                break
            losses.update(loss)
    finally:
        model.set_weights(initial_weights)
        # Without this the optimizer would keep the last (typically largest)
        # swept learning rate for the real training run.
        tf.keras.backend.set_value(optimizer.lr, initial_lr)
    return Lr(learn_rates, losses)
Example #10
0
def soft_update(target: tf.keras.Model, source: tf.keras.Model, tau):
    """Move ``target``'s weights a fraction ``tau`` toward ``source``'s weights.

    Each target weight becomes ``target * (1 - tau) + source * tau``; weights are
    paired positionally and written back with ``target.set_weights``.
    """
    keep = 1. - tau
    blended = [
        current * keep + incoming * tau
        for current, incoming in zip(target.get_weights(), source.get_weights())
    ]
    target.set_weights(blended)
Example #11
0
def train_model(model: tf.keras.Model, model_for_validation: tf.keras.Model,
                list_of_batches: list, data_validate, current_basepath: str):
    """
    trains the model using a custom training loop and early stopping

    Besides its parameters, this function reads module-level configuration
    (``epochs``, ``batch_size``, ``early_stopping_patience``, ``loss_function``,
    ``print_high_error_samples``, ``compute_and_plot_reference_train_error``,
    ``print_and_plot_stats_every_epoch``, ``validation_plotting_point_count``,
    ``data_limit_flag``, ``data_limit``, ``optimizer_choice``, ``np_random_seed``)
    and uses the helper ``init_loss_histories``.

    :param model: the model to train
    :param model_for_validation: an identical model except for the input shape for manual validation
    :param list_of_batches: training set, as list of equal-length batches, each containing only samples of equal
    terminal count (shuffled in place every epoch)
    :param data_validate: validation set
    :param current_basepath: path to directory to save plots, logs, and the trained model in
    :return: nothing, modifies filesystem: plots and text logs go to ``<current_basepath>/plots/``, the best model
    to ``<current_basepath>/model`` and the per-length loss table to ``individual_loss_histories.csv``
    :raises OSError: if the ``plots`` directory already exists (``os.makedirs`` is called without ``exist_ok``)
    """
    # create plot save folder. The trailing separator matters: every output file below is named by plain string
    # concatenation onto this path, and os.path.abspath strips a trailing "/". Without it, files would be written
    # next to the folder as "plotsfull_data_..." instead of inside it.
    plot_basepath = os.path.join(
        os.path.abspath(os.path.join(current_basepath, "plots")), "")
    os.makedirs(plot_basepath)
    # history of training loss (one mean value per epoch)
    loss_history_train = []
    # history of validation loss
    loss_history_validate = []
    # history of "reference training loss"
    loss_history_train_reference = []
    # init individual (per sequence length) loss histories. Judging by the usage below, each entry appears to be
    # [per-epoch loss list, sample count, corrected_hpwl loss] -- TODO confirm against init_loss_histories
    individual_loss_histories_validate, final_individual_loss_histories, \
        loss_corrected_hpwl = init_loss_histories(data_validate)
    # init early stopping variables
    early_stopping_counter = 0
    best_val_loss = math.inf
    early_stopping_checkpoint = None
    # model epochs
    for epoch_count in range(epochs):
        gc.collect()
        print("epoch #", epoch_count)
        # shuffle the individual batches (externally)
        np.random.shuffle(list_of_batches)
        # accumulator for this epoch's summed batch losses
        loss_history_train.append(0)
        # in each epoch train on all the training data, one batch at a time
        for batch_counter in range(len(list_of_batches)):
            print(
                "epoch(" + str(epoch_count) + "/" + str(epochs - 1) +
                ") batch #", batch_counter)
            # shuffle the specific batch (internally)
            np.random.shuffle(list_of_batches[batch_counter])
            # prepare the data as input to the model: (batch_size, sequence length, 2)
            x_train = np.asarray([
                data.coordinate_pairs
                for data in list_of_batches[batch_counter]
            ])
            x_train = x_train.reshape(
                batch_size,
                len(list_of_batches[batch_counter][0].coordinate_pairs), 2)
            y_train = np.asarray(
                [data.true_cost for data in list_of_batches[batch_counter]])
            # train on one batch
            loss_history_train_tmp = model.fit(x_train, y_train, epochs=1)
            # mean squared error can be summed up and later divided by the number of batches, as long as each batch
            # is of equal size
            loss_history_train[epoch_count] += loss_history_train_tmp.history[
                'loss'][0]
        # finalize mean squared error of train run
        loss_history_train[epoch_count] /= len(list_of_batches)
        # copy weights to the validation model with batch size of 1
        model_for_validation.set_weights(model.get_weights())
        # validate
        produced_results = []
        # reset to 0 (might have been used above)
        squared_loss_sum = 0
        # prepare individual loss records: open a new per-epoch slot for every sequence length that has samples
        for individual_loss_list_index in range(
                len(individual_loss_histories_validate)):
            if not individual_loss_histories_validate[
                    individual_loss_list_index][1] == 0:
                individual_loss_histories_validate[individual_loss_list_index][
                    0].append(0)
        # use the whole validation set
        for data in data_validate:
            result = model_for_validation.predict(
                np.asarray(data.coordinate_pairs).reshape(
                    1, len(data.coordinate_pairs), 2))
            # keep every prediction of this epoch. Whether the scatter plot below is drawn (first, middle, last or
            # early-stopping epoch) is only known after the early-stopping routine has run. A condition tested here
            # would still see the previous epoch's counter and miss the early-stopping epoch.
            produced_results.append(result)
            # compute validation loss
            if loss_function == 'mean_squared_error':
                squared_error = (result - data.true_cost) * (result -
                                                             data.true_cost)
                if print_high_error_samples and squared_error > 5:
                    print("large error detected:", squared_error)
                    print("expected value:", data.true_cost)
                    print("computed value:", result)
                    print("sequence length:", len(data.coordinate_pairs))
                    print("sequence:", data.coordinate_pairs)
                squared_loss_sum += squared_error
                # histories are indexed by sequence length - 2 (column names below add the 2 back)
                individual_loss_histories_validate[
                    len(data.coordinate_pairs) -
                    2][0][epoch_count] += squared_error
            else:
                print("ERROR: loss function not implemented")
                sys.exit("failed to compute loss function: not implemented.")
        # finalize loss computation
        current_val_loss = float(squared_loss_sum / len(data_validate))
        # perform early stopping routine
        if current_val_loss < best_val_loss:
            print("loss improved from {} to {}".format(best_val_loss,
                                                       current_val_loss))
            # model improved, reset early stopping counter
            early_stopping_counter = 0
            # save current model weights as checkpoint
            early_stopping_checkpoint = model.get_weights()
            # update bsf val loss
            best_val_loss = current_val_loss
        else:
            print("loss did not improved from {} (currently {})".format(
                best_val_loss, current_val_loss))
            print("\tpatience left: {}".format(early_stopping_patience -
                                               early_stopping_counter))
            # model did not improve, increase counter
            early_stopping_counter += 1
        # log the mean loss for this epoch
        loss_history_validate.append(current_val_loss)
        # finalize individual mean loss values for this epoch (sum / sample count)
        for individual_loss_list_index in range(
                len(individual_loss_histories_validate)):
            if not individual_loss_histories_validate[
                    individual_loss_list_index][1] == 0:
                individual_loss_histories_validate[individual_loss_list_index][0][epoch_count] /= \
                    individual_loss_histories_validate[individual_loss_list_index][1]
        # optionally compute reference train error (= error produced on training set with same Network state used for
        # calculating validation error)
        squared_loss_sum = 0
        if (compute_and_plot_reference_train_error
                or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            for batch in list_of_batches:
                for data in batch:
                    result = model_for_validation.predict(
                        np.asarray(data.coordinate_pairs).reshape(
                            1, len(data.coordinate_pairs), 2))
                    # compute validation loss
                    if loss_function == 'mean_squared_error':
                        squared_error = (result - data.true_cost) * (
                            result - data.true_cost)
                        squared_loss_sum += squared_error
                    else:
                        print("ERROR: loss function not implemented")
                        sys.exit(
                            "failed to compute loss function: not implemented."
                        )
            # log the mean loss for this epoch
            loss_history_train_reference.append(
                float(squared_loss_sum / (len(list_of_batches) * batch_size)))

        # plot expected and produced results (only for first, middle and last iteration because of visual clutter)
        if ((epoch_count == 0) or (epoch_count == epochs - 1)
                or (epoch_count == int(epochs / 2))
                or (early_stopping_counter > early_stopping_patience)):
            plt.title("predictions against true values, epoch " +
                      str(epoch_count))
            plt.scatter(range(validation_plotting_point_count) if
                        validation_plotting_point_count < len(produced_results)
                        else range(len(produced_results)),
                        produced_results[:validation_plotting_point_count] if
                        validation_plotting_point_count < len(produced_results)
                        else produced_results,
                        c='r',
                        label='predicted_cost')
            plt.scatter(
                range(validation_plotting_point_count)
                if validation_plotting_point_count < len(data_validate) else
                range(len(data_validate)),
                [data.true_cost
                 for data in data_validate][:validation_plotting_point_count]
                if validation_plotting_point_count < len(data_validate) else
                [data.true_cost for data in data_validate],
                c='g',
                label='true_cost')
            plt.scatter(
                range(validation_plotting_point_count)
                if validation_plotting_point_count < len(data_validate) else
                range(len(data_validate)),
                [data.corrected_hpwl
                 for data in data_validate][:validation_plotting_point_count]
                if validation_plotting_point_count < len(data_validate) else
                [data.corrected_hpwl for data in data_validate],
                c='b',
                label='corrected_hpwl')
            plt.legend(loc='best')
            # an early-stopped run ends here, so its plot is the "final" one (consistent with the individual loss
            # plot below)
            if epoch_count == 0:
                scatter_stage = "initial"
            elif (epoch_count == epochs - 1) or (early_stopping_counter >
                                                 early_stopping_patience):
                scatter_stage = "final"
            else:
                scatter_stage = "midway"
            plt.savefig(
                plot_basepath +
                ("limited_data_" if data_limit_flag else "full_data_") +
                "scatterplot_" + scatter_stage + ".png",
                bbox_inches='tight')
            plt.show()

        # plot combined validation loss over epochs
        plt.title("loss plot for whole validation data set")
        plt.xlabel("epoch")
        plt.ylabel(loss_function)
        plt.yscale('log')
        plt.plot(range(epoch_count + 1),
                 loss_history_validate,
                 'bo-',
                 label='validation loss')
        plt.plot(range(epoch_count + 1),
                 loss_history_train,
                 'gx-',
                 label='training loss')
        if (compute_and_plot_reference_train_error
                or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            if not compute_and_plot_reference_train_error:
                # only the last epoch's reference loss was computed; pad earlier epochs with 0 so the series lines
                # up with range(epoch_count + 1)
                reference_loss_last_iter = loss_history_train_reference[0]
                loss_history_train_reference = [0] * epoch_count + [
                    reference_loss_last_iter
                ]
            plt.plot(range(epoch_count + 1),
                     loss_history_train_reference,
                     'r*-',
                     label='reference training loss')
        plt.axhline(y=loss_corrected_hpwl,
                    color='y',
                    linestyle='-',
                    label='corrected_hpwl loss')
        plt.legend(loc='best')
        if epoch_count == epochs - 1 or (early_stopping_counter >
                                         early_stopping_patience):
            plt.savefig(
                plot_basepath +
                ("limited_data_" if data_limit_flag else "full_data_") +
                "loss.png",
                bbox_inches='tight')
        plt.show()

        # plot individual loss histories in one combined plot for easy qualitative comparison at the end
        if (print_and_plot_stats_every_epoch or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            # extract only non-empty histories and assign subplot names (the sequence length)
            column_names = []
            filtered_individual_loss_histories = []
            filtered_individual_original_loss_histories = []
            for individual_loss_list_index in range(
                    len(individual_loss_histories_validate)):
                if not individual_loss_histories_validate[
                        individual_loss_list_index][1] == 0:
                    column_names.append(str(individual_loss_list_index + 2))
                    filtered_individual_loss_histories.append(
                        individual_loss_histories_validate[
                            individual_loss_list_index][0])
                    # the corrected_hpwl loss is a constant, repeated once per recorded epoch
                    individual_original_loss_tmp = []
                    for count in range(
                            len(individual_loss_histories_validate[
                                individual_loss_list_index][0])):
                        individual_original_loss_tmp.append(
                            individual_loss_histories_validate[
                                individual_loss_list_index][2])
                    filtered_individual_original_loss_histories.append(
                        individual_original_loss_tmp)

            # prepare and print the combined plot
            # prepare data: one column per sequence length, one row per epoch
            df_individual_loss_histories = pd.DataFrame(
                np.asarray(filtered_individual_loss_histories).reshape(
                    len(filtered_individual_loss_histories),
                    epoch_count + 1).T,
                columns=np.asarray(column_names))
            df_individual_original_loss_histories = pd.DataFrame(
                np.asarray(
                    filtered_individual_original_loss_histories).reshape(
                        len(filtered_individual_original_loss_histories),
                        epoch_count + 1).T,
                columns=np.asarray(column_names))

            print("corrected_hpwl loss history:")
            print(df_individual_original_loss_histories)
            # create plot
            fig1, ax1 = plt.subplots(figsize=(20, 15))
            # assign different colors to different histories
            colors = sns.cubehelix_palette(
                len(filtered_individual_loss_histories), rot=0.9)
            df_individual_loss_histories.plot(color=colors,
                                              ax=ax1,
                                              linestyle='-')
            df_individual_original_loss_histories.plot(color=colors,
                                                       ax=ax1,
                                                       linestyle='--')
            # generate the legend
            plt.semilogy()
            plt.legend(ncol=4, loc='best')
            # specify title and axis labels
            plt.title(
                "individual loss histories for validation sequences with different lengths"
            )
            plt.xlabel("epoch")
            plt.ylabel(loss_function)
            # if it is the final plot save it (also save a plot just for the initial phase)
            if (epoch_count == epochs - 1 or epoch_count == int(epochs / 3)
                    or (early_stopping_counter > early_stopping_patience)):
                final_individual_loss_histories = df_individual_loss_histories
                plt.savefig(
                    plot_basepath +
                    ("limited_data_" if data_limit_flag else "full_data_") +
                    "loss_individual_" +
                    ("final" if (epoch_count == epochs - 1) or
                     (early_stopping_counter > early_stopping_patience) else
                     "initial") + ".png",
                    bbox_inches='tight')
            # finally make the plot visible
            plt.show()

            print("validation loss history:")
            print(df_individual_loss_histories)
            if ((epoch_count == epochs - 1)
                    or (early_stopping_counter > early_stopping_patience)):
                df_individual_loss_histories.to_csv(
                    os.path.abspath(
                        os.path.join(current_basepath,
                                     "individual_loss_histories.csv")))
            print("loss history:")
            print(loss_history_validate)
        # check if training finished due to early stopping
        if early_stopping_counter > early_stopping_patience:
            # patience exhausted, abort training (best model will be saved from checkpointed weights saved)
            break

    gc.collect()

    if early_stopping_checkpoint is None:
        # no epoch ever beat math.inf (e.g. the validation loss was always NaN, or epochs == 0); save the last
        # weights instead of crashing in set_weights(None)
        print("WARNING: validation loss never improved, saving the last model weights")
        early_stopping_checkpoint = model.get_weights()
    model_for_validation.set_weights(early_stopping_checkpoint)
    model_for_validation.save(os.path.abspath(
        os.path.join(current_basepath, "model")),
                              include_optimizer=False)

    with open(
            plot_basepath +
        ("limited_data_" if data_limit_flag else "full_data_") +
            "hyperparameters.txt", "w") as text_file:
        print("batch size: {}".format(batch_size), file=text_file)
        print("number of epochs: {}".format(epochs), file=text_file)
        print("loss function: {}".format(loss_function), file=text_file)
        print("optimizer: {}".format(optimizer_choice), file=text_file)
        print("numpy random number generator seed: {}".format(np_random_seed),
              file=text_file)
        print("data limit factor: {}".format(
            "none" if not data_limit_flag else data_limit),
              file=text_file)
    with open(
            plot_basepath +
        ("limited_data_" if data_limit_flag else "full_data_") +
            "loss_history.txt", "w") as text_file:
        print("combined validation loss history:", file=text_file)
        print("", file=text_file)
        print(loss_history_validate, file=text_file)
        print("", file=text_file)
        print("individual validation loss histories:", file=text_file)
        print("", file=text_file)
        print(final_individual_loss_histories, file=text_file)
Example #12
0
 def softsync(target: tf.keras.Model, online: tf.keras.Model):
     """Overwrite ``target``'s weights with ``weighted_sum(target_w, online_w)``.

     ``weighted_sum`` is applied pairwise to the two models' weight lists via
     ``tf.nest.map_structure`` (which requires both lists to share a structure),
     and the result is written back into ``target``.
     """
     blended = tf.nest.map_structure(
         weighted_sum, target.get_weights(), online.get_weights())
     target.set_weights(blended)
Example #13
0
 def update_target_network(net: tf.keras.Model,
                           target_net: tf.keras.Model):
     """Soft-update ``target_net`` toward ``net``.

     Each target weight becomes ``tau * net_w + (1 - tau) * target_w`` (``tau``
     comes from the enclosing scope).

     The blend is done per weight array. Wrapping ``get_weights()`` in
     ``np.array`` fails for typical Keras models: their weight list is ragged
     (kernels and biases differ in shape), which NumPy >= 1.24 rejects.
     """
     net_weights = net.get_weights()
     target_net_weights = target_net.get_weights()
     new_weights = [
         tau * online + (1.0 - tau) * target
         for online, target in zip(net_weights, target_net_weights)
     ]
     target_net.set_weights(new_weights)
Example #14
0
def _train_model(model: tf.keras.Model, database: Dict[str, float], num_epochs: int, test_set: Optional[List[str]],
                 batch_size: int = 32, validation_split: float = 0.1, bootstrap: bool = False,
                 random_state: int = 1, learning_rate: float = 1e-3, patience: int = None,
                 timeout: float = None) -> Union[Tuple[List, dict], Tuple[List, dict, List[float]]]:
    """Train a model

    Args:
        model: Model to be trained
        database: Training dataset of molecule mapped to a property
        num_epochs: Maximum number of epochs to run
        test_set: Hold-out set. If provided, this function will return predictions on this set
        batch_size: Number of molecules per training batch
        validation_split: Probability that a molecule is placed in the validation set
        bootstrap: Whether to perform a bootstrap sample of the training split
        random_state: Seed to the random number generator. Ensures entries do not move between train
            and validation set as the database becomes larger
        learning_rate: Initial learning rate for the Adam optimizer; it decays geometrically toward 1e-6
        patience: Number of epochs without improvement before terminating training.
            Defaults to ``num_epochs // 8``
        timeout: Maximum training time in seconds
    Returns:
        weights: Trained weights, one numpy array per model weight
        history: Training history (``History.history`` dict)
        test_pred: Only when ``test_set`` is provided: first output column of the test-set predictions
    Raises:
        ValueError: If any epoch reported a NaN loss, or the final weights contain NaN
    """
    # Compile the model with a new optimizer
    #  We find that it is best to reset the optimizer before updating
    #  (``learning_rate=`` rather than the deprecated ``lr=`` alias, which newer Keras rejects)
    model.compile(tf.keras.optimizers.Adam(learning_rate=learning_rate), 'mean_squared_error')

    # Separate the database into molecules and properties
    smiles, y = zip(*database.items())
    smiles = np.array(smiles)
    y = np.array(y)

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(smiles)) > validation_split
    train_X = smiles[train_split]
    train_y = y[train_split]
    valid_X = smiles[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X),), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Make the training data loaders
    train_loader = GraphLoader(train_X, train_y, batch_size=batch_size, shuffle=True)
    val_loader = GraphLoader(valid_X, valid_y, batch_size=batch_size, shuffle=False)

    # Make the callbacks
    final_learn_rate = 1e-6
    init_learn_rate = learning_rate
    if num_epochs > 1:
        decay_rate = (final_learn_rate / init_learn_rate) ** (1. / (num_epochs - 1))
    else:
        # The exponent 1 / (num_epochs - 1) is undefined for a single epoch; keep the rate constant
        decay_rate = 1.

    # NOTE(review): LearningRateScheduler also calls this at the start of epoch 0, so the first epoch already
    #  runs at init * decay_rate rather than init -- confirm whether that offset is intended.
    def lr_schedule(epoch, lr):
        return lr * decay_rate

    if patience is None:
        patience = num_epochs // 8

    early_stopping = cb.EarlyStopping(patience=patience, restore_best_weights=True)
    my_callbacks = [
        LRLogger(),
        EpochTimeLogger(),
        cb.LearningRateScheduler(lr_schedule),
        early_stopping,
        cb.TerminateOnNaN(),
        train_loader  # So the shuffling gets called
    ]
    if timeout is not None:
        my_callbacks += [
            TimeLimitCallback(timeout)
        ]

    # Run the desired number of epochs
    history = model.fit(train_loader, epochs=num_epochs, validation_data=val_loader,
                        verbose=False, shuffle=False, callbacks=my_callbacks)

    # If a timeout is used, make sure we are using the best weights
    #  The training may have exited without storing the best weights
    #  (best_weights stays None if the time limit hit before any epoch was evaluated)
    if timeout is not None and early_stopping.best_weights is not None:
        model.set_weights(early_stopping.best_weights)

    # Check if there is a NaN loss
    if np.isnan(history.history['loss']).any():
        raise ValueError('Training failed due to a NaN loss.')

    # If provided, evaluate model on test set
    test_pred = None
    if test_set is not None:
        test_pred = evaluate_mpnn([model], test_set, batch_size, cache=False)

    # Convert weights to numpy arrays (avoids mmap issues)
    weights = []
    for v in model.get_weights():
        v = np.array(v)
        if np.isnan(v).any():
            raise ValueError('Found some NaN weights.')
        weights.append(v)

    # Once we are finished training call "clear_session"
    tf.keras.backend.clear_session()
    if test_pred is None:
        return weights, history.history
    else:
        return weights, history.history, test_pred[:, 0].tolist()
Example #15
0
    def warmstart(self, target_model: tf.keras.Model, custom_objects=None):
        if not self._params.model:
            logger.debug("No warm start model provided")
            return

        # Names that will be ignored in both the loaded and target model (no real weights)
        names_to_ignore = {"print_limit:0", "count:0"}

        # 1. Load as saved model. if successful -> 3, if failed -> 2
        # 2. If OSError (-> no saved model), load weights directly fron checkpoint -> 5
        # 3. Load weights and assign (only works if the model is identical), if failed -> 4
        # 4. Load weights by name -> 5
        # 5. Apply renaming rules and match weights
        abs_model_path = os.path.abspath(os.path.expanduser(
            self._params.model))
        try:
            # 1. Load model
            src_model = tf.keras.models.load_model(
                abs_model_path, compile=False, custom_objects=custom_objects)
        except OSError:
            # 2. load as checkpoint, then go to 5.
            logger.debug(
                f"Could not load '{abs_model_path}' as saved model. Attempting to load as a checkpoint."
            )
            ckpt = tf.train.load_checkpoint(abs_model_path)
            name_shapes = ckpt.get_variable_to_shape_map()
            model_to_load_var_names = name_shapes.keys()

            def rename_ckpt_var_name(name: str):
                name = name.replace("/.ATTRIBUTES/VARIABLE_VALUE", "")
                return name

            names = self._apply_renamings(model_to_load_var_names,
                                          self._params.rename)
            if self._params.add_suffix:
                names = [name + self._params.add_suffix for name in names]
            weights_ckpt = {
                rename_ckpt_var_name(pp_name): ckpt.get_tensor(name)
                for pp_name, name in zip(names, model_to_load_var_names)
            }
            all_loaded_weights = weights_ckpt
            trainable_name_to_loaded_var_name = {k: k for k in names}
        else:
            # 3. apply weights directly
            logger.info("Source model successfully loaded for warmstart.")
            skip_direct_apply = self._params.include or self._params.exclude
            if not skip_direct_apply:
                try:
                    self.apply_weights(target_model, [
                        np.asarray(w) for w in src_model.weights
                        if w.name not in names_to_ignore
                    ])
                except Exception as e:
                    # 4. Load and rename weights
                    logger.exception(e)
                    logger.warning(
                        "Weights could not be copied directly. Retrying application of renaming rules to"
                        "match variable names.")
                else:
                    # successful, nothing to do
                    return

            loaded_var_names = self._apply_renamings(
                [w.name for w in src_model.weights], self._params.rename)
            model_to_load_var_names = [w.name for w in src_model.weights]
            loaded_weights = list(
                zip(loaded_var_names, src_model.weights,
                    src_model.get_weights()))
            all_loaded_weights = {
                name: weight
                for name, var, weight in loaded_weights
                if name not in names_to_ignore
            }
            trainable_name_to_loaded_var_name = {
                k: v
                for k, v in zip(
                    self._apply_renamings(all_loaded_weights.keys(),
                                          self._params.rename_targets),
                    all_loaded_weights.keys(),
                )
            }

        # 5. Apply names with renaming rules
        target_var_names = self._apply_renamings(
            [w.name for w in target_model.weights],
            self._params.rename_targets)
        target_weights = list(
            zip(target_var_names, target_model.weights,
                target_model.get_weights()))
        if len(set(name for name, var, weight in target_weights)) != len(
                target_weights):
            logger.critical(
                "Non unique names detected in model weight names. You can ignore this warning but the "
                "model will not be initialized correctly!")
        all_trainable_target_weights = {
            name: weight
            for name, var, weight in target_weights
            if name not in names_to_ignore
        }
        # all_trainable_target_weights = {name: weight for name, var, weight in target_weights if var.trainable}
        trainable_name_to_target_var_name = {
            k: v
            for k, v in zip(
                self._apply_renamings(all_trainable_target_weights.keys(),
                                      self._params.rename_targets),
                all_trainable_target_weights.keys(),
            )
        }
        target_var_name_to_trainable_name = {
            v: k
            for k, v in trainable_name_to_target_var_name.items()
        }

        # Filter the params and validate
        names_target = set(trainable_name_to_target_var_name.keys())
        names_loaded = set(trainable_name_to_loaded_var_name.keys())
        if self._params.exclude or self._params.include:
            names_to_load = names_loaded
            if self._params.include:
                inc = re.compile(self._params.include)
                names_to_load = [
                    name for name in names_to_load if inc.fullmatch(name)
                ]

            if self._params.exclude:
                exc = re.compile(self._params.exclude)
                names_to_load = [
                    name for name in names_to_load if not exc.fullmatch(name)
                ]

            if len(names_target.intersection(names_to_load)) == 0:
                raise NameError(
                    f"Not a weight could be matched.\nLoaded: {names_to_load}\nTarget: {names_target}"
                )
        elif self._params.allow_partial:
            names_to_load = names_target.intersection(names_loaded)
        else:
            names_to_load = names_target
            diff_target = names_loaded.difference(names_target)
            diff_load = names_target.difference(names_loaded)
            if len(diff_target) > 0 or len(diff_load) > 0:
                raise NameError(
                    f"Not all weights could be matched:\nTargets '{diff_target}'\nLoaded: '{diff_load}'. "
                    f"\nUse allow_partial to allow partial loading")

        new_weights = []
        warm_weights_names = []
        cold_weights_names = []
        non_trainable_weights_names = []
        for weight_idx, name in enumerate(target_var_names):
            # access original weight via index because names might not be unique (e.g. in metrics)
            trainable_name = target_var_name_to_trainable_name.get(
                name, None)  # None == not existing
            if trainable_name in names_to_load:
                warm_weights_names.append(name)
                new_weights.append(
                    self._transform_weight(
                        all_loaded_weights[
                            trainable_name_to_loaded_var_name[trainable_name]],
                        trainable_name, name))
            else:
                if name in all_trainable_target_weights:
                    cold_weights_names.append(name)
                else:
                    non_trainable_weights_names.append(name)
                new_weights.append(
                    target_weights[weight_idx][2])  # set to original weight
        not_loaded_weights = [
            name for name in names_loaded if name not in warm_weights_names
        ]
        newline = "\n\t"
        logger.info(
            newline.join(["model-to-load weights:"] + model_to_load_var_names))
        logger.info(
            newline.join(["renamed unmached weights:"] +
                         [str(x) for x in not_loaded_weights]))
        logger.info(
            newline.join(["Warm weights:"] +
                         [str(x) for x in warm_weights_names]))
        logger.info(
            newline.join(["Cold weights:"] +
                         [str(x) for x in cold_weights_names]))
        logger.info(
            f"There are {len(non_trainable_weights_names)} non trainable weights."
        )
        if len(names_to_load) == 0:
            raise ValueError(
                "No warmstart weight could be matched! Set TFAIP_LOG_LEVEL=INFO for more information."
            )

        self.apply_weights(target_model, new_weights)
Example #16
0
def hard_update(target: tf.keras.Model, source: tf.keras.Model):
    """Copy every weight of ``source`` into ``target`` in a single step.

    No interpolation is performed: afterwards ``target`` holds exactly the
    values returned by ``source.get_weights()``.

    Args:
        target: Model whose weights are overwritten in place.
        source: Model whose current weights are read.

    Note:
        Both models are expected to expose weight lists of the same length,
        order and shapes (assumption - Keras ``set_weights`` rejects
        mismatches).
    """
    snapshot = source.get_weights()
    target.set_weights(snapshot)
Example #17
0
 def set_model_weights(self, model: tf.keras.Model):
     """Overwrite ``self.local_model``'s weights with those of ``model``.

     Args:
         model: Model whose current weights are copied into the local model.
     """
     incoming_weights = model.get_weights()
     self.local_model.set_weights(incoming_weights)