コード例 #1
0
class KerasHook(TensorflowBaseHook, tf.keras.callbacks.Callback):
    def __init__(
        self,
        out_dir,
        export_tensorboard=False,
        tensorboard_dir=None,
        dry_run=False,
        reduction_config=None,
        save_config=None,
        include_regex=None,
        include_collections=None,
        save_all=False,
        include_workers="one",
    ):
        """Initialize the Keras callback hook.

        Every argument is forwarded unchanged to ``TensorflowBaseHook.__init__``
        together with ``init_step=-1``. The Keras ``Callback`` base is then
        initialized and the hook's own per-run state is set up.
        """
        base_hook_kwargs = {
            "out_dir": out_dir,
            "export_tensorboard": export_tensorboard,
            "tensorboard_dir": tensorboard_dir,
            "init_step": -1,
            "dry_run": dry_run,
            "reduction_config": reduction_config,
            "save_config": save_config,
            "include_regex": include_regex,
            "include_collections": include_collections,
            "save_all": save_all,
            "include_workers": include_workers,
        }
        TensorflowBaseHook.__init__(self, **base_hook_kwargs)
        tf.keras.callbacks.Callback.__init__(self)

        # The hook state below is assigned only after both base initializers
        # have run.
        # TensorRefs selected for saving in the current step.
        self.tensor_refs_to_save_this_step = set()
        # Tensors this hook appended to the execution function's fetches.
        self._fetches_added = set()
        self.callable_cache = CallableCache()
        # tensor name -> (value, collection names) for tensors the user saved
        # explicitly; written and cleared by _save_custom_tensors_post_step.
        self.custom_tensors_to_save = {}
        # layer name -> InputOutputSaver capturing that layer's input/output.
        self.saved_layers = {}
        self.has_registered_model = False
        # supports_tf_logs (introduced in TF 2.3.0) tells the framework that
        # this callback is not limited to reading numpy-only logs.
        self._supports_tf_logs = True
        # Works around a callback-ordering bug in TF 2.3.0. When True, the step
        # was already incremented in the mode-begin callback, so the next
        # batch-begin callback must not increment it again.
        self.step_incremented_in_on_train_begin = False

    def _is_not_supported(self):
        """Return True if SMDebug must stay disabled for this run.

        The distribution strategy is resolved lazily on the first call. The
        support decision is computed once and cached in ``self._hook_supported``.
        SMDebug is disabled when:

        * running TF 1.x eagerly (globally or via ``model.run_eagerly``),
        * using a mirrored strategy on a TF version where Keras'
          ``get_distributed_model`` helper cannot be imported (TF < 1.14),
        * using a distribution strategy classified as UNSUPPORTED.
        """
        if self.distribution_strategy is None:
            self.distribution_strategy = self._get_distribution_strategy()
        if self._hook_supported is None:
            self._hook_supported = True
            model_runs_eagerly = getattr(self.model, "run_eagerly", False)
            if not is_tf_version_2x() and (tf.executing_eagerly() or
                                           model_runs_eagerly):
                # Adjacent literals are concatenated; keep the trailing space.
                self.logger.info(
                    "Disabling SMDebug as it does not support eager mode "
                    "for TF versions 1.x")
                self._hook_supported = False
            elif self.distribution_strategy == TFDistributionStrategy.MIRRORED:
                try:
                    # Only probes availability; _get_distributed_model imports
                    # and uses the helper later.
                    from tensorflow.python.keras.distribute.distributed_training_utils import (  # noqa: F401
                        get_distributed_model, )
                except ImportError:
                    # for tf1.13 we can't import this, so we can't support mirrored strategy
                    self.logger.info(
                        "Disabling SMDebug as it does not support mirrored strategy "
                        "with TensorFlow version <1.14")
                    self._hook_supported = False
            elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
                self.logger.info(f"Disabling SMDebug as it does not support "
                                 f"{tf.distribute.get_strategy()}")
                self._hook_supported = False
        return not self._hook_supported

    def register_model(self, model):
        """Attach ``model`` to this hook (called by the hook in the AWS TF codebase).

        When a GradientTape is in use (``self.tape`` is set), every layer of
        the model is wrapped so its input and output values can be captured.
        """
        self.model = model
        uses_gradient_tape = self.tape is not None
        if uses_gradient_tape:
            self._wrap_model_with_input_output_saver()
        self.has_registered_model = True

    def _get_matching_collections(self,
                                  mode,
                                  tensor,
                                  tensor_type,
                                  ts_name,
                                  is_input_to_model=False,
                                  is_output_of_model=False):
        """Return the set of collections ``tensor`` should be added to.

        Args:
            mode: not used here; kept for interface compatibility.
            tensor: the layer tensor or variable being considered.
            tensor_type: ``"input"``, ``"output"`` or ``"weight"``.
            ts_name: export name matched against each collection's
                ``include_regex``.
            is_input_to_model: the tensor feeds the model (adds INPUTS).
            is_output_of_model: the tensor is a model output (adds OUTPUTS).

        A weight goes to BIASES when ``tensor.name`` matches the BIASES regex,
        and to WEIGHTS otherwise. Every other collection (WEIGHTS and BIASES
        excluded) whose ``include_regex`` matches ``ts_name`` and which does
        not already hold the tensor is added as well.
        """
        colls_with_tensor = set()
        if tensor_type == "weight":
            biases = self.collection_manager.get(CollectionKeys.BIASES)
            if match_inc(tensor.name, biases.include_regex):
                colls_with_tensor.add(biases)
            else:
                colls_with_tensor.add(
                    self.collection_manager.get(CollectionKeys.WEIGHTS))
        elif is_input_to_model:
            colls_with_tensor.add(
                self.collection_manager.get(CollectionKeys.INPUTS))
        elif is_output_of_model:
            colls_with_tensor.add(
                self.collection_manager.get(CollectionKeys.OUTPUTS))

        # Membership key for has_tensor(). It is computed lazily and only once:
        # calling experimental_ref() again on the returned reference object
        # would fail when several collections match.
        tensor_key = None
        for current_coll in self.collection_manager.get_collections().values():
            if current_coll.name in [
                    CollectionKeys.WEIGHTS, CollectionKeys.BIASES
            ]:
                # Regexes are not matched for these; they were handled above.
                # This also stops users from misconfiguring these collections.
                continue
            if not match_inc(ts_name, current_coll.include_regex):
                continue
            if tensor_key is None:
                tensor_key = tensor
                # In TF 2.x eager mode tensors cannot go in a set/dict because
                # tensor.__hash__() is unavailable; experimental_ref() returns
                # a hashable reference. It is an experimental API and may change.
                # Ref: https://www.tensorflow.org/api_docs/python/tf/Tensor#experimental_ref
                if is_tf_version_2x() and tf.executing_eagerly():
                    tensor_key = tensor.experimental_ref()
            if not current_coll.has_tensor(tensor_key):
                # the tensor is added to this collection by
                # _create_tensors_for_matching_collections
                colls_with_tensor.add(current_coll)
            # Adding tensors externally is discouraged: they have a different
            # internal name, and that name (not the Keras-style name) is used
            # to save their data.
        return colls_with_tensor

    def _check_and_add_layer_tensor(self,
                                    mode,
                                    layer,
                                    tensor_type,
                                    tensor,
                                    is_input_to_model=False,
                                    is_output_of_model=False):
        """Register one layer tensor (input, output or weight) with its collections.

        Under a mirrored strategy, tensors without a device are skipped. They
        are extra tensors, and ignoring them still covers every replica's
        tensors.
        """
        is_mirrored = (
            self.distribution_strategy == TFDistributionStrategy.MIRRORED)
        if is_mirrored and not tensor.device:
            return

        self._add_to_device_map(tensor)

        # More than one TF name is returned only for a mirrored variable.
        tf_names = get_tf_names(tensor)
        export_name = get_export_name_for_keras(layer, tensor_type, tensor)
        if tf_names[0] in self.tensor_to_collections:
            # The tensor was registered before, so reuse its export name. For
            # mirrored variables every replica's ref shares one export name.
            # Cases where this happens:
            # 1. layer0_output0 == layer1_input0: first come, first served,
            #    hopefully layer0/outputs/<name> (may not hold for
            #    non-sequential models).
            # 2. The tensor was added to a collection outside this prepare
            #    call (e.g. gradients), where the TF name is the export name.
            # 3. The same tensor was added in a previous mode.
            export_name = self._get_tensor_ref(tf_names[0]).export_name

        matching_collections = self._get_matching_collections(
            mode,
            tensor,
            tensor_type,
            export_name,
            is_input_to_model=is_input_to_model,
            is_output_of_model=is_output_of_model,
        )
        self._create_tensors_for_matching_collections(
            mode, tensor, tf_names, export_name, matching_collections)

    def _are_tensors_already_added(self, tf_names):
        """Return True if any of ``tf_names`` is already in tensor_to_collections.

        Several names occur only for a mirrored variable. If more than one of
        them is registered, all of them are asserted to be registered.
        """
        registered = sum(name in self.tensor_to_collections for name in tf_names)
        if registered > 1:
            assert registered == len(tf_names)
        return registered > 0

    def _create_tensors_for_matching_collections(self, mode, tensor, tf_names,
                                                 export_name,
                                                 colls_with_tensor):
        """Register ``tensor`` with every collection in ``colls_with_tensor``.

        Args:
            mode: the mode the tensor is being registered for.
            tensor: a ``tf.Variable``, ``tf.Tensor`` or
                ``values.DistributedValues``. Any other type raises
                ``NotImplementedError`` when a new entry must be created.
            tf_names: TF name(s) of the tensor (more than one only for a
                mirrored variable).
            export_name: name under which the tensor is saved.
            colls_with_tensor: set of collections the tensor matched.

        If none of ``tf_names`` is registered yet, the first collection
        iterated creates the TensorRef(s) and the remaining collections reuse
        them. Each ref name is then mapped to the same ``colls_with_tensor``
        set object in ``self.tensor_to_collections``. If the tensor was already
        registered, it is re-enabled (``mode`` added to its existing
        TensorRefs) only when one of the matched collections is WEIGHTS or
        BIASES.
        """
        # if this tensor was already added to some collection in the previous call
        # do not use it as it is for previous mode
        if colls_with_tensor and not self._are_tensors_already_added(tf_names):
            # need to create new entry in tensor_to_collections dict for the tensor object
            tensor_refs = []
            for coll in colls_with_tensor:
                if not tensor_refs:
                    # First collection: create the TensorRef(s) for this tensor.
                    if isinstance(tensor, tf.Variable):
                        tensor_refs.append(
                            coll.add_variable(tensor,
                                              export_name=export_name,
                                              mode=mode))
                    elif isinstance(tensor, tf.Tensor):
                        tensor_refs.append(
                            coll.add_tensor(tensor,
                                            name=export_name,
                                            mode=mode))
                    elif isinstance(tensor, values.DistributedValues):
                        # add_distributed_variable returns several refs
                        # (they are extended into tensor_refs).
                        tensor_refs.extend(
                            coll.add_distributed_variable(
                                tensor, export_name=export_name, mode=mode))
                    else:
                        raise NotImplementedError
                else:
                    # for second collection onwards, reuse the refs created above
                    for t in tensor_refs:
                        coll.set_tensor_ref(t)
            for t in tensor_refs:
                self.tensor_to_collections[t.name] = colls_with_tensor
        elif colls_with_tensor:
            # we should only re-add tensors which were already added if these are variables
            # other tensors are part of a different mode, and will cause a crash if fetched
            # because their input placeholders will not be passed.
            if any([
                    c.name in [CollectionKeys.WEIGHTS, CollectionKeys.BIASES]
                    for c in colls_with_tensor
            ]):
                # set mode of the tensorref object for these tensors
                # these are special because they are tf.Variables which require no input
                # they will be present in all modes
                for tf_name in tf_names:
                    tensor_ref = self._get_tensor_ref(tf_name)
                    tensor_ref.add_mode(mode)
        return

    def _get_distributed_model(self, mode):
        """Return the distributed Keras model built for ``mode``.

        The helper is not importable on TF 1.13. ``_is_not_supported``
        disables the hook there under a mirrored strategy, so this code is not
        expected to run on 1.13.
        """
        from tensorflow.python.keras.distribute.distributed_training_utils import (
            get_distributed_model, )

        base_model = self.model
        return get_distributed_model(base_model, get_keras_mode(mode))

    def _get_model(self, mode):
        """Return the model to inspect for ``mode``.

        Under a mirrored strategy this is the distributed model for the mode;
        otherwise it is ``self.model``.
        """
        use_distributed = (
            self.distribution_strategy == TFDistributionStrategy.MIRRORED)
        return self._get_distributed_model(mode) if use_distributed else self.model

    def _is_input_layer(self, mode, layer_inputs):
        """Return True if any tensor in ``layer_inputs`` is an input of the model.

        Under a mirrored strategy the model exposes ``values`` (one model per
        replica), and the inputs of every replica model are considered.
        """
        model = self._get_model(mode)
        if hasattr(model, "values"):
            replica_models = model.values
        else:
            replica_models = [model]
        model_inputs = [t for replica in replica_models for t in replica.inputs]
        # A list (not a generator) keeps every membership test evaluated.
        return any([inp in model_inputs for inp in layer_inputs])

    def _is_output_layer(self, mode, layer_outputs):
        """Return True if any tensor in ``layer_outputs`` is an output of the model.

        Under a mirrored strategy the model exposes ``values`` (one model per
        replica), and the outputs of every replica model are considered.
        """
        model = self._get_model(mode)
        if hasattr(model, "values"):
            replica_models = model.values
        else:
            replica_models = [model]
        model_outputs = [t for replica in replica_models for t in replica.outputs]
        # NOTE(review): per the original note, in TF 2.x `layer_outputs[0] in
        # model_outputs` can raise OperatorNotAllowedInGraphError, and `==`
        # between graph tensors yields a symbolic tensor. The check below
        # relies on `in`; confirm it is safe on supported TF versions.
        return any([out in model_outputs for out in layer_outputs])

    def _prepare_layers(self, mode):
        """Add every layer's tensors (inputs, outputs, weights) to their collections.

        Layer input and output tensors cannot be obtained in TF 2.x eager
        mode, so they are registered only under TF 1.x or TF 2.x graph mode.
        Weights are registered in every case.
        """
        capture_layer_io = not is_tf_version_2x() or not tf.executing_eagerly()
        for layer in self.model.layers:
            if capture_layer_io:
                inputs = get_keras_layer_inputs(layer)
                feeds_model = self._is_input_layer(mode, inputs)
                for inp in inputs:
                    self._check_and_add_layer_tensor(
                        mode, layer, "input", inp,
                        is_input_to_model=feeds_model)

                outputs = get_keras_layer_outputs(layer)
                ends_model = self._is_output_layer(mode, outputs)
                for out in outputs:
                    self._check_and_add_layer_tensor(
                        mode, layer, "output", out,
                        is_output_of_model=ends_model)

            for weight in layer.weights:
                self._check_and_add_layer_tensor(mode, layer, "weight", weight)

    def _prepare_tensors_available_post_step(self):
        """Map post-step tensors (gradients, optimizer variables, ...) to collections.

        Every TensorRef in the OPTIMIZER_VARIABLES, GRADIENTS, OUTPUTS and
        INPUTS collections is recorded in ``self.tensor_to_collections``.
        Custom collections receive the tensor for the current mode, and are
        recorded too, when their include_regex matches the ref name and the
        ref has a ``tf_obj``.
        """
        custom_collections, _ = self._get_custom_and_default_collections()
        builtin_keys = (
            CollectionKeys.OPTIMIZER_VARIABLES,
            CollectionKeys.GRADIENTS,
            CollectionKeys.OUTPUTS,
            CollectionKeys.INPUTS,
        )
        builtin_collections = [self.get_collection(name=key) for key in builtin_keys]
        for coll in builtin_collections:
            for tensor_ref in coll.get_tensors():
                self.tensor_to_collections.setdefault(tensor_ref.name,
                                                      set()).add(coll)
                for custom_coll in custom_collections:
                    if not match_inc(tensor_ref.name, custom_coll.include_regex):
                        continue
                    if tensor_ref.tf_obj is None:
                        continue
                    custom_coll.add_for_mode(tensor_ref.tf_obj, self.mode)
                    self.tensor_to_collections[tensor_ref.name].add(custom_coll)

    def _prepare_tensors_for_step(self, mode):
        """Select the TensorRefs to fetch during this step in ``mode``.

        Every collection saved this step contributes its tensors, including
        ones users added manually through the collection API. Two exclusions
        apply. METRICS, LOSSES and INPUTS collections are skipped because they
        are retrieved after the step ends. Tensors belonging to the INPUTS
        collection are also removed from every other collection's contribution.
        """
        self.tensor_refs_to_save_this_step = set()
        step_collections = self._get_collections_to_save_for_step()
        input_refs = set(
            self.collection_manager.get(
                CollectionKeys.INPUTS).get_tensors(mode=mode))
        post_step_names = (CollectionKeys.METRICS, CollectionKeys.LOSSES,
                           CollectionKeys.INPUTS)
        for coll in step_collections:
            if coll.name in post_step_names:
                continue
            self.tensor_refs_to_save_this_step.update(
                set(coll.get_tensors(mode=mode)) - input_refs)

    def _add_metric(self, metric_name, metric_value: tf.Tensor = None):
        """Register ``metric_name`` in the LOSSES or METRICS collection.

        Names already present in ``self.tensor_to_collections`` are ignored.
        ``"loss"`` and ``"val_loss"`` go to LOSSES; every other name goes to
        METRICS.

        Args:
            metric_name: name of the metric in the Keras logs.
            metric_value: optional tensor for the metric. When it is omitted, a
                non-graph TensorRef is created from the name.
        """
        if metric_name in self.tensor_to_collections:
            return

        if metric_name in ("loss", "val_loss"):
            coll_name = CollectionKeys.LOSSES
        else:
            coll_name = CollectionKeys.METRICS
        coll = self.collection_manager.get(coll_name)
        # Compare with None explicitly. Truth-testing a tf.Tensor raises for
        # graph tensors and would treat a zero-valued eager tensor as missing.
        if metric_value is not None:
            coll.set_tensor_ref(metric_value, metric_name)
        else:
            coll.set_tensor_ref(TensorRef.from_non_graph_var(metric_name))
        self.tensor_to_collections[metric_name] = {coll}

    def _save_custom_tensors_post_step(self):
        """Write every tensor the user saved with the save_tensor API, then clear them.

        Each entry of ``self.custom_tensors_to_save`` maps a tensor name to a
        ``(value, collection_names)`` pair.
        """
        pending = self.custom_tensors_to_save
        for name, (value, collection_names) in pending.items():
            self._save_tensor_to_file(name, value, collection_names)
        pending.clear()

    def should_save_layer(self, layer_name):
        """Return whether the value of ``layer_name`` should be saved.

        Called from AWS TF. Delegates to ``should_save_tensor_or_collection``
        with the LAYERS collection.
        """
        layers_key = CollectionKeys.LAYERS
        return self.should_save_tensor_or_collection(layer_name, layers_key)

    def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
        """Save ``tensor_value`` under ``tensor_name`` for the current step.

        Args:
            tensor_name: export name of the tensor.
            tensor_value: value to save. A ``values.PerReplica`` is split and
                each replica's value is saved separately.
            collections: a single collection (object or name string), or a
                set/frozenset/list/tuple of them. The caller's container is
                never modified. Collections saved this step whose
                include_regex matches ``tensor_name`` are added as well.
        """
        if isinstance(collections, (set, frozenset, list, tuple)):
            # Copy so callers can safely reuse the container they passed in.
            collections_to_write = set(collections)
        else:
            collections_to_write = {collections}
        for c in self._get_collections_to_save_for_step():
            if match_inc(tensor_name, c.include_regex):
                collections_to_write.add(c)
        self._initialize_writers(only_initialize_if_missing=True)

        tensor_refs = []
        if isinstance(tensor_value, values.PerReplica):
            for t in tensor_value._values:
                tensor_ref = TensorRef.from_non_graph_var(tensor_name)
                tensor_refs.append((tensor_ref, t))
        else:
            tensor_ref = TensorRef.from_non_graph_var(tensor_name)
            tensor_refs.append((tensor_ref, tensor_value))

        for tensor_ref, t in tensor_refs:
            for collection in collections_to_write:
                if isinstance(collection, str):
                    collection = self.get_collection(collection)
                collection.set_tensor_ref(tensor_ref)
            self._save_for_tensor(tensor_name, t, check_before_write=True)

    def save_gradients_from_logs(self, gradients):
        """Save gradients passed in through the Keras logs.

        ``gradients`` is either a sequence of ``(variable, gradient)`` pairs or
        a sequence of gradients aligned with ``self.model.trainable_variables``.
        Each gradient is saved as ``gradients/<variable name>Grad``, with any
        ``:<n>`` suffix removed from the name. ``None`` is a no-op.
        """
        if gradients is None:
            return
        gradient_collection = self.get_collection(CollectionKeys.GRADIENTS)
        if gradient_collection in self._get_collections_to_save_for_step():
            collections_to_write = {gradient_collection}
        else:
            collections_to_write = set()
        if gradients and not isinstance(gradients[0], tuple):
            gradients = zip(self.model.trainable_variables, gradients)
        for variable, grad in gradients:
            if isinstance(variable, tf.Tensor):
                # Tensor.name is meaningless with eager execution
                name = str(variable.numpy(), "utf-8")
            elif isinstance(variable, tf.Variable):
                name = variable.name
            elif isinstance(variable, bytes):
                name = str(variable, "utf-8")
            else:
                name = variable
            export_name = "gradients/" + name.split(":")[0] + "Grad"
            if isinstance(grad, IndexedSlices):
                # IndexedSlices wraps a pair of tensors; save only its values.
                # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
                grad = grad.values
            self._save_tensor_to_file(export_name, grad, collections_to_write)

    def save_smdebug_logs(self, logs):
        """Save the SMDebug-specific entries of ``logs``.

        Only keys containing ``SMDEBUG_PREFIX`` are considered:

        * model outputs (keys in ``ModelOutputs``) go to OUTPUTS,
        * ``SMDEBUG_GRADIENTS_KEY`` is handed to save_gradients_from_logs,
        * ``SMDEBUG_LAYER_OUTPUTS_KEY`` is handed to _save_layer_values,
        * model inputs (keys in ``ModelInputs``) go to INPUTS.

        ``None`` is a no-op.
        """
        if logs is None:
            return

        for key in logs:
            if SMDEBUG_PREFIX not in key:
                continue
            if key in ModelOutputs:
                export_name = get_model_output_export_name(key)
                collection_key = CollectionKeys.OUTPUTS
            elif key == SMDEBUG_GRADIENTS_KEY:
                self.save_gradients_from_logs(logs[key])
                continue
            elif key == SMDEBUG_LAYER_OUTPUTS_KEY:
                self._save_layer_values(logs[key])
                continue
            elif key in ModelInputs:
                export_name = get_model_input_export_name()
                collection_key = CollectionKeys.INPUTS
            else:
                continue

            value = logs[key]
            if self._is_collection_being_saved_for_step(collection_key):
                collections_to_write = {self.get_collection(collection_key)}
            else:
                collections_to_write = set()
            self._save_tensor_to_file(export_name, value, collections_to_write)

    def _save_metrics(self, batch, logs, force_save=False):
        """Save metrics and losses from Keras ``logs`` for the current step.

        Args:
            batch: value saved under the ``"batch"`` metric name.
            logs: dict of Keras logs; ``None`` is a no-op. The dict itself is
                not modified.
            force_save: if True, save METRICS and LOSSES without checking
                whether those collections are scheduled for this step.

        ``"loss"`` and ``"val_loss"`` are saved only as LOSSES. ``"outputs"``
        and keys containing ``"smdebug_"`` are skipped because they are saved
        through other paths.
        """
        if logs is None:
            return

        if force_save or self._is_collection_being_saved_for_step(
                CollectionKeys.METRICS):
            self._initialize_writers(only_initialize_if_missing=True)
            # Work on a shallow copy so "batch" is not injected into the
            # caller's logs dict, which may be shared with other callbacks.
            metric_logs = dict(logs)
            metric_logs["batch"] = batch
            for key, value in metric_logs.items():
                if key in {"loss", "val_loss", "outputs"} or "smdebug_" in key:
                    # outputs is saved differently through outputs collection
                    continue
                self._add_metric(metric_name=key)
                self._save_for_tensor(key, value, check_before_write=False)

        if force_save or self._is_collection_being_saved_for_step(
                CollectionKeys.LOSSES):
            self._initialize_writers(only_initialize_if_missing=True)
            for key in ["loss", "val_loss"]:
                if key in logs:
                    self._add_metric(metric_name=key)
                    self._save_for_tensor(key,
                                          logs[key],
                                          check_before_write=False)

    def _save_layer_input_and_outputs(self):
        """Save each wrapped layer's captured input and output values.

        Runs only when a GradientTape is in use (``self.tape`` is set). Values
        come from the InputOutputSaver objects in ``self.saved_layers`` and are
        written to the LAYERS collection when it is saved this step.
        """
        if self.tape is None:
            return
        # The target collections are the same for every layer, so compute
        # them once instead of twice per layer.
        if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS):
            layer_collections = {self.get_collection(CollectionKeys.LAYERS)}
        else:
            layer_collections = set()
        for layer_name, saver in self.saved_layers.items():
            for tensor_type, tensor in (("input", saver.layer_input),
                                        ("output", saver.layer_output)):
                export_name = get_export_name_for_keras(layer_name,
                                                        tensor_type=tensor_type,
                                                        tensor=tensor)
                # _save_tensor_to_file copies the collection set, so sharing
                # it across calls is safe.
                self._save_tensor_to_file(export_name, tensor.numpy(),
                                          layer_collections)

    def _save_tensors_post_step(self, batch, logs):
        """Save tensors whose values are available after the step (metrics, weights).

        Saves metrics/losses, SMDebug log entries and user tensors. In TF 2.x
        eager mode it also saves the value of every TensorRef selected for this
        step, read via ``tf_obj.value()``.
        """
        self._save_metrics(batch, logs)
        self.save_smdebug_logs(logs)
        self._save_custom_tensors_post_step()

        eager_tf2 = is_tf_version_2x() and tf.executing_eagerly()
        if not eager_tf2:
            return
        for ref in self.tensor_refs_to_save_this_step:
            variable = ref.tf_obj
            self._save_for_tensor(tensor_name=variable.name,
                                  tensor_value=variable.value(),
                                  check_before_write=False)

    def _get_exec_function(self, mode):
        """Return the Keras execution function for ``mode``.

        With no distribution strategy or with Horovod, this is the model's
        train/test/predict function. Any other mode raises NotImplementedError.
        Otherwise it is the distributed model's ``_distributed_function``.
        """
        # NOTE: per the original note, in TF 2.x self.model has no
        # train_function/test_function etc. and the result is None.
        local_strategies = [
            TFDistributionStrategy.NONE,
            TFDistributionStrategy.HOROVOD,
        ]
        if self.distribution_strategy not in local_strategies:
            return self._get_distributed_model(mode)._distributed_function
        if mode == ModeKeys.TRAIN:
            return self.model.train_function
        if mode == ModeKeys.EVAL:
            return self.model.test_function
        if mode == ModeKeys.PREDICT:
            return self.model.predict_function
        raise NotImplementedError

    def _validate_exec_function(self, fn):
        """Return True if ``fn`` exists; otherwise log why nothing is saved and return False."""
        if fn is not None:
            return True
        self.logger.info(
            f"Could not save tensors for mode {self.mode.name} step {self.mode_steps[self.mode]} "
            f"as execution function has not yet been built.")
        return False

    def _save_tensor_callback(self, value, name, check):
        """Fetch-callback adapter: takes the fetched value first, then saves it.

        The argument order lets ``functools.partial`` bind ``name`` and
        ``check`` while Keras supplies the value.
        """
        self._save_for_tensor(tensor_value=value,
                              tensor_name=name,
                              check_before_write=check)

    def _add_callbacks(self, mode):
        """Add fetches and save-callbacks for this step's tensors (graph mode).

        Each TensorRef in ``self.tensor_refs_to_save_this_step`` whose tensor
        is not yet fetched and has no callback is appended to the execution
        function's fetches, with a callback that saves its value. Otherwise a
        warning is logged. Added tensors are tracked in ``self._fetches_added``
        so _remove_fetches_and_callbacks can undo this. If a callable was
        cached for the resulting fetch list, it is installed on the execution
        function.
        """
        # Hook fetches are appended after the model's own (the original author
        # notes it is safest for the hook callback to be last).
        exec_fn = self._get_exec_function(mode)  # Returns GraphExecutionFunction
        if not self._validate_exec_function(exec_fn):
            return

        for tensor_ref in self.tensor_refs_to_save_this_step:
            tensor = tensor_ref.tf_obj
            if tensor in exec_fn.fetches or tensor in exec_fn.fetch_callbacks:
                self.logger.warning(
                    f"Cannot save tensor {tensor.name} as there is already "
                    f"a callback registered for this tensor. "
                    f"Please remove the existing callback to save this tensor."
                )
                continue
            exec_fn.fetches.append(tensor)
            self._fetches_added.add(tensor)
            exec_fn.fetch_callbacks[tensor] = functools.partial(
                self._save_tensor_callback, name=tensor_ref.name, check=False)

        cached_fn = self.callable_cache.get_fn(mode, exec_fn.fetches)
        if cached_fn is not None:
            exec_fn._fetches = list(exec_fn.fetches)
            exec_fn._callable_fn = cached_fn

    def _remove_fetches_and_callbacks(self, mode):
        """Undo _add_callbacks for ``mode`` after caching the built callable.

        The current callable is cached for the current fetch list. The fetches
        and fetch callbacks this hook added are then removed again.
        """
        exec_fn = self._get_exec_function(mode)
        self.callable_cache.cache_fn(mode,
                                     fetches=exec_fn.fetches,
                                     callable_fn=exec_fn._callable_fn)
        for added in self._fetches_added:
            exec_fn.fetches.remove(added)
            exec_fn.fetch_callbacks.pop(added)
        self._fetches_added.clear()

    def on_epoch_begin(self, batch, logs=None):
        """Keras epoch-begin callback; intentionally does nothing."""

    def on_epoch_end(self, batch, logs=None):
        """Keras epoch-end callback: force-save metrics and losses, then close writers.

        Does nothing when the hook is not supported in this environment.
        """
        if not self._is_not_supported():
            self._save_metrics(batch=batch, logs=logs, force_save=True)
            self._close_writers()

    def _on_any_mode_begin(self, mode):
        """Shared handler for on_train_begin, on_test_begin and on_predict_begin.

        Records the worker name and the default graph, switches the hook to
        ``mode`` and calls ``change_mode()`` on the callable cache. Under TF
        2.3.x it also prepares collections and advances the step here, the
        first time through. Does nothing if the hook is not supported.
        """
        if self._is_not_supported():
            return
        self.worker = self._get_worker_name()
        self.graph = tf.get_default_graph()
        self.set_mode(mode)

        if self.prepared_collections is False and is_tf_version_2_3_x():
            # Addresses ordering issues in TF 2.3.0
            # sets prepared_collections to True here
            self._prepare_collections()
            self._increment_step()
            # Tells the next _on_any_batch_begin that the step was already
            # incremented, so it skips its own increment once.
            self.step_incremented_in_on_train_begin = True

        # have to clear callable cache if we are not caching per mode
        self.callable_cache.change_mode()

    def on_train_begin(self, logs=None):
        """Keras train-begin callback: switch the hook to TRAIN mode."""
        train_mode = ModeKeys.TRAIN
        self._on_any_mode_begin(train_mode)

    def on_test_begin(self, logs=None):
        """Keras test-begin callback: switch the hook to EVAL mode."""
        eval_mode = ModeKeys.EVAL
        self._on_any_mode_begin(eval_mode)

    def on_test_end(self, logs=None):
        """Keras test-end callback; intentionally a no-op.

        Per the original note, Keras throws an error if this method is absent.
        """

    def on_predict_end(self, logs=None):
        """Keras predict-end callback; intentionally a no-op.

        Per the original note, Keras throws an error if this method is absent.
        """

    def on_predict_begin(self, logs=None):
        """Keras predict-begin callback: switch the hook to PREDICT mode."""
        predict_mode = ModeKeys.PREDICT
        self._on_any_mode_begin(predict_mode)

    def _wrap_model_with_input_output_saver(self):
        """Wrap every layer of ``self.model`` so its input/output can be captured.

        For each layer, this resets ``layer._hooks`` and replaces
        ``layer.call`` with the wrapper from ``get_layer_call_fn``. It then
        installs a per-layer ``register_hook`` and registers an
        ``InputOutputSaver``, also stored in ``self.saved_layers`` under the
        layer's name. Nothing is done if a model was already registered.
        """
        if self.has_registered_model:
            return
        for layer in self.model.layers:
            layer._hooks = []
            layer.call = get_layer_call_fn(layer)
            # Bind the current layer as a default argument. A plain closure over
            # the loop variable would make every layer's register_hook append to
            # the hooks of the *last* layer once the loop has finished.
            layer.register_hook = (
                lambda hook, _layer=layer: _layer._hooks.append(hook))
            saver = InputOutputSaver()
            layer.register_hook(saver)
            self.saved_layers[layer.name] = saver

    def _on_any_batch_begin(self, batch, mode, logs=None):
        """Shared handler for the train/test/predict batch-begin callbacks.

        Per batch, in order:

        1. Sets ``mode`` and closes any open writers.
        2. Advances the step, unless _on_any_mode_begin already did so.
        3. Prepares collections on first use.
        4. Prepares layer and post-step tensors for ``mode`` once. This waits
           until the execution function exists, except in TF 2.x eager mode.
        5. Selects the tensors to save this step and, outside TF 2.x eager
           mode, adds fetches/callbacks for them.

        Does nothing if the hook is not supported.

        Args:
            batch: batch argument from Keras (not used in this method).
            mode: ModeKeys value for the current phase.
            logs: Keras logs (not used in this method).
        """
        if self._is_not_supported():
            return

        # set mode for each batch as when users run model.fit() and pass validation data
        # through the optional argument, then mode_begin is not called for the training steps
        # after first evaluation during training
        self.set_mode(mode)

        # Write the gradients of the past step if the writer is still available.
        if self.writer is not None or len(self.writer_map):
            self._close_writers()

        # Addresses callback ordering bug in TF 2.3.0: if the mode-begin
        # callback already incremented the step, consume the flag instead.
        if self.step_incremented_in_on_train_begin is False:
            self._increment_step()
        else:
            self.step_incremented_in_on_train_begin = False

        if self.prepared_collections is False:
            # sets prepared_collections to True here
            self._prepare_collections()

        if self._prepared_tensors[mode] is False:
            # In TF 2.x eager mode there is no execution function to wait for.
            if (is_tf_version_2x() and
                    tf.executing_eagerly()) or self._validate_exec_function(
                        self._get_exec_function(mode)):
                self._prepare_layers(mode)
                self._prepare_tensors_available_post_step()
                self._prepared_tensors[mode] = True
                # below should be after tensors are processed,
                # so we know that device map is populated
                self._set_chief_worker()
            # else:
            # this will delay the preparation of tensors as the
            # full graph is not built. Gradients are not available
            # at this stage for example

        if self._prepared_tensors[mode]:
            self._prepare_tensors_for_step(mode)
            if self.tensor_refs_to_save_this_step:
                # if saving metric, writer may not be initialized as a result
                self._initialize_writers()

            # Fetch callbacks are only used outside TF 2.x eager mode.
            if not is_tf_version_2x() or (is_tf_version_2x()
                                          and not tf.executing_eagerly()):
                self._add_callbacks(mode)

    def on_train_batch_begin(self, batch, logs=None):
        """Keras callback: a training batch is starting."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.TRAIN, logs=logs)

    def on_test_batch_begin(self, batch, logs=None):
        """Keras callback: an evaluation batch is starting."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.EVAL, logs=logs)

    def on_predict_batch_begin(self, batch, logs=None):
        """Keras callback: a prediction batch is starting."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.PREDICT, logs=logs)

    def _save_layer_values(self, logs):
        """
        Save the captured input and output values of each layer.

        :param logs: iterable of ``(layer_name, layer_input, layer_output)``
            triples, or None, in which case nothing is saved.

        Values are written under names from
        ``get_export_name_for_keras(layer_name, "input" | "output")``. They are
        tagged with the LAYERS collection only if that collection is being
        saved at the current step; otherwise with an empty collection set.
        """
        if logs is None:
            return
        step_collections = self._get_collections_to_save_for_step()
        layer_collection = self.get_collection(CollectionKeys.LAYERS)
        collections_to_write = ({layer_collection}
                                if layer_collection in step_collections else
                                set())
        for layer_name, layer_input, layer_output in logs:
            # layer_name can be bytes when run with mirrored strategy.
            # str() on bytes returns its repr ("b'dense'"), which would corrupt
            # the export name, so decode bytes and str() anything else.
            if isinstance(layer_name, bytes):
                layer_name = layer_name.decode("utf-8")
            else:
                layer_name = str(layer_name)
            if len(layer_input) == 1:
                # Layer Inputs are flattened and passed as a list into
                # the next layer. Unpacking it speeds up the _make_numpy fn.
                layer_input = layer_input[0]
            layer_input_tensor_name = get_export_name_for_keras(
                layer_name, "input")
            self._save_tensor_to_file(layer_input_tensor_name, layer_input,
                                      collections_to_write)
            layer_output_tensor_name = get_export_name_for_keras(
                layer_name, "output")
            self._save_tensor_to_file(layer_output_tensor_name, layer_output,
                                      collections_to_write)

    def _write_optimizer_variables(self):
        """
        Save the current value of every optimizer variable that is due this step.

        Iterates over the TRAIN-mode tensors of the OPTIMIZER_VARIABLES
        collection. For each variable that belongs to at least one collection
        being saved at this step, this method:
        - initializes writers if they are missing,
        - records the variable in the device map,
        - writes ``tensor.value()`` under each of its TF names.
        """
        optimizer_collections = self.collection_manager.get(
            CollectionKeys.OPTIMIZER_VARIABLES)
        step_collections = self._get_collections_to_save_for_step()
        for tensor_ref in optimizer_collections.get_tensors(
                mode=ModeKeys.TRAIN):
            tensor = tensor_ref.tf_obj
            # Keep the per-tensor intersection in its own variable. Reassigning
            # the step-level set here (as before) shrank it on every iteration,
            # so later variables could be skipped even though they were due.
            tensor_collections = self._get_collections_with_tensor(
                tensor.name).intersection(step_collections)
            if not tensor_collections:
                continue
            self._initialize_writers(only_initialize_if_missing=True)
            self._add_to_device_map(tensor)
            for name in get_tf_names(tensor):
                self._save_for_tensor(tensor_name=name,
                                      tensor_value=tensor.value(),
                                      check_before_write=False)

    def _on_any_batch_end(self, batch, mode, logs=None):
        """
        Shared handler behind on_train/test/predict_batch_end.

        Steps, in order:
        - In graph mode, remove the fetches/callbacks added at batch begin.
        - Save values that are only available after the step.
        - In TF 2.x eager mode, re-prepare non-layer tensors and write the
          optimizer variables.
        - Once the tensors for ``mode`` are prepared, export collections once
          and export the model graph once per mode.

        :param batch: batch index passed by Keras
        :param mode: ModeKeys value of the running loop
        :param logs: Keras batch logs, forwarded to ``_save_tensors_post_step``
        """
        if self._is_not_supported():
            return

        # Graph mode only: undo the fetches/callbacks registered at batch begin.
        if not is_tf_version_2x() or (is_tf_version_2x()
                                      and not tf.executing_eagerly()):
            self._remove_fetches_and_callbacks(mode)

        self._save_tensors_post_step(batch, logs)
        if is_tf_version_2x() and tf.executing_eagerly():
            # Need to prepare non layer tensors again since
            # some tensors only become available on  batch end
            self._prepare_tensors_available_post_step()
            self._write_optimizer_variables()

        if self._prepared_tensors[mode]:
            if self._exported_collections is False:
                # in keras, these collections change when mode changes
                # but rest of the project isn't yet capable of handling this
                # this means that collections like outputs, or other collections with intermediate tensors
                # will only have tensor names from first mode

                # this means sometimes collections will be exported after 1 step
                self.export_collections()
                self._exported_collections = True

            # NOTE: keyed by self.mode (set via set_mode at batch begin), not by
            # the ``mode`` argument; presumably they are equal here.
            if self._exported_model[self.mode] is False:
                # confirmed that keras has same graph for all modes
                # but we are writing it multiple times to keep behavior consistent with
                # estimator and to make it easier when seeing tensorboard
                self._export_model()
                self._exported_model[self.mode] = True

    def on_train_batch_end(self, batch, logs=None):
        """Keras callback: a training batch has finished."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.TRAIN, logs=logs)

    def on_test_batch_end(self, batch, logs=None):
        """Keras callback: an evaluation batch has finished."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.EVAL, logs=logs)

    def on_predict_batch_end(self, batch, logs=None):
        """Keras callback: a prediction batch has finished."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.PREDICT, logs=logs)

    def wrap_optimizer(self, optimizer):
        """
        Wrap the optimizer so the hook can find gradient tensors and optimizer
        variables.

        :param optimizer: tf.train.Optimizer or tf.keras.optimizers.Optimizer
            the optimizer object used for training
        :return: Wrapped optimizer of same type as passed.
            This optimizer should be used for training
        """
        if isinstance(optimizer, tf.train.Optimizer):
            optimizer = self._wrap_apply_gradients(optimizer)
        elif (isinstance(optimizer, tf.keras.optimizers.Optimizer)
              or is_keras_optimizer(optimizer)):
            # Either a subclass of the tf.keras OptimizerV2 class or a
            # keras.optimizers.Optimizer.
            # NOTE: methods are patched on the optimizer *class*, so the patch
            # applies to every instance of that class.
            optimizer_cls = optimizer.__class__
            wrapped_get_gradients = optimizer_cls.get_gradients

            def get_gradients_and_record(opt, loss, params):
                gradients = wrapped_get_gradients(opt, loss, params)
                self.set_gradients(gradients=gradients)
                return gradients

            optimizer_cls.get_gradients = get_gradients_and_record

            if isinstance(optimizer, tf.keras.optimizers.Optimizer):
                try:
                    wrapped_add_weight = optimizer_cls.add_weight
                except AttributeError:
                    # TF 1.13 Keras Optimizers have no add_weight attribute,
                    # so optimizer_variables is not supported
                    pass
                else:
                    def add_weight_and_record(opt, *args, **kwargs):
                        variable = wrapped_add_weight(opt, *args, **kwargs)
                        self.set_optimizer_variables(variable)
                        return variable

                    optimizer_cls.add_weight = add_weight_and_record
        else:
            self._log_unsupported_optimizer(optimizer)
        # Optimizer is being saved to support additional features in the future.
        self.optimizer = optimizer
        return optimizer

    def _log_unsupported_tape(self, tape):
        self.logger.warning(
            f"Unsupported tape {tape} {tape.__class__}, cannot automatically find "
            "gradients, loss, weights, and biases.")

    def _unwrap_tape(self):
        """
        Unwrap the wrapped tape. Not doing so on hook cleanup or close,
        will lead to recursive wrapping when there are more tapes in the
        training script.
        """
        def _is_wrapper(f):
            return hasattr(f, "__wrapped__")

        def unwrap(func):
            while _is_wrapper(func):
                func = func.__wrapped__
            return func

        self.tape.__class__._push_tape = unwrap(self.tape.__class__._push_tape)
        self.tape.__class__._pop_tape = unwrap(self.tape.__class__._pop_tape)
        self.tape.__class__.gradient = unwrap(self.tape.__class__.gradient)

    def _cleanup(self):
        """Restore any wrapped GradientTape methods, then run the base-class cleanup."""
        tape = self.tape
        if tape:
            # Unwrap before closing so a later wrap_tape does not nest wrappers.
            self._unwrap_tape()
        super()._cleanup()

    def _wrap_push_tape(self, function):
        """
        tape._push_tape is called at the beginning of the GradientTape block.
        Using this wrapper to prepare collections, initialize writers, and
        increment step.

        The returned wrapper first calls ``function``. If the hook is supported,
        it then:
        - saves custom tensors and closes writers left from the previous step,
        - prepares collections on first use, adding default include regexes to
          WEIGHTS, LOSSES and GRADIENTS,
        - increments the step,
        - opens writers if any collection is due this step,
        - exports collections once a step has been recorded
          (``last_saved_step`` is set).

        The wrapper always returns None; the return value of ``function`` is
        discarded.
        """
        @functools.wraps(function)
        def run(*args, **kwargs):
            function(*args, **kwargs)
            if self._is_not_supported():
                return

            self.worker = self._get_worker_name()

            # Flush the previous step: save custom tensors, then close writers.
            if self.writer is not None or len(self.writer_map):
                self._save_custom_tensors_post_step()
                self._close_writers()

            if not self.prepared_collections:
                # at this point we need all collections to be ready
                # this may not be the case at creation of hook
                # as user's code after hook might add collections
                # Defaults: names under "weights/" whose last path segment does
                # not contain "bias"; any name containing "loss"; names starting
                # with "gradient".
                self.collection_manager.get(CollectionKeys.WEIGHTS).include(
                    "^weights/.*/((?!bias).)*$")
                self.collection_manager.get(
                    CollectionKeys.LOSSES).include(".*loss.*")
                self.collection_manager.get(
                    CollectionKeys.GRADIENTS).include("^gradient")
                self._prepare_collections()
                self.prepared_collections = True

            self._increment_step()

            if self._get_collections_to_save_for_step():
                self._initialize_writers()

            if self.last_saved_step is not None and self._exported_collections is False:
                # in keras, these collections change when mode changes
                # but rest of the project isn't yet capable of handling this
                # this means that collections like outputs, or other collections with intermediate tensors
                # will only have tensor names from first mode

                # this means sometimes collections will be exported after 1 step
                self.export_collections()
                self._exported_collections = True

        return run

    def _wrap_tape_gradient(self, function):
        """
        tape.gradient() is used to compute gradients from loss and model variables.
        Using this wrapper to get gradients, loss, weights, and bias values.
        """
        @functools.wraps(function)
        def run(*args, **kwargs):
            grads = function(*args, **kwargs)
            if self._is_not_supported():
                return grads
            loss = args[1]
            vars = args[2]
            if ((not grads or not vars) or
                (not isinstance(grads, list) or not isinstance(vars, list))
                    or (not ((isinstance(vars[0], tf.Variable))
                             and hasattr(vars[0], "numpy")))
                    or (not ((isinstance(grads[0], tf.Tensor))
                             and hasattr(grads[0], "numpy")))):
                return grads

            if self._get_collections_to_save_for_step():
                for (g, v) in zip(grads, vars):
                    layer = v.name.split(":")[0]
                    # Adding a check to make sure gradients are not None.
                    # gradients may be None if user tries to compute gradients for
                    # non-training variable when using model.variables instead of
                    # model.trainable_variables in tape.gradient().
                    # model.variables includes trainable and non-trainable
                    # variables.
                    if g is not None:
                        self._save_for_tensor(
                            tensor_name="gradients/" + layer + "Grad",
                            tensor_value=g,
                            check_before_write=True,
                        )
                    self._save_for_tensor(
                        tensor_name="weights/" + v.name,
                        tensor_value=v.value(),
                        check_before_write=True,
                    )

            self._write_optimizer_variables()
            self._save_layer_input_and_outputs()
            if not ((isinstance(loss, tf.Tensor)) and hasattr(loss, "numpy")):
                return grads
            self._add_metric(metric_name="loss", metric_value=loss)
            if self._is_collection_being_saved_for_step(CollectionKeys.LOSSES):
                self._initialize_writers(only_initialize_if_missing=True)
                self._save_for_tensor("loss", loss, check_before_write=False)

            return grads

        return run

    def _wrap_pop_tape(self, function):
        """
        tape._pop_tape() is called at the end of a GradientTape execution.
        Using this to export collections
        """
        @functools.wraps(function)
        def run(*args, **kwargs):
            function(*args, **kwargs)
            if self._is_not_supported():
                return

            self.last_saved_step = self.step

        return run

    def save_tape_logs(self, model_inputs=None, outputs=None):
        """
        Save model inputs and outputs recorded while using a GradientTape.
        Called by AWS TF.

        :param model_inputs: model inputs to save (may be None)
        :param outputs: model outputs/predictions to save (may be None)
        """
        self.save_smdebug_logs({
            ModelOutput.PREDICTIONS: outputs,
            ModelInput.INPUTS: model_inputs,
        })

    def wrap_tape(self, tape):
        """
        Wrap the GradientTape so the hook can find gradient tensors and
        optimizer variables.

        The wrapping is applied to the tape's class. The methods
        ``_push_tape``, ``gradient`` and ``_pop_tape`` are replaced. Any
        previously wrapped tape is unwrapped first.

        :param tape: tensorflow.python.eager.backprop.GradientTape
            the tape object used for training
        :return: Wrapped tape of same type as passed.
            This tape should be used for training
        """
        from tensorflow.python.eager.backprop import GradientTape

        if not isinstance(tape, GradientTape):
            self._log_unsupported_tape(tape)
            return tape

        # Restore the earlier tape's methods first so wrappers don't nest.
        if self.tape:
            self._unwrap_tape()

        self.tape = tape
        tape_cls = tape.__class__
        tape_cls._push_tape = self._wrap_push_tape(tape_cls._push_tape)
        tape_cls.gradient = self._wrap_tape_gradient(tape_cls.gradient)
        tape_cls._pop_tape = self._wrap_pop_tape(tape_cls._pop_tape)
        return tape

    def record_tensor_value(self, tensor_name, tensor_value):
        """
        Deprecated: record an eager tensor (e.g. a metric) under ``tensor_name``.
        Use the save_tensor API instead.

        Does nothing unless ``tensor_value`` is a ``tf.Tensor`` with a ``numpy``
        attribute and the hook is supported. Otherwise it:
        - logs a deprecation warning,
        - registers the value as a metric,
        - writes it if the METRICS collection is due this step.
        """
        is_eager_tensor = (isinstance(tensor_value, tf.Tensor)
                           and hasattr(tensor_value, "numpy"))
        if not is_eager_tensor or self._is_not_supported():
            return

        self.logger.warning(
            "This function has been deprecated. Please use the save_tensor API "
        )

        self._add_metric(metric_name=tensor_name, metric_value=tensor_value)
        if not self._is_collection_being_saved_for_step(CollectionKeys.METRICS):
            return
        self._initialize_writers(only_initialize_if_missing=True)
        self._save_for_tensor(tensor_name,
                              tensor_value,
                              check_before_write=False)
コード例 #2
0
class KerasHook(TensorflowBaseHook, tf.keras.callbacks.Callback):
    def __init__(
        self,
        out_dir,
        export_tensorboard=False,
        tensorboard_dir=None,
        dry_run=False,
        reduction_config=None,
        save_config=None,
        include_regex=None,
        include_collections=None,
        save_all=False,
        include_workers="one",
    ):
        """
        Create the Keras callback hook.

        All arguments are forwarded to the base hook, together with
        ``init_step=-1``. Per-mode bookkeeping for TRAIN, EVAL and PREDICT is
        then initialised.
        """
        super().__init__(
            out_dir=out_dir,
            export_tensorboard=export_tensorboard,
            tensorboard_dir=tensorboard_dir,
            init_step=-1,
            dry_run=dry_run,
            reduction_config=reduction_config,
            save_config=save_config,
            include_regex=include_regex,
            include_collections=include_collections,
            save_all=save_all,
            include_workers=include_workers,
        )
        # NOTE(review): only super().__init__ runs here. Whether
        # tf.keras.callbacks.Callback.__init__ is also reached depends on the
        # base hook chaining cooperatively. Confirm.
        all_modes = (ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT)
        self._exported_collections = False
        # per-mode flag: whether the model has been exported for that mode
        self._exported_model = {mode: False for mode in all_modes}
        self.tensor_refs_to_save_this_step = set()
        self._fetches_added = set()
        # per-mode flag: set by _prepare_layers once layer tensors are added
        self._prepared_tensors = {mode: False for mode in all_modes}
        self.callable_cache = CallableCache()

    def _is_not_supported(self):
        """
        Return True if SMDebug cannot run with the current TF/Keras setup.

        The check runs once and its result is cached in ``self._hook_supported``.
        The hook is disabled:
        - in eager mode (global eager execution or ``model.run_eagerly``);
        - for mirrored strategy when ``get_distributed_model`` cannot be
          imported (TF 1.13);
        - for distribution strategies reported as UNSUPPORTED.
        """
        if self._hook_supported is None:
            self._hook_supported = True
            if tf.executing_eagerly() or (hasattr(self.model, "run_eagerly")
                                          and self.model.run_eagerly):
                self.logger.info(
                    "Disabling SMDebug as it does not support eager mode")
                self._hook_supported = False
            else:
                # Query the strategy once instead of once per branch.
                strategy = self._get_distribution_strategy()
                if strategy == TFDistributionStrategy.MIRRORED_STRATEGY:
                    try:
                        # Imported only to probe availability.
                        from tensorflow.python.keras.distribute.distributed_training_utils import (  # noqa: F401
                            get_distributed_model, )
                    except ImportError:
                        # for tf1.13 we can't import this, so we can't support mirrored strategy
                        self.logger.info(
                            "Disabling SMDebug as it does not support mirrored strategy "
                            "with TensorFlow version <1.14")
                        self._hook_supported = False
                elif strategy == TFDistributionStrategy.UNSUPPORTED:
                    self.logger.info(f"Disabling SMDebug as it does not support "
                                     f"{tf.distribute.get_strategy()}")
                    self._hook_supported = False
        return not self._hook_supported

    def _get_matching_collections(self,
                                  mode,
                                  tensor,
                                  tensor_type,
                                  ts_name,
                                  is_input_to_model=False,
                                  is_output_of_model=False):
        """
        Return the set of collections a layer tensor should be added to.

        :param mode: not used by this method
        :param tensor: the layer tensor or variable
        :param tensor_type: "input", "output" or "weight"
        :param ts_name: export name matched against each collection's include_regex
        :param is_input_to_model: add to INPUTS (ignored for weights)
        :param is_output_of_model: add to OUTPUTS (ignored for weights/inputs)

        Collections that already contain ``tensor`` are not added by regex.
        """
        manager = self.collection_manager
        matched = set()
        if tensor_type == "weight":
            # Weights whose TF name matches the BIASES regex are biases;
            # every other weight goes to WEIGHTS.
            bias_regex = manager.get(CollectionKeys.BIASES).include_regex
            if match_inc(tensor.name, bias_regex):
                matched.add(manager.get(CollectionKeys.BIASES))
            else:
                matched.add(manager.get(CollectionKeys.WEIGHTS))
        elif is_input_to_model:
            matched.add(manager.get(CollectionKeys.INPUTS))
        elif is_output_of_model:
            matched.add(manager.get(CollectionKeys.OUTPUTS))

        assigned_explicitly = [CollectionKeys.WEIGHTS, CollectionKeys.BIASES]
        for coll in manager.get_collections().values():
            # WEIGHTS/BIASES are assigned above and never matched by regex, so
            # user misconfiguration cannot route tensors into them.
            if coll.name in assigned_explicitly:
                continue
            # Tensors added to a collection externally keep their own (non
            # keras-style) name, which is the one used to save data.
            if match_inc(ts_name, coll.include_regex) and not coll.has_tensor(tensor):
                matched.add(coll)
        return matched

    def _check_and_add_layer_tensor(self,
                                    mode,
                                    layer,
                                    tensor_type,
                                    tensor,
                                    is_input_to_model=False,
                                    is_output_of_model=False):
        """
        Register one layer tensor (input, output or weight) with its collections.

        Device-less tensors are skipped under mirrored strategy. If the tensor's
        first TF name is already known, its existing export name is reused.
        Cases where that happens:
        1. layer0_output0 == layer1_input0. With first-come ordering the
           layer0 output name wins. This may not work as intended for
           non-sequential models.
        2. The tensor was added to a collection outside this prepare call
           (e.g. gradients), where the TF name is the export name.
        3. The same tensor was added in a previous mode.
        """
        is_mirrored = (self.distribution_strategy ==
                       TFDistributionStrategy.MIRRORED_STRATEGY)
        if is_mirrored and not tensor.device:
            # These extra tensors show up under mirrored strategy; ignoring
            # them still gives access to every replica's tensors.
            return

        self._add_to_device_map(tensor)

        # Multiple TF names are only returned for a mirrored variable; all of
        # their tensor refs share one export name.
        tf_names = get_tf_names(tensor)
        export_name = get_export_name_for_keras(layer, tensor_type, tensor)
        first_name = tf_names[0]
        if first_name in self.tensor_to_collections:
            export_name = self._get_tensor_ref(first_name).export_name

        matching = self._get_matching_collections(
            mode,
            tensor,
            tensor_type,
            export_name,
            is_input_to_model=is_input_to_model,
            is_output_of_model=is_output_of_model,
        )
        self._create_tensors_for_matching_collections(
            mode, tensor, tf_names, export_name, matching)

    def _are_tensors_already_added(self, tf_names):
        # multiple tf_names will be here only for mirrored variable
        seen = 0
        for name in tf_names:
            seen += int(name in self.tensor_to_collections)
        if seen > 1:
            assert seen == len(tf_names)
        return seen > 0

    def _create_tensors_for_matching_collections(self, mode, tensor, tf_names,
                                                 export_name,
                                                 colls_with_tensor):
        """
        Add ``tensor`` to every collection in ``colls_with_tensor`` for ``mode``.

        If none of ``tf_names`` is registered yet:
        - TensorRef(s) are created through the first collection iterated
          (as a variable, tensor or mirrored variable);
        - the same refs are set on the remaining collections;
        - each ref's name is mapped to ``colls_with_tensor`` in
          ``tensor_to_collections``.

        If the tensor was registered before (e.g. in a previous mode), only
        tensors matching WEIGHTS/BIASES get ``mode`` added to their existing
        TensorRef. Other tensors are not re-added.

        :raises NotImplementedError: if ``tensor`` is none of the supported types
        """
        # if this tensor was already added to some collection in the previous call
        # do not use it as it is for previous mode
        if colls_with_tensor and not self._are_tensors_already_added(tf_names):
            # need to create new entry in tensor_to_collections dict for the tensor object
            tensor_refs = []
            for coll in colls_with_tensor:
                if not tensor_refs:
                    # NOTE(review): if values.MirroredVariable subclasses the
                    # tf.Variable checked below in the installed TF, the first
                    # branch would catch it before the mirrored branch. Confirm.
                    if isinstance(tensor, tf.Variable):
                        tensor_refs.append(
                            coll.add_variable(tensor,
                                              export_name=export_name,
                                              mode=mode))
                    elif isinstance(tensor, tf.Tensor):
                        tensor_refs.append(
                            coll.add_tensor(tensor,
                                            name=export_name,
                                            mode=mode))
                    elif isinstance(tensor, values.MirroredVariable):
                        tensor_refs.extend(
                            coll.add_mirrored_variable(tensor,
                                                       export_name=export_name,
                                                       mode=mode))
                    else:
                        raise NotImplementedError
                else:
                    # for second collection onwards
                    for t in tensor_refs:
                        coll.set_tensor_ref(t)
            # NOTE: every entry shares the same set object (colls_with_tensor).
            for t in tensor_refs:
                self.tensor_to_collections[t.name] = colls_with_tensor
        elif colls_with_tensor:
            # we should only readd tensors which were already added if these are variables
            # other tensors are part of a different mode, and will cause a crash if fetched
            # because their input placeholders will not be passed.
            if any([
                    c.name in [CollectionKeys.WEIGHTS, CollectionKeys.BIASES]
                    for c in colls_with_tensor
            ]):
                # set mode of the tensorref object for these tensors
                # these are special because they are tf.Variables which require no input
                # they will be present in all modes
                for tf_name in tf_names:
                    tensor_ref = self._get_tensor_ref(tf_name)
                    tensor_ref.add_mode(mode)
        return

    def _get_distributed_model(self, mode):
        """
        Return the distributed Keras model corresponding to ``mode``.

        The helper is not available in TF 1.13. ``_is_not_supported`` disables
        the hook in that case, so this method should not be reached there.
        """
        from tensorflow.python.keras.distribute.distributed_training_utils import (
            get_distributed_model, )

        model = self.model
        keras_mode = get_keras_mode(mode)
        return get_distributed_model(model, keras_mode)

    def _get_model_io_tensors(self, mode, attr):
        """
        Collect ``model.<attr>`` (``"inputs"`` or ``"outputs"``) for ``mode``.

        Under mirrored strategy the distributed model for ``mode`` is used.
        If the model exposes ``values`` (one model per replica), the tensors of
        every replica are concatenated.
        """
        if self.distribution_strategy == TFDistributionStrategy.MIRRORED_STRATEGY:
            model = self._get_distributed_model(mode)
        else:
            model = self.model
        if hasattr(model, "values"):
            tensors = []
            for per_replica_model in model.values:
                tensors.extend(getattr(per_replica_model, attr))
            return tensors
        return list(getattr(model, attr))

    def _is_input_layer(self, mode, layer_inputs):
        """Return True if any tensor in ``layer_inputs`` is an input of the model."""
        model_inputs = self._get_model_io_tensors(mode, "inputs")
        # generator lets any() stop at the first hit
        return any(i in model_inputs for i in layer_inputs)

    def _is_output_layer(self, mode, layer_outputs):
        """Return True if any tensor in ``layer_outputs`` is an output of the model."""
        model_outputs = self._get_model_io_tensors(mode, "outputs")
        return any(o in model_outputs for o in layer_outputs)

    def _prepare_layers(self, mode):
        """
        Add every layer's input, output and weight tensors to the appropriate
        collections for ``mode``, then mark the mode as prepared.
        """
        for layer in self.model.layers:
            inputs = get_keras_layer_inputs(layer)
            feeds_model = self._is_input_layer(mode, inputs)
            for tensor in inputs:
                self._check_and_add_layer_tensor(
                    mode, layer, "input", tensor,
                    is_input_to_model=feeds_model)

            outputs = get_keras_layer_outputs(layer)
            produces_model_output = self._is_output_layer(mode, outputs)
            for tensor in outputs:
                self._check_and_add_layer_tensor(
                    mode, layer, "output", tensor,
                    is_output_of_model=produces_model_output)

            for weight in layer.weights:
                self._check_and_add_layer_tensor(mode, layer, "weight", weight)

        self._prepared_tensors[mode] = True

    def _prepare_non_layer_tensors(self):
        # for gradients, optimizer_variables
        for coll in self.collection_manager.get_collections().values():
            for tensor_ref in coll.get_tensors():
                if tensor_ref.name not in self.tensor_to_collections:
                    self.tensor_to_collections[tensor_ref.name] = {coll}
                elif coll not in self.tensor_to_collections[tensor_ref.name]:
                    self.tensor_to_collections[tensor_ref.name].add(coll)

    def _prepare_tensors_for_step(self, mode):
        """
        Recompute ``self.tensor_refs_to_save_this_step`` for ``mode``.

        Gathers the tensors of every collection due this step, with two
        exclusions:
        - METRICS, LOSSES and INPUTS collections are skipped, because they are
          retrieved after the step instead of being fetched;
        - tensors that are model inputs are dropped.
        """
        self.tensor_refs_to_save_this_step = set()
        step_collections = self._get_collections_to_save_for_step()
        model_input_refs = set(
            self.collection_manager.get(
                CollectionKeys.INPUTS).get_tensors(mode=mode))
        not_fetched = [
            CollectionKeys.METRICS, CollectionKeys.LOSSES, CollectionKeys.INPUTS
        ]
        for coll in step_collections:
            if coll.name in not_fetched:
                continue
            # also picks up tensors users added manually through the collection API
            fetchable = set(coll.get_tensors(mode=mode)) - model_input_refs
            self.tensor_refs_to_save_this_step.update(fetchable)

    def _save_inputs(self, check_before_write=True):
        # TODO
        pass

    def _add_metric(self, metric_name):
        """
        Register ``metric_name`` as a non-graph tensor in LOSSES or METRICS.

        "loss" and "val_loss" go to the LOSSES collection; every other name goes
        to METRICS. Names already present in ``tensor_to_collections`` are left
        untouched.
        """
        if metric_name not in self.tensor_to_collections:
            is_loss = metric_name in ["loss", "val_loss"]
            target_key = CollectionKeys.LOSSES if is_loss else CollectionKeys.METRICS
            coll = self.collection_manager.get(target_key)
            coll.set_tensor_ref(TensorRef.from_non_graph_var(metric_name))
            self.tensor_to_collections[metric_name] = {coll}

    def _save_metrics(self, batch, logs, force_save=False):
        """Save the values in a Keras ``logs`` dict as METRICS / LOSSES tensors.

        Every key except ``loss``, ``val_loss`` and ``outputs`` is saved to METRICS,
        together with an extra ``"batch"`` entry holding ``batch``. ``loss`` and
        ``val_loss`` are saved to LOSSES when present.

        :param batch: value stored under the ``"batch"`` key next to the metrics
            (``on_epoch_end`` forwards its first positional argument here).
        :param logs: mapping of metric name -> value supplied by Keras. ``None``
            makes this a no-op. The mapping itself is NOT modified.
        :param force_save: if True, save without checking whether the METRICS /
            LOSSES collections are scheduled for the current step.
        """
        if logs is None:
            return

        if force_save or self._is_collection_being_saved_for_step(
                CollectionKeys.METRICS):
            self._initialize_writers(only_initialize_if_missing=True)
            # Work on a copy instead of inserting "batch" into the caller's dict.
            # Keras presumably hands the same dict to later callbacks (e.g. History),
            # which would otherwise record the injected key.
            metrics = {**logs, "batch": batch}
            for key, value in metrics.items():
                if key in ("loss", "val_loss", "outputs"):
                    # losses are saved below; outputs is saved differently through
                    # the outputs collection
                    continue
                self._add_metric(metric_name=key)
                self._save_for_tensor(key, value, check_before_write=False)

        if force_save or self._is_collection_being_saved_for_step(
                CollectionKeys.LOSSES):
            self._initialize_writers(only_initialize_if_missing=True)
            for key in ("loss", "val_loss"):
                if key in logs:
                    self._add_metric(metric_name=key)
                    self._save_for_tensor(key,
                                          logs[key],
                                          check_before_write=False)

    def _save_tensors_post_step(self, batch, logs):
        """Save values that are only available after a step has run.

        Metrics and losses come from the Keras ``logs`` dict. Inputs are handled
        when the INPUTS collection is scheduled for this step.
        """
        self._save_metrics(batch, logs)

        inputs_due = self._is_collection_being_saved_for_step(CollectionKeys.INPUTS)
        if inputs_due:
            self._save_inputs(check_before_write=False)

    def _get_exec_function(self, mode):
        """Return the Keras execution function that runs a step in ``mode``.

        With no distribution strategy or with Horovod, this is the model's own
        ``train_function`` / ``test_function`` / ``predict_function``. With any other
        strategy it is ``_distributed_function`` of the distributed model for
        ``mode``. The result may be ``None`` if it has not been built yet, which
        callers check via ``_validate_exec_function``.

        :raises NotImplementedError: for a non-distributed strategy when ``mode`` is
            not TRAIN, EVAL or PREDICT.
        """
        if self.distribution_strategy not in [
                TFDistributionStrategy.NONE,
                TFDistributionStrategy.HOROVOD,
        ]:
            return self._get_distributed_model(mode)._distributed_function

        function_attr = {
            ModeKeys.TRAIN: "train_function",
            ModeKeys.EVAL: "test_function",
            ModeKeys.PREDICT: "predict_function",
        }.get(mode)
        if function_attr is None:
            raise NotImplementedError
        return getattr(self.model, function_attr)

    def _validate_exec_function(self, fn):
        """Tell whether the execution function ``fn`` exists yet.

        :param fn: result of ``_get_exec_function``.
        :return: True if ``fn`` is not None. Otherwise logs an info message naming
            the current mode and step, and returns False.
        """
        if fn is not None:
            return True
        self.logger.info(
            f"Could not save tensors for mode {self.mode.name} step {self.mode_steps[self.mode]} "
            f"as execution function has not yet been built.")
        return False

    def _save_tensor_callback(self, value, name, check):
        """Fetch-callback adapter: forward a fetched value to ``_save_for_tensor``.

        The value comes first so that ``functools.partial`` can bind ``name`` and
        ``check`` ahead of time, and Keras only has to supply the fetched value.
        """
        save_kwargs = dict(tensor_name=name,
                           tensor_value=value,
                           check_before_write=check)
        self._save_for_tensor(**save_kwargs)

    def _add_callbacks(self, mode):
        """Attach this step's tensors to the Keras execution function as extra fetches.

        For every tensor ref in ``tensor_refs_to_save_this_step``, the graph tensor is
        appended to ``fetches`` and a ``fetch_callbacks`` entry is registered that
        routes the fetched value to ``_save_tensor_callback``. Tensors already fetched
        or already carrying a callback are skipped with a warning. Added tensors are
        tracked in ``_fetches_added`` so ``_remove_fetches_and_callbacks`` can undo
        this. If ``callable_cache`` holds a callable for the resulting fetch list, it
        is installed on the function. Nothing happens if the function is not built yet.
        """
        # safest if hook callback is the last
        exec_fn = self._get_exec_function(mode)  # Returns GraphExecutionFunction
        if not self._validate_exec_function(exec_fn):
            return

        for ref in self.tensor_refs_to_save_this_step:
            graph_tensor = ref.tf_obj
            if graph_tensor in exec_fn.fetches or graph_tensor in exec_fn.fetch_callbacks:
                self.logger.warning(
                    f"Cannot save tensor {graph_tensor.name} as there is already "
                    f"a callback registered for this tensor. "
                    f"Please remove the existing callback to save this tensor."
                )
                continue
            exec_fn.fetches.append(graph_tensor)
            self._fetches_added.add(graph_tensor)
            exec_fn.fetch_callbacks[graph_tensor] = functools.partial(
                self._save_tensor_callback, name=ref.name, check=False)

        # Reuse a previously built callable for this exact fetch list when one exists.
        cached_callable = self.callable_cache.get_fn(mode, exec_fn.fetches)
        if cached_callable is not None:
            exec_fn._fetches = list(exec_fn.fetches)
            exec_fn._callable_fn = cached_callable

    def _remove_fetches_and_callbacks(self, mode):
        """Undo ``_add_callbacks`` for ``mode`` after a step has run.

        The execution function's current callable is first cached under its
        (augmented) fetch list so a later step with the same fetches can reuse it.
        Then every tensor recorded in ``_fetches_added`` is removed from ``fetches``
        and ``fetch_callbacks``. If the execution function has not been built
        (``None``), which ``_on_any_batch_begin`` explicitly tolerates, there is
        nothing to undo; previously this path crashed with AttributeError.
        """
        x = self._get_exec_function(mode)
        if x is None:
            # _add_callbacks never registers anything on a missing function; drop any
            # stale bookkeeping so it is not applied to a different function later.
            self._fetches_added.clear()
            return

        # cache the callable for given fetches
        # NOTE(review): the live ``x.fetches`` list is passed and then mutated below;
        # this assumes cache_fn derives its key immediately - confirm in CallableCache.
        self.callable_cache.cache_fn(mode,
                                     fetches=x.fetches,
                                     callable_fn=x._callable_fn)

        for tf_obj in self._fetches_added:
            x.fetches.remove(tf_obj)
            x.fetch_callbacks.pop(tf_obj)
        self._fetches_added.clear()

    def on_epoch_begin(self, batch, logs=None):
        """Keras epoch-start callback; this hook has nothing to do here."""
        return None

    def on_epoch_end(self, batch, logs=None):
        """Keras epoch-end callback: force-save the epoch's logs, then close writers.

        :param batch: first positional argument from Keras, presumably the epoch
            index; it is forwarded to ``_save_metrics`` as ``batch``.
        :param logs: epoch-level Keras logs dict (may be None).
        """
        if not self._is_not_supported():
            self._save_metrics(batch=batch, logs=logs, force_save=True)
            self._close_writers()

    def _on_any_mode_begin(self, mode):
        """Shared start-of-run setup for train / test / predict.

        Refreshes the distribution strategy, worker name and default graph, switches
        the hook to ``mode`` and notifies the callable cache of the mode change.
        Does nothing when the hook is not supported in this environment.
        """
        if not self._is_not_supported():
            self.distribution_strategy = self._get_distribution_strategy()
            self.worker = self._get_worker_name()
            self.graph = tf.get_default_graph()
            self.set_mode(mode)
            # Per the original note: the callable cache has to be cleared when
            # callables are not cached per mode.
            self.callable_cache.change_mode()

    def on_train_begin(self, logs=None):
        """Keras callback: set the hook up for a training run."""
        self._on_any_mode_begin(mode=ModeKeys.TRAIN)

    def on_test_begin(self, logs=None):
        """Keras callback: set the hook up for an evaluation run."""
        self._on_any_mode_begin(mode=ModeKeys.EVAL)

    # Kept deliberately: the original author notes Keras raises if this hook is missing.
    def on_test_end(self, logs=None):
        """Keras evaluation-end callback; intentionally a no-op."""

    # Kept deliberately: the original author notes Keras raises if this hook is missing.
    def on_predict_end(self, logs=None):
        """Keras prediction-end callback; intentionally a no-op."""

    def on_predict_begin(self, logs=None):
        """Keras callback: set the hook up for a prediction run."""
        self._on_any_mode_begin(mode=ModeKeys.PREDICT)

    def _on_any_batch_begin(self, batch, mode, logs=None):
        """Shared implementation of the Keras ``on_*_batch_begin`` callbacks.

        Switches the hook to ``mode``, advances the step counter, lazily prepares
        collections and tensors, and registers fetch callbacks for the tensors that
        should be saved in the step about to run.

        :param batch: batch index supplied by Keras (not read here).
        :param mode: ModeKeys value of the step about to run.
        :param logs: Keras logs dict (not read here).
        """
        if self._is_not_supported():
            return

        # set mode for each batch as when users run model.fit() and pass validation data
        # through the optional argument, then mode_begin is not called for the training steps
        # after first evaluation during training
        self.set_mode(mode)

        # Presumably closes the previous step's writers before advancing to the new step.
        self._close_writers()
        self._increment_step()

        if self.prepared_collections is False:
            # sets prepared_collections to True here
            self._prepare_collections()

        if self._prepared_tensors[mode] is False:
            # Tensors can only be prepared once the Keras execution function for this
            # mode exists; _validate_exec_function logs and returns False otherwise.
            if self._validate_exec_function(self._get_exec_function(mode)):
                self._prepare_layers(mode)
                self._prepare_non_layer_tensors()
                # below should be after tensors are processed,
                # so we know that device map is populated.
                # With MirroredStrategy and save_all_workers disabled, the first
                # device_map key in sorted order becomes the chief worker.
                if (len(self.device_map) and self.distribution_strategy
                        == TFDistributionStrategy.MIRRORED_STRATEGY
                        and self.save_all_workers is False):
                    self.chief_worker = sorted(self.device_map.keys())[0]
            # Otherwise preparation is deferred to a later batch: the full graph is not
            # built yet (gradients, for example, are not available at this stage).
            # NOTE(review): _prepared_tensors[mode] is presumably set to True inside the
            # _prepare_* helpers above - confirm.

        if self._prepared_tensors[mode]:
            self._prepare_tensors_for_step(mode)
            if self.tensor_refs_to_save_this_step:
                # Tensors will be written this step; the writer may not be initialized
                # yet (e.g. when only metrics were saved so far).
                self._initialize_writers()

            self._add_callbacks(mode)

    def on_train_batch_begin(self, batch, logs=None):
        """Keras callback: prepare tensor saving for the upcoming training batch."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.TRAIN, logs=logs)

    def on_test_batch_begin(self, batch, logs=None):
        """Keras callback: prepare tensor saving for the upcoming evaluation batch."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.EVAL, logs=logs)

    def on_predict_batch_begin(self, batch, logs=None):
        """Keras callback: prepare tensor saving for the upcoming prediction batch."""
        self._on_any_batch_begin(batch=batch, mode=ModeKeys.PREDICT, logs=logs)

    def _on_any_batch_end(self, batch, mode, logs=None):
        """Shared implementation of the Keras ``on_*_batch_end`` callbacks.

        Removes the fetches/callbacks added at batch begin and saves post-step values
        (metrics, losses). Once tensors for ``mode`` are prepared, it exports the
        collections once, and the model graph once per mode.

        :param batch: batch index supplied by Keras.
        :param mode: ModeKeys value of the step that just ran. It is used for every
            per-mode lookup, rather than mixing it with ``self.mode``.
        :param logs: Keras logs dict for this batch (may be None).
        """
        if self._is_not_supported():
            return

        self._remove_fetches_and_callbacks(mode)
        self._save_tensors_post_step(batch, logs)

        if not self._prepared_tensors[mode]:
            return

        if self._exported_collections is False:
            # in keras, these collections change when mode changes
            # but rest of the project isn't yet capable of handling this
            # this means that collections like outputs, or other collections with intermediate tensors
            # will only have tensor names from first mode

            # this means sometimes collections will be exported after 1 step
            self.export_collections()
            self._exported_collections = True

        if self._exported_model[mode] is False:
            # confirmed that keras has same graph for all modes
            # but we are writing it multiple times to keep behavior consistent with
            # estimator and to make it easier when seeing tensorboard
            self._export_model()
            self._exported_model[mode] = True

    def on_train_batch_end(self, batch, logs=None):
        """Keras callback: clean up and save post-step values after a training batch."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.TRAIN, logs=logs)

    def on_test_batch_end(self, batch, logs=None):
        """Keras callback: clean up and save post-step values after an evaluation batch."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.EVAL, logs=logs)

    def on_predict_batch_end(self, batch, logs=None):
        """Keras callback: clean up and save post-step values after a prediction batch."""
        self._on_any_batch_end(batch=batch, mode=ModeKeys.PREDICT, logs=logs)

    def wrap_optimizer(self, optimizer):
        """
        Wrapping your optimizer with this method enables finding gradient tensors and optimizer
        variables.

        ``tf.train.Optimizer`` instances are wrapped via ``_wrap_apply_gradients``.
        For Keras optimizers, ``get_gradients`` (and, for ``tf.keras`` optimizers,
        ``add_weight`` when present) is patched on the optimizer's CLASS, so every
        instance of that class is affected. Other types are logged as unsupported.

        :param optimizer: tf.train.Optimizer or tf.keras.optimizers.Optimizer
            the optimizer object used for training
        :return: Wrapped optimizer of same type as passed.
            This optimizer should be used for training
        """
        if isinstance(optimizer, tf.train.Optimizer):
            optimizer = self._wrap_apply_gradients(optimizer)
        elif isinstance(optimizer, tf.keras.optimizers.Optimizer
                        ) or is_keras_optimizer(optimizer):
            # either subclass of optimizerV2 class in tf.keras
            # or keras.optimizers.Optimizer
            optimizer_cls = optimizer.__class__
            wrapped_get_gradients = optimizer_cls.get_gradients

            def get_gradients_and_record(opt, loss, params):
                gradients = wrapped_get_gradients(opt, loss, params)
                self.set_gradients(gradients=gradients)
                return gradients

            optimizer_cls.get_gradients = get_gradients_and_record

            if isinstance(optimizer, tf.keras.optimizers.Optimizer):
                try:
                    wrapped_add_weight = optimizer_cls.add_weight
                except AttributeError:
                    # TF 1.13 Keras Optimizers have no add_weight attribute,
                    # so optimizer_variables is not supported
                    pass
                else:
                    def add_weight_and_record(opt, *args, **kwargs):
                        variable = wrapped_add_weight(opt, *args, **kwargs)
                        self.set_optimizer_variables(variable)
                        return variable

                    optimizer_cls.add_weight = add_weight_and_record
        else:
            self._log_unsupported_optimizer(optimizer)
        # Optimizer is being saved to support additional features in the future.
        self.optimizer = optimizer
        return optimizer