예제 #1
0
def magic_install(init_args=None):
    """Install the wandb "magic" hooks; only the first call has any effect.

    The sequence is: parse the magic configuration from the environment,
    process system args, wrap argparse (plain Python only), register the
    ``on_init`` trigger, call ``wandb.init``, merge any ``wandb_magic.*``
    keys found in ``wandb.config`` into the magic configuration, and finally
    patch (or register import hooks for) keras / tf.keras and absl.

    Magic can be switched off either by the early (environment) config or by
    the late (``wandb.config``) config; in both cases the function returns
    without installing the remaining hooks.

    Args:
        init_args: Optional dict of keyword arguments forwarded to
            ``wandb.init``. ``magic=True`` is always added. The caller's
            dict is copied, not modified.
    """
    global _run_once
    if _run_once:
        return
    _run_once = True

    global _magic_config
    global _import_hook
    # Imported for its side effect only: add keras import hooks first.
    from wandb.integration.keras import WandbCallback  # noqa: F401

    # parse config early, before we have wandb.config overrides
    _magic_config, magic_set = _parse_magic(wandb.env.get_magic())

    # we are implicitly enabling magic
    if _magic_config.get("enable") is None:
        _magic_config["enable"] = True
        magic_set["enable"] = True

    # allow early config to disable magic
    if not _magic_config.get("enable"):
        return

    # process system args
    _process_system_args()
    # install argparse wrapper
    # _get_python_type must be *called*: comparing the function object itself
    # to "python" is always unequal, which silently disabled the argparse
    # wrapper in every environment.
    in_jupyter_or_ipython = (
        wandb.wandb_sdk.lib.ipython._get_python_type() != "python"
    )
    if not in_jupyter_or_ipython:
        _monkey_argparse()

    # track init calls
    trigger.register("on_init", _magic_init)

    # if wandb.init has already been called, this call is ignored
    # Work on a copy so the caller's dict does not gain a "magic" key.
    init_args = dict(init_args or {})
    init_args["magic"] = True
    wandb.init(**init_args)

    # parse magic from wandb.config (from flattened to dict)
    magic_from_config = {}
    MAGIC_KEY = "wandb_magic"
    magic_prefix = MAGIC_KEY + "."
    for k in wandb.config.keys():
        if not k.startswith(magic_prefix):
            continue
        d = _dict_from_keyval(k, wandb.config[k], json_parse=False)
        _merge_dicts(d, magic_from_config)
    magic_from_config = magic_from_config.get(MAGIC_KEY, {})
    _merge_dicts(magic_from_config, _magic_config)

    # allow late config to disable magic
    if not _magic_config.get("enable"):
        return

    # store magic_set into config
    if magic_set:
        wandb.config["magic"] = magic_set
        wandb.config.persist()

    # Monkey patch tf.keras
    if "tensorflow.python.keras" in sys.modules or "keras" in sys.modules:
        _monkey_tfkeras()

    # Always setup import hooks looking for keras or tf.keras
    add_import_hook(fullname="keras", on_import=_monkey_tfkeras)
    add_import_hook(fullname="tensorflow.python.keras",
                    on_import=_monkey_tfkeras)

    if "absl.app" in sys.modules:
        _monkey_absl()
    else:
        add_import_hook(fullname="absl.app", on_import=_monkey_absl)

    # update wandb.config on fit or program finish
    trigger.register("on_fit", _magic_update_config)
    trigger.register("on_finished", _magic_update_config)
예제 #2
0
def _check_keras_version():
    """Warn via ``wandb.termwarn`` when the installed keras is older than 2.4.

    Only the major and minor components of ``keras.__version__`` are
    inspected, so version strings with fewer or more than three
    dot-separated parts (e.g. ``"2.4"`` or ``"2.11.0.dev1"``) are accepted.

    Raises:
        ValueError: if the major or minor component is not a plain integer.
    """
    import keras

    keras_version = keras.__version__
    # Pad with "0" so a bare major version still yields a (major, minor) pair.
    major, minor = (
        int(part) for part in (keras_version.split(".") + ["0"])[:2]
    )
    # Compare as a tuple: checking major and minor independently would flag
    # newer majors with a small minor (e.g. 3.0) as unsupported.
    if (major, minor) < (2, 4):
        wandb.termwarn(
            "Keras version %s is not fully supported. Required keras >= 2.4.0"
            % (keras_version)
        )


# Check the keras version right away when keras is already imported;
# otherwise defer the check to an import hook that fires on "keras".
if "keras" not in sys.modules:
    add_import_hook("keras", _check_keras_version)
else:
    _check_keras_version()

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K

# Logger object returned by tf.get_logger(), fetched once at import time.
tf_logger = tf.get_logger()


# Executed at import time. patch_tf_keras is defined outside this view.
# NOTE(review): presumably it patches tf.keras so wandb can hook into it —
# confirm against its definition.
patch_tf_keras()


### For gradient logging ###


class _CustomOptimizer(tf.keras.optimizers.Optimizer):