def magic_install(init_args=None):
    """Enable wandb "magic" auto-instrumentation for the current process.

    Runs at most once per process (guarded by the module-level ``_run_once``
    flag). Parses magic configuration from ``wandb.env.get_magic()``, installs
    an argparse wrapper when running as a plain Python script, calls
    ``wandb.init`` with ``magic=True``, merges any ``wandb_magic.*`` keys found
    in ``wandb.config``, and installs keras / tf.keras / absl hooks plus
    config-update triggers.

    Args:
        init_args: Optional mapping of keyword arguments forwarded to
            ``wandb.init``. The caller's mapping is not modified; ``"magic"``
            is set to ``True`` on a copy.
    """
    global _run_once
    if _run_once:
        return
    _run_once = True

    global _magic_config
    global _import_hook

    # add keras import hooks first (the import is kept for its side effects;
    # the WandbCallback name itself is not used in this function)
    from wandb.integration.keras import WandbCallback  # noqa: F401

    # parse config early, before we have wandb.config overrides
    _magic_config, magic_set = _parse_magic(wandb.env.get_magic())

    # we are implicitly enabling magic
    if _magic_config.get("enable") is None:
        _magic_config["enable"] = True
        magic_set["enable"] = True

    # allow early config to disable magic
    if not _magic_config.get("enable"):
        return

    # process system args
    _process_system_args()

    # install argparse wrapper only for plain python scripts.
    # _get_python_type must be *called*: comparing the function object itself
    # to "python" is always unequal, which previously disabled the wrapper
    # everywhere.
    in_jupyter_or_ipython = (
        wandb.wandb_sdk.lib.ipython._get_python_type() != "python"
    )
    if not in_jupyter_or_ipython:
        _monkey_argparse()

    # track init calls
    trigger.register("on_init", _magic_init)

    # if wandb.init has already been called, this call is ignored.
    # Work on a copy so the caller's dict is not mutated by the "magic" flag.
    init_args = dict(init_args or {})
    init_args["magic"] = True
    wandb.init(**init_args)

    # parse magic from wandb.config (from flattened "wandb_magic.a.b" keys
    # back into a nested dict)
    magic_from_config = {}
    MAGIC_KEY = "wandb_magic"
    for k in wandb.config.keys():
        if not k.startswith(MAGIC_KEY + "."):
            continue
        d = _dict_from_keyval(k, wandb.config[k], json_parse=False)
        _merge_dicts(d, magic_from_config)
    magic_from_config = magic_from_config.get(MAGIC_KEY, {})
    # presumably merges the config-derived values into _magic_config
    # (second argument appears to be the destination, as in the loop above)
    _merge_dicts(magic_from_config, _magic_config)

    # allow late config to disable magic
    if not _magic_config.get("enable"):
        return

    # store magic_set into config
    if magic_set:
        wandb.config["magic"] = magic_set
        wandb.config.persist()

    # Monkey patch tf.keras if it is already loaded
    if "tensorflow.python.keras" in sys.modules or "keras" in sys.modules:
        _monkey_tfkeras()

    # Always setup import hooks looking for keras or tf.keras
    add_import_hook(fullname="keras", on_import=_monkey_tfkeras)
    add_import_hook(fullname="tensorflow.python.keras", on_import=_monkey_tfkeras)

    if "absl.app" in sys.modules:
        _monkey_absl()
    else:
        add_import_hook(fullname="absl.app", on_import=_monkey_absl)

    # update wandb.config on fit or program finish
    trigger.register("on_fit", _magic_update_config)
    trigger.register("on_finished", _magic_update_config)
def _check_keras_version(): import keras keras_version = keras.__version__ major, minor, patch = keras_version.split(".") if int(major) < 2 or int(minor) < 4: wandb.termwarn( "Keras version %s is not fully supported. Required keras >= 2.4.0" % (keras_version) ) if "keras" in sys.modules: _check_keras_version() else: add_import_hook("keras", _check_keras_version) import tensorflow as tf import tensorflow.keras as keras import tensorflow.keras.backend as K tf_logger = tf.get_logger() patch_tf_keras() ### For gradient logging ### class _CustomOptimizer(tf.keras.optimizers.Optimizer):