示例#1
0
def assert_gpr_vs_vgp(
        m1: tf.Module,
        m2: tf.Module,
        gamma: float = 1.0,
        maxiter: int = 1,
        xi_transform: Optional[gpflow.optimizers.natgrad.XiTransform] = None):
    """Assert that natural-gradient steps on ``m2`` make it agree with ``m1``.

    The two models must start with different ``log_likelihood()`` values.
    ``maxiter`` NaturalGradient steps with step size ``gamma`` are then taken
    on ``m2``'s ``(q_mu, q_sqrt)``. ``xi_transform`` is appended to that group
    when it is given. Afterwards both log likelihoods must match to within
    ``atol=1e-4``.
    """
    assert maxiter >= 1

    initial_m2 = m2.log_likelihood()
    initial_m1 = m1.log_likelihood()
    assert initial_m2 != initial_m1

    @tf.function(autograph=False)
    def negative_objective() -> tf.Tensor:
        return -m2.log_marginal_likelihood()

    if xi_transform is None:
        natgrad_group = (m2.q_mu, m2.q_sqrt)
    else:
        natgrad_group = (m2.q_mu, m2.q_sqrt, xi_transform)

    optimizer = NaturalGradient(gamma)

    @tf.function(autograph=False)
    def single_step():
        # NaturalGradient expects a list of (q_mu, q_sqrt[, xi_transform]) groups.
        optimizer.minimize(negative_objective, var_list=[natgrad_group])

    for _step in range(maxiter):
        single_step()

    final_m2 = m2.log_likelihood()
    final_m1 = m1.log_likelihood()
    np.testing.assert_allclose(final_m1, final_m2, atol=1e-4)
def SwitchDevice(model: tf.Module,
                 create_model_method,
                 epoch: int,
                 device='/GPU:0'):
    """Rebuild ``model`` on another device, carrying its weights over via a checkpoint.

    The current weights are saved under ``training/cache/ckpt/epoch_{epoch}/ckpt``.
    ``create_model_method()`` is then called inside ``tf.device(device)``, and
    the saved weights are loaded into the new model.

    Args:
        model: The model currently being trained. Once ``device`` passes
            validation, its ``stop_training`` flag is set to True.
        create_model_method: Zero-argument callable returning a new model that
            can load the weights saved from ``model``.
        epoch: Current epoch. It names the checkpoint directory and is used to
            suggest a restart epoch of ``epoch - 1``.
        device: Target device such as ``'/GPU:0'``, ``'/CPU:0'`` or ``'/TPU:0'``.
            The type must be upper-case CPU, GPU or TPU.

    Returns:
        The new model on success. Otherwise the original ``model`` is returned,
        both when ``device`` is malformed and when no such device exists
        (the latter previously returned None).
    """
    device_type = device[1:4]
    # startswith() also copes with an empty string (device[0] raised IndexError).
    if not device.startswith('/') or device_type not in ('CPU', 'GPU', 'TPU'):
        print(
            "ERROR:the device should be like '/GPU:0',and only support CPU GPU TPU"
        )
        return model
    model.stop_training = True
    print('Training of ' + model.name + ' has been stopped')

    requested = device[1:]  # e.g. 'GPU:0', or just 'GPU'
    for logical_device in tf.config.experimental.list_logical_devices(
            device_type):
        # TF logical device names end in '.../device:TYPE:index'. Compare that
        # tail exactly, or by type alone when no index was requested. A
        # substring test on a 5-char prefix let '/GPU:10' match 'GPU:1'.
        available = logical_device.name.rsplit('device:', 1)[-1]
        if available != requested and not available.startswith(requested +
                                                                ':'):
            continue

        checkpoint_path = f"training/cache/ckpt/epoch_{epoch}"
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_prefix = os.path.join(checkpoint_path, "ckpt")
        model.save_weights(filepath=checkpoint_prefix)
        print('Weights of ' + model.name + ' has been saved in ' +
              checkpoint_prefix)
        with tf.device(device):
            new_model = create_model_method()
            new_model.load_weights(filepath=checkpoint_prefix)
        new_epoch = epoch - 1
        print(
            f"New model's device has been swtiched to {device},training is ready to run"
        )
        print(
            f"WARNING:if you are using epoch-releated lrate,the start epoch should be {new_epoch}"
        )
        return new_model

    print(
        "ERROR:All devices missmatched,anything wrong with the 'device' str?")
    # Match the invalid-format path: hand the original model back, never None.
    return model
示例#3
0
文件: util.py 项目: MelleStarke/BNQD
def compare_modules(source: tf.Module,
                    target: tf.Module,
                    condition: Callable[[object, object], bool],
                    target_types=(gf.Parameter, tf.Variable),
                    path=None):
    """
    Recursively walk ``source`` and ``target`` in lockstep and return the paths
    at which ``condition`` holds.

    The walk descends into three kinds of container:

    * lists/tuples, paired by index, up to the shorter length;
    * dicts, paired by key;
    * tf.Module attributes, paired by name, skipping
      ``tf.Module._TF_MODULE_IGNORED_PROPERTIES``.

    Keys and attributes present on only one side are skipped. Every object that
    is an instance of ``target_types`` is passed to
    ``condition(source_obj, target_obj)``. When that returns True, the object's
    path (e.g. ``"Model.kernel.variance"``) is added to the result.

    Args:
        source: Root of the first structure.
        target: Root of the second structure, expected to mirror ``source``.
        condition: Predicate taking one object from the source and one from the
            target.
        target_types: Types whose instances are tested with ``condition``.
        path: Path of the current object. Defaults to ``source``'s class name.

    Returns:
        List of path strings for which ``condition`` returned True.

    Raises:
        ValueError: If two paired objects are not of exactly the same type.
    """
    # Default the path first so a type mismatch at the root reports it.
    if path is None:
        path = source.__class__.__name__

    if type(source) is not type(target):
        raise ValueError(
            "Source and target module aren't of the same type.\n"
            "\tPath: {}\n\tSource class: {}\n\t Target class: {}".format(
                path, source.__class__.__name__, target.__class__.__name__))

    res = []

    if isinstance(source, target_types) and condition(source, target):
        res.append(path)

    if isinstance(source, (list, tuple)):
        for i, (sub_source, sub_target) in enumerate(zip(source, target)):
            res += compare_modules(sub_source, sub_target, condition,
                                   target_types, f"{path}[{i}]")

    elif isinstance(source, dict):
        # Pair entries by key rather than position. The two dicts need not
        # share insertion order.
        for key, sub_source in source.items():
            if key not in target:
                continue
            res += compare_modules(sub_source, target[key], condition,
                                   target_types, f"{path}['{key}']")

    elif isinstance(source, tf.Module):
        # Pair attributes by name. Attribute order depends on assignment
        # order, which may differ between instances.
        target_attrs = vars(target)
        for attr_name, sub_source in vars(source).items():
            if attr_name in tf.Module._TF_MODULE_IGNORED_PROPERTIES:
                continue
            if attr_name not in target_attrs:
                continue
            res += compare_modules(sub_source, target_attrs[attr_name],
                                   condition, target_types,
                                   f"{path}.{attr_name}")
    return res
def _predict(model: tf.Module, dataset: tf.data.Dataset) -> np.ndarray:
    """Run ``model.predict`` on each batch of ``dataset`` and stack the outputs.

    ``dataset`` must yield ``(features, label)`` pairs; the labels are ignored.
    The per-batch predictions are concatenated along axis 0.
    """
    batch_outputs = [
        model.predict(features)  # original note: shape (BATCH_SIZE, 1)
        for features, _label in dataset.as_numpy_iterator()
    ]
    return np.concatenate(batch_outputs)
示例#5
0
def assert_sgpr_vs_svgp(m1: tf.Module, m2: tf.Module):
    """Assert that one natural-gradient step on ``m2`` makes it agree with ``m1``.

    ``m1`` carries its data as ``m1.data``. Items 0 and 1 of that data are
    passed explicitly to ``m2``'s likelihood methods. The two log likelihoods
    must differ before a single NaturalGradient step (gamma 1.0) on
    ``(m2.q_mu, m2.q_sqrt)``, and must match to within ``atol=1e-4`` after it.
    """
    data = m1.data
    inputs, outputs = data[0], data[1]

    initial_m1 = m1.log_likelihood()
    initial_m2 = m2.log_likelihood(inputs, outputs)
    assert initial_m2 != initial_m1

    @tf.function(autograph=False)
    def negative_objective() -> tf.Tensor:
        return -m2.log_marginal_likelihood(inputs, outputs)

    NaturalGradient(1.).minimize(negative_objective,
                                 var_list=[(m2.q_mu, m2.q_sqrt)])

    final_m1 = m1.log_likelihood()
    final_m2 = m2.log_likelihood(inputs, outputs)
    np.testing.assert_allclose(final_m1, final_m2, atol=1e-4)
示例#6
0
def predict(model: tf.Module, text: str, maxlen=10000) -> str:
    """Predict diacritics for ``text`` with ``model`` and merge them into the text.

    ``model`` is expected to return three outputs. Each is decoded with
    ``dataset.from_categorical`` as the niqqud, dagesh and sin predictions, in
    that order. After merging, the result is cleaned up in three steps:

    1. U+FEFF characters are dropped.
    2. Double spaces are collapsed, in a single pass.
    3. ``hebrew.RAFE`` is removed.
    """
    data = dataset.Data.from_text(hebrew.iterate_dotted_text(text), maxlen)
    heads = model.predict(data.normalized)
    niqqud, dagesh, sin = (dataset.from_categorical(heads[k])
                           for k in range(3))
    merged = merge_unconditional(data.text, data.normalized, niqqud, dagesh,
                                 sin)

    result = ' '.join(merged)
    result = result.replace('\ufeff', '')
    result = result.replace('  ', ' ')
    return result.replace(hebrew.RAFE, '')
示例#7
0
def trace_and_update_module(
    module: tf.Module, tf_function: function.Function, name: str,
    strip_control_dependencies: bool) -> function.ConcreteFunction:
  """Traces `tf_function` and saves under attr `name` of `module`.

  Tracing runs inside a resource tracker, an object tracker and a
  variable-creator scope. This captures the resources, trackable objects and
  variables created during tracing so they can be attached to `module`.

  Args:
    module: A saveable module which will contain the traced `tf_function` under
      attr `name`. This function also sets `created_variables`, `resources`,
      `trackable_objects`, `initializers` and `assets` on it. It sets
      `pruned_variables` too, when the forward-compatibility branch is taken.
    tf_function: A tf.function to trace.
    name: A name to save the traced `tf_function` to.
    strip_control_dependencies: Boolean. If True, automatic control dependencies
      will be stripped from the outputs of `tf_function`. This should almost
      always be False. It is useful only if you want to use the structure of the
      TF graph to perform any graph manipulations. Only consulted when the
      forward-compatibility branch below is taken.

  Returns:
    The concrete function obtained from tracing `tf_function`. This is the
    trace itself, not the result of `optimize_concrete_function` that may be
    stored on `module`.
  """
  resource_tracker = tracking.ResourceTracker()
  object_tracker = annotators.ObjectTracker()
  created_variables = []

  # Variable-creator hook: records every variable created while tracing so it
  # can be tracked on `module` afterwards.
  def _variable_creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # Trace `tf_function` to gather any resources in it using the
  # resource_tracker. These are then assigned to `module.resources` and tracked
  # before exporting to SavedModel.
  # NOTE(review): get_concrete_function() is called with no arguments, so
  # presumably `tf_function` was defined with an input_signature — confirm.
  with tracking.resource_tracker_scope(resource_tracker), \
       annotators.object_tracker_scope(object_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_fn = tf_function.get_concrete_function()

  # Prior to 2020/10/08, saving a tf.function with a concrete function signature
  # would ensure that the function was not re-traced in a round-trip to a
  # SavedModel. Since this is no longer the case, we save the concrete function
  # directly.
  if tf.compat.forward_compatible(2020, 10, 8):
    pruned_function = optimize_concrete_function(concrete_fn,
                                                 strip_control_dependencies)
    # Keep the optimized function's variables reachable from `module`.
    module.pruned_variables = pruned_function.variables
    setattr(module, name, pruned_function)
  else:
    # Older path: store the tf.function itself.
    setattr(module, name, tf_function)

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  module.trackable_objects = object_tracker.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  initializers = []
  for resource in module.resources:
    # Collect the initializer of each initializable lookup table.
    if isinstance(resource, lookup_ops.InitializableLookupTableBase):
      initializers.append(resource._initializer)  # pylint: disable=protected-access
  module.initializers = initializers
  # Wrap every file path registered in the traced graph's ASSET_FILEPATHS
  # collection as an Asset tracked on `module`.
  module.assets = [
      common_types.Asset(asset_filepath) for asset_filepath in
      concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  return concrete_fn
示例#8
0
文件: Train.py 项目: kyuno053/cnn
def _print_interval(verbose: int, nbIterMax: int) -> int:
    """Return the number of iterations between two loss prints for a verbose level.

    The mapping is:

    * verbose <= 0 -> nbIterMax, so the loss is printed only at iteration 0;
    * 1 -> 100, 2 -> 10, 3 -> 1.

    The result is clamped to at least 1. Without the clamp, levels above 3 gave
    0 and the caller's ``numIter % interval`` raised ZeroDivisionError.
    """
    if verbose <= 0:
        return nbIterMax
    return max(1, int(1000 / (10**verbose)))


def train_model(p_dataset: DataSet,
                p_model: tf.Module,
                p_optimizer: tf.optimizers,
                logFileName: str,
                betaL2: float,
                interv_reload: int,
                nbIterMax: int,
                min_delta: float,
                patience: int,
                nbElemMoyGlissante: int,
                interv_accuracy: int = 500,
                verbose: int = 1):
    """Train ``p_model``, bounded by a maximum iteration count and early stopping.

    Args:
        p_dataset: Provides training batches (``NextTrainingBatch``), accuracy
            evaluation (``get_mean_accuracy``) and reloading
            (``reload_fromBIN_lab01``).
        p_model: Model to train. It must provide ``get_L2_loss()``.
        p_optimizer: Optimizer forwarded to ``train_one_iter``.
        logFileName: Log location given to ``tf.summary.create_file_writer``
            so training can be followed in TensorBoard.
        betaL2: L2 weight. It is forwarded to ``train_one_iter``, and
            ``betaL2 * p_model.get_L2_loss()`` is added to the loss that is
            printed and used for early stopping.
        interv_reload: Reload the dataset every ``interv_reload`` iterations.
        nbIterMax: Maximum number of training iterations.
        min_delta: Minimum decrease of the moving-window loss sum that counts
            as an improvement.
        patience: Stop once more than ``patience`` consecutive iterations pass
            without such an improvement.
        nbElemMoyGlissante: Number of recent losses in the moving window.
        interv_accuracy: Evaluate and log accuracy every ``interv_accuracy``
            iterations.
        verbose: Controls how often the loss is printed. Values <= 0 print only
            at iteration 0. Levels 1, 2 and 3 print every 100, 10 and 1
            iterations; higher levels behave like 3.
    """
    train_summary_writer = tf.summary.create_file_writer(logFileName)

    interv_print = _print_interval(verbose, nbIterMax)

    # Early-stopping state: a ring buffer of the last nbElemMoyGlissante
    # losses. It is pre-filled with a large sentinel so that early windows
    # register as improvements.
    earlyStopping_counter = 0
    # Largest earlyStopping_counter seen since the last loss print.
    max_earlyStopping_counter = 0
    l_lastLosses = np.full(shape=(nbElemMoyGlissante),
                           fill_value=999999,
                           dtype=np.float32)
    minSumLosses = sum(l_lastLosses)
    # Keeps the final accuracy call valid even when the loop never runs.
    numIter = 0
    for numIter in range(nbIterMax):
        tf.summary.experimental.set_step(numIter)

        # Evaluate and log accuracy metrics every interv_accuracy iterations.
        if numIter % interv_accuracy == 0:
            with train_summary_writer.as_default():
                p_dataset.get_mean_accuracy(p_model, numIter)

        # One training step on the next batch (pixel values, labels).
        # Summaries are written every 10 iterations.
        ima, lab = p_dataset.NextTrainingBatch()
        with train_summary_writer.as_default():
            loss = train_one_iter(p_model, p_optimizer, betaL2, ima, lab,
                                  numIter % 10 == 0)
            loss += betaL2 * p_model.get_L2_loss()

        if numIter % interv_print == 0:
            print(
                "numIter = %6d - loss = %.3f - max_earlyStopping_counter = %d"
                % (numIter, loss, max_earlyStopping_counter))
            max_earlyStopping_counter = 0

        # Early stopping on the moving-window sum of recent losses.
        l_lastLosses[numIter % nbElemMoyGlissante] = loss.numpy()
        window_sum = sum(l_lastLosses)
        if minSumLosses - window_sum < min_delta:
            earlyStopping_counter += 1
            max_earlyStopping_counter = max(max_earlyStopping_counter,
                                            earlyStopping_counter)
        else:
            earlyStopping_counter = 0
            minSumLosses = window_sum
        if earlyStopping_counter > patience:
            print(
                "\n----- EARLY STOPPING : numIter = %6d - loss = %f - earlyStopping_counter = %d -----"
                % (numIter, loss, earlyStopping_counter))
            break

        # Periodically empty and reload the dataset.
        if numIter > 0 and numIter % interv_reload == 0:
            p_dataset.reload_fromBIN_lab01()
            p_dataset.get_mean_accuracy(p_model, -1)

    # Final accuracy evaluation.
    with train_summary_writer.as_default():
        p_dataset.get_mean_accuracy(p_model, numIter)