def soft_updates(local_model: tf.keras.Model, target_model: tf.keras.Model) -> np.ndarray:
    """Blend the weights of two models as ``TAU * local + (1 - TAU) * target``.

    The blend is computed layer by layer. Wrapping the whole weight list in a
    single ``np.array(...)`` call fails on recent NumPy releases when the
    layers have different shapes (ragged input), which is the usual case.

    Args:
        local_model: model whose weights are scaled by the module-level ``TAU``.
        target_model: model whose weights are scaled by ``1 - TAU``.

    Returns:
        A 1-D object array with one blended weight array per entry, in the
        order returned by ``get_weights()``.

    Raises:
        ValueError: if the two models expose a different number of weight arrays.
    """
    local_weights = local_model.get_weights()
    target_weights = target_model.get_weights()
    if len(local_weights) != len(target_weights):
        raise ValueError(
            "local and target model must have the same number of weight arrays "
            f"({len(local_weights)} != {len(target_weights)})")

    # An object array keeps the documented ndarray return type while allowing
    # every entry to have its own shape.
    new_weights = np.empty(len(local_weights), dtype=object)
    for index, (local_w, target_w) in enumerate(zip(local_weights, target_weights)):
        new_weights[index] = TAU * np.asarray(local_w) + (1 - TAU) * np.asarray(target_w)
    return new_weights
def make_z_as_input(generator: tf.keras.Model, model_config: dict, speech_config: dict):
    """Build a v2 generator from the configs and load ``generator``'s weights into it.

    Args:
        generator: model whose weights are copied.
        model_config: provides ``g_enc_depths``, ``kwidth`` and ``ratio``.
        speech_config: provides ``window_size``.

    Returns:
        The model created by ``create_generator_v2`` carrying the copied weights.
    """
    builder_kwargs = {
        "g_enc_depths": model_config["g_enc_depths"],
        "window_size": speech_config["window_size"],
        "kwidth": model_config["kwidth"],
        "ratio": model_config["ratio"],
    }
    rebuilt = create_generator_v2(**builder_kwargs)
    source_weights = generator.get_weights()
    rebuilt.set_weights(source_weights)
    return rebuilt
def copy_original_weights(original_model: tf.keras.Model, quantized_model: tf.keras.Model):
    """Copy the original model's weights into the matching slots of the quantized model.

    Weights of ``quantized_model`` whose name is recognised by
    ``is_quantization_weight_name`` keep their current value. Every other
    weight receives, in order, the next value from ``original_model``.

    Raises:
        ValueError: if the number of non-quantization weights in the quantized
            model differs from the number of weights in the original model.
    """
    source_values = original_model.get_weights()
    merged_values = quantized_model.get_weights()

    exhausted = object()  # sentinel marking "no more source weights"
    pending = iter(source_values)

    for slot, variable in enumerate(quantized_model.weights):
        if is_quantization_weight_name(variable.name):
            continue
        candidate = next(pending, exhausted)
        if candidate is exhausted:
            raise ValueError('Not enought original model weights.')
        merged_values[slot] = candidate

    # Any value still left in the iterator was never consumed by the quantized model.
    if next(pending, exhausted) is not exhausted:
        raise ValueError('Not enought quantized model weights.')

    quantized_model.set_weights(merged_values)
def __init__(self, model: tf.keras.Model):
    """Capture a model as its JSON config plus a copy of its weights.

    Args:
        model: Model to be sent
    """
    self.config = model.to_json()

    # Copy each weight into a fresh array so that none of them is a
    # memoryview object.
    copied_weights = []
    for layer_values in model.get_weights():
        copied_weights.append(np.array(layer_values))
    self.weights = copied_weights
def get_num_weights(model: tf.keras.Model) -> int:
    """Utility function to determine the number of weights in a keras model.

    The count is the sum of the element counts of every array returned by
    ``model.get_weights()``. It is computed with plain Python integers, so the
    result is a real ``int`` (as the signature promises) rather than an int32
    tensor, and it cannot overflow for very large models.

    Arguments:
        model {tf.keras.Model} -- the keras model

    Returns:
        int -- the number of weights (0 for a model without weights)
    """
    return sum(int(weight.size) for weight in model.get_weights())
def polyak_averaging(model: tf.keras.Model, old_weights: list, alpha=0.99):
    """Blend the model's current weights with ``old_weights`` and store the result.

    Source: Deep Learning Book (section 8.7.3) - the original implementation is:
        `w = alpha * w_old + (1.0 - alpha) * w_new`,
    here we use `w = alpha * w_new + (1.0 - alpha) * w_old` because it performs better for RL

    Args:
        model: model that is read and then updated in place.
        old_weights: previous weights, in the same order as ``model.get_weights()``.
        alpha: factor applied to the model's current weights.
    """
    current = model.get_weights()
    blended = [
        alpha * fresh + (1.0 - alpha) * previous
        for previous, fresh in zip(old_weights, current)
    ]
    model.set_weights(blended)
def transfer_weights(source_model: tf.keras.Model, target_model: tf.keras.Model):
    """Copy every weight of ``source_model`` into ``target_model``.

    Args:
        source_model: trained `tf.keras.Model` the weights are read from
        target_model: `tf.keras.Model` that receives the weights

    Returns:
        ``target_model`` itself, now holding the source weights
    """
    trained_weights = source_model.get_weights()
    target_model.set_weights(trained_weights)
    return target_model
def adv_perturbation_closed_form(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor, eps: float = 0.01) -> tf.Tensor:
    """
    Uses the closed form of projected descent for single-layer networks.

    Only the first array returned by ``model.get_weights()`` is used.

    :param model: keras model with a single layer of weights
    :param x: input tensor
    :param y: Tensor with true labels
    :param eps: amount to scale perturbation by
    :return: a tensor with adversarial samples
    """
    first_layer = model.get_weights()[0]

    # Map labels equal to 1.0 onto +1 and every other label onto -1.
    signed_labels = tf.where(tf.equal(y, 1.0), 1.0, -1.0)

    scaled_labels = -eps * signed_labels[:, None]
    weight_signs_t = tf.transpose(tf.sign(first_layer))
    delta = scaled_labels @ weight_signs_t

    return x + tf.squeeze(delta)
def lr_finder(
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    loss_fn: tf.keras.losses.Loss,
    dataset,
    learn_rates: LrGenerator,
    losses: SmoothedLoss,
) -> Lr:
    """Run one training step per candidate learning rate and record the losses.

    The model is built from the dataset's first element spec, its initial
    weights are stored, and after the sweep those weights are restored so that
    regular training can start directly afterwards.

    Returns:
        ``Lr(learn_rates, losses)`` built from the sweep.
    """
    # Build first so that there are weights to snapshot and to reset to later.
    model.build(dataset.element_spec[0].shape)
    initial_weights = model.get_weights()

    sweep = zip(learn_rates(), dataset)
    for step_lr, batch in sweep:
        source, target = batch
        tf.keras.backend.set_value(optimizer.lr, step_lr)
        step_loss = train_step(model, optimizer, loss_fn, source, target).numpy()
        if losses.no_progress:
            break
        losses.update(step_loss)

    model.set_weights(initial_weights)
    return Lr(learn_rates, losses)
def soft_update(target: tf.keras.Model, source: tf.keras.Model, tau):
    """Move ``target``'s weights towards ``source``'s by the fraction ``tau``.

    Each target weight becomes ``target * (1 - tau) + source * tau`` and the
    result is written back into ``target``.
    """
    weight_pairs = zip(target.get_weights(), source.get_weights())
    blended = [
        current * (1. - tau) + incoming * tau
        for current, incoming in weight_pairs
    ]
    target.set_weights(blended)
def train_model(model: tf.keras.Model, model_for_validation: tf.keras.Model,
                list_of_batches: list, data_validate, current_basepath: str):
    """
    trains the model using a custom training loop and early stopping

    Besides its arguments the function reads several module-level settings
    (``epochs``, ``batch_size``, ``early_stopping_patience``, ``loss_function``,
    the plotting/printing flags, ...).

    :param model: the model to train
    :param model_for_validation: an identical model except for the input shape for manual validation
    :param list_of_batches: training set, as list of equal-length batches, each containing only samples of equal
        terminal count
    :param data_validate: validation set
    :param current_basepath: path to directory to save plots, logs, and the trained model in
    :return: nothing, modifies filesystem
    """
    # create plot save folder
    plot_basepath = os.path.abspath(os.path.join(current_basepath, "plots"))
    os.makedirs(plot_basepath)
    # os.path.abspath drops a trailing separator, so output files have to be
    # joined onto the folder instead of being string-concatenated to it
    # (otherwise they end up next to the folder as "plots<name>").
    file_prefix = "limited_data_" if data_limit_flag else "full_data_"

    # history of training loss
    loss_history_train = []
    # history of validation loss
    loss_history_validate = []
    # history of "reference training loss"
    loss_history_train_reference = []

    # init individual loss histories
    individual_loss_histories_validate, final_individual_loss_histories, \
        loss_corrected_hpwl = init_loss_histories(data_validate)

    # init early stopping variables
    early_stopping_counter = 0
    best_val_loss = math.inf
    early_stopping_checkpoint = None

    # model epochs
    for epoch_count in range(epochs):
        gc.collect()
        print("epoch #", epoch_count)

        # shuffle the individual batches (externally)
        np.random.shuffle(list_of_batches)

        # open this epoch's accumulator; the batch loop below adds to it by index
        loss_history_train.append(0)

        # in each epoch train on all the training data, one batch at a time
        for batch_counter, batch in enumerate(list_of_batches):
            print(
                "epoch(" + str(epoch_count) + "/" + str(epochs - 1) +
                ") batch #", batch_counter)

            # shuffle the specific batch (internally)
            np.random.shuffle(batch)

            # prepare the data as input to the model
            x_train = np.asarray([data.coordinate_pairs for data in batch])
            x_train = x_train.reshape(batch_size,
                                      len(batch[0].coordinate_pairs), 2)
            y_train = np.asarray([data.true_cost for data in batch])

            # train on one batch
            loss_history_train_tmp = model.fit(x_train, y_train, epochs=1)

            # mean squared error can be summed up and later divided by the number of batches, as long as each
            # batch is of equal size
            loss_history_train[epoch_count] += \
                loss_history_train_tmp.history['loss'][0]

        # finalize mean squared error of train run
        loss_history_train[epoch_count] /= len(list_of_batches)

        # copy weights to the validation model with batch size of 1
        model_for_validation.set_weights(model.get_weights())

        # validate
        produced_results = []
        # reset to 0 (might have been used above)
        squared_loss_sum = 0

        # prepare individual loss records
        for individual_loss_list_index in range(
                len(individual_loss_histories_validate)):
            if not individual_loss_histories_validate[
                    individual_loss_list_index][1] == 0:
                individual_loss_histories_validate[
                    individual_loss_list_index][0].append(0)

        # use the whole validation set
        for data in data_validate:
            result = model_for_validation.predict(
                np.asarray(data.coordinate_pairs).reshape(
                    1, len(data.coordinate_pairs), 2))

            # log produced results for plotting (only for first, middle and last iteration because of visual
            # clutter)
            if ((epoch_count == 0) or (epoch_count == epochs - 1)
                    or (epoch_count == int(epochs / 2))
                    or (early_stopping_counter > early_stopping_patience)):
                produced_results.append(result)

            # compute validation loss
            if loss_function == 'mean_squared_error':
                squared_error = (result - data.true_cost) * (result - data.true_cost)
                if print_high_error_samples and squared_error > 5:
                    print("large error detected:", squared_error)
                    print("expected value:", data.true_cost)
                    print("computed value:", result)
                    print("sequence length:", len(data.coordinate_pairs))
                    print("sequence:", data.coordinate_pairs)
                squared_loss_sum += squared_error
                # histories are indexed by (sequence length - 2)
                individual_loss_histories_validate[
                    len(data.coordinate_pairs) - 2][0][epoch_count] += squared_error
            else:
                print("ERROR: loss function not implemented")
                sys.exit("failed to compute loss function: not implemented.")

        # finalize loss computation
        current_val_loss = float(squared_loss_sum / len(data_validate))

        # perform early stopping routine
        if current_val_loss < best_val_loss:
            print("loss improved from {} to {}".format(best_val_loss,
                                                       current_val_loss))
            # model improved, reset early stopping counter
            early_stopping_counter = 0
            # save current model weights as checkpoint
            early_stopping_checkpoint = model.get_weights()
            # update bsf val loss
            best_val_loss = current_val_loss
        else:
            print("loss did not improved from {} (currently {})".format(
                best_val_loss, current_val_loss))
            print("\tpatience left: {}".format(early_stopping_patience -
                                               early_stopping_counter))
            # model did not improve, increase counter
            early_stopping_counter += 1

        # log the mean loss for this epoch
        loss_history_validate.append(current_val_loss)

        # finalize individual mean loss values for this epoch
        for individual_loss_list_index in range(
                len(individual_loss_histories_validate)):
            if not individual_loss_histories_validate[
                    individual_loss_list_index][1] == 0:
                individual_loss_histories_validate[individual_loss_list_index][0][epoch_count] /= \
                    individual_loss_histories_validate[individual_loss_list_index][1]

        # optionally compute reference train error (= error produced on training set with same Network state used
        # for calculating validation error
        squared_loss_sum = 0
        if (compute_and_plot_reference_train_error
                or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            for batch in list_of_batches:
                for data in batch:
                    result = model_for_validation.predict(
                        np.asarray(data.coordinate_pairs).reshape(
                            1, len(data.coordinate_pairs), 2))

                    # compute validation loss
                    if loss_function == 'mean_squared_error':
                        squared_error = (result - data.true_cost) * (
                            result - data.true_cost)
                        squared_loss_sum += squared_error
                    else:
                        print("ERROR: loss function not implemented")
                        sys.exit(
                            "failed to compute loss function: not implemented."
                        )

            # log the mean loss for this epoch
            loss_history_train_reference.append(
                float(squared_loss_sum / (len(list_of_batches) * batch_size)))

        # plot expected and produced results (only for first, middle and last iteration because of visual clutter)
        if ((epoch_count == 0) or (epoch_count == epochs - 1)
                or (epoch_count == int(epochs / 2))
                or (early_stopping_counter > early_stopping_patience)):
            plt.title("predictions against true values, epoch " +
                      str(epoch_count))
            plt.scatter(range(validation_plotting_point_count)
                        if validation_plotting_point_count < len(produced_results)
                        else range(len(produced_results)),
                        produced_results[:validation_plotting_point_count]
                        if validation_plotting_point_count < len(produced_results)
                        else produced_results,
                        c='r',
                        label='predicted_cost')
            plt.scatter(
                range(validation_plotting_point_count)
                if validation_plotting_point_count < len(data_validate)
                else range(len(data_validate)),
                [data.true_cost for data in data_validate][:validation_plotting_point_count]
                if validation_plotting_point_count < len(data_validate)
                else [data.true_cost for data in data_validate],
                c='g',
                label='true_cost')
            plt.scatter(
                range(validation_plotting_point_count)
                if validation_plotting_point_count < len(data_validate)
                else range(len(data_validate)),
                [data.corrected_hpwl for data in data_validate][:validation_plotting_point_count]
                if validation_plotting_point_count < len(data_validate)
                else [data.corrected_hpwl for data in data_validate],
                c='b',
                label='corrected_hpwl')
            plt.legend(loc='best')
            plt.savefig(
                os.path.join(
                    plot_basepath,
                    file_prefix + "scatterplot_" +
                    ("initial" if (epoch_count == 0) else
                     ("final" if (epoch_count == epochs - 1) else "midway")) +
                    ".png"),
                bbox_inches='tight')
            plt.show()

        # plot combined validation loss over epochs
        plt.title("loss plot for whole validation data set")
        plt.xlabel("epoch")
        plt.ylabel(loss_function)
        plt.yscale('log')
        plt.plot(range(epoch_count + 1),
                 loss_history_validate,
                 'bo-',
                 label='validation loss')
        plt.plot(range(epoch_count + 1),
                 loss_history_train,
                 'gx-',
                 label='training loss')
        if (compute_and_plot_reference_train_error
                or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            if not compute_and_plot_reference_train_error:
                # only the final value was computed; pad the earlier epochs with zeros so the curve has one
                # point per epoch
                reference_loss_last_iter = loss_history_train_reference[0]
                loss_history_train_reference = []
                for _ in range(epoch_count):
                    loss_history_train_reference.append(0)
                loss_history_train_reference.append(reference_loss_last_iter)
            plt.plot(range(epoch_count + 1),
                     loss_history_train_reference,
                     'r*-',
                     label='reference training loss')
        plt.axhline(y=loss_corrected_hpwl,
                    color='y',
                    linestyle='-',
                    label='corrected_hpwl loss')
        plt.legend(loc='best')
        if epoch_count == epochs - 1 or (early_stopping_counter >
                                         early_stopping_patience):
            plt.savefig(os.path.join(plot_basepath, file_prefix + "loss.png"),
                        bbox_inches='tight')
        plt.show()

        # plot individual loss histories in one combined plot for easy qualitative comparison at the end
        if (print_and_plot_stats_every_epoch or (epoch_count == epochs - 1)
                or (early_stopping_counter > early_stopping_patience)):
            # extract only non-empty histories and assign subplot names
            column_names = []
            filtered_individual_loss_histories = []
            filtered_individual_original_loss_histories = []
            for individual_loss_list_index in range(
                    len(individual_loss_histories_validate)):
                if not individual_loss_histories_validate[
                        individual_loss_list_index][1] == 0:
                    column_names.append(str(individual_loss_list_index + 2))
                    filtered_individual_loss_histories.append(
                        individual_loss_histories_validate[
                            individual_loss_list_index][0])
                    # constant series: entry [2] repeated once per recorded epoch
                    individual_original_loss_tmp = []
                    for _ in range(
                            len(individual_loss_histories_validate[
                                individual_loss_list_index][0])):
                        individual_original_loss_tmp.append(
                            individual_loss_histories_validate[
                                individual_loss_list_index][2])
                    filtered_individual_original_loss_histories.append(
                        individual_original_loss_tmp)

            # prepare and print the combined plot
            # prepare data
            df_individual_loss_histories = pd.DataFrame(
                np.asarray(filtered_individual_loss_histories).reshape(
                    len(filtered_individual_loss_histories),
                    epoch_count + 1).T,
                columns=np.asarray(column_names))
            df_individual_original_loss_histories = pd.DataFrame(
                np.asarray(
                    filtered_individual_original_loss_histories).reshape(
                        len(filtered_individual_original_loss_histories),
                        epoch_count + 1).T,
                columns=np.asarray(column_names))
            print("corrected_hpwl loss history:")
            print(df_individual_original_loss_histories)

            # create plot
            _fig, ax1 = plt.subplots(figsize=(20, 15))

            # assign different colors to different histories
            colors = sns.cubehelix_palette(
                len(filtered_individual_loss_histories), rot=0.9)
            df_individual_loss_histories.plot(color=colors,
                                              ax=ax1,
                                              linestyle='-')
            df_individual_original_loss_histories.plot(color=colors,
                                                       ax=ax1,
                                                       linestyle='--')

            # generate the legend
            plt.semilogy()
            plt.legend(ncol=4, loc='best')

            # specify title and axis labels
            plt.title(
                "individual loss histories for validation sequences with different lengths"
            )
            plt.xlabel("epoch")
            plt.ylabel(loss_function)

            # if it is the final plot save it (also save a plot just for the initial phase)
            if (epoch_count == epochs - 1 or epoch_count == int(epochs / 3)
                    or (early_stopping_counter > early_stopping_patience)):
                final_individual_loss_histories = df_individual_loss_histories
                plt.savefig(
                    os.path.join(
                        plot_basepath,
                        file_prefix + "loss_individual_" +
                        ("final" if (epoch_count == epochs - 1) or
                         (early_stopping_counter > early_stopping_patience)
                         else "initial") + ".png"),
                    bbox_inches='tight')

            # finally make the plot visible
            plt.show()

            print("validation loss history:")
            print(df_individual_loss_histories)
            if ((epoch_count == epochs - 1)
                    or (early_stopping_counter > early_stopping_patience)):
                df_individual_loss_histories.to_csv(
                    os.path.abspath(
                        os.path.join(current_basepath,
                                     "individual_loss_histories.csv")))
            print("loss history:")
            print(loss_history_validate)

        # check if training finished due to early stopping
        if early_stopping_counter > early_stopping_patience:
            # patience exhausted, abort training (best model will be saved from checkpointed weights saved)
            break

    gc.collect()

    # No checkpoint exists when no epoch ever improved the validation loss
    # (e.g. zero epochs or a NaN loss); fall back to the current training
    # weights instead of passing None to set_weights.
    if early_stopping_checkpoint is None:
        early_stopping_checkpoint = model.get_weights()
    model_for_validation.set_weights(early_stopping_checkpoint)
    model_for_validation.save(os.path.abspath(
        os.path.join(current_basepath, "model")),
                              include_optimizer=False)

    with open(os.path.join(plot_basepath, file_prefix + "hyperparameters.txt"),
              "w") as text_file:
        print("batch size: {}".format(batch_size), file=text_file)
        print("number of epochs: {}".format(epochs), file=text_file)
        print("loss function: {}".format(loss_function), file=text_file)
        print("optimizer: {}".format(optimizer_choice), file=text_file)
        print("numpy random number generator seed: {}".format(np_random_seed),
              file=text_file)
        print("data limit factor: {}".format(
            "none" if not data_limit_flag else data_limit),
              file=text_file)

    with open(os.path.join(plot_basepath, file_prefix + "loss_history.txt"),
              "w") as text_file:
        print("combined validation loss history:", file=text_file)
        print("", file=text_file)
        print(loss_history_validate, file=text_file)
        print("", file=text_file)
        print("individual validation loss histories:", file=text_file)
        print("", file=text_file)
        print(final_individual_loss_histories, file=text_file)
def softsync(target: tf.keras.Model, online: tf.keras.Model):
    """Update ``target`` with ``weighted_sum`` applied to each (target, online) weight pair."""
    blended = tf.nest.map_structure(
        weighted_sum,
        target.get_weights(),
        online.get_weights(),
    )
    target.set_weights(blended)
def update_target_network(net: tf.keras.Model, target_net: tf.keras.Model):
    """Soft-update ``target_net`` towards ``net`` using the module-level ``tau``.

    Every weight of the target network becomes
    ``tau * net_weight + (1 - tau) * target_weight``. The blend is computed per
    layer: packing the whole weight list into one ``np.array`` fails on recent
    NumPy releases whenever the layers have different shapes.

    Args:
        net: network providing the new weights.
        target_net: network that is updated in place.
    """
    new_weights = [
        tau * np.asarray(net_w) + (1.0 - tau) * np.asarray(target_w)
        for net_w, target_w in zip(net.get_weights(), target_net.get_weights())
    ]
    target_net.set_weights(new_weights)
def _train_model(model: tf.keras.Model,
                 database: Dict[str, float],
                 num_epochs: int,
                 test_set: Optional[List[str]],
                 batch_size: int = 32,
                 validation_split: float = 0.1,
                 bootstrap: bool = False,
                 random_state: int = 1,
                 learning_rate: float = 1e-3,
                 patience: int = None,
                 timeout: float = None) -> Union[Tuple[List, dict], Tuple[List, dict, List[float]]]:
    """Train a model

    Args:
        model: Model to be trained
        database: Training dataset of molecule mapped to a property
        test_set: Hold-out set. If provided, this function will return predictions on this set
        num_epochs: Maximum number of epochs to run
        batch_size: Number of molecules per training batch
        validation_split: Fraction of molecules used for the training/validation split
        bootstrap: Whether to perform a bootstrap sample of the dataset
        random_state: Seed to the random number generator. Ensures entries do not
            move between train and validation set as the database becomes larger
        learning_rate: Learning rate for the Adam optimizer
        patience: Number of epochs without improvement before terminating training.
            Defaults to ``num_epochs // 8``.
        timeout: Maximum training time in seconds
    Returns:
        weights: Updated weights as a list of numpy arrays
        history: Training history
        predictions: Predictions on ``test_set`` (only returned if ``test_set`` was provided)
    Raises:
        ValueError: If the training loss or any resulting weight is NaN
    """
    # Compile the model with a new optimizer
    #  We find that it is best to reset the optimizer before updating
    # (`learning_rate` is the supported keyword; `lr` is a deprecated alias)
    model.compile(tf.keras.optimizers.Adam(learning_rate=learning_rate), 'mean_squared_error')

    # Separate the database into molecules and properties
    smiles, y = zip(*database.items())
    smiles = np.array(smiles)
    y = np.array(y)

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(smiles)) > validation_split
    train_X = smiles[train_split]
    train_y = y[train_split]
    valid_X = smiles[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X),), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Make the training data loaders
    train_loader = GraphLoader(train_X, train_y, batch_size=batch_size, shuffle=True)
    val_loader = GraphLoader(valid_X, valid_y, batch_size=batch_size, shuffle=False)

    # Make the callbacks
    final_learn_rate = 1e-6
    init_learn_rate = learning_rate
    if num_epochs > 1:
        decay_rate = (final_learn_rate / init_learn_rate) ** (1. / (num_epochs - 1))
    else:
        # A single epoch leaves nothing to decay over (and would divide by zero)
        decay_rate = 1.0

    def lr_schedule(epoch, lr):
        return lr * decay_rate

    if patience is None:
        patience = num_epochs // 8

    early_stopping = cb.EarlyStopping(patience=patience, restore_best_weights=True)
    my_callbacks = [
        LRLogger(),
        EpochTimeLogger(),
        cb.LearningRateScheduler(lr_schedule),
        early_stopping,
        cb.TerminateOnNaN(),
        train_loader  # So the shuffling gets called
    ]
    if timeout is not None:
        my_callbacks += [
            TimeLimitCallback(timeout)
        ]

    # Run the desired number of epochs
    history = model.fit(train_loader, epochs=num_epochs, validation_data=val_loader,
                        verbose=False, shuffle=False, callbacks=my_callbacks)

    # If a timeout is used, make sure we are using the best weights
    #  The training may have exited without storing the best weights
    if timeout is not None:
        best_weights = early_stopping.best_weights
        # No best weights are recorded if training stopped before any epoch improved
        if best_weights is not None:
            model.set_weights(best_weights)

    # Check if there is a NaN loss
    if np.isnan(history.history['loss']).any():
        raise ValueError('Training failed due to a NaN loss.')

    # If provided, evaluate model on test set
    test_pred = None
    if test_set is not None:
        test_pred = evaluate_mpnn([model], test_set, batch_size, cache=False)

    # Convert weights to numpy arrays (avoids mmap issues)
    weights = []
    for v in model.get_weights():
        v = np.array(v)
        if np.isnan(v).any():
            raise ValueError('Found some NaN weights.')
        weights.append(v)

    # Once we are finished training call "clear_session"
    tf.keras.backend.clear_session()
    if test_pred is None:
        return weights, history.history
    else:
        return weights, history.history, test_pred[:, 0].tolist()
def warmstart(self, target_model: tf.keras.Model, custom_objects=None):
    """Initialise ``target_model`` with weights from the model at ``self._params.model``.

    Does nothing if no warm-start model is configured. Otherwise the source is
    loaded (saved model first, checkpoint as fallback), its weight names are
    matched against the target's weight names after applying the configured
    renaming rules, and the matched weights are applied via ``self.apply_weights``.
    Unmatched target weights keep their current value.

    Args:
        target_model: model that receives the weights.
        custom_objects: forwarded to ``tf.keras.models.load_model``.

    Raises:
        NameError: if include/exclude filters leave nothing that matches the
            target, or (without ``allow_partial``) the name sets differ.
        ValueError: if no weight could be matched at all.
    """
    if not self._params.model:
        logger.debug("No warm start model provided")
        return

    # Names that will be ignored in both the loaded and target model (no real weights)
    names_to_ignore = {"print_limit:0", "count:0"}

    # 1. Load as saved model. if successful -> 3, if failed -> 2
    # 2. If OSError (-> no saved model), load weights directly from checkpoint -> 5
    # 3. Load weights and assign (only works if the model is identical), if failed -> 4
    # 4. Load weights by name -> 5
    # 5. Apply renaming rules and match weights
    abs_model_path = os.path.abspath(os.path.expanduser(self._params.model))
    try:
        # 1. Load model
        src_model = tf.keras.models.load_model(
            abs_model_path, compile=False, custom_objects=custom_objects)
    except OSError:
        # 2. load as checkpoint, then go to 5.
        logger.debug(
            f"Could not load '{abs_model_path}' as saved model. Attempting to load as a checkpoint."
        )
        ckpt = tf.train.load_checkpoint(abs_model_path)
        name_shapes = ckpt.get_variable_to_shape_map()
        # Materialise as a list: the logging below concatenates it to a list,
        # which raises TypeError for a dict_keys view.
        model_to_load_var_names = list(name_shapes.keys())

        def rename_ckpt_var_name(name: str):
            name = name.replace("/.ATTRIBUTES/VARIABLE_VALUE", "")
            return name

        names = self._apply_renamings(model_to_load_var_names,
                                      self._params.rename)
        if self._params.add_suffix:
            names = [name + self._params.add_suffix for name in names]

        weights_ckpt = {
            rename_ckpt_var_name(pp_name): ckpt.get_tensor(name)
            for pp_name, name in zip(names, model_to_load_var_names)
        }
        all_loaded_weights = weights_ckpt
        # NOTE(review): keys here are the un-stripped `names`, while
        # `all_loaded_weights` is keyed by the stripped names - confirm the
        # renaming rules are expected to remove the checkpoint suffix.
        trainable_name_to_loaded_var_name = {k: k for k in names}
    else:
        # 3. apply weights directly
        logger.info("Source model successfully loaded for warmstart.")
        skip_direct_apply = self._params.include or self._params.exclude
        if not skip_direct_apply:
            try:
                self.apply_weights(target_model, [
                    np.asarray(w) for w in src_model.weights
                    if w.name not in names_to_ignore
                ])
            except Exception as e:
                # 4. Load and rename weights
                logger.exception(e)
                logger.warning(
                    "Weights could not be copied directly. Retrying application of renaming rules to"
                    "match variable names.")
            else:
                # successful, nothing to do
                return

        loaded_var_names = self._apply_renamings(
            [w.name for w in src_model.weights], self._params.rename)
        model_to_load_var_names = [w.name for w in src_model.weights]
        loaded_weights = list(
            zip(loaded_var_names, src_model.weights, src_model.get_weights()))
        all_loaded_weights = {
            name: weight
            for name, var, weight in loaded_weights
            if name not in names_to_ignore
        }
        trainable_name_to_loaded_var_name = {
            k: v
            for k, v in zip(
                self._apply_renamings(all_loaded_weights.keys(),
                                      self._params.rename_targets),
                all_loaded_weights.keys(),
            )
        }

    # 5. Apply names with renaming rules
    target_var_names = self._apply_renamings(
        [w.name for w in target_model.weights], self._params.rename_targets)
    target_weights = list(
        zip(target_var_names, target_model.weights,
            target_model.get_weights()))
    if len(set(name for name, var, weight in target_weights)) != len(
            target_weights):
        logger.critical(
            "Non unique names detected in model weight names. You can ignore this warning but the "
            "model will not be initialized correctly!")
    all_trainable_target_weights = {
        name: weight
        for name, var, weight in target_weights
        if name not in names_to_ignore
    }
    trainable_name_to_target_var_name = {
        k: v
        for k, v in zip(
            self._apply_renamings(all_trainable_target_weights.keys(),
                                  self._params.rename_targets),
            all_trainable_target_weights.keys(),
        )
    }
    target_var_name_to_trainable_name = {
        v: k
        for k, v in trainable_name_to_target_var_name.items()
    }

    # Filter the params and validate
    names_target = set(trainable_name_to_target_var_name.keys())
    names_loaded = set(trainable_name_to_loaded_var_name.keys())
    if self._params.exclude or self._params.include:
        names_to_load = names_loaded
        if self._params.include:
            inc = re.compile(self._params.include)
            names_to_load = [
                name for name in names_to_load if inc.fullmatch(name)
            ]

        if self._params.exclude:
            exc = re.compile(self._params.exclude)
            names_to_load = [
                name for name in names_to_load if not exc.fullmatch(name)
            ]

        if len(names_target.intersection(names_to_load)) == 0:
            raise NameError(
                f"Not a weight could be matched.\nLoaded: {names_to_load}\nTarget: {names_target}"
            )
    elif self._params.allow_partial:
        names_to_load = names_target.intersection(names_loaded)
    else:
        names_to_load = names_target
        diff_target = names_loaded.difference(names_target)
        diff_load = names_target.difference(names_loaded)
        if len(diff_target) > 0 or len(diff_load) > 0:
            raise NameError(
                f"Not all weights could be matched:\nTargets '{diff_target}'\nLoaded: '{diff_load}'. "
                f"\nUse allow_partial to allow partial loading")

    new_weights = []
    warm_weights_names = []
    cold_weights_names = []
    non_trainable_weights_names = []
    for weight_idx, name in enumerate(target_var_names):
        # access original weight via index because names might not be unique (e.g. in metrics)
        trainable_name = target_var_name_to_trainable_name.get(
            name, None)  # None == not existing
        if trainable_name in names_to_load:
            warm_weights_names.append(name)
            new_weights.append(
                self._transform_weight(
                    all_loaded_weights[
                        trainable_name_to_loaded_var_name[trainable_name]],
                    trainable_name, name))
        else:
            if name in all_trainable_target_weights:
                cold_weights_names.append(name)
            else:
                non_trainable_weights_names.append(name)
            new_weights.append(
                target_weights[weight_idx][2])  # set to original weight

    not_loaded_weights = [
        name for name in names_loaded if name not in warm_weights_names
    ]
    newline = "\n\t"
    logger.info(
        newline.join(["model-to-load weights:"] + model_to_load_var_names))
    logger.info(
        newline.join(["renamed unmached weights:"] +
                     [str(x) for x in not_loaded_weights]))
    logger.info(
        newline.join(["Warm weights:"] + [str(x) for x in warm_weights_names]))
    logger.info(
        newline.join(["Cold weights:"] + [str(x) for x in cold_weights_names]))
    logger.info(
        f"There are {len(non_trainable_weights_names)} non trainable weights."
    )

    if len(names_to_load) == 0:
        raise ValueError(
            "No warmstart weight could be matched! Set TFAIP_LOG_LEVEL=INFO for more information."
        )

    self.apply_weights(target_model, new_weights)
def hard_update(target: tf.keras.Model, source: tf.keras.Model):
    """Overwrite ``target``'s weights with an exact copy of ``source``'s weights."""
    source_weights = source.get_weights()
    target.set_weights(source_weights)
def set_model_weights(self, model: tf.keras.Model):
    """Load the weights of ``model`` into ``self.local_model``."""
    incoming_weights = model.get_weights()
    self.local_model.set_weights(incoming_weights)